In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import glob
import os
import sys
sys.path.append('../')
import matplotlib.pyplot as plt
import shutil
from prediction_utils.util import df_dict_concat, yaml_read, yaml_write
from pathlib import Path
from tqdm import tqdm
import json
import hashlib

In [None]:
# Root directory holding the experiment results. The default is the original
# cluster location; set STANFORD_ROBUSTNESS_RESULTS_DIR to run the notebook
# against a different copy of the results without editing this cell.
project_dir = Path(
    os.environ.get(
        "STANFORD_ROBUSTNESS_RESULTS_DIR",
        "/scratch/hdd001/home/haoran/stanford_robustness/results",
    )
)

In [None]:
# Load every ERM evaluation result into one long DataFrame (`df_all`),
# caching the combined frame as a pickle so re-runs are cheap.

# Hyperparameter columns that identify one model configuration. Defined before
# the cache check so that `hparams` exists for later cells even when `df_all`
# is restored from the pickle instead of being rebuilt.
hparams = ['lr', 'num_hidden', 'drop_prob', 'hidden_dim', 'model_type']

force_reload = False  # set True to ignore the cached pickle and rebuild it
pkl_path = Path('df_all_erm.pkl')
if pkl_path.exists() and not force_reload:
    # NOTE: unpickling executes arbitrary code; only load a cache file that
    # this notebook produced itself.
    df_all = pd.read_pickle(pkl_path)
else:
    res = []
    for result_path in tqdm(project_dir.glob('*ERM/**/result_df_group_standard_eval.parquet')):
        run_dir = result_path.parent

        # Close the args file deterministically instead of leaking the handle.
        with (run_dir / 'args.json').open('r') as args_file:
            args_i = json.load(args_file)

        # Keep only runs without group balancing, selected on loss, with no
        # subset attribute and with gender as the sensitive attribute. Checked
        # before reading the parquet file so discarded runs cost no extra I/O.
        if (args_i['balance_groups'] or args_i['selection_metric'] != 'loss'
            or not pd.isnull(args_i['subset_attribute'])
            or args_i['sensitive_attribute'] != 'gender'):
            continue

        # Task name: grandparent directory name minus its last three characters
        # (presumably an "ERM" suffix -- TODO confirm), joined with the label column.
        args_i['task'] = run_dir.parent.name[:-3] + '_' + args_i['label_col']
        args_i['config_filename'] = run_dir.name

        df_i = pd.read_parquet(result_path)

        # Broadcast scalar run arguments onto every row; list/tuple-valued
        # arguments are skipped.
        for j in ['task', 'config_filename', 'group_objective_type', 'selection_metric',
                  'balance_groups', 'sensitive_attribute', 'fold_id', 'group_objective_metric',
                  'subset_attribute', 'subset_group']:
            if not isinstance(args_i[j], (list, tuple)):
                df_i[j] = args_i[j]

        for hparam in hparams:
            df_i[hparam] = args_i[hparam]
        # Stable identifier for the hyperparameter combination of this run.
        df_i['hparams_id'] = hashlib.md5(str(df_i[hparams].iloc[0].values.tolist()).encode('utf-8')).hexdigest()

        res.append(df_i)

    if not res:
        # pd.concat([]) would fail with an opaque "No objects to concatenate".
        raise ValueError(f'No matching ERM result files found under {project_dir}')

    df_all = pd.concat(res).reset_index(drop=True)
    df_all.to_pickle(pkl_path)

In [None]:
# sanity check: every (hparams_id, task) pair must have exactly 5 eval-phase AUC
# rows for eval_attribute == 'gender' and eval_group == 'M'
# (presumably one row per fold -- TODO confirm)
np.all(
    df_all
    .query('phase == "eval" & eval_attribute == "gender" & eval_group == "M" & metric == "auc"')
    .groupby(['hparams_id', 'task'])['performance']
    .count()
    == 5
)

In [None]:
# For each task, select the ERM hyperparameter configuration with the lowest
# mean eval-phase `loss_bce` (column `performance_overall`) and write it to
# `<task>_erm_config.yaml` under `project_dir`.
# `hparams` is the list of hyperparameter column names set in the
# results-loading cell.
for task_prefix in ['eICUMortality_target', 'MIMICMortality_target']:

    result_df_erm = df_all[df_all.task == task_prefix]

    # Mean eval loss per hyperparameter configuration, over all matching rows.
    mean_performance = (
        result_df_erm
        .query('metric == "loss_bce" & phase == "eval"')
        .groupby(['hparams_id'])
        .agg(performance=('performance_overall', 'mean'))
        .reset_index()
    )
    if mean_performance.empty:
        # Fail with a clear message instead of a bare assertion further down.
        raise ValueError(f'No eval loss_bce results found for task {task_prefix!r}')

    # Row(s) attaining the minimum mean loss. A direct comparison against the
    # minimum replaces the float-keyed merge; ties are still kept and are
    # caught by the uniqueness check below.
    best_model = mean_performance[
        mean_performance['performance'] == mean_performance['performance'].min()
    ]

    # Recover the hyperparameter values belonging to the winning hparams_id.
    selected_config_df = best_model[['hparams_id']].merge(result_df_erm)
    selected_config_dict_list = (
        selected_config_df[hparams]
        .drop_duplicates()
        .to_dict('records')
    )
    assert len(selected_config_dict_list) == 1, (
        f'Expected exactly one best configuration for task {task_prefix!r}, '
        f'found {len(selected_config_dict_list)}'
    )
    selected_config_dict = selected_config_dict_list[0]
    selected_config_dict['task'] = task_prefix
    print(selected_config_dict)

    yaml_write(selected_config_dict, project_dir/f'{task_prefix}_erm_config.yaml')