## Setup (only needed when running on Colab)

In [None]:
# prompt: mount drive
# Mount Google Drive so the project folder under MyDrive/Colab Notebooks is reachable
# (Colab only; skip this when running locally).

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Move into the repository root so requirements.txt and setup.py (used below) are found.
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
# Sanity check: list the repository contents.
%ls

In [None]:
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# replace `develop` with `install` if you wont make library code changes
!python setup.py develop
time.sleep(2)
clear_output()
# Restart the session after running this

In [None]:
# Move back to the notebooks folder; the relative ./results/... paths below resolve from here.
%cd /content/drive/MyDrive/Colab\ Notebooks

# Main Logic

In [None]:
# Automatically reload edited library modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle as pk

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_with_mediator_violations, gen_data_no_controls_discrete_m
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from proximalde.crossfit import fit_predict
from proximalde.utilities import covariance, svd_critical_value
from proximalde.proximal import residualizeW
from proximalde.proxy_rm_utils import *
from proximalde.ukbb_data_utils import *

# Synthetic data

In [None]:
# Synthetic data-generating parameters, passed positionally to gen_data_with_mediator_violations.
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .0  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .0  # this can be zero; does not hurt

n = 50000  # presumably the sample size — TODO confirm against gen_data signature
pw = 100  # presumably the number of controls W — TODO confirm
pz, px = 50, 40  # number of Z and X proxies (the valid index ranges below are built from these)
# Proxy columns generated as invalid (passed as invalidZinds / invalidXinds).
invalidZ = [0, 4, 5]
invalidX = [0, 6, 8]
np.random.seed(0)

# Complement of the invalid indices: the proxies generated as valid.
validZ = np.setdiff1d(np.arange(pz), invalidZ)
validX = np.setdiff1d(np.arange(px), invalidX)
W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g,
                                                     invalidZinds=invalidZ, invalidXinds=invalidX)
# Discard the generated W so that residualizeW below is called with W=None.
# NOTE(review): confirm that ignoring the controls is intended here.
W = None

np.random.seed(0)
Dres, Zres, Xres, Yres, *_ = residualizeW(W, D, Z, X, Y)

In [None]:
# Wrap the residualized synthetic data in the weak-proxy-removal helper and show
# its violation benchmarks (dv_bench / pv_bench; presumably dual / primal).
prm = WeakProxyRemoval(Xres,Zres,Dres,primal_type='est')
prm.update_Y(Yres)
prm.dv_bench, prm.pv_bench

In [None]:
# Violation scores when only the proxies generated as valid are kept.
prm.violation([*validX], [*validZ])

In [None]:
# Same check restricted to the proxies generated as invalid, for contrast.
prm.violation([*invalidX], [*invalidZ])

In [None]:
# Acceptance thresholds for the dual / primal violation, then search for candidate
# (Xset, Zset) pairs; `10` is presumably the number of random trials — TODO confirm.
prm.dthresh = .08
prm.pthresh = .08
candidates = prm.find_candidate_sets(10)
len(candidates)

In [None]:
# Estimate every candidate; only candidates passing more than `npass` tests
# (here > 2) are shown — TODO confirm against get_estimates.
prm.get_estimates(candidates, verbose=1,npass=2)

# Real data (UK Biobank)

In [None]:
# Interpretable X, Z features
# Load the UK Biobank proxies with their feature names, human-readable names
# (Xint, Zint), and the (presumably residualized) data for D='Obese', Y='back'.
# Note the return order: Y before D.
X, X_feats, Z, Z_feats = load_ukbb_XZ_data()
Xint = get_int_feats(X_feats)
Zint = get_int_feats(Z_feats)
D_label = 'Obese'
Y_label = 'back'
Xres, Zres, Yres, Dres = load_ukbb_res_data(D_label, Y_label)

