## Setup Only for Colab

In [None]:
# prompt: mount drive

from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
%ls

In [None]:
from IPython.display import clear_output

In [None]:
# Install the dependencies listed in requirements.txt, then clear the cell output.
import time
!pip install -r requirements.txt
time.sleep(2)  # brief pause before the cell output is cleared
clear_output()

In [None]:
import time
# Replace `develop` with `install` if you won't make library code changes.
!python setup.py develop
time.sleep(2)  # brief pause before the cell output is cleared
clear_output()
# Restart the session after running this

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks

# Main Logic

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_with_mediator_violations, gen_data_no_controls_discrete_m
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from proximalde.crossfit import fit_predict
from proximalde.utilities import covariance, svd_critical_value
from proximalde.proximal import residualizeW
from proximalde.proxy_rm_utils import *
from proximalde.ukbb_data_utils import *
import seaborn as sns
import matplotlib.pyplot as plt

# Synthetic

In [None]:
# Structural coefficients passed positionally to the synthetic data generator below.
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .0  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .0  # this can be zero; does not hurt

# Sizes handed to the generator; presumably n = samples and pw/pz/px = number of
# W/Z/X columns — TODO confirm against gen_data_with_mediator_violations.
n = 10000
pw = 100
pz, px = 50, 40
# Indices passed as invalidZinds / invalidXinds; validZ / validX are their complements
# within range(pz) / range(px).
invalidZ = [0, 4, 5]
invalidX = [0, 6, 8]
validZ = np.setdiff1d(np.arange(pz), invalidZ)
validX = np.setdiff1d(np.arange(px), invalidX)
np.random.seed(0)
W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g,
                                                     invalidZinds=invalidZ, invalidXinds=invalidX)
# The generated W is discarded: residualizeW below is called with W=None.
W = None

# Re-seed so the residualization starts from the same global NumPy RNG state.
np.random.seed(0)
Dres, Zres_rm, Xres, Yres, *_ = residualizeW(W, D, Z, X, Y)

In [None]:
# Build the proxy-removal helper from the residualized X, Z and D, then attach the
# residualized outcome through update_Y.
prm = WeakProxyRemoval(Xres,Zres_rm,Dres,primal_type='full', violation_type='full',est_thresh=.05)
prm.update_Y(Yres)

In [None]:
# Call violation_est with every X column index and every Z column index; display the
# first two items of what it returns.
prm.violation_est(np.arange(Xres.shape[1]), np.arange(Zres_rm.shape[1]))[:2]

In [None]:
# Same call as for violation_est, but through violation_full, for side-by-side comparison.
prm.violation_full(np.arange(Xres.shape[1]), np.arange(Zres_rm.shape[1]))[:2]

In [None]:
# Search for candidate (X set, Z set) pairs and show how many were returned.
# NOTE(review): the meaning of the positional 10, niters and second_xset_thresh is
# defined by WeakProxyRemoval.find_candidate_sets — confirm there.
candidates = prm.find_candidate_sets(10,niters=2,second_xset_thresh=1)
len(candidates)

In [None]:
# Only shows estimates for candidates that pass enough tests.
# NOTE(review): the original comment says "# test passes > 2" while npass=2 is passed
# below — confirm whether the cutoff is inclusive.
prm.get_estimates(candidates, verbose=1,npass=2)

# Real

In [None]:
# Interpretable X, Z features 
def rmNaZ(Zres_rm, Zint):
    """Drop the Z columns whose label is a non-answer option.

    A label counts as a non-answer when it contains 'Do not know' or 'Prefer not to'.

    Parameters
    ----------
    Zres_rm : 2-D array; columns are indexed in parallel with `Zint`.
    Zint : array of label strings, one per column of `Zres_rm`.

    Returns
    -------
    (Zres_rm, Zint) restricted to the remaining columns / labels.
    """
    markers = ('Do not know', 'Prefer not to')
    is_nonanswer = np.array([any(marker in label for marker in markers) for label in Zint])
    keep = ~is_nonanswer
    return Zres_rm[:, keep], Zint[keep]

# Load the raw UKBB X / Z matrices with their feature names and map the names to
# interpretable labels.
X, X_feats, Z, Z_feats = load_ukbb_XZ_data()
Xint = get_int_feats(X_feats)
Zint_ = get_int_feats(Z_feats)
D_label = 'Obese'
Y_label = 'back'

