# Main Logic

In [None]:
# Re-import edited modules (e.g. the proximalde package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from proximalde.crossfit import fit_predict
from proximalde.utilities import covariance, svd_critical_value
from proximalde.proximal import residualizeW
from proximalde.proxy_rm_utils import *
from proximalde.ukbb_data_utils import *
import seaborn as sns
import pickle as pk

In [None]:
def get_median_item(y):
    """Return the median element of ``y`` and its position in ``y``.

    Odd-length input: the exact middle element of the sorted order.
    Even-length input: of the two middle elements, the one with the larger
    absolute value is returned (on a tie, the upper-middle element wins).

    Returns
    -------
    (value, index) : tuple
        ``value == y[index]``.
    """
    order = np.argsort(y)
    half = len(y) // 2
    if len(y) % 2:
        pick = order[half]
    else:
        lower, upper = order[half - 1], order[half]
        pick = lower if np.abs(y[lower]) > np.abs(y[upper]) else upper
    return y[pick], pick

# Load the X / Z feature matrices and their feature names. The loaders and
# get_int_feats presumably come from the star-imported proximalde utility
# modules (not visible here).
X, X_feats, Z, Z_feats = load_ukbb_XZ_data()
Xint = get_int_feats(X_feats)
Zint_ = get_int_feats(Z_feats)
# Flag Z features whose interpreted name is a non-answer ("Do not know" /
# "Prefer not to ..."); later cells drop these columns with ~bad_idx.
bad_idx = np.array([('Do not know' in x) or ('Prefer not to' in x) for x in Zint_])
# ss_dy maps "<D label>_<Y label>" keys to saved results, which later cells
# unpack as (point, test, inf_dict, path, (Xset, Zset)).
# NOTE(review): the file handle is never closed, and pickle.load can execute
# arbitrary code — only load trusted files.
ss_dy = pk.load(open('ss_dy_updated_inf.pkl', 'rb'))
ss_dy.keys()

# Influence-set heat maps

In [None]:
# Build one 0/1 indicator column per (D, Y) analysis marking the participants
# in its 'switch_sign' influence set, then compare those sets across analyses.
bin_idxs = []
i=0
labels = []
# Display names for the Y (outcome) suffix of each ss_dy key.
yy={'mela': 'Melanoma','endo': 'Endometriosis','infl': 'IBD','preg': 'Complications during labor','OA': 'Osteoarthritis','mgrn':'Migraine','copd':'COPD', 'back': 'Back pain', 'deprs': 'Depression', 'myoc': "Heart disease", 'RA': 'Rh. Arthritis', 'fibro': 'Fibromyalgia', 'chrkd': 'Chronic kidney disease'}
# Display names for the D (treatment) prefix of each ss_dy key.
dd={'No_priv_insr': 'Not on private insr.','No_uni': 'No p.s. education', 
    'Low_inc': 'Low income','Obese':'Obese', 'Female': 'Female', 'Black': 'Black', 'Asian': "Asian", 
    'On_dis': 'Disability insr.'}

# Iterate analyses ordered by their Y suffix so columns group by outcome.
for k in sorted(ss_dy.keys(), key=lambda x: x.split('_')[-1]):
    point, test, inf_dict, path, (Xset, Zset) = ss_dy[k]
    inf = inf_dict['switch_sign']
    # Keep only analyses whose first stored point estimate has magnitude > 0.01.
    if np.abs(point.point.iloc[0])> 0.01:
        # NOTE(review): 502411 is a hard-coded row count (presumably the size of
        # the full cohort); every index in `inf` must be below it — confirm.
        bin_idx = np.zeros(502411)
        bin_idx[inf] = 1
        bin_idxs.append(bin_idx[:,None])
        # Label: "<D name>, <Y name> (<number of flagged participants>)".
        labels.append(f"{dd['_'.join(k.split('_')[:-1])] + ', ' + yy[k.split('_')[-1]]} ({int(bin_idx.sum())})")
        # Counts included analyses (the name `i` is reused as a loop index below).
        i += 1
    