In [None]:
# Weak-proxy-removal helper on the real data, with readable feature names attached
# so printed X/Z sets can be interpreted; show the violation benchmarks.
prm = WeakProxyRemoval(Xres,Zres,Dres,primal_type='est')
prm.update_Y(Yres)
prm.Xint = Xint
prm.Zint = Zint
prm.dv_bench, prm.pv_bench

In [None]:
# NOTE(review): this cell is identical to the previous one (it rebuilds `prm`
# from scratch); one of the two can be removed.
prm = WeakProxyRemoval(Xres,Zres,Dres,primal_type='est')
prm.update_Y(Yres)
prm.Xint = Xint
prm.Zint = Zint
prm.dv_bench, prm.pv_bench

In [None]:
# Search candidate (Xset, Zset) pairs on the real data with random growth order,
# then estimate a random subset of them.
prm.dthresh = .08
prm.pthresh = .08
prm.change_primal_type('est')
N = 10
gen_nextX='random'
gen_nextZ='random'

np.random.seed(0)
candidates = prm.find_candidate_sets(N,gen_nextX=gen_nextX,gen_nextZ=gen_nextZ)
print(len(candidates))
# Sample distinct candidates (with replacement the same candidate could be
# estimated twice) and never more than exist (an empty list used to raise).
idx_list = np.random.choice(len(candidates), size=min(6, len(candidates)), replace=False)
prm.get_estimates(candidates, idx_list=idx_list, verbose=1, npass=1)

In [None]:
# Interpretable X, Z features
X, X_feats, Z, Z_feats = load_ukbb_XZ_data()
Xint = get_int_feats(X_feats)
Zint = get_int_feats(Z_feats)
D_label, Y_label = 'Obese', 'back'
Xres, Zres, Yres, Dres = load_ukbb_res_data(D_label, Y_label)

# Flag Z proxies whose readable name marks a non-answer, and keep the rest.
NON_ANSWER_MARKERS = ('Do not know', 'Prefer not to')
bad_idx = np.array([any(marker in name for marker in NON_ANSWER_MARKERS) for name in Zint])
keep_z = ~bad_idx

prm2 = WeakProxyRemoval(Xres, Zres[:, keep_z], Dres, primal_type='est')
prm2.update_Y(Yres)
prm2.Xint, prm2.Zint = Xint, Zint[keep_z]
prm2.dv_bench, prm2.pv_bench

In [None]:
# Re-estimate a random subset of previously saved candidates.
# NOTE(review): judging by the directory name these were saved with thresholds 0.1
# and primal type "full", while `prm` uses primal_type='est' and the unfiltered Z set —
# confirm this pairing (and whether `prm2` was meant instead).
# `pk` was only imported in a later cell, so this cell failed on a fresh kernel.
import pickle as pk

save_dir = './results/proxyrm/Obese_back/ntrials200_dth0.1_pth0.1_ptyfull_genXrandom_genZrandom_Rgrs=linear/'
# pickle can execute arbitrary code on load: only open result files you produced.
with open(save_dir + 'candidates.pkl', 'rb') as fh:
    candidates = pk.load(fh)
np.random.seed(2)
# Distinct candidates only, capped at the number available.
idx_list = np.random.choice(len(candidates), size=min(50, len(candidates)), replace=False)
prm.get_estimates(candidates, idx_list=idx_list, verbose=1, npass=1)

In [None]:
# Reload the data, drop non-answer Z proxies ("Do not know" / "Prefer not to ..."),
# then run a larger candidate search (N=25) and estimate a random subset.
X, X_feats, Z, Z_feats = load_ukbb_XZ_data()
Zint = get_int_feats(Z_feats)
Xint = get_int_feats(X_feats)
D_label = 'Obese'
Y_label = 'back'
Xres, Zres, Yres, Dres = load_ukbb_res_data(D_label, Y_label)
bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in Zint])
prm2 = WeakProxyRemoval(Xres,Zres[:,~bad_idx],Dres,primal_type='est')
prm2.update_Y(Yres)
prm2.Xint = Xint
prm2.Zint = Zint[~bad_idx]
# A bare expression in the middle of a cell is evaluated and discarded;
# display the violation benchmarks explicitly.
display((prm2.dv_bench, prm2.pv_bench))
N = 25
gen_nextX='random'
gen_nextZ='random'

