In [4]:
import glob
from math import ceil
import os
from pprint import pprint
import re

import pandas as pd
import matplotlib.pyplot as plt

In [5]:
results_dir = "../domain_shift_results"

# experiments still in process; ignore any existing results for these.
in_process_expts = [
    "acsfoodstamps_region",
    "acsincome_region",
    "acspubcov_year",
    "acsunemployment_year",
    'diabetes_admtype',  # not in progress - we just want to exclude from results (removed from benchmark)
    'mimic_extract_los_3_ins',
    'mooc_course'
]

In [6]:
def extract_task_from_filepath(fp:str)->str:
    task = re.search(".*/domain_shift_results/(\w+)/.*", fp).group(1)
    return task

In [7]:
files = []

for expt in os.listdir(results_dir):
    if expt in in_process_expts:
        print(f"skipping in progress expt at {expt}")
        continue
    wc = os.path.join(results_dir, expt, "**", "**_full.csv")
    full_results = glob.glob(wc)
    if full_results:
        most_recent_result = full_results[-1]
        print(f"got recent result file for expt {expt}")
        files.append(most_recent_result)

got recent result file for expt anes_year
skipping in progress expt at acsfoodstamps_region
got recent result file for expt brfss_diabetes_race
skipping in progress expt at acsincome_region
got recent result file for expt anes_region
got recent result file for expt nhanes_lead_poverty
skipping in progress expt at mimic_extract_los_3_ins
skipping in progress expt at mooc_course
got recent result file for expt physionet_set
got recent result file for expt diabetes_admsrc
skipping in progress expt at diabetes_admtype
skipping in progress expt at acspubcov_year
got recent result file for expt nhanes_cholesterol_race
skipping in progress expt at acsunemployment_year
got recent result file for expt mimic_extract_mort_hosp_ins
got recent result file for expt brfss_blood_pressure_income


In [8]:
dfs = []
for f in files:
    df = pd.read_csv(f)
    task = extract_task_from_filepath(f)
    df["task"] = task
    dfs.append(df)

In [9]:
len(dfs)

9

In [10]:
df = pd.concat(dfs)
df['domain_split_ood_values'] = df['domain_split_ood_values'].apply(lambda x: str(x).replace("'", ""))
df['domain_split_id_values'] = df['domain_split_id_values'].apply(lambda x: str(x).replace("'", ""))
df

Unnamed: 0,train-auc,train-map,validation-auc,validation-map,id_test-auc,id_test-map,ood_test-auc,ood_test-map,ood_validation-auc,ood_validation-map,...,domain_split_id_values,train-average_precision,validation-average_precision,id_test-average_precision,ood_test-average_precision,ood_validation-average_precision,config/params/min_child_samples,config/params/reg_alpha,config/params/reg_lambda,task
0,0.859967,0.902029,0.881125,0.925866,0.872542,0.919961,0.894500,0.934414,0.902139,0.939591,...,"[1948, 1952, 1954, 1956, 1958, 1960, 1962, 196...",,,,,,,,,anes_year
1,0.805127,0.856644,0.823616,0.882994,0.811675,0.870999,0.852571,0.905029,0.841470,0.905356,...,"[1948, 1952, 1954, 1956, 1958, 1960, 1962, 196...",,,,,,,,,anes_year
2,0.747346,0.829408,0.708851,0.825744,0.709987,0.826910,0.681453,0.824842,0.663097,0.827163,...,"[1948, 1952, 1954, 1956, 1958, 1960, 1962, 196...",,,,,,,,,anes_year
3,0.802530,0.855615,0.814410,0.878067,0.800995,0.866320,0.853459,0.909304,0.875131,0.928597,...,"[1948, 1952, 1954, 1956, 1958, 1960, 1962, 196...",,,,,,,,,anes_year
4,0.675308,0.779708,0.668050,0.789081,0.669585,0.795185,0.665199,0.798295,0.669186,0.807106,...,"[1948, 1952, 1954, 1956, 1958, 1960, 1962, 196...",,,,,,,,,anes_year
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,1.000000,,1.000000,,1.000000,,1.000000,,1.000000,,...,[0],1.0,1.0,1.0,1.0,1.0,8.0,8.429910,2.660396e-03,brfss_blood_pressure_income
96,1.000000,,1.000000,,1.000000,,1.000000,,1.000000,,...,[0],1.0,1.0,1.0,1.0,1.0,8.0,17.220488,1.365476e-02,brfss_blood_pressure_income
97,1.000000,,1.000000,,1.000000,,1.000000,,1.000000,,...,[0],1.0,1.0,1.0,1.0,1.0,8.0,45.201172,1.431158e-01,brfss_blood_pressure_income
98,1.000000,,1.000000,,1.000000,,1.000000,,1.000000,,...,[0],1.0,1.0,1.0,1.0,1.0,16.0,0.282971,4.923437e-03,brfss_blood_pressure_income


In [11]:
print(sorted(df.task.unique()))
print(sorted(df.columns))