# Boolean mask over Z labels that are non-answer options. It is kept as a module-level
# name because it is also used to index covariance rows further down the notebook.
bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in Zint_])

# Residualized data for the chosen treatment/outcome. The original cell issued this
# identical call twice; it is now loaded once.
Xres, Zres_rm, Yres, Dres = load_ukbb_res_data(D_label, Y_label)
Zres_rm = Zres_rm[:,~bad_idx]
Zint = Zint_[~bad_idx]

In [None]:
# NOTE(review): `est` is not assigned by any cell above this one; the only assignment
# in this notebook is inside the loop of the final cell, so this fails on a fresh
# Restart & Run All.
est.summary()

In [None]:
# Split three covariance matrices row-wise into "non-answer" rows (bad_idx) and the rest,
# and plot each as one heatmap with a zero separator row between the two groups.
# NOTE(review): `prm` is last assigned in the synthetic section, while Dres and bad_idx
# come from the real-data cell — confirm prm was rebuilt on the real data before this runs.
covZD = covariance(prm.Zres_rm, Dres)
split_covs = []
for m in [prm.covXZ.T, prm.covZY, covZD]:
    split_covs += [m[bad_idx], m[~bad_idx]]
    # Stacked along axis 0: the bad_idx rows end up above the separator row, the rest
    # below it (the title reads 'na | rest', but the split is vertical).
    x = np.concatenate([m[bad_idx], np.zeros(m[:1].shape), m[~bad_idx]],axis=0)
    sns.heatmap(np.abs(x),cmap='Blues')
    plt.title('na | rest')
    plt.show()
#     x = np.concatenate([m[bad_idx], np.zeros(m[:1].shape), np.abs(m[~bad_idx])],axis=1)
#     sns.heatmap(np.abs(x))
#     plt.title('na | rest')
#     plt.show()
# Two pieces per matrix, in loop order: (bad rows, remaining rows).
XZna, XZ, ZYna, ZY, ZDna, ZD = split_covs


### Interpreting the candidate sets found by the algorithm

In [None]:
import itertools
import pickle as pk 

# Treatment (D) and outcome (Y) label codes; `dys` lists every "D_Y" combination,
# with the treatment varying slowest.
ds = ['Low_inc', 'On_dis', 'No_priv_insr', 'No_uni', 'Female', 'Black', 'Obese', 'Asian']
ys = ['OA', 'myoc','deprs', 'back', 'RA', 'fibro', 'infl', 'copd','chrkd','mgrn','mela', 'preg', 'endo']
dys = [f'{d_code}_{y_code}' for d_code in ds for y_code in ys]

def get_median_item(y):
    """Return the median element of `y` together with its index in `y`.

    For an odd number of elements this is the middle element of the sorted order.
    For an even number, the two middle elements are compared and the one with the
    larger absolute value is returned (the upper one on ties).

    Returns
    -------
    (value, index) where `index` is the position of `value` in the original `y`.
    """
    order = np.argsort(y)
    half = len(y) // 2
    if len(y) % 2 == 1:
        pick = order[half]
        return y[pick], pick
    lower, upper = order[half - 1], order[half]
    pick = lower if np.abs(y[lower]) > np.abs(y[upper]) else upper
    return y[pick], pick

def rmNaZ(Zres_rm, Zint):
    """Remove Z columns labelled with a non-answer option.

    Labels containing 'Do not know' or 'Prefer not to' are treated as non-answers.
    Returns the filtered matrix and the filtered label array as a 2-tuple.
    """
    non_answer_tags = ('Do not know', 'Prefer not to')
    dropped = np.array([any(tag in name for tag in non_answer_tags) for name in Zint])
    kept = ~dropped
    return Zres_rm[:, kept], Zint[kept]

# Feature matrices, their names and interpretable labels.
X, X_feats, Z, Z_feats = load_ukbb_XZ_data()
Xint = get_int_feats(X_feats)
Zint_ = get_int_feats(Z_feats)
# Mask of Z labels that are non-answer options ('Do not know' / 'Prefer not to').
bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in Zint_])

# Residuals for one treatment/outcome pair; the non-answer Z columns are dropped so that
# column positions line up with Zint.
Xres, Zres, Yres, Dres = load_ukbb_res_data('Black', 'OA')
Zres = Zres[:,~bad_idx]
Zint = Zint_[~bad_idx]