np.random.seed(0)
candidates = prm2.find_candidate_sets(N,gen_nextX=gen_nextX,gen_nextZ=gen_nextZ)
print(len(candidates))
# Distinct candidates only (no duplicate estimations), capped at the number available.
idx_list = np.random.choice(len(candidates), size=min(40, len(candidates)), replace=False)
prm2.get_estimates(candidates, idx_list=idx_list, verbose=1, npass=1)


In [None]:
# Repeat the filtered-Z candidate search with N=20 trials.
prm2.Zint = Zint[~bad_idx]
N = 20
gen_nextX='random'
gen_nextZ='random'

np.random.seed(0)
candidates = prm2.find_candidate_sets(N,gen_nextX=gen_nextX,gen_nextZ=gen_nextZ)
print(len(candidates))
# Distinct candidates only (no duplicate estimations), capped at the number available.
idx_list = np.random.choice(len(candidates), size=min(40, len(candidates)), replace=False)
prm2.get_estimates(candidates, idx_list=idx_list, verbose=1, npass=1)


In [None]:
# IGNORE
# The cells below are exploratory scans of previously saved ./results/ tables.

In [None]:
# For every (D, Y) pair:
#   * show the full-data estimate;
#   * for each proxy-removal run, print the fraction of estimates passing more than 2 tests;
#   * display legacy (old_est) estimates that pass more than 2 tests.
np.random.seed(30)
for D_label in ['Female','Obese','Black', 'Asian']:
    for Y_label in tqdm(['OA', 'myoc', 'back']):
        dy = f'{D_label}_{Y_label}'
        print("\n"*3)
        print(dy)
        # `res_dir` rather than `dir`, which shadowed the builtin.
        res_dir = f'./results/D={D_label}_Y={Y_label}/Dbin=False_Ybin=False_XZbin=False_Rgr=linear/'
        test = pd.read_csv(res_dir + '/table2.csv', header=1, index_col=1)
        point = pd.read_csv(res_dir + '/table0.csv', header=1, index_col=1)
        print("Estimation on all data:")
        display(point)
        display(test)

        rm_dir = f'./results/proxyrm/{dy}'
        print("Estimations after rm weak proxies (that pass >2 tests):")
        for var in os.listdir(rm_dir):
            print(var)
            est_paths = [x for x in os.listdir(f'{rm_dir}/{var}') if '.pkl' not in x]
            pass_tests = []
            for est in est_paths:
                test = pd.read_csv(f'{rm_dir}/{var}/{est}/table2.csv', header=1, index_col=1)
                pass_tests.append(test['pass test'].sum())
            print((np.array(pass_tests) > 2).mean())

        old_dir = './results/proxyrm/old_est/'
        for est in [p for p in os.listdir(old_dir) if dy in p]:
            try:
                test = pd.read_csv(f'{old_dir}/{est}/table2.csv', header=1, index_col=1)
                if test['pass test'].sum() > 2:
                    point = pd.read_csv(f'{old_dir}/{est}/table0.csv', header=1, index_col=1)
                    display(point)
                    display(test)
            except Exception as err:
                # Best-effort scan: report unreadable runs instead of silently hiding them.
                print(f"skipping {est}: {err!r}")

In [None]:
# For D='Obese', Y='back', collect the X/Z sets (as 0/1 indicator rows) and point
# estimates of every proxy-removal run that passes more than 2 tests. The resulting
# `xs`, `zs` (lists of 1 x p rows) and `points` feed the feature-attribution cells below.
#
# Previously this was a loop over all (D, Y) labels with `dy` hard-coded to
# 'Obese_back', halted by `print(1/0)` after the first iteration, so the
# "all data" tables shown belonged to Female/OA. The matching D/Y is used directly here.
import pickle as pk