# Stack into a (participants x analyses) matrix and keep only participants
# that belong to at least one influence set.
bin_idxs = np.concatenate(bin_idxs,axis=1)
bin_idxs = bin_idxs[bin_idxs.sum(axis=1)!=0]

# Heat map 1: Pearson correlation between the indicator columns.
sns.heatmap(np.corrcoef(bin_idxs.T),xticklabels=labels, yticklabels=labels,cmap='Blues', annot=False)
plt.show()

# Heat map 2: overlap coefficient |A & B| / min(|A|, |B|) for every pair of sets.
plt.subplots(figsize=(9,7),dpi=100)
n = bin_idxs.shape[1]
coeff = np.zeros((n,n))
for i in range(n):
    inf = bin_idxs[:,i].astype(bool)
    for j in range(i, n):
        inf2 = bin_idxs[:,j].astype(bool)
        numer = (inf & inf2).sum()
        denom = min(inf.sum(), inf2.sum())
        # Alternative normalization (Jaccard index: divide by the union size):
#         denom = (inf | inf2).sum()
        # NOTE(review): an empty influence set gives denom == 0 and a nan entry.
        coeff[i,j] = coeff[j,i] = numer/denom

# The trailing `, ss_dy.keys()` makes the cell display an (Axes, keys) tuple.
sns.heatmap(coeff,xticklabels=labels, yticklabels=labels,cmap='Blues', linewidth=.2,annot=False), ss_dy.keys()


In [None]:
# NOTE(review): leftover scratch cell. It only evaluates the bound method
# `D.mean` (it is never called), and within this notebook D is first assigned
# in a later cell, so a top-to-bottom run raises NameError here. Consider deleting.
D.mean

In [None]:
# Load one (D, Y) analysis in full. Keys have the form "<D label>_<Y label>";
# the D label may itself contain underscores, so only the last piece is Y.
ss_dy = pk.load(open('ss_dy_updated_inf.pkl', 'rb'))
# Known D labels followed by known Y labels. Later cells use this list to
# recognize the D/Y columns, whose feature names carry no '.'-separated field ID.
dy_strs =  ['Low_inc', 'On_dis', 'No_priv_insr', 'No_uni', 'Female', 'Black', 'Obese', 'Asian']+['OA', 'myoc','deprs', 'back', 'RA', 'fibro', 'infl', 'copd','chrkd','mgrn','mela', 'preg', 'endo']
dy = 'Obese_OA'
point, test, inf_dict, path, (Xset, Zset) = ss_dy[dy]
d, y = '_'.join(dy.split('_')[:-1]), dy.split('_')[-1]
W, _, W_feats, X, X_binary, X_feats, Z, Z_binary, Z_feats, Y, D = load_ukbb_data(D_label=d, Y_label=y, pp = False)
# Restrict proxies to the saved selection: first drop the non-answer Z
# columns flagged by bad_idx, then apply the stored Zset / Xset selections.
Z = Z[:,~bad_idx][:,Zset]
X = X[:,Xset]
Z_feats, X_feats = Z_feats[~bad_idx][Zset], X_feats[Xset]
# Column layout of `data`: W | Z | X | D | Y, with `feats` naming each column.
# NOTE(review): feats gets a single name for Y, so Y is presumably one column.
data = np.concatenate([W,Z,X,D[:,None],Y], axis=1)
feats = np.concatenate([W_feats,Z_feats,X_feats,[d],[y]])
# FID: per-column field ID — the second '.'-separated piece of each feature
# name, with the D and Y labels kept as-is.
FID = np.concatenate([[x.split('.')[1] for x in feats[:-2]],feats[-2:]])
data.shape, feats.shape, FID.shape

In [None]:
with open('ss_dy_updated_inf.pkl', 'rb') as fh:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    ss_dy = pk.load(fh)


