## Setup Only for Colab

In [None]:
# prompt: mount drive
# Colab only: mount Google Drive at /content/drive so the project folder under
# MyDrive can be reached from the notebook.

from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
%ls

In [None]:
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# Install the library from the current directory in development mode.
# Replace `develop` with `install` if you won't make library code changes.
!python setup.py develop
time.sleep(2)
clear_output()
# Restart the session after running this cell.

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks

# Main Logic

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_with_mediator_violations, gen_data_no_controls_discrete_m
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from proximalde.ukbb_proximal import ProximalDE_UKBB
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from proximalde.crossfit import fit_predict
from proximalde.utilities import covariance, svd_critical_value
from proximalde.proximal import residualizeW
from proximalde.proxy_rm_utils import *
from proximalde.ukbb_data_utils import *
import seaborn as sns
import matplotlib.pyplot as plt
import pickle as pk

### Collect results 

In [None]:
import itertools
# Treatment (D) labels and outcome (Y) labels; every D/Y combination is analysed.
ds = ['Low_inc', 'On_dis', 'No_priv_insr', 'No_uni', 'Female', 'Black', 'Obese', 'Asian']
ys = ['OA', 'myoc','deprs', 'back', 'RA', 'fibro', 'infl', 'copd','chrkd','mgrn','mela', 'preg', 'endo']
# All pairs as '<D>_<Y>' keys. D labels may themselves contain '_', so downstream
# code splits on the LAST underscore to recover (D, Y).
dys = list(itertools.product(ds, ys))
dys = ['_'.join(x) for x in dys]

# Display names used for figure/table labels: yy maps outcome keys, dd maps treatment keys.
yy={'mela': 'Melanoma','endo': 'Endometriosis','infl': 'IBD','preg': 'Complications during labor','OA': 'Osteoarthritis','mgrn':'Migraine','copd':'COPD', 'back': 'Back pain', 'deprs': 'Depression', 'myoc': "Heart disease", 'RA': 'Rh. Arthritis', 'fibro': 'Fibromyalgia', 'chrkd': 'Chronic kidney disease'}
dd={'No_priv_insr': 'Not on private insr.','No_uni': 'No p.s. education', 'Low_inc': 'Low income','Obese':'Obese', 'Female': 'Female', 'Black': 'Black', 'Asian': "Asian", 'On_dis': 'Disability insurance'}
# Pairs shown in the main figures; the remaining pairs go to the summary table.
dys_main = ['Female_myoc','Asian_OA', 'Low_inc_deprs', 'On_dis_RA',  'Obese_OA', 'Black_chrkd']

# Feature names for the X and Z proxies. Zint_ is used later (rmNaZ) to drop the
# 'Do not know' / 'Prefer not to' answer columns from Z.
# NOTE(review): get_int_feats presumably returns human-readable feature names — confirm.
X, X_feats, Z, Z_feats = load_ukbb_XZ_data()
Xint = get_int_feats(X_feats)
Zint_ = get_int_feats(Z_feats)

In [None]:
def get_median_item(y):
    """Return the median element of ``y`` together with its position in ``y``.

    With an odd number of elements this is the middle element of the sorted
    values. With an even number, the two middle elements are compared and the
    one with the larger absolute value is chosen (the upper one on a tie), so
    the result is always an actual element of ``y`` rather than an average.

    Returns
    -------
    (value, index) : ``index`` refers to the original, unsorted array.
    """
    order = np.argsort(y)
    mid = len(y) // 2
    pick = order[mid]
    if len(y) % 2 == 0:
        below = order[mid - 1]
        # Two middle candidates: keep whichever lies further from zero.
        if np.abs(y[below]) > np.abs(y[order[mid]]):
            pick = below
    return y[pick], pick

def rmNaZ(Zres, Zint):
    """Drop the non-informative answer columns from the Z proxies.

    A column is dropped when its feature name contains 'Do not know' or
    'Prefer not to'.

    Parameters
    ----------
    Zres : 2-D array whose columns line up with ``Zint``.
    Zint : sequence of feature-name strings, one per column of ``Zres``.
        Lists and tuples are accepted and converted to an array so they can be
        filtered with a boolean mask.

    Returns
    -------
    (Zres, Zint) restricted to the columns that were kept.
    """
    if isinstance(Zint, (list, tuple)):
        Zint = np.asarray(Zint)
    # dtype=bool keeps `~` valid even when there are no features at all.
    bad_idx = np.array(
        [('Do not know' in x) or ('Prefer not to' in x) for x in Zint],
        dtype=bool,
    )
    return Zres[:, ~bad_idx], Zint[~bad_idx]


def run_inf_rm(D_label, Y_label, inf_idxs, Xset, Zset, save_dir):
    """Re-fit the estimator after deleting the rows in ``inf_idxs``.

    Parameters
    ----------
    D_label, Y_label : treatment / outcome keys passed to ``load_ukbb_data``.
    inf_idxs : row indices (samples) to delete before fitting.
    Xset, Zset : column indices of the X and Z proxies to keep. ``Zset`` is
        applied after the 'Do not know' / 'Prefer not to' columns are removed.
    save_dir : directory handed to ``est.summary`` for saving its tables.

    Returns
    -------
    The object returned by ``est.summary``.

    NOTE(review): the summary is saved with ``save_fname_addn='_infRm'`` while
    the cell that calls this function reads ``table0_infRm_<name>.csv`` —
    confirm the file names produced by ``ProximalDE_UKBB`` actually match.
    """
    np.random.seed(4)
    # Load the pair named by the arguments (previously the globals `d`/`y` were used,
    # so the function silently depended on whatever loop last set them).
    W, _, W_feats, X, X_binary, X_feats, Z, Z_binary, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)
    # Same column mask as rmNaZ / run_bootstrap, built from the module-level Z feature
    # names instead of relying on an undefined global `bad_idx`.
    # NOTE(review): assumes Zint_ is ordered like the columns of Z — confirm.
    bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in Zint_], dtype=bool)
    Z = Z[:, ~bad_idx][:, Zset]
    X = X[:, Xset]
    est = ProximalDE_UKBB(binary_D=False, semi=True, cv=3, verbose=1, random_state=3)
    est.fit(np.delete(W, inf_idxs, axis=0), np.delete(D, inf_idxs, axis=0),
            np.delete(Z, inf_idxs, axis=0), np.delete(X, inf_idxs, axis=0),
            np.delete(Y, inf_idxs, axis=0), D_label=D_label, Y_label=Y_label,
            save_fname_addn=f'_infRm_{D_label}{Y_label}')
    return est.summary(alpha=0.05, save_dir=save_dir, save_fname_addn='_infRm')


def run_bootstrap(D_label, Y_label, Xset, Zset, save_dir):
    """Fit the estimator for one (D, Y) pair and run subsample inference at three stages.

    Parameters
    ----------
    D_label, Y_label : treatment / outcome keys passed to ``load_ukbb_data``.
    Xset, Zset : proxy column selections forwarded to ``est.fit``.
    save_dir : directory where each stage's summary tables are saved.

    Returns
    -------
    dict with keys 'inf1', 'inf2', 'inf3' holding the inference object of
    stage 1 (10 subsamples), stage 2 (100) and stage 3 (1000). All stages draw
    half of the data without replacement.
    """
    np.random.seed(4)
    # Only the Z feature names are needed here, to flag non-informative answer columns.
    _, _, _, Z_feats_all = load_ukbb_XZ_data()
    Zint_ = get_int_feats(Z_feats_all)
    bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in Zint_])

    # Load the pair named by the arguments (previously the globals `d`/`y` were used).
    W, _, W_feats, X, X_binary, X_feats, Z, Z_binary, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)
    est = ProximalDE_UKBB(binary_D=False, semi=True, cv=3, verbose=1, random_state=3)
    est.fit(W, D, Z, X, Y, D_label, Y_label, Xset=Xset, Zset=Zset, bad_idx=bad_idx)

    inf1 = est.bootstrap_inference(stage=1, n_subsamples=10, fraction=0.5, replace=False, verbose=3, random_state=123)
    inf1.summary(save_dir=save_dir, save_fname_addn='_bs_stage1')

    inf2 = est.bootstrap_inference(stage=2, n_subsamples=100, fraction=0.5, replace=False, verbose=3, random_state=123)
    inf2.summary(save_dir=save_dir, save_fname_addn='_bs_stage2')

    inf3 = est.bootstrap_inference(stage=3, n_subsamples=1000, fraction=0.5, replace=False, verbose=3, random_state=123)
    inf3.summary(save_dir=save_dir, save_fname_addn='_bs_stage3')
    return {'inf1': inf1, 'inf2': inf2, 'inf3': inf3}






### Collect all proxy-removal sets that pass 4/4 tests
There can be several passing sets per (D, Y) pair.

In [None]:
import os
import pickle as pk


def get_cand_and_len(dir, exp_path):
    """Load the candidate list stored for one experiment folder.

    Looks for ``<dir>/<exp_path>/candidates.pkl``.

    Returns
    -------
    (candidates, n) : the unpickled object and its length, or ``([], -1)`` when
        the file does not exist (-1 distinguishes "no file" from "empty list").
    """
    cand_path = f'{dir}/{exp_path}/candidates.pkl'
    if not os.path.exists(cand_path):
        return [], -1
    # NOTE: unpickling can execute arbitrary code; only load files produced by this project.
    with open(cand_path, 'rb') as fh:
        cand = pk.load(fh)
    return cand, len(cand)
    
# Walk ./results/proxyrm/<D_Y>/<experiment>/<candidate>/ and keep every candidate
# whose table2.csv reports that all four tests pass.
verbose = True
ss_dy = {}
for dy in dys:

    if verbose:
        print("\n"*2,dy)
    # NOTE: `dir` shadows the builtin of the same name; kept because it mirrors the
    # parameter name of get_cand_and_len.
    dir = f'./results/proxyrm/{dy}'
    if not os.path.exists(dir):
        continue
    all_dy_hparams = [x for x in os.listdir(dir) if 'rm' in x] #todo: delete all other folders

    for exp_path in all_dy_hparams:
        candidates, n_pairs = get_cand_and_len(dir, exp_path)
        if n_pairs < 1:
            # Either candidates.pkl is missing (-1) or it holds no pairs (0).
            if verbose:
                print(f"{exp_path} has {n_pairs} candidate pairs found")
            continue

        print(exp_path)
        # Every non-pickle entry is a candidate folder named by its index into `candidates`.
        single_cand_paths = [x for x in os.listdir(f'{dir}/{exp_path}') if not '.pkl' in x]
        unfound = True
        for cand_idx in single_cand_paths:
            cand_dir = f'{dir}/{exp_path}/{cand_idx}'
            test = pd.read_csv(f'{cand_dir}/table2.csv',header=1, index_col=1)
            if test['pass test'].sum() != 4:
                continue
            point = pd.read_csv(f'{cand_dir}/table0.csv',header=1, index_col=1)
            inf = np.load(f'{cand_dir}/inf_set.npy')
            candidate = candidates[int(cand_idx)]
            # Entry layout: [point table, test table, influential set, result path, (Xset, Zset)].
            ss_dy.setdefault(dy, []).append([point, test, inf, cand_dir, candidate])
            if verbose and unfound:
                print(f"{exp_path} found 4/4 passing!")
                unfound = False

# Close the file deterministically so the pickle is fully written before it is re-read.
with open('ss_dy.pkl', 'wb') as fh:
    pk.dump(ss_dy, fh)

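### Recompute the influential sets so they target a sign switch of the estimate
As stored, each influential set was computed for alpha=0.05; this step adds the default (sign-switch) set and a set capped at 200 points.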

In [None]:
# For every (D, Y) pair keep the candidate with the median point estimate, re-fit it
# on the residualized data and attach three influential sets to it.
with open('ss_dy.pkl', 'rb') as fh:
    ss_dy = pk.load(fh)

# Must exist before the loop; it was previously only defined if a later cell had been run.
final_ss_dy = {}
for dy in ss_dy.keys():
    points = np.array([x[0].point.iloc[0] for x in ss_dy[dy]])
    m, idx = get_median_item(points)
    print(dy, idx)
    point, test, inf_idxs, path, (Xset, Zset) = ss_dy[dy][idx]
    # D labels may contain '_', so split on the last underscore.
    d, y = '_'.join(dy.split('_')[:-1]), dy.split('_')[-1]
    Xres, Zres, Yres, Dres = load_ukbb_res_data(d, y)
    Zres, Zint = rmNaZ(Zres, Zint_)

    est = ProximalDE(semi=True, cv=3, verbose=1, random_state=3)
    est.fit(None, Dres, Zres[:, Zset], Xres[:, Xset], Yres)
    sm = est.summary()
    diag = est.run_diagnostics()
    inf_mp200 = est.influential_set(max_points=200)
    inf = est.influential_set()
    # 'alpha=.05' is the set stored with the candidate; the other two are recomputed here.
    infs = {'switch_sign': inf, 'alpha=.05': inf_idxs, 'n=200': inf_mp200}
    for k, v in infs.items():
        print(v.shape)
    final_ss_dy[dy] = point, test, infs, path, (Xset, Zset)

with open('ss_dy_updated_inf.pkl', 'wb') as fh:
    pk.dump(final_ss_dy, fh)

In [None]:
ss_dy = pk.load(open('ss_dy_updated_inf.pkl', 'rb'))
for dy in ss_dy.keys():
        
        point, test, inf_dict, path, (Xset, Zset) = ss_dy[dy]
        d, y = '_'.join(dy.split('_')[:-1]), dy.split('_')[-1]
        Xres, Zres, Yres, Dres = load_ukbb_res_data(d, y)
        Zres, Zint = rmNaZ(Zres, Zint_)

        est = ProximalDE(semi=True, cv=3, verbose=1, random_state=3)
        est.fit(None, Dres, Zres[:, Zset], Xres[:, Xset], Yres)
        sm = est.summary()
        diag = est.run_diagnostics()

        # Inf score / removal 
        inf_rm_points = []
        for name,inds in inf_dict.items():
            if 'alpha' not in name:
                if not os.path.exists(path + f'/table0_infRm_{name}.csv'):
                    run_inf_rm(D_label=d, Y_label=y, inf_idxs=inds, Xset=Xset, Zset=Zset, save_dir=path)
                inf_rm_point = pd.read_csv(path + f'/table0_infRm_{name}.csv',header=1, index_col=1)[['point', 'stderr', 'ci_lower', 'ci_upper']]
                inf_rm_point['set_size'] = len(inds)
                inf_rm_point.columns = [x+f'_infRm_{name}' for x in inf_rm_point.columns]
                inf_rm_points.append(inf_rm_point.reset_index())
        Run rank test, saving and loading if saved
        if os.path.exists(path + '/rank.npy'):
            x = np.load(path + '/rank.npy')
            svalues, svalues_crit = x[:-1], x[-1]
        else: 
            svalues, svalues_crit = est.covariance_rank_test(calculate_critical=True)
            x = np.concatenate([svalues, [svalues_crit]])
            np.save(path + '/rank.npy', x)
        
        # Robust CI to weak iv 
        lb_robust, ub_robust = est.robust_conf_int(lb=-1, ub=1)

        point_df=point
        test_df=test
        test_df['stat,crit'] = test_df.apply(lambda x: f"{round(x.statistic,1)},{round(x['critical value'],1)}", axis=1)
        test_df = test_df[['stat,crit','p-value']]
        test_df.index=['id','primal','dual', 'weakIV']
        test_df_flat = test_df.T.unstack().to_frame().sort_index(level=1).T
        test_df_flat.columns = test_df_flat.columns.map('_'.join)
        point_df = point_df[['point', 'stderr', 'ci_lower', 'ci_upper']]                  
        df = pd.concat([point_df.reset_index(), test_df_flat.reset_index()] + inf_rm_points,axis=1)
        point = df[[c for c in df.columns if c != 'index']]
        point['ci_lower_robust'] = lb_robust
        point['ci_upper_robust'] = ub_robust
        point['rank'] = np.sum(svalues >= svalues_crit)
        point['Xdim'] = len(Xset)
        point['Zdim'] = len(Zset)
        point['inf'] = len(inf_dict['switch_sign'])
        point['D_Y'] = dy
        final_ss_dy_df.append(point)
final_ss_dy_df = pd.concat(final_ss_dy_df)
final_ss_dy_df.to_csv('all_ss_points.csv')
pk.dump(final_ss_dy, open('final_ss_dy.pkl', 'wb'))

df = final_ss_dy_df[~final_ss_dy_df.D_Y.isin(dys_main)].reset_index()
df['ci'] = (df.ci_upper - df.ci_lower)/2
def statcrit_fn(df):
    direction = {'dual':'<', 'primal':'<', 'id':'>', 'weakIV':'>'}
df['D_Y'] = df.D_Y.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ', ' + yy[x.split('_')[-1]])
df['dual_stat,crit'] = df['dual_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}<{round(float(x.split(',')[-1]),1)}")
df['primal_stat,crit'] = df['primal_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}<{round(float(x.split(',')[-1]),1)}")
df['id_stat,crit'] = df['id_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}>{round(float(x.split(',')[-1]),1)}")
df['weakIV_stat,crit'] = df['weakIV_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}>{round(float(x.split(',')[-1]),1)}")
df.point = [f'{round(x,2)}$\pm${round(y,2)}' for x,y in zip(df.point,df.ci)]
df =df[['D_Y', 'point','primal_stat,crit', 'dual_stat,crit', 'id_stat,crit', 'weakIV_stat,crit', 'rank']]
df.to_csv('all_main_7metrics.csv',index=False)
df

In [None]:
# Robustness figure for the main (D, Y) pairs: full-data CI, weak-IV-robust CI and
# the CIs obtained after removing influential points.
df = final_ss_dy_df[final_ss_dy_df.D_Y.isin(dys_main)].reset_index()
df['point_robust'] = df.point

ci_types = ['','_robust',  '_infRm_n=200','_infRm_switch_sign',]
df = df[['D_Y'] + [x+f'{y}' for x in ['point','ci_lower', 'ci_upper'] for y in ci_types]]

# Legend text for each column suffix in ci_types.
ci_labels = {'':'Full data', '_robust': 'Weak IV confidence interval',
             '_infRm_switch_sign': 'Remove full high-influence set',
             '_infRm_n=200': 'Remove top 200 high-influence\npoints '}

with sns.axes_style("whitegrid"):
    # Set the palette BEFORE drawing; it was previously set after the plot, so the
    # colours depended on whether some cell had already been run in this kernel.
    sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFC803", "#FAEDCD",  "#3C6E71", ]
))
    fig, ax = plt.subplots(figsize=(8,8), dpi=100)
    df2 = []
    for dy in dys_main:
        sub_df = df[df.D_Y==dy]
        if sub_df.empty:
            print(f"{dy}: no results found, skipping")
            continue
        for ci_type in ci_types:
            lower = sub_df[f'ci_lower{ci_type}'].iloc[0]
            upper = sub_df[f'ci_upper{ci_type}'].iloc[0]
            # Four points [lower, lower, upper, upper] make the box of the boxplot
            # span exactly the confidence interval.
            df2.append(pd.DataFrame({'how': [ci_labels[ci_type]]*4, 'p': lower, 'points':[lower, lower, upper, upper], 'D_Y': [dy]*4}, index=[0,0,0,0]))
    df2 = pd.concat(df2)
    df2['D_Y'] = df2.D_Y.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ', ' + yy[x.split('_')[-1]])
    ax = sns.boxplot(data=df2, whis=0, showfliers=False, width=.7, linewidth=1.4, x='points', y='D_Y', hue='how', ax=ax)
    plt.legend(fontsize=12, title='')

    plt.xlabel("Effect estimate")
    plt.grid(True, which='both', axis='y', linewidth=0.3, color='gray')
    plt.grid(True, which='both', axis='x', linewidth=0.3, color='gray')
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)

    plt.title("Weak IV confidence interval & influence removal test  K", fontsize=15)

In [None]:

# Same intervals as the previous figure, drawn vertically with one group per (D, Y) pair.
# Relies on `df` and `ci_types` from the previous cell.
ci_labels = {'':'Full data', '_robust': 'Weak IV confidence interval',
             '_infRm_switch_sign': 'Remove full high-influence set',
             '_infRm_n=200': 'Remove top 200 high-influence\npoints '}

with sns.axes_style("whitegrid"):
    # Palette is set before drawing so that it applies to this figure.
    sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFC803", "#FAEDCD",  "#3C6E71", ]
))
    # One figure for the whole plot. The per-pair sns.set_theme()/plt.subplots() calls
    # that used to sit inside the loop opened an empty figure per pair and reset the style.
    fig, ax = plt.subplots(figsize=(15,5), dpi=100)
    df2 = []
    for dy in df.D_Y.unique():
        sub_df = df[df.D_Y==dy]
        for ci_type in ci_types:
            lower = sub_df[f'ci_lower{ci_type}'].iloc[0]
            upper = sub_df[f'ci_upper{ci_type}'].iloc[0]
            # Four points [lower, lower, upper, upper] make the box span exactly the CI.
            df2.append(pd.DataFrame({'how': [ci_labels[ci_type]]*4, 'p': lower, 'points':[lower, lower, upper, upper], 'D_Y': [dy]*4}, index=[0,0,0,0]))
    df2 = pd.concat(df2)
    df2['D_Y'] = df2.D_Y.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ', ' + yy[x.split('_')[-1]])
    df2 = df2.sort_values(by='points')
    ax = sns.boxplot(data=df2, whis=0, showfliers=False, width=.7, linewidth=1.4, y='points', x='D_Y', hue='how', ax=ax)
    plt.legend(fontsize=15, title='')

    # The estimates are on the y axis here. The old label "\$\theta\$" also contained
    # a literal TAB because "\t" is an escape sequence.
    plt.xlabel("")
    plt.ylabel("$\\theta$")
    plt.grid(True, which='both', axis='y', linewidth=0.3, color='gray')
    plt.grid(True, which='both', axis='x', linewidth=0.3, color='gray')
    plt.xticks(fontsize=15,rotation=90)
    plt.yticks(fontsize=15)

    plt.title("Weak IV confidence interval", fontsize=15)

In [None]:
# ###Stratification res
# Re-estimate every selected (D, Y) pair within income strata and compare the
# stratum estimates with the full-data estimate.

with open('ss_dy_updated_inf.pkl', 'rb') as fh:
    ss_dy = pk.load(fh)
# W is loaded once, for one fixed pair, and reused to define the strata of all pairs.
# NOTE(review): this assumes every pair kept below has the same rows as this W — confirm.
W, _, W_feats, X, X_binary, X_feats, Z, Z_binary, Z_feats, Y, D = load_ukbb_data(D_label='Asian', Y_label='OA',pp=True)
# Columns of W whose feature id contains '738' and no '-'.
# NOTE(review): presumably the household-income indicator columns — confirm.
inc_idx = np.argwhere(['738' in x and '-' not in x for x in W_feats]).flatten()
W_feats_int = get_int_feats(W_feats[:-1])
strat_dfs = []
for dy in ss_dy.keys():
    print(dy)
    # Pairs whose key contains 'inc', 'endo' or 'preg' are not stratified.
    if 'inc' in dy or 'endo' in dy or 'preg' in dy:
        continue
    point, test, inf_idxs, path, (Xset, Zset) = ss_dy[dy]
    # Copy so the table stored in ss_dy is left untouched.
    full_point = point.copy()
    full_point['dy'] = dy
    full_point['fname'] = 'Full data'
    strat_dfs.append(full_point)

    d, y = '_'.join(dy.split('_')[:-1]), dy.split('_')[-1]
    Xres, Zres, Yres, Dres = load_ukbb_res_data(d, y)
    Zres, Zint = rmNaZ(Zres, Zint_)
    est = ProximalDE(semi=True, cv=3, verbose=1, random_state=3)
    for i in inc_idx:
        # NOTE(review): the +0.5 shift suggests W columns are centred indicators — confirm.
        idx = (W[:,i]+.5).astype(bool).flatten()
        f = W_feats_int[i]
        est.fit(None, Dres[idx], Zres[idx][:, Zset], Xres[idx][:, Xset], Yres[idx])
        sm = est.summary()
        # First record of the summary table is the header row.
        stratum_point = pd.DataFrame.from_records(sm.tables[0].data)
        cols = stratum_point.iloc[0]
        stratum_point = stratum_point.iloc[1:].copy()
        stratum_point.columns = cols
        stratum_point['dy'] = dy
        stratum_point['fname'] = f
        strat_dfs.append(stratum_point)

df = pd.concat(strat_dfs)
df['point'] = df['point'].astype(float)
df['stderr'] = df['stderr'].astype(float)
fname_labels = {'Full data':'Full data',
       'Average total household income before tax=Less than 18,000': 'Low income',
       'Average total household income before tax=Greater than 100,000': 'High income'}
df = df[df.fname.isin(list(fname_labels))].copy()
df['fname'] = df.fname.map(lambda x: fname_labels[x])

# Map keys to display labels FIRST, then exclude. The exclusion list uses display
# labels, so filtering on the raw '<D>_<Y>' keys (as before) never matched anything.
df['dy'] = df.dy.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ',\n' + yy[x.split('_')[-1]])
df_inc_strat = df[~df.dy.isin(['Black,\nChronic kidney disease','Obese,\nOsteoarthritis','Asian,\nOsteoarthritis'])].copy()

with sns.axes_style("whitegrid"):
    # Palette is set before drawing so that it applies to this figure.
    sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFC803", "#FAEDCD",  "#3C6E71", ]
))
    fig, ax = plt.subplots(figsize=(6,20), dpi=100)
    ax = sns.barplot(data=df_inc_strat, x='point', y='dy', hue='fname', ax=ax)
    plt.legend(fontsize=12, title='')

    plt.xlabel("$\\theta$",fontsize=14)
    plt.ylabel('')
    plt.grid(True, which='both', axis='y', linewidth=0.3, color='gray')
    plt.grid(True, which='both', axis='x', linewidth=0.3, color='gray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=11)

    plt.title("Stratification by income", fontsize=18)

In [None]:
# Collect the full-data CI and the saved bootstrap CIs of the main (D, Y) pairs,
# then compare the bootstrap stages (at fraction 0.5) against the full-data CI.
with open('ss_dy.pkl', 'rb') as fh:
    ss_dy = pk.load(fh)

# Extra bootstrap tables saved with an explicit subsample count and fraction:
# (file name, fraction, stage). Each file is checked on its own; previously the
# existence of one file was taken as proof that a differently named one exists.
# NOTE(review): the frac0.25 stage-3 file is named nBs10000 while the frac0.75 one is
# nBs1000 — confirm which name is actually on disk.
extra_bs_tables = [
    ('table0_bs_stage3_nBs10000_frac0.25.csv', 0.25, 3),
    ('table0_bs_stage3_nBs1000_frac0.75.csv', 0.75, 3),
    ('table0_bs_stage1_nBs10_frac0.75.csv', 0.75, 1),
    ('table0_bs_stage1_nBs10_frac0.25.csv', 0.25, 1),
]

bs_dfs = []
for dy in ss_dy.keys():
    if dy not in dys_main:
        continue
    points = np.array([x[0].point.iloc[0] for x in ss_dy[dy]])
    m, idx = get_median_item(points)
    point, test, inf_idxs, path, (Xset, Zset) = ss_dy[dy][idx]
    d, y = '_'.join(dy.split('_')[:-1]), dy.split('_')[-1]

    # Full-data estimate
    df_ = point[['point', 'ci_lower', 'ci_upper']].copy()
    df_['stage'] = 'All'
    df_['frac'] = 'All'
    df_['D_Y'] = dy
    bs_dfs.append(df_)

    # Default run (fraction 0.5). run_bootstrap writes stage 3 last, so its
    # presence is used as the marker that all three stages were saved.
    if os.path.exists(path + '/table0_bs_stage3.csv'):
        for i in range(1,4):
            df_ = pd.read_csv(path + f'/table0_bs_stage{i}.csv',header=1, index_col=1)
            df_['stage'] = i
            df_['frac'] = float(.5)
            df_['D_Y'] = dy
            bs_dfs.append(df_)

    for fname, frac, stage in extra_bs_tables:
        if not os.path.exists(path + '/' + fname):
            print(dy, 'missing', fname)
            continue
        df_ = pd.read_csv(path + '/' + fname,header=1, index_col=1)
        df_['frac'] = float(frac)
        df_['stage'] = stage
        df_['D_Y'] = dy
        bs_dfs.append(df_)

bs_dfs = pd.concat(bs_dfs)[['point', 'ci_lower', 'ci_upper', 'stage', 'frac', 'D_Y']].reset_index()
bs_dfs = bs_dfs[[c for c in bs_dfs.columns if c != 'index']]

# .copy() so relabelling the stage column does not write into a view of bs_dfs.
df = bs_dfs[bs_dfs.frac.isin([.5, 'All'])].copy()
df['stage'] = [{1: '1 (M=10)', 2: '2 (M=100)', 3: '3 (M=1000)', 'All': 'Full data'}[x] for x in df.stage]
with sns.axes_style("whitegrid"):
    # Palette is set before drawing so that it applies to this figure.
    sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFC803", "#FAEDCD",  "#3C6E71", ]
))
    fig, ax = plt.subplots(figsize=(8,8), dpi=100)
    df2 = []
    for dy in dys_main:
        for stage in df.stage.unique():
            sub_df = df[(df.stage==stage) & (df.D_Y==dy)]
            if sub_df.empty:
                # No saved table for this pair/stage; report it instead of crashing on iloc[0].
                print(dy, stage)
                continue
            lower = sub_df.ci_lower.iloc[0]
            upper = sub_df.ci_upper.iloc[0]
            # Four points [lower, lower, upper, upper] make the box span exactly the CI.
            df2.append(pd.DataFrame({'how': [stage]*4, 'p': lower, 'points':[lower, lower, upper, upper], 'D_Y': [dy]*4}, index=[0,0,0,0]))
    df2 = pd.concat(df2)
    df2['D_Y'] = df2.D_Y.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ', ' + yy[x.split('_')[-1]])
    sns.boxplot(data=df2, whis=0, showfliers=False, width=.7, linewidth=1.4, x='points', y='D_Y', hue='how', ax=ax)
    plt.legend(fontsize=13, title_fontsize=13, title='Stage of Estimation\n(M=iterations)', loc='lower left')


    plt.xlabel("Effect estimate")
    plt.grid(True, which='both', axis='y', linewidth=0.3, color='gray')
    plt.grid(True, which='both', axis='x', linewidth=0.3, color='gray')
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)

    plt.title("Comparing stages of re-estimation for bootstrapping       A)   B)", fontsize=20)

In [None]:
import seaborn as sns

# Compare subsample fractions for the STAGE 1 bootstrap against the full-data CI.
# The filter previously selected stage 3, which made this figure a copy of the
# stage-3 figure below despite its "Stage 1" title.
df = bs_dfs[bs_dfs.stage.isin([1, 'All'])].copy()
df['frac'] = [{'All': 'Full data', .5: '50%', .75: '75%', .25: '25%',0.1:'10%' }[x] for x in df.frac]
with sns.axes_style("whitegrid"):
    # Palette is set before drawing so that it applies to this figure.
    sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFB703", "#FAEDCD",  "#3C6E71", ]
))
    fig, ax = plt.subplots(figsize=(8,8), dpi=100)
    df2 = []
    for dy in dys_main:
        for frac in df.frac.unique():
            sub_df = df[(df.frac==frac) & (df.D_Y==dy)]
            if sub_df.empty:
                # No saved table for this pair/fraction: report it and move on.
                print(dy, frac)
                continue
            lower = sub_df.ci_lower.iloc[0]
            upper = sub_df.ci_upper.iloc[0]
            # Four points [lower, lower, upper, upper] make the box span exactly the CI.
            df2.append(pd.DataFrame({'how': [frac]*4, 'p': lower, 'points':[lower, lower, upper, upper], 'D_Y': [dy]*4}, index=[0,0,0,0]))
    df2 = pd.concat(df2)
    # Fix the legend/hue order: full data first, then decreasing fraction.
    custom_order = {'Full data':0, '75%':1, '50%':2, '25%':3}
    df2['order'] = df2.how.map(custom_order)
    df2 = df2.sort_values(by=[ 'order', 'p'])
    df2['D_Y'] = df2.D_Y.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ', ' + yy[x.split('_')[-1]])
    sns.boxplot(data=df2, whis=0, showfliers=False, width=.7, linewidth=1.4, x='points', y='D_Y', hue='how', ax=ax)
    plt.legend(fontsize=13,title_fontsize=13, title='Percent of data resampled')


    plt.xlabel("Effect estimate")
    plt.grid(True, which='both', axis='y', linewidth=0.3, color='gray')
    plt.grid(True, which='both', axis='x', linewidth=0.3, color='gray')
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)

    plt.title("Comparing bootstrap sample size for K=10 iterations\nRe-estimating at Stage 1",fontsize=15)

In [None]:
import seaborn as sns

# Compare subsample fractions for the stage 3 bootstrap against the full-data CI.
# .copy() so relabelling the frac column does not write into a view of bs_dfs.
df = bs_dfs[bs_dfs.stage.isin([3, 'All'])].copy()
df['frac'] = [{.5: '50%', .75: '75%', .25: '25%', 'All': 'Full data'}[x] for x in df.frac]
with sns.axes_style("whitegrid"):
    # Palette is set before drawing so that it applies to this figure.
    sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFB703", "#FAEDCD",  "#3C6E71", ]
))
    fig, ax = plt.subplots(figsize=(8,8), dpi=100)
    df2 = []
    for dy in dys_main:
        for frac in df.frac.unique():
            sub_df = df[(df.frac==frac) & (df.D_Y==dy)]
            if sub_df.empty:
                # No saved table for this pair/fraction: report it and move on
                # (replaces a blanket `except Exception` that hid every other error too).
                print(dy, frac)
                continue
            lower = sub_df.ci_lower.iloc[0]
            upper = sub_df.ci_upper.iloc[0]
            # Four points [lower, lower, upper, upper] make the box span exactly the CI.
            df2.append(pd.DataFrame({'how': [frac]*4, 'p': lower, 'points':[lower, lower, upper, upper], 'D_Y': [dy]*4}, index=[0,0,0,0]))
    df2 = pd.concat(df2)
    # Fix the legend/hue order: full data first, then decreasing fraction.
    custom_order = {'Full data':0, '75%':1, '50%':2, '25%':3}
    df2['order'] = df2.how.map(custom_order)
    df2 = df2.sort_values(by=[ 'order', 'p'])
    df2['D_Y'] = df2.D_Y.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ', ' + yy[x.split('_')[-1]])
    sns.boxplot(data=df2, whis=0, showfliers=False, width=.7, linewidth=1.4, x='points', y='D_Y', hue='how', ax=ax)
    plt.legend(fontsize=13,title_fontsize=13, title='Percent of data resampled')


    plt.xlabel("Effect estimate")
    plt.grid(True, which='both', axis='y', linewidth=0.3, color='gray')
    plt.grid(True, which='both', axis='x', linewidth=0.3, color='gray')
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)

    plt.title("Comparing bootstrap sample size for K=1k iterations\nRe-estimating at Stage 3",fontsize=15)

In [None]:
# Display the list of main (D, Y) pairs for reference.
dys_main

In [None]:
def load_ukbb_res_data_xgb(D_label, Y_label,
                           data_dir='/oak/stanford/groups/rbaltman/karaliu/bias_detection/causal_analysis/data_hm_std'):
    """Load the saved residual arrays whose file names carry the ``Rgrs=xgb`` tag.

    Parameters
    ----------
    D_label : treatment key; underscores are stripped to form the file names.
    Y_label : outcome key. For 'endo' and 'preg' the '_FemOnly' files are used.
    data_dir : directory holding the ``.npy`` files. Defaults to the original
        hard-coded location; pass another path to run elsewhere.

    Returns
    -------
    (Xres, Zres, Yres, Dres) as loaded by ``np.load``.
    """
    # The files read below are the '_Rgrs=xgb' ones, so the message says XGB
    # (it previously claimed linear regression).
    print("Assuming D,Y,Z all treated as continuous, using XGB regression of W")

    def _get_path(fname):
        return f'{data_dir}/{fname}'

    D_label = D_label.replace('_', '')
    save_fname_addn = '_FemOnly' if Y_label in ['endo', 'preg'] else ''
    Winfo = f'_Wrm{D_label}'
    Yres = np.load(_get_path(f'Yres_{Y_label}{Winfo}{save_fname_addn}_Rgrs=xgb.npy'))
    Dres = np.load(_get_path(f'Dres_{D_label}{save_fname_addn}_Rgrs=xgb.npy'))
    Xres = np.load(_get_path(f'Xres{Winfo}{save_fname_addn}_Rgrs=xgb.npy'))
    print(f'Zres{Winfo}{save_fname_addn}_Rgrs=xgb.npy')
    Zres = np.load(_get_path(f'Zres{Winfo}{save_fname_addn}_Rgrs=xgb.npy'))
    return Xres, Zres, Yres, Dres
      
# Re-fit selected (D, Y) pairs on the XGB-based residuals, using the X/Z proxy
# selection stored in ss_dy_updated_inf.pkl, and display the resulting tables.
final_ss_dy = pk.load(open('ss_dy_updated_inf.pkl', 'rb'))

# Currently restricted to a single pair; swap in dys_main to run all main pairs.
for dy in ['Low_inc_deprs']:#dys_main:
    point, test, inf_idxs, path, (Xset, Zset) = final_ss_dy[dy]
    print(dy)
    # D labels may contain '_', so split on the last underscore.
    d, y = '_'.join(dy.split('_')[:-1]), dy.split('_')[-1]

    try:
        Xres, Zres, Yres, Dres = load_ukbb_res_data_xgb(d, y)
        Zres, Zint = rmNaZ(Zres, Zint_)

        est = ProximalDE(semi=True, cv=3, verbose=1, random_state=3)
        est.fit(None, Dres, Zres[:, Zset], Xres[:, Xset], Yres)
        sm = est.summary()
        display(sm.tables[0])
        display(sm.tables[2])
    except Exception as e:
        # Best effort: any failure (e.g. missing residual files) is printed and
        # the loop moves on to the next pair.
        print(e)
# Disabled: recomputing the influential sets on the XGB residuals and overwriting the pickle.
#         diag = est.run_diagnostics()
#         inf_mp200 = est.influential_set(max_points=200)
#         inf_alhpa = inf_idxs
#         inf = est.influential_set()
#         infs = {'switch_sign': inf, 'alpha=.05': inf_idxs, 'n=200':inf_mp200}
#         for k,v in infs.items():
#             print(v.shape)
#         final_ss_dy[dy] = point, test, infs, path, (Xset, Zset)
# pk.dump(final_ss_dy, open('ss_dy_updated_inf.pkl', 'wb'))
# ss_dy = final_ss_dy

In [None]:
# diag.cookd_plot()
# plt.show()

In [None]:
# diag.l2influence_plot()
# plt.show()

In [None]:
from proximalde.ukbb_proximal import ProximalDE_UKBB
def run_inf_rm(D_label, Y_label, inf_idxs, Xset, Zset, save_dir):
    """Re-fit the estimator after deleting the rows in ``inf_idxs``.

    NOTE: this re-defines a function of the same name declared earlier in the
    notebook; the later definition is the one in effect once this cell runs.

    Parameters
    ----------
    D_label, Y_label : treatment / outcome keys passed to ``load_ukbb_data``.
    inf_idxs : row indices (samples) to delete before fitting.
    Xset, Zset : column indices of the X and Z proxies to keep. ``Zset`` is
        applied after the 'Do not know' / 'Prefer not to' columns are removed.
    save_dir : directory handed to ``est.summary`` for saving its tables.

    Returns
    -------
    The object returned by ``est.summary``.
    """
    np.random.seed(4)
    # Load the pair named by the arguments (previously the globals `d`/`y` were used,
    # so the function silently depended on whatever loop last set them).
    W, _, W_feats, X, X_binary, X_feats, Z, Z_binary, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)
    # Column mask built from the module-level Z feature names instead of relying on
    # an undefined global `bad_idx`.
    # NOTE(review): assumes Zint_ is ordered like the columns of Z — confirm.
    bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in Zint_], dtype=bool)
    Z = Z[:, ~bad_idx][:, Zset]
    X = X[:, Xset]
    est = ProximalDE_UKBB(binary_D=False, semi=True, cv=3, verbose=1, random_state=3)
    est.fit(np.delete(W, inf_idxs, axis=0), np.delete(D, inf_idxs, axis=0),
            np.delete(Z, inf_idxs, axis=0), np.delete(X, inf_idxs, axis=0),
            np.delete(Y, inf_idxs, axis=0), D_label=D_label, Y_label=Y_label,
            save_fname_addn=f'_infRm_{D_label}{Y_label}')
    return est.summary(alpha=0.05, save_dir=save_dir, save_fname_addn='_infRm')
# run_inf_rm(D_label=d, Y_label=y, inf_idxs=inds, Xset=Xset, Zset=Zset, save_dir=save_dir)


In [None]:
import pickle as pk

# Scratch cell: show the baseline estimate for each (D, Y) pair, then every
# proxy-removal estimate under ./results/proxyrm/<dy>/ that passes more than
# two tests, recording which X / Z columns each passing run kept.
np.random.seed(30)
for D_label in ['Female','Obese','Black', 'Asian']:
    for Y_label in tqdm(['OA', 'myoc', 'back']):
        dy=f'{D_label}_{Y_label}'
        # NOTE(review): debug override — the proxy-removal results below are
        # always read for 'Obese_back', regardless of the loop variables.
        dy = 'Obese_back'
        print("\n"*3)
        print(dy)
        # NOTE(review): `dir` shadows the builtin; kept because it is a
        # notebook-level name other cells may read.
        dir = f'./results/D={D_label}_Y={Y_label}/Dbin=False_Ybin=False_XZbin=False_Rgr=linear/'
        test = pd.read_csv(dir + '/table2.csv',header=1, index_col=1)
        point = pd.read_csv(dir + '/table0.csv',header=1, index_col=1)
        print("Estimation on all data:")
        display(point)
        display(test)
        dir = f'./results/proxyrm/{dy}'
        all_paths = os.listdir(dir)
        all_data = {}
        print("Estimations after rm weak proxies (that pass >2 tests):")
        xs = []
        zs = []
        points = []
        for var in all_paths:
#             print(var)
            # Close the candidates file as soon as it has been read.
            with open(f'{dir}/{var}/candidates.pkl', 'rb') as fh:
                candidates = pk.load(fh)
            est_paths = [x for x in os.listdir(f'{dir}/{var}') if '.pkl' not in x]
            pass_tests = []
            # NOTE(review): `est` is a directory name here and overwrites any
            # estimator object called `est` left by an earlier cell.
            for est in est_paths:
                test = pd.read_csv(f'{dir}/{var}/{est}/table2.csv',header=1, index_col=1)
                n_pass = test['pass test'].sum()
                if n_pass > 2:
                    point = pd.read_csv(f'{dir}/{var}/{est}/table0.csv',header=1, index_col=1)
                    # The trailing integer of the directory name indexes `candidates`.
                    xset, zset = candidates[int(est.split('_')[-1])]
                    print(len(xset), len(zset))
                    # One-hot indicator of the kept columns.
                    # NOTE(review): 65 / 197 are presumably the full X / Z
                    # feature counts — confirm.
                    xset_ = np.zeros(65)
                    zset_ = np.zeros(197)
                    xset_[xset] = 1
                    zset_[zset] = 1
                    xs.append(xset_[None,:])
                    zs.append(zset_[None,:])
                    display(point)
                    display(test)
                    print("-"*20)
                    points.append(point.point.iloc[0])
                pass_tests.append(n_pass)
#             sns.histplot(pass_tests)
#             print((np.array(pass_tests)>2).mean())
#             if (np.array(pass_tests)==4).any():
#                 print(1/0)
#             plt.show()
        # NOTE(review): deliberate halt (ZeroDivisionError) after the first
        # (D, Y) pair; everything below in this loop body is unreachable
        # until this line is removed.
        print(1/0)
        dir = f'./results/proxyrm/old_est/'
        all_paths = [p for p in os.listdir(dir) if dy in p]
        pass_tests = []
        for est in all_paths:
            try:
                test = pd.read_csv(f'{dir}/{est}/table2.csv',header=1, index_col=1)
                pass_tests.append(test['pass test'].sum())
                if test['pass test'].sum() > 2:
                    point = pd.read_csv(f'{dir}/{est}/table0.csv',header=1, index_col=1)
                    display(point)
                    display(test)
#                 if test['pass test'].sum() == 3:
#                     print(est)
            except Exception as err:
                # Still best-effort, but no longer swallows KeyboardInterrupt
                # and says which run was skipped.
                print(f'skipping {est}: {err!r}')
#         print((np.array(pass_tests)>2).mean())
#         sns.histplot(pass_tests)
#         plt.show()

In [None]:
# For each (D, Y) pair: load the baseline tables, then every "random"
# proxy-removal run under ./results/proxyrm/ that passes more than two tests.
# Passing runs are printed, their kept X / Z columns are drawn as indicator
# images, and everything is collected in `all_data`.
# NOTE(review): relies on `Xres`, `Zres`, `Xint`, `Zint` already being in the
# kernel namespace — confirm on a fresh Run All.
np.random.seed(30)
all_proxyrm_table_paths = os.listdir(f'./results/proxyrm/')
all_data = {}
for D_label in ['Female','Obese','Black', 'Asian']:
    for Y_label in tqdm(['deprs', 'back']):
        dir = f'./results/D={D_label}_Y={Y_label}/Dbin=False_Ybin=False_XZbin=False_Rgr=linear/'
        # Keep the baseline ("original") tables under their own names so the
        # per-run `test` / `point` read below cannot overwrite them.
        og_test = pd.read_csv(dir + '/table2.csv',header=1, index_col=1)
        og_point = pd.read_csv(dir + '/table0.csv',header=1, index_col=1)

#         candidates = pk.load(open(f'./{D_label}_{Y_label}_candidates_reweight_acually.pkl', 'rb'))
        with open(f'./{D_label}_{Y_label}_candidates_reweight.pkl', 'rb') as fh:
            candidates = pk.load(fh)
        new_tests = []
        new_points = []
        xsets=[]
        zsets=[]
        ps=[p for p in all_proxyrm_table_paths if f'{D_label}_{Y_label}' in p and 'Dbin' not in p and 'random' in p]
        if len(ps) > 0:
            print(dir)
            display(og_point)
            display(og_test)
        for path in ps:
            try:
                point = pd.read_csv(f'./results/proxyrm/{path}/table0.csv', header=1, index_col=1)
                test = pd.read_csv(f'./results/proxyrm/{path}/table2.csv', header=1, index_col=1)
            except FileNotFoundError:
                # Runs whose tables were never written are skipped.
                continue
            # The trailing integer of the directory name indexes `candidates`.
            i = int(path.split("_")[-1])
            Xset, Zset = candidates[i]
            if not (test['pass test'].sum() > 2):
                continue

            print("----"*10)
            new_tests.append(test)
            new_points.append(point)

            # Print a shuffled sample of the kept / deleted feature names.
            # The shuffle order is kept as-is so the seeded RNG stream is
            # consumed in the same sequence.
            rmXset = np.setdiff1d(np.arange(Xres.shape[1]), Xset)
            x=Xint[Xset]
            np.random.shuffle(x)
            print("Kept Xs = ", x[:10])
            x=Xint[rmXset]
            np.random.shuffle(x)
            print("Deleted Xs =", x)

            rmZset = np.setdiff1d(np.arange(Zres.shape[1]), Zset)
            x=Zint[Zset]
            np.random.shuffle(x)
            print("Kept Zs = ", x[:10])
            x=Zint[rmZset]
            np.random.shuffle(x)
            print("Deleted Zs =", x[:10])

            # One-hot indicator rows of the kept columns, for the images below.
            x = np.zeros(Xres.shape[1])
            x[Xset] = 1
            z = np.zeros(Zres.shape[1])
            z[Zset] = 1
            xsets.append(x[None,:])
            zsets.append(z[None,:])
            display(test)
            display(point)
        if len(new_tests) > 0:
            plt.subplots(figsize=(10,10))
            plt.imshow(np.concatenate(xsets))
            plt.show()
            plt.subplots(figsize=(10,10))
            plt.imshow(np.concatenate(zsets))
            plt.show()
            all_data[f'{D_label}_Y={Y_label}'] = {'tests': [og_test]+new_tests,
                                    'point': [og_point] + new_points, 'xsets': xsets,
                                                 'zsets':zsets}
            print(f'{D_label}_Y={Y_label}', len(new_tests))

In [None]:
# Gather the saved summary tables (table0 = point estimate, table1, table2 =
# tests) of the baseline runs into three long DataFrames, one row block per
# (D, Y) pair. Pairs whose tables cannot be read are reported and skipped.
SAVE_PATH = './results/'
model_regression = 'linear'
model_classification = 'linear'
Dbin = False
Ybin = False
XZbin = False

test_dfs = []
res_dfs = []
point_dfs = []

for D_label in ['Black', 'Female', 'Obese','Asian']:
    print(D_label)
    for Y_label in ['OA', 'RA', 'myoc','deprs', 'back']:
        try:
            save_dir = f'{SAVE_PATH}/D={D_label}_Y={Y_label}/Dbin={Dbin}_Ybin={Ybin}_XZbin={XZbin}_Rgr={model_regression}'

            # Tests table: flatten to a single row whose columns are
            # "<index>_<column>" pairs.
            test_df = pd.read_csv(save_dir + '/table2.csv', header=1, index_col=1)
            test_df = test_df.drop(columns=['0'])
            stacked = test_df.T.unstack().to_frame()
            test_df_flat = stacked.sort_index(level=1).T
            test_df_flat.columns = test_df_flat.columns.map('_'.join)

            point_df = pd.read_csv(save_dir + '/table0.csv', header=1, index_col=1)
            point_df = point_df.drop(columns=['0'])

            res_df = pd.read_csv(save_dir + '/table1.csv', header=1, index_col=1)
            res_df = res_df.drop(columns=['0'])

            # Tag every table with the pair it belongs to.
            pair_tag = f'{D_label}_{Y_label}'
            for frame in (test_df_flat, point_df, res_df):
                frame['D_Y'] = pair_tag

            # Append only once all three tables loaded successfully.
            res_dfs.append(res_df)
            point_dfs.append(point_df)
            test_dfs.append(test_df_flat)
        except Exception as e:
            print(e)

point_df = pd.concat(point_dfs)
res_df = pd.concat(res_dfs)
test_df = pd.concat(test_dfs)
test_df = test_df.reindex(columns=sorted(test_df.columns))
point_df

In [None]:
from scipy.stats import t
import textwrap 
def wrap_labels(labels, max_characters):
    """Return a list with each label re-wrapped to at most ``max_characters`` per line.

    Wrapping is done by ``textwrap.fill``, so line breaks are inserted as
    newline characters inside each returned string.
    """
    wrapped = []
    for text in labels:
        wrapped.append(textwrap.fill(text, max_characters))
    return wrapped
    
# Directory prefix that the loaders below prepend to file names by plain string
# concatenation, so it must end with '/'.
# NOTE(review): hardcoded absolute cluster path — consider making it configurable.
UKBB_DATA_DIR = '/oak/stanford/groups/rbaltman/karaliu/bias_detection/cohort_creation/data/'

def _load_data(fname: str):
    """Load the ``<fname>_data_rd.npy`` / ``<fname>_feats_rd.npy`` pair from UKBB_DATA_DIR.

    Returns ``(data, feats)``. Asserts that ``data`` contains no NaN.
    """
    stem = f'{UKBB_DATA_DIR}{fname}'
    data = np.load(f'{stem}_data_rd.npy', allow_pickle=False)
    feats = np.load(f'{stem}_feats_rd.npy', allow_pickle=False)
    has_nan = np.isnan(data).any()
    assert not has_nan, 'NaN values cannot exist in data'
    return data, feats
    
def load_XZ_data():
    """Load the Z ('srMntSlp') and X ('biomMed') arrays with their feature arrays.

    Returns ``(X, X_feats, Z, Z_feats)``. Z is read from disk first, then X.
    """
    z_pair = _load_data(fname='srMntSlp')
    x_pair = _load_data(fname='biomMed')
    return (*x_pair, *z_pair)

def load_DY_data(D_label, Y_label):
    """Read the treatment column ``D_label`` and outcome column ``Y_label`` from CSV.

    Returns ``(D, Y)`` as NumPy arrays; ``Y`` gets an extra trailing axis,
    ``D`` is returned as read.
    """
    treatment_table = pd.read_csv(UKBB_DATA_DIR + 'updated_sa_df_pp.csv')
    D = treatment_table[D_label].to_numpy()

    outcome_table = pd.read_csv(UKBB_DATA_DIR + 'updated_Y_labels.csv')
    Y = outcome_table[Y_label].to_numpy()[:, np.newaxis]
    return D, Y

def load_res_data(D_label, Y_label):
    """Load the saved residual arrays for one (D, Y) pair.

    Underscores are stripped from ``D_label`` before the file names are
    built. Returns ``(Xres, Zres, Yres, Dres)``.
    """
    data_root = '/oak/stanford/groups/rbaltman/karaliu/bias_detection/causal_analysis/data_hm/'
    d_tag = D_label.replace('_', '')
    w_suffix = f'_Wrm{d_tag}'

    def _read(file_name):
        # All residual files live directly under `data_root`.
        return np.load(f'{data_root}{file_name}')

    # Files are read in the same order as before: Y, D, X, Z.
    Yres = _read(f'Yres_{Y_label}{w_suffix}.npy')
    Dres = _read(f'Dres_{d_tag}.npy')
    Xres = _read(f'Xres{w_suffix}.npy')
    Zres = _read(f'Zres{w_suffix}.npy')
    return Xres, Zres, Yres, Dres

def XZ_hparam_plot(dual_or_primal='dual'):
    """
    Tool for visualizing covariance matrices
    and how different thresholds for N affect the corresponding covariance.
    Could be for dual or primal.

    Parameters
    ----------
    dual_or_primal : str, default 'dual'
        'dual' plots X (rows) against Z (columns), restricted to the X
        features significantly associated with D, for Y='OA' only.
        Any other value plots the primal view: Z (rows) against X (columns),
        restricted to the Z features significantly associated with Y, for
        every outcome in the hard-coded list.

    Notes
    -----
    ``get_cov`` is defined elsewhere; it is used here as returning
    ``(cov, pvals, threshold)``.
    """
    is_dual = dual_or_primal == 'dual'
    D_labels = ['Female', 'Obese','Black', 'Asian']
    Y_labels = ['OA'] if is_dual else ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']
    n_values = range(40)  # candidate thresholds N

    for D_label in D_labels:
        print(D_label)
        for Y_label in Y_labels:
            # Yres is needed by the primal branch, so it must be bound here.
            Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label=Y_label)
            if is_dual:
                Xrpr, Zrpr, Drpr, label = 'X', 'Z', 'D', D_label
                _, XZres_pvals, XZres_thresh = get_cov(Xres, Zres, get_pvals=True)
                _, DXres_pvals, DXres_thresh = get_cov(Dres, Xres, get_pvals=True)
                figsize = (20, 5)
            else:
                Xrpr, Zrpr, Drpr, label = 'Z', 'X', 'Y', Y_label
                _, XZres_pvals, XZres_thresh = get_cov(Zres, Xres, get_pvals=True)
                _, DXres_pvals, DXres_thresh = get_cov(Yres, Zres, get_pvals=True)
                figsize = (12, 5)

            # We only care about row features with st.sig. assn with D (dual) / Y (primal)
            ss_DXidx = (DXres_pvals < DXres_thresh).squeeze()
            # Boolean matrix: row feature i is significantly associated with
            # column feature j. Computed once and reused below.
            signif = XZres_pvals[ss_DXidx] < XZres_thresh

            fig, axs = plt.subplots(1, 2, figsize=figsize, dpi=70)

            im = axs[0].imshow(signif, aspect='auto', cmap='Blues', interpolation='nearest')
            axs[0].set_title(f"Nonzero Covariance({Xrpr},{Zrpr})\n(if spearman pvalue < .05/dz*dx)", fontsize=12)
            axs[0].set_ylabel(f"{Xrpr} feats", fontsize=10)
            axs[0].set_xlabel(f"{Zrpr} feats", fontsize=10)
            cbar = fig.colorbar(im, ax=axs[0], orientation='vertical')
            cbar.set_ticks([0, 1])
            cbar.set_ticklabels(['Zero', 'Nonzero'])

            # For each N: how many column features are associated with more
            # than N row features, and how many row features are left with no
            # association once only those columns are kept.
            per_column_hits = signif.sum(axis=0)
            nZfeats = []
            Xfeats_w_zero_Zfeats = []
            for i in n_values:
                keep = per_column_hits > i
                zero = (signif[:, keep].sum(axis=1) == 0).sum()
                nZfeats.append(keep.sum())
                Xfeats_w_zero_Zfeats.append(zero)

            ax3 = axs[1]
            ax3.plot(n_values, nZfeats, color='blue')
            ax3.set_ylabel(f"# {Zrpr} feats correlated with >N {Xrpr} feats", color='blue')
            ax3.set_xlabel(f"N\n(hparam for filtering {Xrpr} feats)")
            ax3.tick_params(axis='y', labelcolor='blue')

            # Second y-axis on the same panel for the "orphaned rows" count.
            ax3b = ax3.twinx()
            ax3b.plot(n_values, Xfeats_w_zero_Zfeats, color='green')
            ax3b.set_ylabel(f"# {Xrpr} feats w/ 0 st.sig. feat correlation\nafter rm {Zrpr} feats correlated with <N {Xrpr} feats", color='green')
            ax3b.tick_params(axis='y')

            ax3.set_title(f'How removing {Zrpr} feats affects Cov(X,Z)')
            ax3.grid(True, axis='y', linestyle='--', linewidth=0.5, color='lightgray')

            # Adjust layout to prevent overlapping
            plt.tight_layout()
            plt.suptitle(f"{dual_or_primal} violation plots for {Drpr}={label}\nusing Xres, Zres")
            # Show the combined plots
            plt.show()

def _show_cov_heatmap(matrix, n_rows, n_cols, row_labels, col_labels,
                      row_axis_label, col_axis_label, title,
                      col_fontsize, divider=None):
    """Draw ``|matrix|`` as a labelled heatmap and show it.

    ``n_rows`` / ``n_cols`` are the number of tick positions on each axis;
    ``divider=(row, col)`` draws black lines separating an appended last
    row / column from the main block.
    """
    plt.subplots(1, 1, figsize=(12, 8), dpi=60)
    sns.heatmap(np.abs(matrix), cmap='Blues')
    if divider is not None:
        plt.axhline(y=divider[0], color='black', linewidth=2)
        plt.axvline(x=divider[1], color='black', linewidth=2)
    plt.yticks(ticks=np.arange(n_rows) + .5, labels=row_labels, rotation=0)
    plt.xticks(ticks=np.arange(n_cols) + .5, labels=wrap_labels(col_labels, 50),
               rotation=90, fontsize=col_fontsize)
    plt.xlabel(col_axis_label)
    plt.ylabel(row_axis_label)
    plt.title(title)
    plt.show()


def XZ_vis_cov(rmX_zeroZ_dual, popssZ_dual, rmZ_zeroX_primal, popssX_primal):
    """Plot residual covariances between the filtered X and Z proxies for every (D, Y) pair.

    Parameters
    ----------
    rmX_zeroZ_dual, popssZ_dual : mapping
        Indexed as ``[...]['OA'][D_label]`` to give boolean column masks over
        X (features to remove, negated here) and Z (features to keep).
    rmZ_zeroX_primal, popssX_primal : mapping
        Indexed as ``[...][Y_label][D_label]`` to give boolean column masks
        over Z (features to remove, negated here) and X (features to keep).

    For each pair, four heatmaps are shown:
      1. |Cov(X,Z)| with a Y row and a D column appended,
      2. the same layout with significance indicators instead of covariances,
      3. |Cov([D, Y, X], Z)|,
      4. |Cov([D, Y, Z], X)|.

    Notes
    -----
    ``get_cov`` is used as returning ``(cov, pvals, threshold)``. ``Xint`` /
    ``Zint`` are notebook-level arrays of feature names.
    NOTE(review): shapes of the ``get_cov`` outputs are assumed to be
    (n_first, n_second) — confirm against its definition.
    """
    for D_label in ['Female', 'Obese','Black', 'Asian']:
        for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:
            print(f"{D_label}->{Y_label}")
            # Only the residualised arrays are used by the plots below.
            Xres, Zres, Yres, Dres = load_res_data(D_label, Y_label)

            # Filter X,Z based on bad proxies
            Xprm_idx = popssX_primal[Y_label][D_label]
            Zprm_idx = ~rmZ_zeroX_primal[Y_label][D_label]
            Zdual_idx = popssZ_dual['OA'][D_label]
            Xdual_idx = ~rmX_zeroZ_dual['OA'][D_label]
            keepX = Xprm_idx & Xdual_idx
            keepZ = Zprm_idx & Zdual_idx

            Xres = Xres[:, keepX]
            Zres = Zres[:, keepZ]
            n_x, n_z = Xres.shape[1], Zres.shape[1]
            x_names = list(Xint[keepX])
            z_names = list(Zint[keepZ])
            d_name, y_name = f'D={D_label}', f'Y={Y_label}'

            XZres_cov, XZres_pvals, XZres_thresh = get_cov(Xres,Zres, get_pvals=True)
            DXres_cov, DXres_pvals, DXres_thresh = get_cov(Dres, Xres, get_pvals=True)
            YZres_cov, YZres_pvals, YZres_thresh = get_cov(Yres, Zres, get_pvals=True)

            # Layout of the combined matrix:
            #   rows    = X features, then one Y row   (Cov(Y, Z))
            #   columns = Z features, then one D column (Cov(D, X))
            # The zero fills the Y-row / D-column corner, so it goes last to
            # keep every D-X value on the row of its own X feature.
            with_y_row = np.concatenate([XZres_cov, YZres_cov], axis=0)
            d_column = np.vstack([DXres_cov.reshape(-1, 1), np.array([[0]])])
            _show_cov_heatmap(
                np.concatenate([with_y_row, d_column], axis=1),
                n_x + 1, n_z + 1,
                x_names + [y_name], z_names + [d_name],
                'X feats', 'Z feats',
                f'|Cov(X,Z)| after filtering X,Z\n{D_label}->{Y_label}',
                col_fontsize=8, divider=(n_x, n_z))

            # After filtering, every X must be tied to some Z and vice versa.
            assert ((XZres_pvals < XZres_thresh).sum(axis=0) > 0).all()
            assert ((XZres_pvals < XZres_thresh).sum(axis=1) > 0).all()

            # Same layout, showing which entries are statistically significant.
            with_y_row = np.concatenate([XZres_pvals < XZres_thresh, YZres_pvals < YZres_thresh], axis=0)
            d_column = np.vstack([DXres_pvals.reshape(-1, 1) < DXres_thresh, np.array([[0]])])
            _show_cov_heatmap(
                np.concatenate([with_y_row, d_column], axis=1),
                n_x + 1, n_z + 1,
                x_names + [y_name], z_names + [d_name],
                'X feats', 'Z feats',
                f'Nonzero Cov(X,Z) after filtering X,Z\n{D_label}->{Y_label}',
                col_fontsize=8, divider=(n_x, n_z))

            # D, Y and all X features against every Z feature.
            DYXres = np.concatenate([Dres, Yres, Xres], axis=1)
            Zall_cov, Zall_pvals, Zall_thresh = get_cov(DYXres, Zres, get_pvals=True)
            _show_cov_heatmap(
                Zall_cov, n_x + 2, n_z,
                [d_name, y_name] + x_names, z_names,
                'X feats', 'Z feats',
                f'|Cov(DYX,Z)| after filtering X,Z\n{D_label}->{Y_label}',
                col_fontsize=10)

            # D, Y and all Z features against every X feature.
            DYZres = np.concatenate([Dres, Yres, Zres], axis=1)
            Zall_cov, Zall_pvals, Zall_thresh = get_cov(DYZres, Xres, get_pvals=True)
            _show_cov_heatmap(
                Zall_cov, n_z + 2, n_x,
                [d_name, y_name] + z_names, x_names,
                'Z feats', 'X feats',
                f'|Cov(DYZ,X)| after filtering X,Z\n{D_label}->{Y_label}',
                col_fontsize=10)

In [None]:
# Fit ProximalDE on two randomly drawn (Xset, Zset) candidates and show the
# estimate only when every row of the tests table passes.
# NOTE(review): relies on `candidates`, `Dres`, `Zres`, `Xres`, `Yres` already
# being in the kernel namespace — confirm on a fresh Run All.
np.random.seed(0)
idx_list = np.random.choice(np.arange(len(candidates)), size=2)
for idx in idx_list:
    # Index one candidate at a time so this works when `candidates` is a list.
    Xset, Zset = candidates[idx]
    print("Xset =", Xset)
    print("Zset =", Zset)
    print()
    est = ProximalDE(random_state=3)
    est.fit(None, Dres, Zres[:, Zset], Xres[:, Xset], Yres)
    summary = est.summary()
    test_table = summary.tables[2]
    df = pd.DataFrame.from_records(test_table.data)
    header = df.iloc[0] # grab the first row for the header
    df = df[1:].copy() # take the data less the header row (copy: it is modified below)
    df.columns = header
    # The table stores booleans as text.
    df['pass test'] = df['pass test'].map(lambda x: x == 'True')
    if df['pass test'].all():
        display(summary.tables[0], summary.tables[2])
        print("Xset =", Xset)
        print("Zset =", Zset)