D_label, Y_label = 'Obese', 'back'
dy = f'{D_label}_{Y_label}'
# Indicator widths used for the saved candidate sets.
# NOTE(review): hard-coded in the original; presumably the number of X and Z
# proxies for this dataset — confirm against Xres.shape[1] / Zres.shape[1].
N_X_FEATS, N_Z_FEATS = 65, 197

print("\n"*3)
print(dy)
res_dir = f'./results/D={D_label}_Y={Y_label}/Dbin=False_Ybin=False_XZbin=False_Rgr=linear/'
test = pd.read_csv(res_dir + '/table2.csv', header=1, index_col=1)
point = pd.read_csv(res_dir + '/table0.csv', header=1, index_col=1)
print("Estimation on all data:")
display(point)
display(test)

rm_dir = f'./results/proxyrm/{dy}'
print("Estimations after rm weak proxies (that pass >2 tests):")
xs = []
zs = []
points = []
for var in os.listdir(rm_dir):
    # pickle can execute arbitrary code on load: only open result files you produced.
    with open(f'{rm_dir}/{var}/candidates.pkl', 'rb') as fh:
        candidates = pk.load(fh)
    est_paths = [x for x in os.listdir(f'{rm_dir}/{var}') if '.pkl' not in x]
    for est in est_paths:
        test = pd.read_csv(f'{rm_dir}/{var}/{est}/table2.csv', header=1, index_col=1)
        if test['pass test'].sum() > 2:
            point = pd.read_csv(f'{rm_dir}/{var}/{est}/table0.csv', header=1, index_col=1)
            # the run directory name ends with the candidate index
            xset, zset = candidates[int(est.split('_')[-1])]
            print(len(xset), len(zset))
            xset_ = np.zeros(N_X_FEATS)
            zset_ = np.zeros(N_Z_FEATS)
            xset_[xset] = 1
            zset_[zset] = 1
            xs.append(xset_[None, :])
            zs.append(zset_[None, :])
            display(point)
            display(test)
            print("-"*20)
            points.append(point.point.iloc[0])

In [None]:
# Compare how often each X proxy appears in passing runs with a positive vs a
# negative point estimate, sorted by the absolute difference.
# positive numbers mean including that X feature tends to create a positive bias effect (more likely to be diag)
# negative numbers mean including that X feature tends to create a more negative effect (less likely to be diag)
#
# `np.vstack` (rather than np.concatenate) leaves an already-stacked 2-D array
# unchanged, so re-running this cell no longer flattens `xs` and fails.
xs = np.vstack(xs)
point_arr = np.array(points)
xsp = xs[point_arr > 0]
xsn = xs[point_arr < 0]
diff = xsp.mean(axis=0) - xsn.mean(axis=0)
idx = np.argsort(np.abs(diff))[::-1]
print(diff[idx])
Xint[idx]

In [None]:
### Most important features for passing some tests - caution bc depends on Z
# Count, per X proxy, how many collected (passing) runs kept it; print the counts
# in descending order and show the matching readable feature names.
sm = xs.sum(axis=0)
idx1 = np.argsort(np.abs(sm))[::-1]
print(sm[idx1])
Xint[idx1]

In [None]:
# Compare how often each Z proxy appears in passing runs with a positive vs a
# negative point estimate, sorted by the absolute difference.
# positive numbers mean including that Z feature tends to create a positive bias effect (more likely to be diag)
# negative numbers mean including that Z feature tends to create a more negative effect (less likely to be diag)
#
# `np.vstack` leaves an already-stacked array unchanged, so this cell can be re-run.
# NOTE(review): only the first 196 indicator columns are kept (the rows are built
# wider); presumably to match len(Zint) — confirm.
zs = np.vstack(zs)[:, :196]
point_arr = np.array(points)
xsp = zs[point_arr > 0]
xsn = zs[point_arr < 0]
diff = xsp.mean(axis=0) - xsn.mean(axis=0)
idx = np.argsort(np.abs(diff))[::-1]
print(diff[idx])
Zint[idx]

