In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import dependence_measures
from toy_data_experiment import main_kci_partial
import pval_computations
import cond_mean_estimation as cme
import scipy

def std_xy(xy):
    """Pooled standard deviation over the last two axes of `xy`.

    The variance is assembled as (mean over axis -2 of the variance along
    axis -1) + (variance, over axis -2, of the means along axis -1), using
    torch's default `std` on both axes. The result has the shape of `xy`
    without its last two axes.
    """
    inner_means = xy.mean(dim=-1)
    inner_vars = xy.std(dim=-1).pow(2)
    between_var = inner_means.std(dim=-1).pow(2)
    total_var = inner_vars.mean(dim=-1) + between_var
    return total_var.pow(0.5)

def get_error_mean_std(data, ground_truth, pval):
    """Turn `data` into 0/1 error indicators and summarise them.

    Args:
        data: tensor compared element-wise against `pval` (presumably p-values
            of repeated test runs -- confirm against the code writing the files).
        ground_truth: 'H0' counts entries below `pval` as errors; any other
            value counts entries at or above `pval` as errors.
        pval: threshold the entries are compared with.

    Returns:
        (mean, std): the indicator mean over the last two axes, and `std_xy`
        of the indicators.
    """
    not_rejected = (data >= pval).float()
    errors = 1 - not_rejected if ground_truth == 'H0' else not_rejected
    return errors.mean(dim=(-2, -1)), std_xy(errors)

In [None]:
# Root folder holding the saved result tensors (trailing slash required: file
# names are appended by plain string concatenation). Figures go to './figs'.
folder = 'results/'
# Results of the gamma approximation live in a sub-folder of the results root.
folder_gamma = folder + 'gamma_res/'