# SECURITY: unpickling executes arbitrary code — only load an ss_dy.pkl you produced yourself.
# The context manager closes the file handle (the original left it open).
with open('ss_dy.pkl', 'rb') as ss_dy_file:
    ss_dy = pk.load(ss_dy_file)


In [None]:
# Column means of xs.
# NOTE(review): xs is not assigned by any active code above this cell; its construction
# only appears in commented-out code further down the notebook.
xs.mean(axis=0)#==18

In [None]:
# Clustered row-by-row correlation of the concatenated [zs, xs] matrix.
# NOTE(review): zs, xs and labels_ are not assigned by any active code above this cell.
sns.clustermap(np.corrcoef(np.concatenate([zs,xs],axis=1)),yticklabels=labels_,xticklabels=labels_,linecolor='white',linewidths=.01,center=0,cmap='RdBu')


In [None]:

# Build one 0/1 row per D_Y pair marking which X (resp. Z) columns belong to the candidate
# set whose point estimate is the median over all candidates stored for that pair.
# The original cell had this construction commented out, so xs / zs / labels only existed
# as leftover kernel state and the cell failed on a fresh kernel.
N_X_FEATS = 65  # width hard-coded in the original; presumably Xres.shape[1] — TODO confirm
MIN_ABS_EFFECT = .05  # pairs whose median |point estimate| does not exceed this are skipped

xs, zs, labels = [], [], []
for dy, runs in ss_dy.items():
    # Each stored run is (point, test, inf_idxs, path, (Xset, Zset)); run[0].point.iloc[0]
    # is the scalar point estimate of that run.
    points = np.array([run[0].point.iloc[0] for run in runs])
    if len(points) == 0:
        continue  # nothing stored for this pair; get_median_item cannot handle empty input
    m, idx = get_median_item(points)
    if np.abs(m) > MIN_ABS_EFFECT:
        labels.append(dy)
        point, test, inf_idxs, path, (Xset, Zset) = runs[idx]
        xset_ = np.zeros(N_X_FEATS)
        zset_ = np.zeros(Zres.shape[1])
        xset_[Xset] = 1
        zset_[Zset] = 1
        xs.append(xset_)
        zs.append(zset_)

if not labels:
    raise ValueError(f'no D_Y pair in ss_dy has a median |point estimate| above {MIN_ABS_EFFECT}')
xs = np.vstack(xs)
zs = np.vstack(zs)

# Keep only the columns that vary across pairs (neither never selected nor always selected),
# so that row-wise correlations are informative.
xs_ = xs[:,(xs.mean(axis=0) != 0) & (xs.mean(axis=0) !=1)]
zs_ = zs[:,(zs.mean(axis=0) != 0) & (zs.mean(axis=0) !=1)]
yy={'OA': 'Osteoarthritis','mgrn':'Migraine','copd':'COPD', 'back': 'Back pain', 'deprs': 'Depression', 'myoc': "Heart disease", 'RA': 'Rh. Arthritis', 'fibro': 'Fibromyalgia', 'chrkd': 'Chronic kidney disease'}
dd={'No_uni': 'No p.s. education', 'Low_inc': 'Low income','Obese':'Obese', 'Female': 'Female', 'Black': 'Black', 'Asian': "Asian", 'On_dis': 'Disability insurance'}
# "D_Y" -> "D display name, Y display name". The outcome code is the part after the last
# underscore; codes missing from dd / yy fall back to the raw code instead of raising KeyError.
labels_ = [
    dd.get(d_code, d_code) + ', ' + yy.get(y_code, y_code)
    for d_code, y_code in (label.rsplit('_', 1) for label in labels)
]
# --- X side: raw matrix image, clustered row correlations, most/least frequent columns ---
# presumably xs is a 0/1 matrix with one row per entry of labels_ — TODO confirm
plt.subplots(figsize=(10,10),dpi=100)
plt.imshow(xs)
plt.show()
# sns.heatmap(np.corrcoef(xs),xticklabels=labels,center=0,cmap='RdBu')
# np.corrcoef correlates rows, so this clusters the labelled rows by how similar they are.
sns.clustermap(np.corrcoef(xs_),yticklabels=labels_,xticklabels=labels_,center=0,linecolor='white',linewidths=.006,cmap='RdBu')
plt.show()
# Columns sorted by descending mean: print the six highest and the six lowest with their means.
idx = np.argsort(xs.mean(axis=0))[::-1]
print(Xint[idx[:6]],xs.mean(axis=0)[idx[:6]] )
print(Xint[idx[-6:]],xs.mean(axis=0)[idx[-6:]] )
# Distribution of per-column variance across rows.
plt.hist(xs.var(axis=0))
plt.show()