In [None]:
### Most important features for passing some tests - caution bc depends on X
# Count, per Z proxy, how many collected (passing) runs kept it; print the counts
# in descending order and show the matching readable feature names.
sm = zs.sum(axis=0)
idx1 = np.argsort(np.abs(sm))[::-1]
print(sm[idx1])
Zint[idx1]

In [None]:
# Per-X-proxy count of collected runs that kept it, in original column order.
xs.sum(axis=0)

In [None]:
# For each (D, Y) pair with saved "random" proxy-removal runs: show the full-data
# estimate, then every run passing more than 2 tests with (random samples of) the
# kept / deleted X and Z proxy names, and plot the kept-set indicator matrices.
# Per-pair results are collected in `all_data`.
# NOTE(review): Xres / Zres / Xint / Zint come from earlier cells (last loaded D/Y);
# this assumes the proxy columns are the same for every D/Y pair — confirm.
import pickle as pk

np.random.seed(30)
all_proxyrm_table_paths = os.listdir('./results/proxyrm/')
all_data = {}
for D_label in ['Female','Obese','Black', 'Asian']:
    for Y_label in tqdm(['deprs', 'back']):
        res_dir = f'./results/D={D_label}_Y={Y_label}/Dbin=False_Ybin=False_XZbin=False_Rgr=linear/'
        # Full-data tables get their own names: `test` / `point` are reused for each
        # run below, and the `og_test` / `og_point` stored in `all_data` were never defined.
        og_test = pd.read_csv(res_dir + '/table2.csv', header=1, index_col=1)
        og_point = pd.read_csv(res_dir + '/table0.csv', header=1, index_col=1)

        ps = [p for p in all_proxyrm_table_paths
              if f'{D_label}_{Y_label}' in p and 'Dbin' not in p and 'random' in p]
        if not ps:
            continue  # nothing saved for this pair; no need to load its candidates
        print(res_dir)
        display(og_point)
        display(og_test)

        # pickle can execute arbitrary code on load: only open result files you produced.
        with open(f'./{D_label}_{Y_label}_candidates_reweight.pkl', 'rb') as fh:
            candidates = pk.load(fh)
        new_tests = []
        new_points = []
        xsets = []
        zsets = []
        for path in ps:
            try:
                point = pd.read_csv(f'./results/proxyrm/{path}/table0.csv', header=1, index_col=1)
                test = pd.read_csv(f'./results/proxyrm/{path}/table2.csv', header=1, index_col=1)
            except FileNotFoundError:
                continue  # run directory without saved tables
            # the run name ends with the candidate index
            i = int(path.split("_")[-1])
            Xset, Zset = candidates[i]
            if test['pass test'].sum() <= 2:
                continue
            print("----"*10)
            new_tests.append(test)
            new_points.append(point)

            # Fancy indexing returns copies, so shuffling does not touch Xint / Zint.
            rmXset = np.setdiff1d(np.arange(Xres.shape[1]), Xset)
            x = Xint[Xset]
            np.random.shuffle(x)
            print("Kept Xs = ", x[:10])
            x = Xint[rmXset]
            np.random.shuffle(x)
            print("Deleted Xs =", x)

            rmZset = np.setdiff1d(np.arange(Zres.shape[1]), Zset)
            x = Zint[Zset]
            np.random.shuffle(x)
            print("Kept Zs = ", x[:10])
            x = Zint[rmZset]
            np.random.shuffle(x)
            print("Deleted Zs =", x[:10])

            # 0/1 indicator rows of the kept sets, for the heat-maps below.
            x = np.zeros(Xres.shape[1])
            x[Xset] = 1
            z = np.zeros(Zres.shape[1])
            z[Zset] = 1
            xsets.append(x[None, :])
            zsets.append(z[None, :])
            display(test)
            display(point)

        if new_tests:
            for sets, name in ((xsets, 'X'), (zsets, 'Z')):
                fig, ax = plt.subplots(figsize=(10, 10))
                ax.imshow(np.concatenate(sets))
                ax.set(title=f'{D_label}_{Y_label}: kept {name} proxies per passing run',
                       xlabel=f'{name} proxy index', ylabel='passing run')
                plt.show()
            all_data[f'{D_label}_Y={Y_label}'] = {'tests': [og_test] + new_tests,
                                                  'point': [og_point] + new_points,
                                                  'xsets': xsets,
                                                  'zsets': zsets}
            print(f'{D_label}_Y={Y_label}', len(new_tests))

