# Main Logic: fit `ProximalDE_UKBB` for each D/Y label pair and summarize the saved result tables

In [None]:
# Reload edited project modules (e.g. `proximalde`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle as pk
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

from proximalde.ukbb_data_utils import load_ukbb_data
from proximalde.ukbb_proximal import ProximalDE_UKBB
from proximalde.proximal import ProximalDE
from proximalde.ukbb_data_utils import *

pd.options.display.max_columns = None

# Root folder of the per-configuration result directories written by the fitting loop
# and read back by the analysis cells.
SAVE_PATH = './results/'


In [None]:
np.random.seed(30)
dfs = []
D_labels = ['Black', 'Female', 'Obese', 'Asian', 'On_dis', "No_uni", 'Low_inc', 'No_priv_insr']
Y_labels = ['OA', 'myoc','deprs', 'back', 'RA', 'fibro', 'infl', 'copd','chrkd','mgrn','mela', 'preg', 'endo']
for D_label in D_labels:
    for Y_label in Y_labels:
        dir = f'./results/D={D_label}_Y={Y_label}/Dbin=False_Ybin=False_XZbin=False_Rgr=linear/'
        if not os.path.exists(dir + '/table2.csv'):
            continue
        else:
            test_df = pd.read_csv(dir + '/table2.csv',header=1, index_col=1)
            point_df = pd.read_csv(dir + '/table0.csv',header=1, index_col=1)
            point_df['ci'] = point_df.stderr * 1.96
            test_df['stat,crit'] = test_df.apply(lambda x: f"{round(x.statistic,1)},{round(x['critical value'],1)}", axis=1)
            test_df = test_df[['stat,crit']]
            test_df.index=['id','primal','dual', 'weakIV']
            test_df_flat = test_df.T.unstack().to_frame().sort_index(level=1).T
            test_df_flat.columns = test_df_flat.columns.map('_'.join)
            point_df = point_df[['point', 'ci']]                  
            df = pd.concat([point_df.reset_index(), test_df_flat.reset_index()],axis=1)
            point = df[[c for c in df.columns if c != 'index']]
            point['dy'] = f'{D_label}_{Y_label}'
            dfs.append(point)
df = pd.concat(dfs)            
def statcrit_fn(df):
    direction = {'dual':'<', 'primal':'<', 'id':'>', 'weakIV':'>'}