# --- Z side: same sequence of plots and printouts for zs ---
plt.subplots(figsize=(10,10),dpi=100)
plt.imshow(zs)
plt.show()
# sns.heatmap(np.corrcoef(zs),xticklabels=labels,center=0,cmap='RdBu')
sns.clustermap(np.corrcoef(zs_),yticklabels=labels_,xticklabels=labels_,linecolor='white',linewidths=.01,center=0,cmap='RdBu')
plt.show()
idx = np.argsort(zs.mean(axis=0))[::-1]
print(Zint[idx[:6]],zs.mean(axis=0)[idx[:6]] )
print(Zint[idx[-6:]],zs.mean(axis=0)[idx[-6:]] )
plt.hist(zs.var(axis=0))
plt.show()

In [None]:
def get_cov(A,B,get_pvals=True):
    """Return the cross-moment matrix A.T @ B divided by the number of rows of B.

    No mean-centering is applied here. `get_pvals` is accepted but ignored: a single
    array is returned regardless of its value.
    NOTE(review): other code in this notebook unpacks three values (cov, pvals, thresh)
    from get_cov — confirm which implementation those callers expect.
    """
    n_rows = B.shape[0]
    cross_moment = A.T @ B
    return cross_moment / n_rows
def wrap_labels(labels, max_characters):
    """Wrap each label so that no line is longer than `max_characters`.

    Parameters
    ----------
    labels : iterable of str
    max_characters : int, maximum line width passed to textwrap.fill.

    Returns
    -------
    list of str with newline characters inserted at the wrap points.
    """
    # Imported locally: in this notebook `textwrap` is only imported by a later cell,
    # so relying on the global name fails when the cells run top to bottom.
    import textwrap

    return [textwrap.fill(label, max_characters) for label in labels]
   
from scipy import stats

print(f"{D_label}->{Y_label}")
# NOTE(review): the inputs used below (Zres_rm, Xres_rm, DXres_cov, DXres_rm_cov, YZres_cov,
# YZres_rm_cov) are only assigned by the commented-out block that follows, and Xset / Zset
# are not assigned by any active code above this cell. Re-enable this block (and define
# Xset / Zset) before running on a fresh kernel.
# _, _, Yres, Dres = load_ukbb_res_data(D_label, Y_label)
# Zres_rm = Zres[:,Zset]
# Xres_rm = Xres[:,Xset]

# Zres_other = Zres[:, np.setdiff1d(np.arange(Zres.shape[1]), Zset)]
# Xres_other = Xres[:, np.setdiff1d(np.arange(Xres.shape[1]), Xset)]

# XZres_rm_cov= get_cov(Xres_rm,Zres_rm, get_pvals=True)
# XZres_cov= get_cov(Xres_other,Zres_other, get_pvals=True)
# DXres_rm_cov = get_cov(Dres, Xres_rm, get_pvals=True)
# DXres_cov = get_cov(Dres, Xres_other, get_pvals=True)
# YZres_rm_cov = get_cov(Yres, Zres_rm, get_pvals=True)
# YZres_cov = get_cov(Yres, Zres_other, get_pvals=True)

# Density of the D-X covariances for the 'rm' columns versus the 'other' columns.
sns.kdeplot(DXres_cov.squeeze(), label='other')
sns.kdeplot(DXres_rm_cov.squeeze(), label='rm')
plt.legend()  # must come before show(); the original called it afterwards, so no legend was drawn
plt.show()
# Two-sample KS test: a large p-value means the two samples are compatible with one distribution.
print(stats.ks_2samp(DXres_rm_cov.squeeze(),DXres_cov.squeeze()))

# Same comparison for the Y-Z covariances.
sns.kdeplot(YZres_cov.squeeze(), label='other')
sns.kdeplot(YZres_rm_cov.squeeze(), label='rm')
plt.legend()
plt.show()
print(stats.ks_2samp(YZres_cov.squeeze(),YZres_rm_cov.squeeze()))