def _influence_summary(results):
    """Tabulate D and Y prevalence inside vs. outside each influence set.

    For every "<D label>_<Y label>" key in ``results`` this reloads the
    cohort's D and Y columns and splits rows into the 'switch_sign' influence
    set and all remaining rows. It then formats the mean of D and of Y in
    each part as "<influence %>, <rest %>", rounded to whole percents.

    All loop state is local on purpose. Running this at notebook top level
    used to overwrite inf_dict, D, Y, X, Z, ... from the single-analysis
    cell. Later cells index `data` with `inf_dict`, so those two must stay
    paired with the same analysis.
    """
    table = {'dy': [], 'D': [], 'Y': []}
    for key, (_point, _test, influence, _path, (_xset, _zset)) in results.items():
        d_label, y_label = '_'.join(key.split('_')[:-1]), key.split('_')[-1]
        # load_ukbb_data's last two outputs are Y and D.
        *_, Y_k, D_k = load_ukbb_data(D_label=d_label, Y_label=y_label, pp=False)
        inf_rows = influence['switch_sign']
        # Complement computed once and shared by the D and Y summaries.
        rest_rows = np.setdiff1d(np.arange(D_k.shape[0]), inf_rows)
        table['dy'].append(key)
        table['Y'].append(f'{round(Y_k[inf_rows].mean()*100)}%, {round(Y_k[rest_rows].mean()*100)}%')
        table['D'].append(f'{round(D_k[inf_rows].mean()*100)}%, {round(D_k[rest_rows].mean()*100)}%')
    return pd.DataFrame(table)


_influence_summary(ss_dy)

In [None]:
# Show the shape of every influence set stored for the current analysis, then
# select the 'switch_sign' set used to split the data below.
# NOTE(review): inf_dict is whatever value was bound most recently. Confirm it
# belongs to the same analysis as `data` before splitting.
for k,v in inf_dict.items():
    print(k, v.shape)
inf_idxs = inf_dict['switch_sign']

In [None]:
# First run single categorical association test.
# Accumulators for the multi-column categorical tests of the next cell.
cat_assn = []
cat_fids = []
# Split rows into the influence set (d1) and all remaining rows (d0).
d1 = data[inf_idxs]
d0 = data[np.setdiff1d(np.arange(data.shape[0]), inf_idxs)]
d1.shape, d0.shape


In [None]:
# Classify every field ID in FID:
#   - single-column fields whose values are all 0/1, and the D/Y labels, go to
#     `binary_fids` (tested column-by-column later);
#   - multi-column fields where fewer than 10% of rows in each group have more
#     than one active column are tested jointly as one categorical variable;
#   - other multi-column fields are treated as independent binary columns.
binary_fids = []
# Reset here so re-running this cell does not append duplicate test results.
cat_assn = []
cat_fids = []
for f in tqdm(np.unique(FID)):
    rel_feat_idx = np.where(f == FID)[0]
    if len(rel_feat_idx) == 1:
        if ((data[:, rel_feat_idx] == 0) | (data[:, rel_feat_idx] == 1)).all():
            binary_fids.append(f)
        continue
    if f in dy_strs:
        # D/Y labels are handled as binary; skip the categorical test for them.
        binary_fids.append(f)
        continue

    idx1_data = d1[:, rel_feat_idx]
    idx0_data = d0[:, rel_feat_idx]

    # Fraction of rows with more than one active column, per group.
    frac_multi1 = (idx1_data.sum(axis=1) > 1).mean()
    frac_multi0 = (idx0_data.sum(axis=1) > 1).mean()
    if frac_multi1 < .1 and frac_multi0 < .1:
        # Mostly one response per row: test as a single categorical variable.
        chi2, p = cat_sim_test(idx1_data, idx0_data)
        cat_fids.append('.' + f)
        cat_assn.append([chi2, p])
    else:
        # Multiple responses per row are common: treat each column independently.
        print(f, frac_multi1, frac_multi0)
        binary_fids.append(f)