['anes_region', 'anes_year', 'brfss_blood_pressure_income', 'brfss_diabetes_race', 'diabetes_admsrc', 'mimic_extract_mort_hosp_ins', 'nhanes_cholesterol_race', 'nhanes_lead_poverty', 'physionet_set']
['config/params/alpha', 'config/params/colsample_bylevel', 'config/params/colsample_bytree', 'config/params/gamma', 'config/params/lambda', 'config/params/learning_rate', 'config/params/max_bin', 'config/params/max_depth', 'config/params/min_child_samples', 'config/params/min_child_weight', 'config/params/reg_alpha', 'config/params/reg_lambda', 'config/params/subsample', 'date', 'domain_split_id_values', 'domain_split_ood_values', 'domain_split_varname', 'done', 'episodes_total', 'estimator', 'experiment_id', 'hostname', 'id_test-auc', 'id_test-average_precision', 'id_test-map', 'id_test_accuracy', 'iterations_since_restore', 'logdir', 'node_ip', 'ood_test-auc', 'ood_test-average_precision', 'ood_test-map', 'ood_test_accuracy', 'ood_validation-auc', 'ood_validation-average_precision', 'ood

In [12]:
df[['task', 'estimator', 'validation_accuracy']]

Unnamed: 0,task,estimator,validation_accuracy
0,anes_year,xgb,0.823033
1,anes_year,xgb,0.795159
2,anes_year,xgb,0.692463
3,anes_year,xgb,0.794242
4,anes_year,xgb,0.700165
...,...,...,...
95,brfss_blood_pressure_income,lightgbm,1.000000
96,brfss_blood_pressure_income,lightgbm,1.000000
97,brfss_blood_pressure_income,lightgbm,1.000000
98,brfss_blood_pressure_income,lightgbm,1.000000


In [13]:
def best_by_metric(df, metric='validation_accuracy'):
    idxs = df.groupby(['task', 'estimator', 'domain_split_ood_values'])['validation_accuracy'].idxmax().values
    return df.iloc[idxs]

In [14]:
best_acc_per_task = best_by_metric(df)
best_acc_per_task[['validation_accuracy', 'task', 'estimator', 'domain_split_ood_values']]

Unnamed: 0,validation_accuracy,task,estimator,domain_split_ood_values
191,0.674491,anes_year,lightgbm,[2016]
307,0.677211,anes_year,lightgbm,[2020]
194,0.870640,anes_region,lightgbm,[1.0]
147,0.874523,brfss_diabetes_race,lightgbm,"[2, 3, 4, 5, 6]"
15,0.785256,anes_year,xgb,[2016]
...,...,...,...,...
0,0.823033,anes_year,xgb,[2016]
142,0.674491,anes_year,lightgbm,[2016]
342,0.677211,anes_year,lightgbm,[2020]
86,0.781955,anes_year,xgb,[2016]


In [15]:
best_acc_per_task.groupby(['task', 'estimator', 'domain_split_ood_values']).size()

task                 estimator  domain_split_ood_values
anes_region          lightgbm   [1.0]                      3
                                [2.0]                      1
                                [3.0]                      1
                                [4.0]                      1
                     xgb        [1.0]                      3
                                [2.0]                      1
                                [3.0]                      1
                                [4.0]                      1
anes_year            lightgbm   [2016]                     8
                                [2020]                     5
                     xgb        [2016]                     9
                                [2020]                     5
brfss_diabetes_race  lightgbm   [2, 3, 4, 5, 6]            3
                     xgb        [2, 3, 4, 5, 6]            3
diabetes_admsrc      lightgbm   (1,)                       1
                             

In [None]:
tasks = sorted(df.task.unique())

fig, axs = plt.subplots(ncols=2, nrows=ceil(len(tasks)/2), figsize=(10,20))
for i, task in enumerate(tasks):
    rownum = i // 2
    colnum = i % 2
    ax = axs[rownum,colnum]
    for est in sorted(df.estimator.unique()):
        df_ = df.query(f"estimator == '{est}' and task == '{task}'")
        ax.scatter(df_['id_test_accuracy'].values, df_['ood_test_accuracy'].values, label=est)
    ax.axline((0.5,0.5), (1, 1), linestyle="dashed", color="grey", label="y=x")
    ax.legend()
    
    ax.set_title(task)

In [None]:
best_acc_per_task.task.unique()

In [None]:
# fig, axs = plt.subplots(ncols=2, nrows=ceil(len(tasks)/2), figsize=(10,20))
# for i, task in enumerate(tasks):
task = 'mimic_extract_mort_hosp_ins'

df_ = best_acc_per_task.query(f"estimator == '{est}' and task == '{task}'")
for ood_vals in df_['domain_split_ood_values'].unique():
    tmp = df_.query(f"domain_split_ood_values == '{ood_vals}'")
#     pprint(tmp[['task', 'estimator', 'domain_split_ood_values',
#        'ood_test_accuracy', 'id_test_accuracy']].to_dict())
    plt.scatter(df_['id_test_accuracy'].values, 
                df_['ood_test_accuracy'].values, 
                label=ood_vals, 
                alpha=0.8)
#     break
# plt.axline((0.5,0.5), (1, 1), linestyle="dashed", color="grey", label="y=x")
plt.legend()

In [None]:
tmp[['task', 'estimator', 'domain_split_id_values', 'domain_split_ood_values',
       'ood_test_accuracy', 'id_test_accuracy', 'validation_accuracy']]

In [None]:
tmp = df.query(f"estimator == '{est}' and task == '{task}' and domain_split_ood_values == '(25,)'")
tmp