# Heatmap 1 — rows: D, Y, then the X columns in Xset; columns: the Z columns in Zset.
DYXres_rm = np.concatenate([Dres, Yres, Xres_rm], axis=1)
Zall_cov = get_cov(DYXres_rm, Zres_rm, get_pvals=True)
plt.subplots(1,1,figsize=(12,8),dpi=60)
sns.heatmap(np.abs(Zall_cov), cmap='Blues')
row_labels = [f'D={D_label}', f'Y={Y_label}'] + list(Xint[Xset])  # first two rows are D and Y
plt.yticks(ticks=np.arange(Xres_rm.shape[1]+2)+.5, labels=row_labels, rotation=0)
col_labels = list(Zint[Zset])
plt.xticks(ticks=np.arange(Zres_rm.shape[1])+.5, labels=wrap_labels(col_labels,50),rotation=90, fontsize=10)
plt.ylabel('X feats')
plt.xlabel('Z feats')
plt.title(f'|Cov(DYX,Z)| after filtering X,Z\n{D_label}->{Y_label}')
plt.show()


# Heatmap 2 — rows: D, Y, then the Z columns in Zset; columns: the X columns in Xset.
DYZres_rm = np.concatenate([Dres, Yres, Zres_rm], axis=1)
Zall_cov = get_cov(DYZres_rm, Xres_rm, get_pvals=True)
plt.subplots(1,1,figsize=(12,8),dpi=60)
sns.heatmap(np.abs(Zall_cov), cmap='Blues')
row_labels = [f'D={D_label}', f'Y={Y_label}'] + list(Zint[Zset])  # first two rows are D and Y
plt.yticks(ticks=np.arange(Zres_rm.shape[1]+2)+.5, labels=row_labels, rotation=0)
col_labels = list(Xint[Xset])
plt.xticks(ticks=np.arange(Xres_rm.shape[1])+.5, labels=wrap_labels(col_labels,50),rotation=90, fontsize=10)
plt.ylabel('Z feats')
plt.xlabel('X feats')
plt.title(f'|Cov(DYZ,X| after filtering X,Z\n{D_label}->{Y_label}')
plt.show()

In [None]:
def XZ_vis_cov(D_label, Y_label, Xset, Zset):
    """Plot |covariance| heatmaps of [D, Y, X] against Z and of [D, Y, Z] against X.

    Residualized D and Y are loaded for the given labels. X, Z and their display names
    are read from the notebook-level names Xres, Zres_rm, Xint and Zint.

    Parameters
    ----------
    D_label, Y_label : str
        Treatment / outcome codes passed to load_ukbb_res_data and shown in the titles.
    Xset, Zset :
        NOTE(review): accepted but not applied to the plotted matrices. The original
        body only used them to build slices that were never read. Those dead statements
        were removed because they raised before any plot was drawn: the tuple returned by
        rmNaZ was passed to get_cov, and three values were unpacked from get_cov results.
        Confirm whether the heatmaps were meant to be restricted to Xset / Zset.
    """

    def _cov_only(result):
        # get_cov is unpacked as (cov, pvals, thresh) in parts of this notebook, while the
        # get_cov defined in this notebook returns a bare array; accept either form.
        return result[0] if isinstance(result, tuple) else result

    print(f"{D_label}->{Y_label}")
    _, _, Yres, Dres = load_ukbb_res_data(D_label, Y_label)

    # Heatmap 1 — rows: D, Y, then every X column; columns: every Z column.
    DYXres = np.concatenate([Dres, Yres, Xres], axis=1)
    Zall_cov = _cov_only(get_cov(DYXres, Zres_rm, get_pvals=True))
    plt.subplots(1,1,figsize=(12,8),dpi=60)
    sns.heatmap(np.abs(Zall_cov), cmap='Blues')
    row_labels = [f'D={D_label}', f'Y={Y_label}'] + list(Xint)  # first two rows are D and Y
    plt.yticks(ticks=np.arange(Xres.shape[1]+2)+.5, labels=row_labels, rotation=0)
    col_labels = list(Zint)
    plt.xticks(ticks=np.arange(Zres_rm.shape[1])+.5, labels=wrap_labels(col_labels,50),rotation=90, fontsize=10)
    plt.ylabel('X feats')
    plt.xlabel('Z feats')
    plt.title(f'|Cov(DYX,Z)| after filtering X,Z\n{D_label}->{Y_label}')
    plt.show()

    # Heatmap 2 — rows: D, Y, then every Z column; columns: every X column.
    DYZres_rm = np.concatenate([Dres, Yres, Zres_rm], axis=1)
    Zall_cov = _cov_only(get_cov(DYZres_rm, Xres, get_pvals=True))
    plt.subplots(1,1,figsize=(12,8),dpi=60)
    sns.heatmap(np.abs(Zall_cov), cmap='Blues')
    row_labels = [f'D={D_label}', f'Y={Y_label}'] + list(Zint)  # first two rows are D and Y
    plt.yticks(ticks=np.arange(Zres_rm.shape[1]+2)+.5, labels=row_labels, rotation=0)
    col_labels = list(Xint)
    plt.xticks(ticks=np.arange(Xres.shape[1])+.5, labels=wrap_labels(col_labels,50),rotation=90, fontsize=10)
    plt.ylabel('Z feats')
    plt.xlabel('X feats')
    plt.title(f'|Cov(DYZ,X| after filtering X,Z\n{D_label}->{Y_label}')
    plt.show()