cat_assn_df = get_assn_df(cat_fids, cat_assn)


In [None]:
# Bonferroni-style significance threshold: 0.05 divided by the number of feature columns.
thresh = .05 / len(feats)
cat_assn_df = get_assn_df(cat_fids, cat_assn).drop_duplicates()
import seaborn as sns
with sns.axes_style("whitegrid"):
    # cat_fids entries look like '.<fid>'; keep only the bare field ID.
    cat_assn_df.fid = cat_assn_df.fid.map(lambda x: x.split('.')[1])
    coding_dict = get_coding_dict(np.unique(cat_assn_df.fid))
    # Plot every categorical field below the threshold, most significant first.
    for _, row in cat_assn_df.sort_values(by='pval').iterrows():
        if row.pval < thresh:
            idxs = np.where(row.fid==FID)[0]
            print(row.fid)
            assert len(idxs) > 1        
            # The third '.'-separated piece of each column name is used as a
            # code into coding_dict to get a readable class label.
            class_names = [coding_dict[row.fid][int(i.split('.')[2])] for i in feats[idxs]]
            stat, pval = row.stat, row.pval
            plot_cat_data(d1[:,idxs], d0[:,idxs], 
                          fids=class_names,title=f'{row.fid} {row.names}')#':\nChi2 {round(stat,2)} P={round(pval,6)}')


In [None]:
# Per-column tests for every column not covered by a joint categorical test:
#   - D/Y columns and columns of fields in binary_fids -> chi-squared on 0/1 counts;
#   - remaining columns (not part of a categorical field) -> two-sample KS test.
bin_assn = []
bin_fids = []
cont_assn = []
cont_fids = []
for i,f in tqdm(enumerate(feats)):
    if f in dy_strs or f.split('.')[1] in binary_fids:
        stat, pval = chi2_binary(d1[:,i], d0[:,i])
        bin_assn.append([stat,pval])
        bin_fids.append(f)
    elif '.' + f.split('.')[1] not in cat_fids:
        stat, pval = cont_sim_test(d1[:,i], d0[:,i])
        cont_assn.append([stat,pval])
        cont_fids.append(f)

bin_assn_df =  get_assn_df(bin_fids, bin_assn)
cont_assn_df = get_assn_df(cont_fids, cont_assn)

In [None]:
# For each continuous feature below the threshold (most significant first),
# print mean and median for the influence set (d1) then the rest (d0), and plot both densities.
for _, row in cont_assn_df.sort_values(by='pval').iterrows():
    if row.pval < thresh:
        i = np.where(feats==row.fid)[0]
        print(d1[:,i].mean(), np.median(d1[:,i]))
        print(d0[:,i].mean(), np.median(d0[:,i]))
        plot_cont_data(d1[:,i], d0[:,i], title=f'{row.names}')
        

In [None]:
# Rebinds bin_fids: now the field IDs of binary columns whose names have two or
# more dots (expanded columns), used only to fetch their category codings.
bin_fids = [f.split('.')[1] for f in bin_assn_df.fid if f.count('.') > 1]
coding_dict = get_coding_dict(np.unique(bin_fids))
for _, row in bin_assn_df.sort_values(by='pval').iterrows():
    if row.pval < thresh:
        print(row)
        if row.fid.count('.') > 1:
            # Expanded column: append the readable label of its code to the title.
            class_id = row.fid.split('.')[2]
            class_name = '\n=' + coding_dict[row.fid.split('.')[1]][int(class_id)]
        else:
            class_name=''
        i = np.where(feats==row.fid)[0]
        stat, pval = row.stat, row.pval
        plot_cat_data(d1[:,i], d0[:,i], fids=[row.fid],title=f'{row.names}{class_name}')#':\nChi2 {round(stat,2)} P={round(pval,6)}')