In [None]:
# Gather the saved full-data tables (table0, table1, table2) of every D/Y pair into
# three stacked DataFrames, each tagged with a 'D_Y' column. The test table (table2)
# is flattened to a single row with '<row>_<column>' column names.
model_regression = 'linear'
model_classification = 'linear'
Dbin, Ybin, XZbin = False, False, False
SAVE_PATH = './results/'


def _read_results_table(folder, table_name):
    """Read ``folder/table_name`` (header on row 1, index column 1) and drop column '0'."""
    table = pd.read_csv(f'{folder}/{table_name}', header=1, index_col=1)
    return table.drop(columns=['0'])


test_dfs, res_dfs, point_dfs = [], [], []
for D_label in ['Black', 'Female', 'Obese', 'Asian']:
    print(D_label)
    for Y_label in ['OA', 'RA', 'myoc', 'deprs', 'back']:
        try:
            save_dir = (f'{SAVE_PATH}/D={D_label}_Y={Y_label}/'
                        f'Dbin={Dbin}_Ybin={Ybin}_XZbin={XZbin}_Rgr={model_regression}')
            tests_tbl = _read_results_table(save_dir, 'table2.csv')
            test_df_flat = tests_tbl.T.unstack().to_frame().sort_index(level=1).T
            test_df_flat.columns = test_df_flat.columns.map('_'.join)
            points_tbl = _read_results_table(save_dir, 'table0.csv')
            res_tbl = _read_results_table(save_dir, 'table1.csv')

            tag = f'{D_label}_{Y_label}'
            test_df_flat['D_Y'] = tag
            points_tbl['D_Y'] = tag
            res_tbl['D_Y'] = tag

            res_dfs.append(res_tbl)
            point_dfs.append(points_tbl)
            test_dfs.append(test_df_flat)
        except Exception as e:
            print(e)

point_df = pd.concat(point_dfs)
res_df = pd.concat(res_dfs)
test_df = pd.concat(test_dfs)
test_df = test_df.reindex(sorted(test_df.columns), axis=1)
point_df

In [None]:

def zset_trial2(it, remnantX, verbose):
    ''' Greedily grow a Z set for a fixed X set.

    The Z columns are visited in a random order (seeded by ``it``). Each one is kept
    only if adding it keeps the primal violation ``pv`` (second value returned by
    ``violation``) below ``0.1 * pv_bench``.

    Parameters
    ----------
    it : int
        Trial index; also used as the numpy random seed so each trial is reproducible.
    remnantX : array-like of int
        Column indices of the X proxies held fixed.
    verbose : bool or int
        If truthy, print the final Z set and its violations.

    Returns
    -------
    numpy.ndarray of int or None
        One-hot encoding (length ``Zres.shape[1]``) of the selected Z columns, or
        ``None`` if no Z column could be added.

    Notes
    -----
    Reads the notebook globals ``violation``, ``pv_bench`` and ``Zres``.
    NOTE(review): ``violation`` / ``pv_bench`` are not defined in this notebook (they
    mirror ``WeakProxyRemoval.violation`` / ``.pv_bench``); bind them before use.
    '''
    np.random.seed(it)
    unusedZ = np.arange(Zres.shape[1])
    remnantZ = []
    while len(unusedZ) > 0:
        # `pick` rather than `next`, which shadowed the builtin
        pick = np.random.choice(len(unusedZ), size=1)[0]
        dv, pv = violation(remnantX, remnantZ + [unusedZ[pick]])
        if pv < 0.1 * pv_bench:
            remnantZ += [unusedZ[pick]]
        unusedZ = np.delete(unusedZ, pick)

    if not remnantZ:
        return None

    remnantZ = np.sort(remnantZ)
    if verbose:
        print(remnantZ, violation(remnantX, remnantZ))

    ohe = np.zeros(Zres.shape[1], dtype=int)
    ohe[remnantZ] = 1
    return ohe