In [None]:
def _plot_error_grid(measures_list, task, dim, kernels, n_points, ground_truth,
                     h0_pval_bottom, h0_pval_top, ignore_second_row, name_siffux, file_prefix):
    """Shared implementation behind `plot_pvals` and `plot_pvals_circe`.

    Draws error-rate curves against the number of (z, y) points on a grid with one
    row per threshold (0.05, plus 0.01 unless `ignore_second_row`) and one column
    per regime ('separate' -> "unbalanced data regime", 'joint' -> "standard
    regime"). The figure is written to
    'figs/{file_prefix}_{task}_dim{dim}_{kernels}_{n_points}_{ground_truth}{name_siffux}.pdf'
    and then shown.
    """
    n_zy_points_min = 100  # only used as a token of the result file names
    names = {'kci': 'KCI', 'kci_xsplit': 'SplitKCI', 'circe': 'CIRCE', 'gcm': 'GCM', 'rbpt2_ub': 'RBPT2\'',
             'rbpt2': 'RBPT2'}
    colors = {'kci': '#0072B2', 'kci_xsplit': '#009E73', 'circe': '#821651', 'gcm': '#E69F00',
              'rbpt2_ub': '#D55E00', 'rbpt2': '#000000'}

    # x-axis grid: n_points, n_points + 100, ..., 1000 (plus 2000 when n_points == 200).
    points_list = torch.linspace(n_points, 1000, (1000 - n_points) // 100 + 1).int()
    ticks = ['', '200', '', '400', '', '600', '', '800', '', '1000']
    if n_points == 200 and points_list[-1] == 1000:
        points_list = torch.hstack((points_list, torch.tensor([2000])))
        ticks = ['200', '', '', '', '600', '', '', '', '1000', '2000']

    n_rows = 1 if ignore_second_row else 2
    fig, axes = plt.subplots(nrows=n_rows, ncols=2, figsize=(6, 3 * n_rows), sharex=True)
    if ignore_second_row:
        axes = axes[None, :]  # keep the 2-D (row, column) indexing used below
        pval_list = [0.05]
    else:
        pval_list = [0.05, 0.01]

    for row_idx, pval in enumerate(pval_list):
        for col_idx, xzy in enumerate(['separate', 'joint']):
            ax = axes[row_idx, col_idx]
            axes[0, col_idx].set_title('unbalanced data regime' if xzy == 'separate' else 'standard regime')

            for measure in measures_list:
                # CIRCE results are always read from the 'separate', gaussian-kernel
                # files with twice as many points, whatever panel is being drawn.
                is_circe = 'circe' in measure
                n_points_actual = 2 * n_points if is_circe else n_points
                xzy_actual = 'separate' if is_circe else xzy
                kernels_actual = 'gaussian' if is_circe else kernels

                # Each result file must fit a 100 x 100 slot (presumably a grid of
                # p-values from repeated runs -- confirm against the experiment code).
                data = torch.zeros(len(points_list), 100, 100)
                for idx, n_zy_points in enumerate(points_list):
                    # ${TASK}_${MEASURE}_${KERNELS}_${XZY}_${NPOINTS}_${NPOINTSZY}_${GROUND}
                    file = f'{task}_{measure}_{kernels_actual}_{xzy_actual}_{n_points_actual}_{n_zy_points_min}_{ground_truth}d{dim}_nzy{n_zy_points}.pt'
                    data[idx] = torch.load(folder + file)

                mean, std = get_error_mean_std(data, ground_truth, pval)
                band = std / np.sqrt(data.shape[1])
                ax.plot(points_list, mean, label=names[measure], color=colors[measure])
                ax.fill_between(points_list, mean - band, mean + band, alpha=0.05, color=colors[measure])

            # Decorate the panel once, after every curve is drawn: `set_ylim` switches
            # y-autoscaling off, so calling it between curves would freeze an open
            # upper limit at the range of the first curve only.
            if ground_truth == 'H0':
                ax.axhline(pval, color='#5F5F5F', linestyle='--')
                axes[row_idx, 0].set_ylabel(fr'Type I error ($\alpha$={pval})')
                if h0_pval_top is None:
                    ax.set_ylim(bottom=pval * h0_pval_bottom)
                else:
                    ax.set_ylim(bottom=pval * h0_pval_bottom, top=min(1.05, pval * h0_pval_top))
                ax.set_yscale('log')
            else:
                axes[row_idx, 0].set_ylabel(fr'Type II error ($\alpha$={pval})')
                ax.set_ylim(bottom=-0.05, top=1.05)
            axes[-1, col_idx].set_xlabel(r'$m_{yz}$')
            axes[-1, col_idx].set_xticks(points_list, ticks)

    # The bottom-right panel is the one `plt.legend()` used to target.
    axes[-1, -1].legend()
    fig.tight_layout()
    fig.savefig(f'figs/{file_prefix}_{task}_dim{dim}_{kernels}_{n_points}_{ground_truth}{name_siffux}.pdf')
    plt.show()


def plot_pvals(measures_list, task='get_xzy_randn_nl', dim=2, kernels='gaussian', n_points=100, ground_truth='H0',
              h0_pval_bottom=1/12, h0_pval_top=None, ignore_second_row=False, name_siffux=''):
    """Plot error rates versus the number of (z, y) points and save them as 'figs/full_*.pdf'.

    Args:
        measures_list: measure keys to draw; must be keys of the name/colour tables
            ('kci', 'kci_xsplit', 'circe', 'gcm', 'rbpt2_ub', 'rbpt2').
        task, dim, kernels, n_points, ground_truth: tokens of the result file names
            loaded from the module-level `folder`; `ground_truth` ('H0' or anything
            else) also selects Type I vs Type II error.
        h0_pval_bottom, h0_pval_top: multipliers of the threshold giving the y-limits
            of the H0 panels; `h0_pval_top=None` leaves the upper limit automatic.
        ignore_second_row: draw only the 0.05 row instead of 0.05 and 0.01.
        name_siffux: text appended verbatim to the output file name before '.pdf'.
    """
    _plot_error_grid(measures_list, task, dim, kernels, n_points, ground_truth,
                     h0_pval_bottom, h0_pval_top, ignore_second_row, name_siffux, file_prefix='full')


def plot_pvals_circe(measures_list, task='get_xzy_randn_nl', dim=2, kernels='gaussian', n_points=100, ground_truth='H0',
              h0_pval_bottom=1/12, h0_pval_top=None, ignore_second_row=False, name_siffux=''):
    """Same figure as `plot_pvals`, saved as 'figs/circe_*.pdf' instead of 'figs/full_*.pdf'.

    All arguments have the same meaning as in `plot_pvals`.
    """
    _plot_error_grid(measures_list, task, dim, kernels, n_points, ground_truth,
                     h0_pval_bottom, h0_pval_top, ignore_second_row, name_siffux, file_prefix='circe')

In [None]:
def plot_pvals_h0h1(measures_list, task='get_xzy_randn_nl', dim=2, kernels='gaussian', n_points=100,
                    h0_pval_bottom=1/12, h0_pval_top=None, pval=0.05):
    """Plot Type I (top row, 'H0') and Type II (bottom row, 'H1') error in one figure.

    Columns are the 'separate' ("unbalanced data regime") and 'joint' ("standard
    regime") result files; the x-axis is the number of (z, y) points. The figure is
    saved to 'figs/{task}_dim{dim}_{kernels}_{n_points}_p{pval}.pdf' and shown.

    Args:
        measures_list: measure keys to draw; must be keys of the name/colour tables
            ('kci', 'kci_xsplit', 'circe', 'gcm', 'rbpt2_ub', 'rbpt2').
        task, dim, kernels, n_points: tokens of the result file names loaded from the
            module-level `folder`.
        h0_pval_bottom, h0_pval_top: multipliers of `pval` giving the y-limits of the
            H0 row; `h0_pval_top=None` leaves the upper limit automatic.
        pval: threshold applied to the loaded values.
    """
    n_zy_points_min = 100  # only used as a token of the result file names
    # 'rbpt2' is included so this function accepts the same measures as `plot_pvals`.
    names = {'kci': 'KCI', 'kci_xsplit': 'SplitKCI', 'circe': 'CIRCE', 'gcm': 'GCM', 'rbpt2_ub': 'RBPT2\'',
             'rbpt2': 'RBPT2'}
    colors = {'kci': '#0072B2', 'kci_xsplit': '#009E73', 'circe': '#821651', 'gcm': '#E69F00',
              'rbpt2_ub': '#D55E00', 'rbpt2': '#000000'}

    # x-axis grid: n_points, n_points + 100, ..., 1000 (plus 2000 when n_points == 200).
    points_list = torch.linspace(n_points, 1000, (1000 - n_points) // 100 + 1).int()
    ticks = ['', '200', '', '400', '', '600', '', '800', '', '1000']
    if n_points == 200 and points_list[-1] == 1000:
        points_list = torch.hstack((points_list, torch.tensor([2000])))
        ticks = ['200', '', '', '', '600', '', '', '', '1000', '2000']

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6, 6), sharex=True)

    for row_idx, ground_truth in enumerate(['H0', 'H1']):
        for col_idx, xzy in enumerate(['separate', 'joint']):
            ax = axes[row_idx, col_idx]
            axes[0, col_idx].set_title('unbalanced data regime' if xzy == 'separate' else 'standard regime')

            for measure in measures_list:
                # CIRCE results are always read from the 'separate', gaussian-kernel
                # files with twice as many points, whatever panel is being drawn.
                is_circe = 'circe' in measure
                n_points_actual = 2 * n_points if is_circe else n_points
                xzy_actual = 'separate' if is_circe else xzy
                kernels_actual = 'gaussian' if is_circe else kernels

                # Each result file must fit a 100 x 100 slot (presumably a grid of
                # p-values from repeated runs -- confirm against the experiment code).
                data = torch.zeros(len(points_list), 100, 100)
                for idx, n_zy_points in enumerate(points_list):
                    # ${TASK}_${MEASURE}_${KERNELS}_${XZY}_${NPOINTS}_${NPOINTSZY}_${GROUND}
                    file = f'{task}_{measure}_{kernels_actual}_{xzy_actual}_{n_points_actual}_{n_zy_points_min}_{ground_truth}d{dim}_nzy{n_zy_points}.pt'
                    data[idx] = torch.load(folder + file)

                mean, std = get_error_mean_std(data, ground_truth, pval)
                band = std / np.sqrt(data.shape[1])
                ax.plot(points_list, mean, label=names[measure], color=colors[measure])
                ax.fill_between(points_list, mean - band, mean + band, alpha=0.05, color=colors[measure])

            # Decorate the panel once, after every curve is drawn: `set_ylim` switches
            # y-autoscaling off, so calling it between curves would freeze an open
            # upper limit at the range of the first curve only.
            if ground_truth == 'H0':
                ax.axhline(pval, color='#5F5F5F', linestyle='--')
                axes[row_idx, 0].set_ylabel(fr'Type I error ($\alpha$={pval})')
                if h0_pval_top is None:
                    ax.set_ylim(bottom=pval * h0_pval_bottom)
                else:
                    ax.set_ylim(bottom=pval * h0_pval_bottom, top=min(1.05, pval * h0_pval_top))
                ax.set_yscale('log')
            else:
                axes[row_idx, 0].set_ylabel(fr'Type II error ($\alpha$={pval})')
                ax.set_ylim(bottom=-0.05, top=1.05)
            axes[-1, col_idx].set_xlabel(r'$m_{yz}$')
            axes[-1, col_idx].set_xticks(points_list, ticks)

    # The bottom-right panel is the one `plt.legend()` used to target.
    axes[-1, -1].legend()
    fig.tight_layout()
    fig.savefig(f'figs/{task}_dim{dim}_{kernels}_{n_points}_p{pval}.pdf')
    plt.show()