In [None]:
# Collect the saved result tables (table0 = point estimates, table1 = residual
# diagnostics, table2 = tests — naming inferred from the variable names; TODO confirm)
# for every treatment/outcome pair and stack them into three frames.
test_dfs = []
res_dfs = []
point_dfs = []
model_regression = 'linear'
model_classification = 'linear'
Dbin = False
Ybin = False
SAVE_PATH = './results/'
XZbin = False
for D_label in ['Black', 'Female', 'Obese','Asian']:
    print(D_label)
    for Y_label in ['OA', 'RA', 'myoc','deprs', 'back']:
        try:
            save_dir = f'{SAVE_PATH}/D={D_label}_Y={Y_label}/Dbin={Dbin}_Ybin={Ybin}_XZbin={XZbin}_Rgr={model_regression}'
            test_df = pd.read_csv(save_dir + '/table2.csv', header=1, index_col=1)
            test_df = test_df.drop(columns=['0'])
            # Flatten the test table into a single row with "<row>_<column>" column names.
            test_df_flat = test_df.T.unstack().to_frame().sort_index(level=1).T
            test_df_flat.columns = test_df_flat.columns.map('_'.join)
            point_df = pd.read_csv(save_dir + '/table0.csv', header=1, index_col=1)
            point_df = point_df.drop(columns=['0'])
            res_df = pd.read_csv(save_dir + '/table1.csv', header=1, index_col=1)
            res_df = res_df.drop(columns=['0'])
            # Tag every table with the pair it belongs to.
            test_df_flat['D_Y'] = point_df['D_Y'] = res_df['D_Y'] = f'{D_label}_{Y_label}'
            res_dfs.append(res_df)
            point_dfs.append(point_df)
            test_dfs.append(test_df_flat)
        except Exception as err:
            # Best effort: a missing or malformed result directory is reported and skipped.
            # The exception is bound to `err`, not `e`: `except ... as e` would delete the
            # notebook-level `e` coefficient defined in the synthetic-data cell.
            print(err)

point_df = pd.concat(point_dfs)
res_df = pd.concat(res_dfs)
test_df = pd.concat(test_dfs)
test_df = test_df.reindex(sorted(test_df.columns), axis=1)
point_df

In [None]:
from scipy.stats import t
import textwrap 
def wrap_labels(labels, max_characters):
    """Return the labels with each one wrapped to at most `max_characters` per line."""
    wrapped = []
    for label in labels:
        wrapped.append(textwrap.fill(label, max_characters))
    return wrapped
    
UKBB_DATA_DIR = '/oak/stanford/groups/rbaltman/karaliu/bias_detection/cohort_creation/data/'

def _load_data(fname: str):
    """Load the `<fname>_data_rd.npy` / `<fname>_feats_rd.npy` pair from UKBB_DATA_DIR.

    Both files are read with allow_pickle=False. An AssertionError is raised when the
    data array contains any NaN.

    Returns
    -------
    (data, feats) as loaded by np.load.
    """
    data_path = UKBB_DATA_DIR + f'{fname}_data_rd.npy'
    feats_path = UKBB_DATA_DIR + f'{fname}_feats_rd.npy'
    data = np.load(data_path, allow_pickle=False)
    feats = np.load(feats_path, allow_pickle=False)
    nan_count = np.isnan(data).sum()
    assert nan_count == 0, 'NaN values cannot exist in data'
    return data, feats
    
