In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import warnings
warnings.filterwarnings('ignore')

In [3]:
import matplotlib
font = {'family' : 'normal',
        'weight' : 'bold',
        'size'   : 14}

matplotlib.rc('font', **font)

In [4]:
save_plots = False

In [5]:
# helpers
def get_col_names(results_lines,mode):
    breaks = []
    for i,line in enumerate(results_lines):
        if(line==''):
            breaks.append(i)
    if(mode=='train'):
        start_idx = 3
    elif(mode=='val'):
        start_idx = breaks[0]+2
    elif(mode=='test'):
        start_idx = breaks[1]+2
    else:
        raise Exception('mode must be train,val, or test')
        
    metrics = results_lines[start_idx].split('|')[1:]
    col_names = [m.strip() for m in metrics]
    col_names_all = []
    for i,name in enumerate(col_names):
        col_names_all.append(name)
        col_names_all.append(name+'_stderr')
    col_names_all.insert(0,'idx')
    return col_names_all

def process_result(result):
    result_processed = []
    result_arr = result.split('|')
    result_processed.append(int(result_arr[0].strip()))
    for r in result_arr[1:]:
        arr = r.split('(')
        result_processed.append(float(arr[0].strip()))
        result_processed.append(float(arr[1].split(')')[0]))
    return result_processed

def make_df(col_names_all,results):
    d = []
    for result in results:
        arr = process_result(result)
        d.append(arr)
    return pd.DataFrame(d,columns=col_names_all)

def process_used_configs_file(lines):
    df = pd.DataFrame()
    for line in lines:
        pairs = line.split(',')
        d = {}
        for pair in pairs:
            key,val = pair.split(':')
            d[key] = val
        df = df.append(d,ignore_index=True)
        
    return df

def process_config(config):
    arr = config.split(',')
    d = {}
    for el in arr:
        k,v = el.split('=')
        try:
            v = float(v)
        except:
            v = v
        d[k.strip()] = v
    return d

In [6]:
def txt_to_df(results_name,configs_name):
    results_file = open(results_name,'r+')
    results_lines= results_file.read().splitlines()
    
    breaks = []
    for i,line in enumerate(results_lines):
        if(line==''):
            breaks.append(i)

    train_idx = range(5,breaks[0])
    val_idx = range(breaks[0]+4,breaks[1])
    test_idx = range(breaks[1]+4,breaks[2])

    train_results = [results_lines[i] for i in train_idx]
    val_results = [results_lines[i] for i in val_idx]
    test_results = [results_lines[i] for i in test_idx]
    
    col_names_train = get_col_names(results_lines,'train')
    col_names_val = get_col_names(results_lines,'val')
    col_names_test = get_col_names(results_lines,'test')

    train_df = make_df(col_names_train,train_results)
    val_df = make_df(col_names_val,val_results)
    test_df = make_df(col_names_test,test_results)
    
    configs_file = open(configs_name,'r+')
    configs_lines= configs_file.read().splitlines()
    l = []
    for config in configs_lines:
        l.append(list(process_config(config).values()))
        
    l = list(map(list, zip(*l)))
    col_names = list(process_config(configs_lines[0]).keys())
    for i,el in enumerate(l):
        train_df[col_names[i]] = el
        val_df[col_names[i]] = el
        test_df[col_names[i]] = el
    return train_df,val_df,test_df

# make IHDP dataframes

In [7]:
exp_name = 'ihdp100_neurips'
plot_dir = exp_name+'_plots'
if(save_plots):
    if (not os.path.isdir(plot_dir)):
        os.mkdir(plot_dir)
        
results_name = "/media/common/"+exp_name+"/results_summary.txt"
configs_name = "/media/common/"+ exp_name+"/configs_sorted.txt"

train_df,val_df,test_df = txt_to_df(results_name,configs_name)

# Best val PEHE_NN for each weight scheme

In [10]:
weight_schemes = ['IPW','OW','MW','TruncIPW']
argmins = [np.argmin(val_df[val_df['weight_scheme']==w]['Pehe_nn']) for w in weight_schemes]
test_df.iloc[argmins][['Pehe','Pehe_stderr','weight_scheme']]

Unnamed: 0,Pehe,Pehe_stderr,weight_scheme
1,0.769,0.053,IPW
11,0.66,0.063,OW
8,0.659,0.063,MW
4,0.647,0.046,TruncIPW


# Save these configs & run on IHDP1000

In [11]:
hyperparams = ['n_prop','imb_fun','n_in','dim_prop','dim_in','p_alpha','dim_out','n_out','weight_scheme']
best_df = test_df.iloc[argmins][hyperparams]