In [None]:
plt.figure(figsize=(20, 7), dpi=90)
with sns.axes_style('whitegrid'):
    # Hand-entered prevalences (%) for selected features, alternating
    # high-influence / typical patients. This uses its own name so the D-label
    # mapping `dd` used by the heat-map cell is not overwritten.
    prevalence_data = {
        'feat': ['Recurrent severe major depression', 'Recurrent severe major depression',
                 'Felt tense/restless for\nseveral days in last 2 weeks',
                 'Felt tense/restless for\nseveral days in last 2 weeks',
                 'Never felt tired/lethargic\nin last 2 weeks',
                 'Never felt tired/lethargic\nin last 2 weeks',
                 'Nervous feelings', 'Nervous feelings'],
        'point': [6, 2, 41, 20, 32, 51, 50, 24],
        'group': ['High influence patients', 'Typical patients'] * 4,
    }
    sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC", "#FB8500", "#FFB703", "#FAEDCD", "#3C6E71", ]))
    barplot = sns.barplot(data=prevalence_data, x='feat', y='point', hue='group')
    # Annotate each bar with its percentage.
    for p in barplot.patches:
        height = p.get_height()
        barplot.text(
            x=p.get_x() + p.get_width() / 2,
            y=height,
            s=f'{int(height)}%',
            ha='center',
            va='bottom', fontsize=22,
        )
    plt.ylim(0, 60)
    plt.title('Feature prevalence for patients with high influence in Obese, Osteoarthritis', fontsize=22)
    plt.xlabel('')
    plt.ylabel('Prevalence', fontsize=17)
    plt.legend(loc='upper left', fontsize=17)
    plt.xticks(fontsize=17)
    plt.yticks(fontsize=17)
    plt.show()

# 'General pain for 3+ months'
# [28, 1]
# 'Unable to work because of sickness or disability'
# [19, 3]
# 'Loneliness, isolation'
# [36, 18]
# 'Obtained university degree'
# [33 19]

In [None]:
def plot_cat_data(dataA, dataB, fids, title='', labelA='High influence points', labelB='Typical points'):
    """Bar-plot the within-group category distribution for two groups.

    NOTE: this notebook defines ``plot_cat_data`` twice (the other copy uses a
    different figure size and dpi); whichever cell ran last is in effect.

    Parameters
    ----------
    dataA, dataB : array-like
        Rows of the two groups. These are passed to ``get_cat_df_from_1hot``
        either as a single column of category values or as one-hot columns.
    fids : sequence of str
        Names of the one-hot columns (ignored for single-column input). Its
        length also sets the figure width.
    title : str
        Figure title.
    labelA, labelB : str
        Legend labels for the two groups.
    """
    combined = get_cat_df_from_1hot(dataA, dataB, feats=fids, labelA=labelA, labelB=labelB)
    with sns.axes_style("whitegrid"):

        plt.figure(figsize=(4*len(fids), 3), dpi=100)
        # Normalize counts within each group so each group's bars sum to 1.
        relative_freq = combined.groupby(['Dataset', 'Category']).size().reset_index(name='Count')
        total_counts = relative_freq.groupby('Dataset')['Count'].transform('sum')
        relative_freq['Density'] = relative_freq['Count'] / total_counts
        sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFB703", "#FAEDCD",  "#3C6E71", ]))

        barplot = sns.barplot(data=relative_freq, x='Category', y='Density', hue='Dataset')
        # Annotate each bar with its proportion as a percentage.
        for p in barplot.patches:
            height = p.get_height()
            barplot.text(
                x=p.get_x() + p.get_width() / 2,
                y=height,
                s=f'{height*100:.0f}%',
                ha='center',
                va='bottom',
            )

        plt.title(f'{title}')
        plt.xlabel('Category')
        # Bars show within-group proportions, not raw counts.
        plt.ylabel('Proportion')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        plt.show()