def load_XZ_data():
    """Load the Z ('srMntSlp') and X ('biomMed') arrays with their feature-name arrays.

    Returns
    -------
    (X, X_feats, Z, Z_feats)
    """
    z_data, z_names = _load_data(fname='srMntSlp')
    x_data, x_names = _load_data(fname='biomMed')
    return x_data, x_names, z_data, z_names

def load_DY_data(D_label, Y_label):
    """Read one treatment column and one outcome column from the UKBB CSV files.

    Returns
    -------
    (D, Y) where D is the 1-D array of column `D_label` from 'updated_sa_df_pp.csv' and
    Y is column `Y_label` from 'updated_Y_labels.csv' with a trailing axis added.
    """
    treatment_table = pd.read_csv(UKBB_DATA_DIR + 'updated_sa_df_pp.csv')
    D = treatment_table[D_label].to_numpy()
    outcome_table = pd.read_csv(UKBB_DATA_DIR + 'updated_Y_labels.csv')
    outcome_column = outcome_table[Y_label].to_numpy()
    Y = outcome_column[:, None]
    return D, Y

def load_res_data(D_label, Y_label):
    """Load the saved residual arrays for one treatment/outcome pair.

    Underscores are stripped from `D_label` before it is used in the file names; the
    Y, X and Z file names carry a `_Wrm<D_label>` suffix.

    Returns
    -------
    (Xres, Zres_rm, Yres, Dres)
    """
    res_dir = '/oak/stanford/groups/rbaltman/karaliu/bias_detection/causal_analysis/data_hm/'
    d_key = D_label.replace('_', '')
    w_suffix = f'_Wrm{d_key}'
    Yres = np.load(res_dir + f'Yres_{Y_label}{w_suffix}.npy')
    Dres = np.load(res_dir + f'Dres_{d_key}.npy')
    Xres = np.load(res_dir + f'Xres{w_suffix}.npy')
    Zres_rm = np.load(res_dir + f'Zres_rm{w_suffix}.npy')
    return Xres, Zres_rm, Yres, Dres