def xset_trial2(it, remnantZ, verbose):
    ''' Greedily grow an X set for a fixed Z set.

    The X columns are visited in a random order (seeded by ``it``). Each one is kept
    only if adding it keeps the dual violation ``dv`` (first value returned by
    ``violation``) below ``0.1 * dv_bench``.

    Parameters
    ----------
    it : int
        Trial index; also used as the numpy random seed so each trial is reproducible.
    remnantZ : array-like of int
        Column indices of the Z proxies held fixed.
    verbose : bool or int
        If truthy, print the final X and Z sets and their violations.

    Returns
    -------
    numpy.ndarray of int or None
        Concatenated one-hot encoding of length ``Xres.shape[1] + Zres.shape[1]``
        (X columns first, then Z columns), or ``None`` if no X column could be added.

    Notes
    -----
    Reads the notebook globals ``violation``, ``dv_bench``, ``Xres`` and ``Zres``.
    NOTE(review): ``violation`` / ``dv_bench`` are not defined in this notebook (they
    mirror ``WeakProxyRemoval.violation`` / ``.dv_bench``); bind them before use.
    '''
    np.random.seed(it)

    unusedX = np.arange(Xres.shape[1])
    remnantX = []
    while len(unusedX) > 0:
        # `pick` rather than `next`, which shadowed the builtin
        pick = np.random.choice(len(unusedX), size=1)[0]
        dv, pv = violation(remnantX + [unusedX[pick]], remnantZ)
        if dv < .1 * dv_bench:
            remnantX += [unusedX[pick]]
        unusedX = np.delete(unusedX, pick)

    if not remnantX:
        return None

    remnantX = np.sort(remnantX)
    dv, pv = violation(remnantX, remnantZ)
    if verbose:
        print(remnantX, remnantZ, dv, pv)

    ohe = np.zeros(Xres.shape[1] + Zres.shape[1], dtype=int)
    ohe[remnantX] = 1
    # asarray so a plain list of Z indices can be offset too (int + list raised TypeError)
    ohe[Xres.shape[1] + np.asarray(remnantZ, dtype=int)] = 1
    return ohe

    
    