In [None]:
def plot_pvals_kk(measures_list, task='get_xzy_randn_nl', dim=2, ground_truth='H0', n_points=100,
                    h0_pval_bottom=1/12, h0_pval_top=None, pval=0.05):
    """Compare kernel choices: 'gaussian' results in the top row, 'all' in the bottom row.

    Columns are the 'separate' ("unbalanced data regime") and 'joint' ("standard
    regime") result files; the x-axis is the number of (z, y) points. The figure is
    saved under 'figs/' and shown.

    Args:
        measures_list: measure keys to draw; must be keys of the name/colour tables
            ('kci', 'kci_xsplit', 'circe', 'gcm', 'rbpt2_ub', 'rbpt2').
        task, dim, ground_truth, n_points: tokens of the result file names loaded from
            the module-level `folder`; `ground_truth` ('H0' or anything else) also
            selects Type I vs Type II error.
        h0_pval_bottom, h0_pval_top: multipliers of `pval` giving the y-limits when
            `ground_truth == 'H0'`; `h0_pval_top=None` leaves the upper limit automatic.
        pval: threshold applied to the loaded values.
    """
    n_zy_points_min = 100  # only used as a token of the result file names
    # 'rbpt2' is included so this function accepts the same measures as `plot_pvals`.
    names = {'kci': 'KCI', 'kci_xsplit': 'SplitKCI', 'circe': 'CIRCE', 'gcm': 'GCM', 'rbpt2_ub': 'RBPT2\'',
             'rbpt2': 'RBPT2'}
    colors = {'kci': '#0072B2', 'kci_xsplit': '#009E73', 'circe': '#821651', 'gcm': '#E69F00',
              'rbpt2_ub': '#D55E00', 'rbpt2': '#000000'}

    # x-axis grid: n_points, n_points + 100, ..., 1000 (plus 2000 when n_points == 200).
    points_list = torch.linspace(n_points, 1000, (1000 - n_points) // 100 + 1).int()
    ticks = ['', '200', '', '400', '', '600', '', '800', '', '1000']
    if n_points == 200 and points_list[-1] == 1000:
        points_list = torch.hstack((points_list, torch.tensor([2000])))
        ticks = ['200', '', '', '', '600', '', '', '', '1000', '2000']

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6, 6), sharex=True)

    for row_idx, kernels in enumerate(['gaussian', 'all']):
        name_addition = 'Gaussian' if kernels == 'gaussian' else 'best kernel'
        for col_idx, xzy in enumerate(['separate', 'joint']):
            ax = axes[row_idx, col_idx]
            axes[0, col_idx].set_title('unbalanced data regime' if xzy == 'separate' else 'standard regime')

            for measure in measures_list:
                # CIRCE results are always read from the 'separate', gaussian-kernel
                # files with twice as many points, whatever panel is being drawn.
                is_circe = 'circe' in measure
                n_points_actual = 2 * n_points if is_circe else n_points
                xzy_actual = 'separate' if is_circe else xzy
                kernels_actual = 'gaussian' if is_circe else kernels

                # Each result file must fit a 100 x 100 slot (presumably a grid of
                # p-values from repeated runs -- confirm against the experiment code).
                data = torch.zeros(len(points_list), 100, 100)
                for idx, n_zy_points in enumerate(points_list):
                    # ${TASK}_${MEASURE}_${KERNELS}_${XZY}_${NPOINTS}_${NPOINTSZY}_${GROUND}
                    file = f'{task}_{measure}_{kernels_actual}_{xzy_actual}_{n_points_actual}_{n_zy_points_min}_{ground_truth}d{dim}_nzy{n_zy_points}.pt'
                    data[idx] = torch.load(folder + file)

                mean, std = get_error_mean_std(data, ground_truth, pval)
                band = std / np.sqrt(data.shape[1])
                ax.plot(points_list, mean, label=names[measure], color=colors[measure])
                ax.fill_between(points_list, mean - band, mean + band, alpha=0.05, color=colors[measure])

            # Decorate the panel once, after every curve is drawn: `set_ylim` switches
            # y-autoscaling off, so calling it between curves would freeze an open
            # upper limit at the range of the first curve only.
            if ground_truth == 'H0':
                ax.axhline(pval, color='#5F5F5F', linestyle='--')
                axes[row_idx, 0].set_ylabel(fr'Type I error ($\alpha$={pval}, {name_addition})')
                if h0_pval_top is None:
                    ax.set_ylim(bottom=pval * h0_pval_bottom)
                else:
                    ax.set_ylim(bottom=pval * h0_pval_bottom, top=min(1.05, pval * h0_pval_top))
                ax.set_yscale('log')
            else:
                axes[row_idx, 0].set_ylabel(fr'Type II error ($\alpha$={pval}, {name_addition})')
                ax.set_ylim(bottom=-0.05, top=1.05)
            axes[-1, col_idx].set_xlabel(r'$m_{yz}$')
            axes[-1, col_idx].set_xticks(points_list, ticks)

    # The bottom-right panel is the one `plt.legend()` used to target.
    axes[-1, -1].legend()
    fig.tight_layout()
    # NOTE: `kernels` here is the last value of the row loop ('all'), although the
    # figure shows both kernel choices; kept so existing figure paths do not change.
    fig.savefig(f'figs/{task}_dim{dim}_{kernels}_{n_points}_p{pval}_{ground_truth}.pdf')
    plt.show()