def XZ_hparam_plot(dual_or_primal='dual'):
    """
    Tool for visualizing covariance matrices
    and how different thresholds for N affect the corresponding covariance.
    Could be for dual or primal.

    For every treatment (and, for the primal variant, every outcome) two panels are drawn:
    left, which feature pairs have a covariance flagged as nonzero; right, how many
    column features remain and how many row features lose all their flagged partners as
    the removal threshold N grows.

    Parameters
    ----------
    dual_or_primal : str
        'dual' relates X (rows) to Z (columns) and screens rows by their association with D;
        any other value runs the primal variant: Z (rows) against X (columns), screened by Y.

    NOTE(review): this function unpacks (cov, pvals, thresh) from get_cov(..., get_pvals=True).
    A get_cov that returns a single array will not work here — confirm which implementation
    is in scope when this is called.
    """
    n_thresholds = 40  # values of N swept on the right-hand panel (taken from the original code)

    D_labels = ['Female', 'Obese','Black', 'Asian']
    if dual_or_primal=='dual':
        Y_labels = ['OA']
    else:
        Y_labels = ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']
    for D_label in D_labels:
        print(D_label)
        for Y_label in Y_labels:
            # Yres is bound here: the primal branch needs it, and the original discarded it,
            # which made that branch read a stale notebook-level Yres (or raise NameError).
            Xres, Zres_rm, Yres, Dres = load_res_data(D_label, Y_label=Y_label)
            if dual_or_primal=='dual':
                Xrpr, Zrpr, Drpr, label = 'X', 'Z', 'D', D_label
                XZres_rm_cov, XZres_rm_pvals, XZres_rm_thresh = get_cov(Xres, Zres_rm, get_pvals=True)
                DXres_cov, DXres_pvals, DXres_thresh = get_cov(Dres, Xres, get_pvals=True)
            else:
                Xrpr, Zrpr, Drpr, label = 'Z', 'X', 'Y', Y_label
                XZres_rm_cov, XZres_rm_pvals, XZres_rm_thresh = get_cov(Zres_rm, Xres, get_pvals=True)
                DXres_cov, DXres_pvals, DXres_thresh = get_cov(Yres, Zres_rm, get_pvals=True)

            # We only care about row feats with a st.sig. association with D (dual) / Y (primal)
            ss_DXidx = (DXres_pvals < DXres_thresh).squeeze()
            # Significance mask restricted to those rows; computed once and reused below.
            sig = XZres_rm_pvals[ss_DXidx] < XZres_rm_thresh
            if dual_or_primal=='dual':
                fig, axs = plt.subplots(1, 2, figsize=(20, 5), dpi=70)
            else:
                fig, axs = plt.subplots(1, 2, figsize=(12, 5), dpi=70)

            im = axs[0].imshow(sig, aspect='auto', cmap='Blues', interpolation='nearest')
            axs[0].set_title(f"Nonzero Covariance({Xrpr},{Zrpr})\n(if spearman pvalue < .05/dz*dx)", fontsize=12)
            axs[0].set_ylabel(f"{Xrpr} feats", fontsize=10)
            axs[0].set_xlabel(f"{Zrpr} feats", fontsize=10)
            cbar = fig.colorbar(im, ax=axs[0], orientation='vertical')
            cbar.set_ticks([0, 1])
            cbar.set_ticklabels(['Zero', 'Nonzero'])

            # For each N: how many column feats are flagged with more than N row feats, and
            # how many row feats have no flagged partner left once the other columns are dropped.
            sig_per_col = sig.sum(axis=0)
            nZfeats = []
            Xfeats_w_zero_Zfeats = []
            for i in range(n_thresholds):
                keep = sig_per_col > i
                zero = (sig[:, keep].sum(axis=1) == 0).sum()
                nZfeats.append(keep.sum())
                Xfeats_w_zero_Zfeats.append(zero)

            ax3 = axs[1]
            ax3.plot(range(n_thresholds), nZfeats, color='blue')
            ax3.set_ylabel(f"# {Zrpr} feats correlated with >N {Xrpr} feats", color='blue')
            ax3.set_xlabel(f"N\n(hparam for filtering {Xrpr} feats)")
            ax3.tick_params(axis='y', labelcolor='blue')

            ax3b = ax3.twinx()
            ax3b.plot(range(n_thresholds), Xfeats_w_zero_Zfeats, color='green')
            ax3b.set_ylabel(f"# {Xrpr} feats w/ 0 st.sig. feat correlation\nafter rm {Zrpr} feats correlated with <N {Xrpr} feats", color='green')
            ax3b.tick_params(axis='y')

            # The original first set a different title on this axes and then overwrote it
            # with this one; only the title that was actually displayed is kept.
            ax3.set_title(f'How removing {Zrpr} feats affects Cov(X,Z)')
            ax3.grid(True, axis='y', linestyle='--', linewidth=0.5, color='lightgray')

            # Adjust layout to prevent overlapping
            plt.tight_layout()
            plt.suptitle(f"{dual_or_primal} violation plots for {Drpr}={label}\nusing Xres, Zres_rm")
            # Show the combined plots
            plt.show()



In [None]:
# Fit the estimator on two randomly drawn candidate (Xset, Zset) pairs and show the
# summary tables only for candidates where every test passes.
np.random.seed(0)
# The original line `for i in idx_list = ...` was a syntax error. np.random.choice samples
# with replacement by default, so the same candidate may be drawn twice.
idx_list = np.random.choice(np.arange(len(candidates)), size=2)
for i in idx_list:
    # Index one candidate at a time: a Python list cannot be indexed with an index array.
    Xset, Zset = candidates[i]
    print("Xset =", Xset)
    print("Zset =", Zset)
    print()
    est = ProximalDE(random_state=3)
    est.fit(None, Dres, Zres_rm[:, Zset], Xres[:, Xset], Yres)
    summary = est.summary()
    # Named `tests_table` rather than `t` so that `t` imported from scipy.stats is not shadowed.
    tests_table = summary.tables[2]
    df = pd.DataFrame.from_records(tests_table.data)
    header = df.iloc[0]  # the first record holds the column names
    df = df.iloc[1:].copy()  # data rows only; copy so the column assignment below is safe
    df.columns = header
    # The table cells are strings; convert the pass flag to a real boolean.
    df['pass test'] = df['pass test'].map(lambda x: x == 'True')
    if df['pass test'].all():
        display(summary.tables[0], summary.tables[2])
        print("Xset =", Xset)
        print("Zset =", Zset)