# Description

Notebook for processing the results of the multi-dimensional simulated case using
sample-splitting evaluation: each node's rule is re-tested on an independently
generated holdout sample.

# Imports

In [1]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import multiprocessing
import numpy as np
import os
import pandas as pd
import pingouin as pg


import pickle
import seaborn as sns
import sklearn
import sys
from tqdm import tqdm

# to hide warnings for pretty notebook rendering in repo
import warnings
warnings.filterwarnings('ignore')

# user imports
sys.path.append("../../")

from utils.sim import *
from utils.rddd import *
from utils.pwr import *

In [2]:
# Directory holding the per-seed result pickles read below
# ("seed{seed}_nfeats_{n_feat}_multidim.pkl"); relative to the notebook's cwd.
RESULT_DIR = "../kdd23/"

In [30]:
# --- Holdout-generation settings (passed to generate_blended_rdd_with_covars) ---
n = 1000            # second positional argument for the holdout set (presumably the sample size)
fuzzy_gap = 0.5     # forwarded as fuzzy_gap=...
seed_offset = 2000  # holdout seed = result seed + offset; presumably keeps holdout draws
                    # independent of the data behind the pickled results

# --- Discontinuity-test settings (passed to test_discontinuity) ---
bw = 0.25           # `bw` argument, used with kernel='triangular'
treat = 't'         # treatment column name
running_cols = ['x']  # presumably the running-variable column(s)


# Maps a rule's `path_dir` string to the equivalent pandas Series comparison method.
_PATH_DIR_METHODS = {'<': 'lt', '>=': 'ge', '<=': 'le', '>': 'gt', '==': 'eq'}


def _apply_rule_path(df, rules):
    """Return a new DataFrame with the rows of ``df`` that satisfy every rule.

    Each rule must expose ``feature``, ``threshold`` and ``path_dir``. Rules whose
    ``path_dir`` is not one of ``<, >=, <=, >, ==`` are skipped, as in the original
    if/elif chain. The result is always a fresh copy, so downstream code may mutate
    it without touching ``df``.
    """
    mask = pd.Series(True, index=df.index)
    for rule in rules:
        method = _PATH_DIR_METHODS.get(rule.path_dir)
        if method is None:
            # NOTE(review): an unrecognised direction leaves the subset unfiltered
            # for this rule -- confirm that is intended.
            continue
        mask &= getattr(df[rule.feature], method)(rule.threshold)
    return df[mask].copy()