In [12]:
arrs = []
for i,row in best_df.iterrows():
    out = []
    for k,val in row.iteritems():
        if(k=='len'):
            continue
        if(type(val)==float):
            if(val.is_integer()):
                val = int(val)
        if(type(val)==str):
            val = [val]
        out.append(k+'='+str(val))
    with open('configs/neurips/ihdp100.txt','r') as f:
    #     with open('configs/neurips/missing_ihdp/missing_{}.txt'.format(i),'w') as f:
        lines = f.readlines()
        for line in lines:
            var = line.split('=')[0]
            if(var in best_df):
                continue
            elif(var=='experiments'):
                out.append('experiments=1000')
            elif(var=='outdir'):
                out.append("outdir='/media/common/ihdp1000_neurips/{}".format(row['weight_scheme'])+"/'")
            elif(var=='datadir'):
                out.append("datadir='../datasets/IHDP1000/'")
            elif(var=='dataform'):
                out.append("dataform='ihdp_npci_1-1000.train.npz'")
            elif(var=='data_test'):
                out.append("data_test='ihdp_npci_1-1000.test.npz'")
            else:
#                 print(line.strip())
                out.append(line.strip())
    arrs.append(out)

In [13]:
savedir = 'configs/neurips/ihdp1000/'
weights = np.array(best_df['weight_scheme'])
for i,arr in enumerate(arrs):
    savepath = os.path.join(savedir,'{}.txt'.format(weights[i]))
    with open(savepath,'w') as outfile:
        outfile.write("\n".join(arr))

# Process IHDP1000 results

In [14]:
import sys
exp_name = 'ihdp1000_neurips'

weight_schemes = ['IPW','MW','OW','TruncIPW']
train_all = []
val_all = []
test_all = []
for w in weight_schemes:
    
    results_name = os.path.join('/media/common/',exp_name,w,'results_summary.txt')
    configs_name = os.path.join('/media/common/',exp_name,w,'configs_sorted.txt')
    results_file = open(results_name,'r+')
    results_lines= results_file.read().splitlines()
    breaks = []
    for i,line in enumerate(results_lines):
        if(line==''):
            breaks.append(i)

    train_idx = range(5,breaks[0])
    val_idx = range(breaks[0]+4,breaks[1])
    test_idx = range(breaks[1]+4,breaks[2])

    train_results = [results_lines[i] for i in train_idx]
    val_results = [results_lines[i] for i in val_idx]
    test_results = [results_lines[i] for i in test_idx]
    
    col_names_train = get_col_names(results_lines,'train')
    col_names_val = get_col_names(results_lines,'val')
    col_names_test = get_col_names(results_lines,'test')
    
    
    train = make_df(col_names_train,train_results)
    val = make_df(col_names_val,val_results)
    test = make_df(col_names_test,test_results)
    train['weight_scheme'] = w
    val['weight_scheme'] = w
    test['weight_scheme'] = w
    train_all.append(train)
    val_all.append(val)
    test_all.append(test)
    
train_all = pd.concat(train_all)
val_all = pd.concat(val_all)
test_all = pd.concat(test_all)

In [15]:
results = test_all[['weight_scheme','Pehe','Pehe_stderr','Bias_ate','Bias_ate_stderr']]

In [16]:
results

Unnamed: 0,weight_scheme,Pehe,Pehe_stderr,Bias_ate,Bias_ate_stderr
0,IPW,0.722,0.014,0.205,0.008
0,MW,0.659,0.017,0.176,0.008
0,OW,0.65,0.016,0.176,0.007
0,TruncIPW,0.632,0.013,0.186,0.008


In [17]:
results["$\epsilon_{PEHE}$"] = results["Pehe"].astype(str) + ' $\pm$ '+ results["Pehe_stderr"].astype(str)
results["$\epsilon_{ATE}$"] = results["Bias_ate"].astype(str) + ' $\pm$ '+ results["Bias_ate_stderr"].astype(str)

In [18]:
print(results[['weight_scheme','$\epsilon_{PEHE}$','$\epsilon_{ATE}$']].to_latex(escape=False,index=False))

\begin{tabular}{lll}
\toprule
weight_scheme &  $\epsilon_{PEHE}$ &   $\epsilon_{ATE}$ \\
\midrule
          IPW &  0.722 $\pm$ 0.014 &  0.205 $\pm$ 0.008 \\
           MW &  0.659 $\pm$ 0.017 &  0.176 $\pm$ 0.008 \\
           OW &   0.65 $\pm$ 0.016 &  0.176 $\pm$ 0.007 \\
     TruncIPW &  0.632 $\pm$ 0.013 &  0.186 $\pm$ 0.008 \\
\bottomrule
\end{tabular}