from tqdm import tqdm
def find_candidate_sets_Xfirst(ntrials, verbose=0, n_jobs=-1, gen_next='random', n_rounds=5):
    '''Search candidate (Xset, Zset) proxy pairs by alternating greedy growth.

    Starting from the full X set, each round
      1. grows ``ntrials`` random Z sets for every current X set (``zset_trial2``,
         primal-violation check), then
      2. grows ``ntrials`` random X sets for every unique Z set from step 1
         (``xset_trial2``, dual-violation check).
    The unique X halves from step 2 seed the next round. After the last round,
    every unique (Xset, Zset) pair is re-checked against both thresholds.

    Parameters
    ----------
    ntrials : int
        Random trials per seed set in each step.
    verbose : int, default 0
        Passed to the trial functions; if truthy, also print every final pair.
    n_jobs : int, default -1
        joblib parallelism for the trials.
    gen_next : str, default 'random'
        Unused; kept for interface compatibility.
    n_rounds : int, default 5
        Number of alternating Z/X rounds (previously hard-coded to 5).

    Returns
    -------
    list of (numpy.ndarray, numpy.ndarray)
        ``(Xset, Zset)`` column-index pairs with ``pv < 0.1 * pv_bench`` and
        ``dv < 0.1 * dv_bench``; empty if any step produced no candidates.

    Raises
    ------
    ValueError
        If ``n_rounds < 1`` (no candidate pairs would ever be generated).
    '''
    if n_rounds < 1:
        raise ValueError("n_rounds must be at least 1")

    n_x = Xres.shape[1]
    # Start from a single X set containing every X column.
    unique_Xsets = np.ones((1, n_x), dtype=int)

    for _ in range(n_rounds):
        # Step 1: for each current X set, grow maximal Z sets.
        candidateZ = []
        for remnantX in unique_Xsets:
            remnantX = np.argwhere(remnantX).flatten()
            candidateZ += Parallel(n_jobs=n_jobs, verbose=3)(delayed(zset_trial2)(it, remnantX, verbose)
                                                             for it in range(ntrials))
        candidateZ = [c for c in candidateZ if c is not None]
        if not candidateZ:
            return []
        # keep only the unique Z sets
        unique_Zsets = np.unique(np.array(candidateZ, dtype=int), axis=0)

        # Step 2: for each unique Z set, grow maximal X sets.
        candidateXZ = []
        for remnantZ in tqdm(unique_Zsets):
            remnantZ = np.argwhere(remnantZ).flatten()
            candidateXZ += Parallel(n_jobs=n_jobs, verbose=3)(delayed(xset_trial2)(it, remnantZ, verbose)
                                                              for it in range(ntrials))
        candidateXZ = [c for c in candidateXZ if c is not None]
        if not candidateXZ:
            return []

        # Rows are one-hot encodings: the first n_x columns encode the X set, the rest the Z set.
        candidateXZ = np.array(candidateXZ, dtype=int)
        # The unique X halves seed the next round's Z search (step 1 iterates over X sets).
        unique_Xsets = np.unique(candidateXZ[:, :n_x], axis=0)

    # Keep only unique (Xset, Zset) pairs and convert one-hot rows back to index sets.
    final_candidates = []
    for unique_XZ in np.unique(candidateXZ, axis=0):
        Xset = np.argwhere(unique_XZ[:n_x]).flatten()
        Zset = np.argwhere(unique_XZ[n_x:]).flatten()
        dv, pv = violation(Xset, Zset)
        if verbose:
            print(Xset, Zset, dv, pv)
        if pv < 0.1 * pv_bench and dv < 0.1 * dv_bench:
            final_candidates.append((Xset, Zset))
    return final_candidates

In [None]:
# Re-fit ProximalDE on a couple of randomly chosen candidate (Xset, Zset) pairs and
# display the summary whenever every test passes.
# The original `for i in idx_list = ...` line was a SyntaxError, and
# `candidates[idx_list]` fails on a list, so the candidates are indexed one at a time.
np.random.seed(0)
idx_list = np.random.choice(len(candidates), size=min(2, len(candidates)), replace=False)
for i in idx_list:
    Xset, Zset = candidates[i]
    print("Xset =", Xset)
    print("Zset =", Zset)
    print()
    est = ProximalDE(random_state=3)
    est.fit(None, Dres, Zres[:, Zset], Xres[:, Xset], Yres)
    summary = est.summary()  # computed once and reused below
    t = summary.tables[2]
    df = pd.DataFrame.from_records(t.data)
    header = df.iloc[0]  # grab the first row for the header
    df = df[1:]  # take the data less the header row
    df.columns = header
    # the table cells are strings; convert the pass flags to booleans
    df['pass test'] = df['pass test'].map(lambda x: x == 'True')
    if df['pass test'].all():
        display(summary.tables[0], summary.tables[2])
        print("Xset =", Xset)
        print("Zset =", Zset)