def process_multidim_results_sample_split(n_feat, seeds, alpha=0.05):
    """Re-evaluate the multidim results on an independent holdout sample.

    For each seed in ``range(seeds)``:
      1. load the pickled ``(result, n_tests)`` tuple for ``n_feat`` from ``RESULT_DIR``;
      2. generate a holdout set with seed ``seed + seed_offset``;
      3. for each cutoff (0.25 -> "lower", 0.75 -> "upper") and each node with
         non-None ``llr_results`` that passes the node-level filter
         (``neff_pval < alpha`` or a single-element ``rule_path``), restrict the
         holdout to the node's rule path and re-run ``test_discontinuity``;
      4. if the holdout p-value for ``'z'`` is below the Bonferroni threshold
         ``alpha / n_tests``, record the power from ``rdd_power``.

    Args:
        n_feat: number of simulated features. It selects the result file and is used
            for both n_features and n_informative when generating the holdout.
        seeds: number of seeds to process (seeds ``0 .. seeds - 1``).
        alpha: family-wise significance level (default 0.05).

    Returns:
        ``(n_feat, pwr_dict)``. ``pwr_dict["lower_all"]``/``["upper_all"]`` hold the
        power of every significant node. ``["lower_max"]``/``["upper_max"]`` hold the
        per-seed maximum, only for seeds with at least one significant node.
    """
    pwr_dict = {
        "lower_all": [],
        "upper_all": [],
        "lower_max": [],
        "upper_max": [],
    }

    for seed in range(seeds):
        result_path = os.path.join(RESULT_DIR, "seed{}_nfeats_{}_multidim.pkl".format(seed, n_feat))
        # `with` guarantees the handle is closed (the old code leaked one per seed).
        # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
        with open(result_path, "rb") as fh:
            result, n_tests = pickle.load(fh)
        x_dict = result['x']
        threshold = alpha / n_tests  # Bonferroni-corrected level

        # Generate the iid hold-out set once per seed instead of once per node.
        # The original regenerated it with the identical seed for every node; this
        # assumes generate_blended_rdd_with_covars is deterministic in its seed.
        holdout_full = generate_blended_rdd_with_covars(
            seed + seed_offset,
            n,
            fuzzy_gap=fuzzy_gap,
            take=0.05,
            reg_dict=dict(n_informative=n_feat, noise=0, n_features=n_feat),
        )

        for x_cutoff, label in [(0.25, "lower"), (0.75, "upper")]:
            pwrs = []
            for node in x_dict[x_cutoff]:
                if node['llr_results'] is None:
                    continue
                rule_path = node['rule_path']
                # Node-level filter, checked before the expensive holdout test.
                # Single-element paths bypass it, so their neff_pval is never compared.
                if not (len(rule_path) == 1 or node['neff_pval'] < alpha):
                    continue

                # The last rule_path entry is not applied (as in the original);
                # presumably it describes the node itself -- TODO confirm.
                holdout = _apply_rule_path(holdout_full, rule_path[:-1])
                llr_results, _, _ = test_discontinuity(holdout, x_cutoff, 'x', treat=treat, bw=bw, kernel='triangular')

                if llr_results.pvalues['z'] < threshold:
                    sig_power = rdd_power(llr_results.params['z'], llr_results.std_errors['z']**2, alpha=threshold)
                    pwrs.append(sig_power)
                    pwr_dict["{}_all".format(label)].append(sig_power)

            if pwrs:
                pwr_dict["{}_max".format(label)].append(max(pwrs))

    return (n_feat, pwr_dict)


In [31]:
%time
seeds = 500
f_args = [(n_feat, seeds) for n_feat in [2, 4, 8, 16]]

with multiprocessing.Pool(4) as pool:
    results = pool.starmap(process_multidim_results_sample_split, f_args)


CPU times: user 4 µs, sys: 1 µs, total: 5 µs
Wall time: 10 µs


In [37]:
# Persist the per-n_feat power results. Create the output directory if it is missing,
# and use `with` so the file is flushed and closed deterministically.
output_path = "../../results/tmlr/multidim.pkl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "wb") as fh:
    pickle.dump(results, fh)

In [36]:
# Summarise each power list as mean +/- population std (np.std, ddof=0).
# The `_max` lists only contain seeds with at least one significant node, so
# the number of contributing values is printed as well.
for n_feat, res in results:
    print("n_feat: {}".format(n_feat))
    for k, v in res.items():
        if len(v) == 0:
            # np.mean([]) would yield nan plus a RuntimeWarning that the global
            # warnings filter hides; report the empty case explicitly instead.
            print("\t{}: n/a (n=0)".format(k))
            continue
        print("\t{}: {:.3f} +/- {:.3f} (n={})".format(k, np.mean(v), np.std(v), len(v)))

n_feat: 2
	lower_all: 0.820 +/- 0.154
	upper_all: 0.800 +/- 0.157
	lower_max: 0.841 +/- 0.150
	upper_max: 0.816 +/- 0.157
n_feat: 4
	lower_all: 0.791 +/- 0.146
	upper_all: 0.791 +/- 0.155
	lower_max: 0.807 +/- 0.142
	upper_max: 0.803 +/- 0.154
n_feat: 8
	lower_all: 0.771 +/- 0.150
	upper_all: 0.776 +/- 0.158
	lower_max: 0.779 +/- 0.149
	upper_max: 0.785 +/- 0.156
n_feat: 16
	lower_all: 0.787 +/- 0.155
	upper_all: 0.778 +/- 0.155
	lower_max: 0.789 +/- 0.154
	upper_max: 0.780 +/- 0.155