yy={'OA': 'Osteoarthritis', 'back': 'Back pain', 'deprs': 'Depression', 'myoc': "Heart disease", 'RA': 'Rh. Arthritis', 'fibro': 'Fibromyalgia', 'chrkd': 'Chronic kidney disease'}
dd={'Low_inc': 'Low Income','Obese':'Obese', 'Female': 'Female', 'Black': 'Black', 'Asian': "Asian", 'On_dis': 'Disability insurance'}
df['D_Y'] = df.dy.map(lambda x: dd['_'.join(x.split('_')[:-1])] + ', ' + yy[x.split('_')[-1]])
df['dual_stat,crit'] = df['dual_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}<{round(float(x.split(',')[-1]),1)}")
df['primal_stat,crit'] = df['primal_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}<{round(float(x.split(',')[-1]),1)}")
df['id_stat,crit'] = df['id_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}>{round(float(x.split(',')[-1]),1)}")
df['weakIV_stat,crit'] = df['weakIV_stat,crit'].map(lambda x: f"{round(float(x.split(',')[0]),1)}>{round(float(x.split(',')[-1]),1)}")
df.point = [f'{round(x,2)}$\pm${round(y,2)}' for x,y in zip(df.point,df.ci)]
df =df[['D_Y', 'point','primal_stat,crit', 'dual_stat,crit', 'id_stat,crit', 'weakIV_stat,crit']]
df.to_csv('all_norm_7metrics.csv',index=False)

In [None]:
# List which (D, Y) pairs have baseline results and read their raw test / point tables.
np.random.seed(30)
all_data = {}
D_labels = ['Black', 'Female', 'Obese', 'Asian', 'On_dis', "No_uni", 'Low_inc', 'No_priv_insr']
Y_labels = ['OA', 'myoc','deprs', 'back', 'RA', 'fibro', 'infl', 'copd','chrkd','mgrn','mela', 'preg', 'endo']
for D_label in D_labels:
    for Y_label in Y_labels:
        # Named `result_dir` rather than `dir` so the builtin `dir()` stays usable in the notebook.
        result_dir = f'./results/D={D_label}_Y={Y_label}/Dbin=False_Ybin=False_XZbin=False_Rgr=linear/'
        if not os.path.exists(result_dir + '/table2.csv'):
            continue
        print(result_dir + ' exists')
        # NOTE(review): og_test / og_point are overwritten on every iteration, so only the last
        # existing pair survives the loop; they are only consumed by the commented-out code below.
        og_test = pd.read_csv(result_dir + '/table2.csv')
        og_point = pd.read_csv(result_dir + '/table0.csv')

#         candidates = pk.load(open(f'./{D_label}_{Y_label}_candidates_reweight_acually.pkl', 'rb'))
#         new_tests = []
#         new_points = []
#         saved_cand=[]
#         for idx in np.arange(len(candidates)):
#             try:
#                 fname = f'./results/proxyrm/{D_label}_{Y_label}_{idx}_random/'
#                 point = pd.read_csv(f'{fname}/table0.csv', header=1, index_col=1)
#                 test = pd.read_csv(f'{fname}/table2.csv', header=1, index_col=1)
#                 if test['pass test'].sum() > 2:
#                     new_tests.append(test)
#                     new_points.append(point)
#                     # TODO: Xset / Zset are not defined anywhere visible (presumably from candidates[idx]).
#                     saved_cand.append((Xset, Zset))
#             except FileNotFoundError:
#                 pass
#         if len(new_tests) > 0:
#             all_data[f'{D_label}_Y={Y_label}'] = {'tests': [og_test]+new_tests, 'point': [og_point] + new_points, 'cand': saved_cand}
#             print(f'{D_label}_Y={Y_label}', len(new_tests))

In [None]:

# Fit ProximalDE_UKBB for every combination of regressor, classifier, D label, Y label and
# binarization flags, saving the summary tables to one folder per configuration. Folders that
# already contain table1.csv are skipped, so re-running the cell resumes where it stopped.
for regr, clsf, D_label, Y_label, Dbin, Ybin, XZbin in product(['linear', 'xgb'], ['xgb'],
                                                              ['Black', 'Female', 'Obese', 'Asian'],
                                                              ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back'],
                                                              [True, False],
                                                              [True, False],
                                                              [True, False]):
    # The classifier is only part of the folder name when at least one flag is set.
    # Kept in its own variable instead of rebinding the loop variable `clsf`.
    clsf_suffix = f'_Cls={clsf}' if any([Dbin, Ybin, XZbin]) else ''
    save_dir = f'{SAVE_PATH}/D={D_label}_Y={Y_label}/Dbin={Dbin}_Ybin={Ybin}_XZbin={XZbin}_Rgr={regr}{clsf_suffix}'
    print(save_dir)
    if os.path.exists(save_dir + '/table1.csv'):
        # Already fitted: skip before paying for the data load.
        continue
    try:
        W, _, W_feats, X, X_binary, X_feats, Z, Z_binary, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)
        # makedirs also creates the D=..._Y=... parent; os.mkdir raised FileNotFoundError when
        # it did not exist yet, and the configuration was silently skipped.
        os.makedirs(save_dir, exist_ok=True)
        np.random.seed(4)
        est = ProximalDE_UKBB(model_regression=regr,
                              model_classification=clsf,
                              binary_D=Dbin, binary_Y=Ybin,
                              binary_X=X_binary if XZbin else [], binary_Z=Z_binary if XZbin else [],
                              semi=True, cv=3, verbose=1, random_state=3)
        est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label, save_fname_addn='')
        sm = est.summary(decimals=5, save_dir=save_dir)
        print(sm.tables[0])
        print(sm.tables[1])
        print(sm.tables[2])
    except Exception as e:
        # Best effort across configurations: report which one failed and continue.
        print(f'{save_dir}: {e!r}')

#             if Y_label == 'OA':
#                 svalues, svalues_crit = est.covariance_rank_test(calculate_critical=True)
#                 np.save(save_dir + '/covrank_test.npy', np.concatenate([svalues, [svalues_crit]]))

## Analyze results 

### Load all data into a single dataframe

In [None]:
dfs = []
from itertools import product

for regr, clsf, D_label, Y_label, Dbin, Ybin, XZbin in product(['linear', 'xgb'], ['linear', 'xgb'], 
                                            ['Black', 'Female', 'Obese', 'Asian'], 
                                            ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back'],
                                           [True, False],
                                           [True, False],
                                           [True,False]):
        try:
            if np.array([Dbin,Ybin,XZbin]).any():
                clsf = f'_Cls={clsf}'
            else:
                clsf = ''
            save_dir = f'{SAVE_PATH}/D={D_label}_Y={Y_label}/Dbin={Dbin}_Ybin={Ybin}_XZbin={XZbin}_Rgr={regr}{clsf}'
            test_df = pd.read_csv(save_dir + '/table2.csv', header=1, index_col=1)
            test_df = test_df.drop(columns=['0'])
            test_df_flat = test_df.T.unstack().to_frame().sort_index(level=1).T
            test_df_flat.columns = test_df_flat.columns.map('_'.join)
            point_df = pd.read_csv(save_dir + '/table0.csv', header=1, index_col=1)
            point_df = point_df.drop(columns=['0'])                    
            res_df = pd.read_csv(save_dir + '/table1.csv', header=1, index_col=1)
            res_df = res_df.drop(columns=['0'])
            df = pd.concat([point_df.reset_index(), test_df_flat.reset_index(), res_df.reset_index()],axis=1)
            df = df[[c for c in df.columns if c != 'index']]
            df['D_Y'] = f'{D_label}_{Y_label}'
            df['hparams'] = f'Dbin={Dbin}_Ybin={Ybin}_XZbin={XZbin}_Rgr={regr}{clsf}'
            dfs.append(df)
        except Exception as e:
            print(e)


df = pd.concat(dfs,axis=0)
# res_df = pd.concat(res_dfs)
# test_df = pd.concat(test_dfs)
# test_df = test_df.reindex(sorted(test_df.columns), axis=1)
# point_df.to_csv('./results/all_point_est.csv')
# test_df.to_csv('./results/all_tests.csv')
# res_df.to_csv('./results/all_res.csv')

In [None]:
# Display the combined per-configuration results frame.
df

In [None]:
sns.set_theme()
# Keep only D_Y pairs for which some configuration has a CI whose bounds share a sign and
# |point estimate| > 0.05. (The former `point_df[...]` line here evaluated a leftover
# last-iteration frame and discarded the result, so it was removed.)
sig_mask = (np.sign(df.ci_lower) == np.sign(df.ci_upper)) & (np.abs(df.point) > .05)
ss_DY = df.loc[sig_mask, 'D_Y'].unique()
df = df[df.D_Y.isin(ss_DY)]
metric_cols = ['point', 'stderr'] + [c for c in df.columns
                                     if any(key in c for key in ('statistic', 'critical', 'r2'))]
for y in metric_cols:
    # Explicit figure/axes so seaborn and the decorations target the same plot.
    fig, ax = plt.subplots(figsize=(30,6))
    sns.barplot(data=df, hue='hparams', x='D_Y', y=y, ax=ax)
    ax.grid(True)
    ax.set_title(f"Comparing {y}")
    ax.legend(loc='upper right')
    plt.show()
    # Release the figure; many figures are created in this loop.
    plt.close(fig)

In [None]:
# Rows of the combined results frame with a CI whose bounds share a sign and |point| > 0.05.
# Uses `df`: `point_df` only holds the last table0 read by the loading loop and has no D_Y column.
sig_mask = (np.sign(df.ci_lower) == np.sign(df.ci_upper)) & (np.abs(df.point) > .05)
ss_DY = df.loc[sig_mask, 'D_Y'].unique()
df[sig_mask]

In [None]:
test_df = test_df.reindex(sorted(test_df.columns), axis=1)
test_df[test_df.D_Y.isin(ss_DY)] 

### Visualize rank test results

In [None]:
# Scatter the singular values from each saved covariance rank test against their threshold.
for D_label in ['Black', 'Female', 'Obese','Asian']:
    print(D_label)
    for Y_label in ['OA']:
        # NOTE(review): this reads the older `ivreg=..._dual=...` folder layout, whereas the
        # commented-out code that saves covrank_test.npy writes into the Dbin=... folders — confirm.
        covrank_path = f'{SAVE_PATH}/ivreg=adv_dual=Z_D={D_label}_Y={Y_label}/covrank_test.npy'
        covrank = np.load(covrank_path)
        # Last entry is the threshold; everything before it are the singular values.
        svalues = covrank[:-1]
        svalues_crit = covrank[-1]
        n_above = np.sum(svalues >= svalues_crit)
        fig, ax = plt.subplots()
        ax.set_title(f"D={D_label}_Y={Y_label}\nNumber of singular values above threshold: {n_above}. "
                     f"\nThreshold={svalues_crit:.3f}. Top singular value={svalues[0]:.3f}")
        ax.scatter(np.arange(len(svalues)), svalues)
        ax.axhline(svalues_crit)
        plt.show()

### Other diagnostics

In [None]:
# For each selected D_Y pair: refit the estimator, run its diagnostics and show the
# cookd / l2influence / influence plots; optionally refit without the influential set.
refit_rm_inf = False
for D_Y in ss_DY:
    # rsplit on the last '_' so D codes containing '_' (e.g. 'On_dis') split correctly.
    D_label, Y_label = D_Y.rsplit('_', 1)
    # Same 11-value unpacking as the fitting loop (the previous 8-name unpacking did not match).
    W, _, W_feats, X, X_binary, X_feats, Z, Z_binary, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)
    # NOTE(review): `ivreg_type` and `dual_type` are not defined in any visible cell, and the
    # fitting loop builds ProximalDE_UKBB without them — confirm against the current API and
    # define them before running this cell.
    save_dir = f'{SAVE_PATH}/ivreg={ivreg_type}_dual={dual_type}_D={D_label}_Y={Y_label}'
    print(save_dir)
    np.random.seed(4)
    est = ProximalDE_UKBB(cv=3, semi=True, dual_type=dual_type, ivreg_type=ivreg_type,
                          multitask=False, n_jobs=-1, random_state=3, verbose=1)
    est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label)
    diag = est.run_diagnostics()
    inds = est.influential_set(alpha=0.05)
    print(len(inds))
    diag.cookd_plot()
    plt.title(save_dir)
    plt.show()
    diag.l2influence_plot()
    plt.title(save_dir)
    plt.show()
    diag.influence_plot(influence_measure='cook', npoints=10)
    plt.title(save_dir)
    plt.show()

    if refit_rm_inf:
        # Imported lazily: only needed when refitting without the influential observations.
        from sklearn.base import clone
        np.random.seed(4)
        est2 = clone(est)
        est2.fit(np.delete(W, inds, axis=0), np.delete(D, inds, axis=0),
                 np.delete(Z, inds, axis=0), np.delete(X, inds, axis=0),
                 np.delete(Y, inds, axis=0),D_label=D_label, Y_label=Y_label, save_fname_addn=f'_rmInf{dual_type}')
        est2.summary(alpha=0.05)
    print()
    print()