In [None]:
import scipy.stats as stats
from scipy.stats import chi2_contingency, ks_2samp
# from utils import * 
import numpy as np
import pandas as pd 
import seaborn as sns 
import matplotlib.pyplot as plt
def get_names(helper_dir='/oak/stanford/groups/rbaltman/karaliu/bias_detection/cohort_creation/helper_data'):
    """Load the field-description tables and stack them into one DataFrame.

    For each variable group ``sc``, ``mc``, ``intg`` and ``cont`` this reads
    ``<helper_dir>/<group>.csv``. The files are tab-separated despite the
    extension. Each row is tagged with its group in a new ``var`` column, and
    the four tables are concatenated in that order (original indices kept).

    Parameters
    ----------
    helper_dir : str or path-like
        Directory holding the four files. Defaults to the original cluster path.
    """
    import os

    frames = []
    for var in ('sc', 'mc', 'intg', 'cont'):
        frame = pd.read_csv(os.path.join(helper_dir, f'{var}.csv'), sep='\t')
        frame['var'] = var
        frames.append(frame)
    return pd.concat(frames)
# Global lookup table of field descriptions, read when this cell runs;
# get_dscr_from_fid looks feature descriptions up in it.
names = get_names()
def one_hot_to_category(data, feats):
    """Wrap a one-hot matrix in a DataFrame and decode each row's category.

    The decoded category of a row is the name of its largest-valued column,
    with any 'Category_' text removed.
    NOTE(review): an all-zero row decodes to the first column (idxmax behaviour).

    Returns
    -------
    (frame, categories) : (pandas.DataFrame, pandas.Series)
    """
    frame = pd.DataFrame(data, columns=feats)
    decoded = frame.idxmax(axis=1).str.replace('Category_', '')
    return frame, decoded
       
def get_cat_df_from_1hot(dataA, dataB, feats=None, labelA='outlier', labelB='typical'):
    """Stack two groups' categorical data into one long DataFrame.

    With a single input column, that column's values (as strings) are the
    categories. With several columns, the columns are treated as one-hot
    indicators named by ``feats``. Each row's category is decoded with
    ``one_hot_to_category``, and the indicator columns are kept alongside it.

    Parameters
    ----------
    dataA, dataB : 2-D arrays with the same number of columns
    feats : sequence of str, optional
        Column names for the one-hot case; defaults to '0', '1', ...
    labelA, labelB : str
        Values written to the ``Dataset`` column for each group.

    Returns
    -------
    pandas.DataFrame
        Rows of A followed by rows of B, with ``Category`` and ``Dataset`` columns.
    """
    assert dataA.shape[1] == dataB.shape[1]
    if dataA.shape[1] == 1:
        # ravel (not squeeze) keeps a 1-row input 1-D; a 0-d array would make
        # the DataFrame constructor reject it as an all-scalar input.
        A = pd.DataFrame({'Category': np.ravel(dataA).astype(str)})
        B = pd.DataFrame({'Category': np.ravel(dataB).astype(str)})
    else:
        if feats is None:
            feats = np.arange(dataA.shape[1]).astype(str)
        assert len(feats) == dataA.shape[1]
        if dataB.shape[0] < 100 or dataA.shape[0] < 100:
            print("Warning: sparse sample size")
        A, cat = one_hot_to_category(dataA, feats)
        A['Category'] = cat
        B, cat = one_hot_to_category(dataB, feats)
        B['Category'] = cat

    # Tag each row with its group, then combine.
    A['Dataset'] = labelA
    B['Dataset'] = labelB
    return pd.concat([A, B])