In [None]:
def plot_pvals_rbpt(measures_list, task='get_xzy_rbpt2', dim=2, kernels='gaussian', n_points=100,
              h0_pval_bottom=1/12, h0_pval_top=None, ignore_second_row=False, pval=0.05, full=False):
    """Plot Type I ('H0' row) and Type II ('H1' row) error, pooling result files of seeds 1-5.

    Columns are the 'separate' ("unbalanced data regime") and 'joint' ("standard
    regime") result files; the x-axis is the number of (z, y) points. The 'H0' files
    are read with the tokens c=0.0, gamma=0.02 and the 'H1' files with c=0.1,
    gamma=0.0. The figure is saved under 'figs/' (with a 'full_' prefix when `full`)
    and shown.

    Args:
        measures_list: measure keys to draw; must be keys of the name/colour tables
            ('kci', 'kci_xsplit', 'circe', 'gcm', 'rbpt2_ub', 'rbpt2').
        task, dim, kernels, n_points: tokens of the result file names loaded from the
            module-level `folder`.
        h0_pval_bottom, h0_pval_top: multipliers of `pval` giving the y-limits of the
            H0 row; `h0_pval_top=None` leaves the upper limit automatic.
        ignore_second_row: draw only the 'H0' row.
        pval: threshold applied to the loaded values.
        full: prefix the output file name with 'full_'.
    """
    n_zy_points_min = 100  # only used as a token of the result file names
    # 'rbpt2' is included so this function accepts the same measures as `plot_pvals`.
    names = {'kci': 'KCI', 'kci_xsplit': 'SplitKCI', 'circe': 'CIRCE', 'gcm': 'GCM', 'rbpt2_ub': 'RBPT2\'',
             'rbpt2': 'RBPT2'}
    colors = {'kci': '#0072B2', 'kci_xsplit': '#009E73', 'circe': '#821651', 'gcm': '#E69F00',
              'rbpt2_ub': '#D55E00', 'rbpt2': '#000000'}

    # x-axis grid: n_points, n_points + 100, ..., 1000 (plus 2000 when n_points == 200).
    points_list = torch.linspace(n_points, 1000, (1000 - n_points) // 100 + 1).int()
    ticks = ['', '200', '', '400', '', '600', '', '800', '', '1000']
    if n_points == 200 and points_list[-1] == 1000:
        points_list = torch.hstack((points_list, torch.tensor([2000])))
        ticks = ['200', '', '', '', '600', '', '', '', '1000', '2000']

    n_rows = 1 if ignore_second_row else 2
    fig, axes = plt.subplots(nrows=n_rows, ncols=2, figsize=(6, 3 * n_rows), sharex=True)
    if ignore_second_row:
        axes = axes[None, :]  # keep the 2-D (row, column) indexing used below
        truth_list = ['H0']
    else:
        truth_list = ['H0', 'H1']

    for row_idx, ground_truth in enumerate(truth_list):
        # Data-generation tokens baked into the result file names.
        if ground_truth == 'H0':
            rbpt_c, rbpt_gamma = 0.0, 0.02
        else:
            rbpt_c, rbpt_gamma = 0.1, 0.0

        for col_idx, xzy in enumerate(['separate', 'joint']):
            ax = axes[row_idx, col_idx]
            axes[0, col_idx].set_title('unbalanced data regime' if xzy == 'separate' else 'standard regime')

            for measure in measures_list:
                # CIRCE results are always read from the 'separate', gaussian-kernel
                # files with twice as many points, whatever panel is being drawn.
                is_circe = 'circe' in measure
                n_points_actual = 2 * n_points if is_circe else n_points
                xzy_actual = 'separate' if is_circe else xzy
                kernels_actual = 'gaussian' if is_circe else kernels

                # One 100 x 100 slot per (x-axis point, seed); presumably a grid of
                # p-values from repeated runs -- confirm against the experiment code.
                data = torch.zeros(len(points_list), 5, 100, 100)
                for rbpt_seed in range(1, 6):
                    for idx, n_zy_points in enumerate(points_list):
                        # ${TASK}_${MEASURE}_${KERNELS}_${XZY}_${NPOINTS}_${NPOINTSZY}_${GROUND}
                        file = f'{task}_{measure}_{kernels_actual}_{xzy_actual}_{n_points_actual}_{n_zy_points_min}_{ground_truth}_{rbpt_c}_{rbpt_gamma}_{rbpt_seed}d{dim}_nzy{n_zy_points}.pt'
                        data[idx, rbpt_seed - 1] = torch.load(folder + file)
                # Fold the seed axis into the first trial axis: (points, 5 * 100, 100).
                data = data.view(len(points_list), -1, 100)

                mean, std = get_error_mean_std(data, ground_truth, pval)
                band = std / np.sqrt(data.shape[1])
                ax.plot(points_list, mean, label=names[measure], color=colors[measure])
                ax.fill_between(points_list, mean - band, mean + band, alpha=0.05, color=colors[measure])

            # Decorate the panel once, after every curve is drawn: `set_ylim` switches
            # y-autoscaling off, so calling it between curves would freeze an open
            # upper limit at the range of the first curve only.
            if ground_truth == 'H0':
                ax.axhline(pval, color='#5F5F5F', linestyle='--')
                axes[row_idx, 0].set_ylabel(fr'Type I error ($\alpha$={pval})')
                if h0_pval_top is None:
                    ax.set_ylim(bottom=pval * h0_pval_bottom)
                else:
                    ax.set_ylim(bottom=pval * h0_pval_bottom, top=min(1.05, pval * h0_pval_top))
                ax.set_yscale('log')
            else:
                axes[row_idx, 0].set_ylabel(fr'Type II error ($\alpha$={pval})')
                ax.set_ylim(bottom=-0.05, top=1.05)
            axes[-1, col_idx].set_xlabel(r'$m_{yz}$')
            axes[-1, col_idx].set_xticks(points_list, ticks)

    # The bottom-right panel is the one `plt.legend()` used to target.
    axes[-1, -1].legend()
    fig.tight_layout()
    if full:
        fig.savefig(f'figs/full_{task}_dim{dim}_both_{n_points}_{kernels}_p{pval}.pdf')
    else:
        fig.savefig(f'figs/{task}_dim{dim}_both_{n_points}_{kernels}_p{pval}.pdf')
    plt.show()

In [None]:
def plot_pvals_gamma(measures_list, task='get_xzy_randn_nl', dim=2, kernels='gaussian', n_points=100, ground_truth='H0',
              h0_pval_bottom=1/12, h0_pval_top=None, ignore_second_row=False, name_siffux=''):
    """Overlay results from `folder` (solid, "wild") and `folder_gamma` (dotted, "gamma").

    Rows are the thresholds (0.05, plus 0.01 unless `ignore_second_row`); columns are
    the 'separate' ("unbalanced data regime") and 'joint' ("standard regime") result
    files; the x-axis is the number of (z, y) points. The figure is saved to
    'figs/gamma_{task}_dim{dim}_{kernels}_{n_points}_{ground_truth}{name_siffux}.pdf'
    and shown.

    Args:
        measures_list: measure keys to draw; must be keys of the name/colour tables
            ('kci', 'kci_xsplit', 'circe', 'gcm', 'rbpt2_ub', 'rbpt2').
        task, dim, kernels, n_points, ground_truth: tokens of the result file names;
            `kernels` must be 'gaussian', and `ground_truth` ('H0' or anything else)
            also selects Type I vs Type II error. The gamma file names carry no task
            token, only the lower-cased `ground_truth`, measure, regime, sizes and `dim`.
        h0_pval_bottom, h0_pval_top: multipliers of the threshold giving the y-limits
            of the H0 panels; `h0_pval_top=None` leaves the upper limit automatic.
        ignore_second_row: draw only the 0.05 row.
        name_siffux: text appended verbatim to the output file name before '.pdf'.

    Raises:
        AssertionError: if `kernels` is not 'gaussian'.
    """
    assert kernels == 'gaussian'
    n_zy_points_min = 100  # only used as a token of the result file names
    names = {'kci': 'KCI', 'kci_xsplit': 'SplitKCI', 'circe': 'CIRCE', 'gcm': 'GCM', 'rbpt2_ub': 'RBPT2\'',
             'rbpt2': 'RBPT2'}
    colors = {'kci': '#0072B2', 'kci_xsplit': '#009E73', 'circe': '#821651', 'gcm': '#E69F00',
              'rbpt2_ub': '#D55E00', 'rbpt2': '#000000'}

    # x-axis grid: n_points, n_points + 100, ..., 1000 (plus 2000 when n_points == 200).
    points_list = torch.linspace(n_points, 1000, (1000 - n_points) // 100 + 1).int()
    ticks = ['', '200', '', '400', '', '600', '', '800', '', '1000']
    if n_points == 200 and points_list[-1] == 1000:
        points_list = torch.hstack((points_list, torch.tensor([2000])))
        ticks = ['200', '', '', '', '600', '', '', '', '1000', '2000']

    n_rows = 1 if ignore_second_row else 2
    fig, axes = plt.subplots(nrows=n_rows, ncols=2, figsize=(6, 3 * n_rows), sharex=True)
    if ignore_second_row:
        axes = axes[None, :]  # keep the 2-D (row, column) indexing used below
        pval_list = [0.05]
    else:
        pval_list = [0.05, 0.01]

    for row_idx, pval in enumerate(pval_list):
        for col_idx, xzy in enumerate(['separate', 'joint']):
            ax = axes[row_idx, col_idx]
            axes[0, col_idx].set_title('unbalanced data regime' if xzy == 'separate' else 'standard regime')

            for measure in measures_list:
                # CIRCE results are always read from the 'separate', gaussian-kernel
                # files with twice as many points, whatever panel is being drawn.
                is_circe = 'circe' in measure
                n_points_actual = 2 * n_points if is_circe else n_points
                xzy_actual = 'separate' if is_circe else xzy
                kernels_actual = 'gaussian' if is_circe else kernels

                # Each result file must fit a 100 x 100 slot (presumably a grid of
                # p-values from repeated runs -- confirm against the experiment code).
                data = torch.zeros(len(points_list), 100, 100)
                data_gamma = torch.zeros(len(points_list), 100, 100)
                for idx, n_zy_points in enumerate(points_list):
                    # ${TASK}_${MEASURE}_${KERNELS}_${XZY}_${NPOINTS}_${NPOINTSZY}_${GROUND}
                    file = f'{task}_{measure}_{kernels_actual}_{xzy_actual}_{n_points_actual}_{n_zy_points_min}_{ground_truth}d{dim}_nzy{n_zy_points}.pt'
                    data[idx] = torch.load(folder + file)

                    # e.g. h0_kci_separate_n200_d5_nzy600.pt -- the dimension token
                    # follows `dim` (it used to be hard-coded to 2).
                    file = f'{ground_truth.lower()}_{measure}_{xzy_actual}_n{n_points_actual}_d{dim}_nzy{n_zy_points}.pt'
                    data_gamma[idx] = torch.load(folder_gamma + file)

                # Same colour for both variants; the gamma one is dotted.
                for values, tag, line_style in ((data, ' (wild)', {}), (data_gamma, ' (gamma)', {'ls': ':'})):
                    mean, std = get_error_mean_std(values, ground_truth, pval)
                    band = std / np.sqrt(values.shape[1])
                    ax.plot(points_list, mean, label=names[measure] + tag,
                            color=colors[measure], **line_style)
                    ax.fill_between(points_list, mean - band, mean + band, alpha=0.05,
                                    color=colors[measure], **line_style)

            # Decorate the panel once, after every curve is drawn: `set_ylim` switches
            # y-autoscaling off, so calling it between curves would freeze an open
            # upper limit at the range of the first curve only.
            if ground_truth == 'H0':
                ax.axhline(pval, color='#5F5F5F', linestyle='--')
                axes[row_idx, 0].set_ylabel(fr'Type I error ($\alpha$={pval})')
                if h0_pval_top is None:
                    ax.set_ylim(bottom=pval * h0_pval_bottom)
                else:
                    ax.set_ylim(bottom=pval * h0_pval_bottom, top=min(1.05, pval * h0_pval_top))
                ax.set_yscale('log')
            else:
                axes[row_idx, 0].set_ylabel(fr'Type II error ($\alpha$={pval})')
                ax.set_ylim(bottom=-0.05, top=1.05)
            axes[-1, col_idx].set_xlabel(r'$m_{yz}$')
            axes[-1, col_idx].set_xticks(points_list, ticks)

    # Legend sits outside, to the right of the bottom-right panel (the one
    # `plt.legend()` used to target).
    axes[-1, -1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
    fig.tight_layout()
    fig.savefig(f'figs/gamma_{task}_dim{dim}_{kernels}_{n_points}_{ground_truth}{name_siffux}.pdf')
    plt.show()

# Plotting

In [None]:
# Task 1 (get_xzy_randn): KCI, SplitKCI and CIRCE; one single-row figure per
# kernel choice, sample size and ground truth.
measures_list = ['kci', 'kci_xsplit', 'circe']

for kernels, n_points, ground_truth in [(k, n, g)
                                        for k in ['all', 'gaussian']
                                        for n in [100, 200]
                                        for g in ['H0', 'H1']]:
    plot_pvals_circe(
        measures_list,
        task='get_xzy_randn',
        dim=2,
        kernels=kernels,
        n_points=n_points,
        ground_truth=ground_truth,
        h0_pval_bottom=1/2,
        h0_pval_top=30,
        ignore_second_row=True,
    )

In [None]:
# Task 2 (get_xzy_randn_nl): H0 and H1 rows in one figure per kernel choice
# and sample size.
measures_list = ['kci', 'kci_xsplit', 'gcm', 'rbpt2_ub']

for kernels, n_points in [(k, n) for k in ['all', 'gaussian'] for n in [100, 200]]:
    plot_pvals_h0h1(
        measures_list,
        task='get_xzy_randn_nl',
        dim=2,
        kernels=kernels,
        n_points=n_points,
        pval=0.05,
        h0_pval_bottom=1/5,
    )

In [None]:
# Task 3 (get_xzy_circ): gaussian vs. best-kernel rows, one figure per sample
# size and ground truth.
measures_list = ['kci', 'kci_xsplit', 'gcm', 'rbpt2_ub']

for n_points in [100, 200]:
    for ground_truth in ['H0', 'H1']:
        # Pass the loop value through: with a hard-coded n_points=100 the second
        # pass only re-drew and overwrote the n_points=100 figure.
        plot_pvals_kk(measures_list, task='get_xzy_circ', dim=2, ground_truth=ground_truth, n_points=n_points,
                      h0_pval_bottom=1/2.6, h0_pval_top=None, pval=0.05)

In [None]:
# Task 4 (get_xzy_rbpt, dim 40): H0 and H1 rows, one figure per kernel choice.
measures_list = ['kci', 'kci_xsplit', 'gcm', 'rbpt2_ub']

for kernels, n_points in [(k, n) for k in ['all', 'gaussian'] for n in [200]]:
    plot_pvals_rbpt(
        measures_list,
        task='get_xzy_rbpt',
        dim=40,
        kernels=kernels,
        n_points=n_points,
        h0_pval_bottom=1/2,
        h0_pval_top=100,
    )

In [None]:
# "Full" figures: every measure (CIRCE included) on all tasks.
measures_list = ['kci', 'kci_xsplit', 'circe', 'gcm', 'rbpt2_ub']

for kernels in ['all', 'gaussian']:
    for n_points in [100, 200]:
        for ground_truth in ['H0', 'H1']:
            print(kernels, n_points)
            for task in ['get_xzy_circ', 'get_xzy_randn', 'get_xzy_randn_nl']:
                plot_pvals(measures_list, task=task, dim=2,
                           kernels=kernels, n_points=n_points, ground_truth=ground_truth,
                           h0_pval_bottom=1/2, h0_pval_top=30, ignore_second_row=True)

        # The RBPT figure contains both the H0 and the H1 row and does not take
        # `ground_truth`, so it is drawn once per (kernels, n_points) rather than
        # once per ground truth (which just re-rendered the same file).
        if n_points == 200:
            plot_pvals_rbpt(measures_list, task='get_xzy_rbpt', dim=40,
                            kernels=kernels, n_points=n_points,
                            h0_pval_bottom=1/2, h0_pval_top=100, full=True)

In [None]:
# RBPT unbiasing: RBPT2 against RBPT2' on get_xzy_randn (gaussian kernels only).
measures_list = ['rbpt2', 'rbpt2_ub']

for kernels, n_points, ground_truth in [(k, n, g)
                                        for k in ['gaussian']
                                        for n in [100, 200]
                                        for g in ['H0', 'H1']]:
    print(kernels, n_points)
    plot_pvals(
        measures_list,
        task='get_xzy_randn',
        dim=2,
        kernels=kernels,
        n_points=n_points,
        ground_truth=ground_truth,
        h0_pval_bottom=1/100,
        h0_pval_top=30,
        ignore_second_row=True,
        name_siffux='rbpt_bias',
    )

In [None]:
# Gamma approximation vs. wild bootstrap on get_xzy_randn (gaussian kernels only).
measures_list = ['kci', 'kci_xsplit', 'circe']

for kernels, n_points, ground_truth in [(k, n, g)
                                        for k in ['gaussian']
                                        for n in [100]
                                        for g in ['H0', 'H1']]:
    print(kernels, n_points)
    plot_pvals_gamma(
        measures_list,
        task='get_xzy_randn',
        dim=2,
        kernels=kernels,
        n_points=n_points,
        ground_truth=ground_truth,
        h0_pval_bottom=1/20,
        h0_pval_top=30,
        ignore_second_row=True,
    )