In [None]:
# Per-feature binary flags for W: the flags computed from `dem_data_rd.npy`, followed by one
# `True` per potential-D feature (those are all treated as binary).
# NOTE(review): the definitions of `UKBB_DATA_DIR`, `potD_intfids` and `Wfeats` are commented
# out below, so this cell (and the W feature-table cell at the end) only works if the star
# import or earlier kernel state provides them — confirm on a fresh kernel.
from proximalde.ukbb_data_utils import *
# UKBB_DATA_DIR = '/oak/stanford/groups/rbaltman/karaliu/bias_detection/cohort_creation/data/'
    
# potD_fids = np.load(f'{UKBB_DATA_DIR}/potD_feats.npy')
# potD_intfids = list(get_int_feats(np.load(f'{UKBB_DATA_DIR}/potD_feats.npy')[:-4])) + ['Race=Asian','Race=Black']
potD_binary = [True]*len(potD_intfids)

# Wfeats = get_int_feats(np.load(f'{UKBB_DATA_DIR}/dem_feats_rd.npy'))
Wbinary_=is_matrix_binary(np.load(f'{UKBB_DATA_DIR}/dem_data_rd.npy'))
# Wfeats = np.concatenate([Wfeats,potD_intfids])
# Same order as the (commented-out) Wfeats concatenation: dem_* flags first, then potential-D flags.
Wbinary = np.concatenate([Wbinary_,potD_binary])