def plot_cat_data(dataA, dataB, fids, title='', labelA='High influence points', labelB='Typical points'):
    """Bar-plot the within-group category distribution for two groups.

    NOTE: this notebook defines ``plot_cat_data`` twice (the other copy uses a
    different figure size and dpi); whichever cell ran last is in effect.

    Parameters
    ----------
    dataA, dataB : array-like
        Rows of the two groups. These are passed to ``get_cat_df_from_1hot``
        either as a single column of category values or as one-hot columns.
    fids : sequence of str
        Names of the one-hot columns (ignored for single-column input). Its
        length also sets the figure width.
    title : str
        Figure title.
    labelA, labelB : str
        Legend labels for the two groups.
    """
    combined = get_cat_df_from_1hot(dataA, dataB, feats=fids, labelA=labelA, labelB=labelB)
    with sns.axes_style("whitegrid"):

        plt.figure(figsize=(3*len(fids), 3), dpi=120)
        # Normalize counts within each group so each group's bars sum to 1.
        relative_freq = combined.groupby(['Dataset', 'Category']).size().reset_index(name='Count')
        total_counts = relative_freq.groupby('Dataset')['Count'].transform('sum')
        relative_freq['Density'] = relative_freq['Count'] / total_counts
        sns.set_palette(sns.color_palette(["#B9E5FA", "#219EBC","#FB8500", "#FFB703", "#FAEDCD",  "#3C6E71", ]))

        barplot = sns.barplot(data=relative_freq, x='Category', y='Density', hue='Dataset')
        # Annotate each bar with its proportion as a percentage.
        for p in barplot.patches:
            height = p.get_height()
            barplot.text(
                x=p.get_x() + p.get_width() / 2,
                y=height,
                s=f'{height*100:.0f}%',
                ha='center',
                va='bottom'
            )

        plt.title(f'{title}')
        plt.xlabel('Category')
        # Bars show within-group proportions, not raw counts.
        plt.ylabel('Proportion')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        plt.show()
                  
def cat_sim_test(dataA, dataB):
    """Chi-squared test of independence between group membership and category.

    The two groups are stacked with ``get_cat_df_from_1hot`` (default labels)
    and cross-tabulated as Dataset x Category.

    Returns
    -------
    (chi2 statistic, p-value)
    """
    stacked = get_cat_df_from_1hot(dataA, dataB)
    table = pd.crosstab(stacked['Dataset'], stacked['Category'])
    statistic, pvalue, _dof, _expected = stats.chi2_contingency(table)
    return statistic, pvalue

             
def get_assn_df(fids, assn):
    """Build a table of association-test results, sorted by test statistic.

    Parameters
    ----------
    fids : sequence of str
        Feature identifiers, one per test. Each is passed to
        ``get_dscr_from_fid`` for a human-readable description.
    assn : sequence of (stat, pval) pairs
        Test results aligned with ``fids``.

    Returns
    -------
    pandas.DataFrame
        Columns ``pval``, ``stat``, ``fid`` and ``names``, sorted ascending by
        ``stat``. When no tests were run, the frame is empty but has those columns.
    """
    results = np.asarray(assn)
    if results.size == 0:
        # An empty list becomes a 1-D array; give it the expected (0, 2) shape.
        results = results.reshape(0, 2)
    # Named `descriptions` to avoid shadowing the global `names` table.
    descriptions = [get_dscr_from_fid(f) for f in fids]
    df = pd.DataFrame({'pval': results[:, 1],
                       'stat': results[:, 0],
                       'fid': fids,
                       'names': descriptions})
    return df.sort_values(by='stat')

def get_dscr_from_fid(f):
    """Look up the description of a feature in the global ``names`` table.

    ``f`` must be a D/Y label listed in ``dy_strs`` or contain a '.', whose
    second piece is the integer field ID. If no row matches, a message is
    printed and ``f`` itself is returned. This also happens for D/Y labels,
    which have no second '.' piece.
    """
    assert f in dy_strs or  f.count('.') > 0, f"Invalid feature {f}"
    try:
        field_ids = names['Field ID'].astype(int)
        wanted = int(f.split('.')[1])
        return names[field_ids == wanted].Description.iloc[0]
    except IndexError:
        print(f"Didn't find {f} in names")
        return f

def chi2_binary(vector1, vector2):
    """Chi-squared test comparing the 0/1 frequencies of two binary vectors.

    The contingency table has one row per vector: [count of 0s, count of 1s].
    Values other than 0 and 1 are ignored.

    Returns
    -------
    (chi2 statistic, p-value). Identical frequencies give chi2 = 0 and p = 1.
    """
    table = np.array([[np.sum(vec == 0), np.sum(vec == 1)]
                      for vec in (vector1, vector2)])
    return chi2_contingency(table)[:2]


def cont_sim_test(dataA, dataB, rm_nan=True):
    """Two-sample Kolmogorov-Smirnov test between two continuous samples.

    Parameters
    ----------
    dataA, dataB : array-like
        Samples that must be 1-D after squeezing (e.g. shape (n,) or (n, 1)).
    rm_nan : bool
        If True, drop the imputed placeholder values with ``rm_nan_cont``
        first. Otherwise test the data as given.

    Returns
    -------
    The ``scipy.stats.ks_2samp`` result (unpacks to statistic, p-value).
    """
    # atleast_1d keeps a single-element sample 1-D instead of 0-d.
    dataA = np.atleast_1d(np.squeeze(dataA))
    dataB = np.atleast_1d(np.squeeze(dataB))
    assert len(dataA.shape) == 1
    assert len(dataB.shape) == 1
    if rm_nan:
        vecA, vecB = rm_nan_cont(dataA, dataB)
    else:
        # Previously missing: rm_nan=False left vecA/vecB unbound.
        vecA, vecB = dataA, dataB
    return ks_2samp(vecA, vecB)

def plot_cont_data(dataA, dataB, title='', rm_nan=True, labelA='outlier', labelB='typical'):
    """Overlay kernel-density plots of two continuous samples.

    Inputs must be 1-D after squeezing. With ``rm_nan`` the imputed
    placeholder values are first dropped by ``rm_nan_cont``. The second
    sample uses a doubled KDE bandwidth. The x axis is capped at the larger
    99.9th percentile when that point lies more than 3 standard deviations
    above the second sample's mean, which stops long tails from squashing the plot.
    """
    dataA = dataA.squeeze()
    dataB = dataB.squeeze()
    assert len(dataA.shape)==1
    assert len(dataB.shape)==1
    vecA, vecB = rm_nan_cont(dataA, dataB) if rm_nan else (dataA, dataB)

    sns.kdeplot(vecA, fill=True, label=labelA)
    sns.kdeplot(vecB, fill=True, label=labelB, bw_adjust=2)

    upper = max(np.percentile(vecA, 99.9), np.percentile(vecB, 99.9))
    z_upper = (upper - vecB.mean()) / vecB.std()
    if z_upper > 3:
        plt.xlim(right=upper)

    plt.title(title)
    plt.legend()
    plt.show()
def rm_nan_cont(v1, v2):
    """Drop mean-imputed placeholder values from two continuous vectors.

    Missing values were mean-imputed, so the placeholder is taken to be the
    most frequent value. When several values tie, the smallest is used.

    - If the mode of ``v2`` is integer-valued, it is treated as a genuine
      discrete value: a message is printed and both inputs are returned unchanged.
    - If the two modes differ, a message is printed, and the mode of the
      longer vector is used (``v2`` wins ties in length).
    - Entries within 1e-5 of the placeholder are removed from both vectors.
    """
    def _mode(vec):
        values, freq = np.unique(vec, return_counts=True)
        return values[np.argmax(freq)]

    mode1, mode2 = _mode(v1), _mode(v2)
    if int(mode2) == mode2:
        print("Found discrete mode. Ignoring..")
        return v1, v2

    fill = mode2
    if mode1 != mode2:
        print("Nans not the same", mode1, mode2)
        fill = mode1 if len(v1) > len(v2) else mode2

    tol = 1e-5
    return v1[np.abs(v1 - fill) > tol], v2[np.abs(v2 - fill) > tol]