In [None]:
from proximalde.ukbb_data_utils import _load_ukbb_data
def gen_feat_csv_for_paper(isbinary,fids=None, intfids=None,rm_na=False):
    """Build a per-feature table (full name, base feature, variable type) for the paper.

    Args:
        isbinary: one flag per feature; True -> 'Categorical', False -> 'Continuous'.
        fids: raw feature ids, converted with `get_int_feats` when `intfids` is not given.
        intfids: feature names, expected to look like '<feature>=<level>' (or just
            '<feature>'); used as-is when given. Lists and arrays are both accepted.
        rm_na: if True, drop features whose name contains 'Do not know' or 'Prefer not to'.

    Returns:
        DataFrame with columns 'full_feat' (the name), 'feat' (text before the first '=')
        and 'var' ('Categorical' / 'Continuous').
    """
    if intfids is None:
        intfids = get_int_feats(fids)
    if rm_na:
        # Convert to arrays so boolean masking also works when plain lists are passed.
        intfids = np.asarray(intfids)
        isbinary = np.asarray(isbinary)
        # Explicit bool dtype: for empty input np.array([]) is float and `~` would raise.
        bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in intfids], dtype=bool)
        intfids, isbinary = intfids[~bad_idx], isbinary[~bad_idx]
    return pd.DataFrame({'full_feat':intfids, 'feat':[x.split('=')[0] for x in intfids], 'var':['Categorical' if x else 'Continuous' for x in isbinary]})

# Z, Z_binary, Z_feats = _load_ukbb_data(fname = 'srMntSlp')
# X, X_binary, X_feats = _load_ukbb_data(fname = 'biomMed')
# df = gen_feat_csv_for_paper(fids=Z_feats,isbinary=Z_binary,rm_na=True)
# df.groupby(['feat', 'var']).size().reset_index(name='count').to_csv('Zcsv.csv')
# df = gen_feat_csv_for_paper(fids=X_feats,isbinary=X_binary,rm_na=True)
# df.groupby(['feat', 'var']).size().reset_index(name='count').to_csv('Xcsv.csv')


In [None]:
# Total number of features across all (feat, var) groups.
# NOTE(review): on a fresh kernel no cell above defines a `df` with a 'count' column (the
# commented-out Z/X calls never assign the grouped table to `df`); this presumably relies on
# the W feature-table cell below having been run first.
df['count'].agg('sum')

In [None]:
# Count W features per (base feature, variable type), save the table to Wcsv.csv and display it.
df = (
    gen_feat_csv_for_paper(intfids=Wfeats, isbinary=Wbinary, rm_na=False)
    .groupby(['feat', 'var'])
    .size()
    .reset_index(name='count')
)
df.to_csv('Wcsv.csv')
df