In [2]:
import numpy as np
import pandas as pd
import os
import glob
import torch
import torch.nn.functional as F
import joblib
import itertools
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import warnings
import string
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss, recall_score, precision_score
from prediction_utils.util import df_dict_concat, yaml_read
from matplotlib.ticker import FormatStrFormatter

In [3]:
project_dir = "/share/pi/nigam/projects/spfohl/cohorts/admissions/starr_20200523"
experiment_name_baseline = 'baseline_tuning_fold_1_10'
experiment_name_fair = 'fair_tuning_fold_1_10'
tasks = ['hospital_mortality', 'LOS_7', 'readmission_30']
cohort_path = os.path.join(project_dir, 'cohort', 'cohort.parquet')
row_id_map_path = os.path.join(
    project_dir, 'merged_features_binary/features_sparse/features_row_id_map.parquet'
)
result_path = os.path.join(project_dir, 'experiments', 'merged_results_fold_1_10')
os.makedirs(result_path, exist_ok=True)

In [4]:
attributes = ['gender_concept_name', 'age_group', 'race_eth']

In [5]:
cohort = pd.read_parquet(cohort_path)
row_id_map = pd.read_parquet(row_id_map_path)
cohort = cohort.merge(row_id_map)

### Generate the cohort table

In [6]:
### Cohort table
cohort_df_long = (
    cohort
    .melt(
        id_vars = ['person_id'] + attributes,
        value_vars = tasks,
        var_name = 'task',
        value_name = 'labels'
    )
    .melt(
        id_vars = ['person_id', 'task', 'labels'],
        value_vars = attributes,
        var_name = 'attribute',
        value_name = 'group'
    )
)

In [7]:
cohort_statistics_df = (
    cohort_df_long
    .groupby(['task', 'attribute', 'group'])
    .agg(
        prevalence=('labels', 'mean'),
    )
    .reset_index()
    .groupby('attribute')
    .apply(lambda x: x.pivot_table(index = 'group', columns = 'task', values = 'prevalence'))
    .reset_index()
)

group_size_df = (
    cohort_df_long
    .groupby(['task', 'attribute', 'group'])
    .agg(
        size = ('labels', lambda x: x.shape[0])
    )
    .reset_index()
    .drop(columns = 'task')
    .drop_duplicates()
)

cohort_statistics_df = cohort_statistics_df.merge(group_size_df)
cohort_statistics_df = (
    cohort_statistics_df
    .set_index(['attribute', 'group'])
    [['size'] + tasks]
)

In [8]:
cohort_statistics_df

Unnamed: 0_level_0,Unnamed: 1_level_0,size,hospital_mortality,LOS_7,readmission_30
attribute,group,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
age_group,[18-30),23042,0.006814,0.175072,0.04609
age_group,[30-45),43432,0.005963,0.129536,0.039648
age_group,[45-55),27394,0.017778,0.205227,0.052712
age_group,[55-65),35703,0.025096,0.227432,0.055794
age_group,[65-75),36084,0.028378,0.234204,0.054844
age_group,[75-91),32989,0.040013,0.237807,0.054533
gender_concept_name,FEMALE,112713,0.016085,0.165935,0.045177
gender_concept_name,MALE,85923,0.027117,0.244335,0.057098
gender_concept_name,No matching concept,8,0.0,0.125,0.0
race_eth,Asian,29460,0.020876,0.171317,0.053632


In [9]:
## Write to Latex
table_path = './../figures/starr_20200523'
os.makedirs(table_path, exist_ok=True)
with open(os.path.join(table_path, 'cohort_table.txt'), 'w') as fp:
    (
        cohort_statistics_df
        .reset_index().drop(columns='attribute').set_index(['group'])
        .to_latex(
            fp, 
            float_format = '%.3g', 
            index_names = False, 
            index=True
        )
    )

### Get the results

In [10]:
def get_result_df_baseline(base_path, filename='result_df_group_standard_eval.parquet'):
    """
    Gets the results for training the baseline models
    """
    selected_models_path = os.path.join(
        base_path, 
        'config',
        'selected_models', '**', '*.yaml'
    )
    selected_models_dict = {
        filename.split('/')[-2]: filename.split('/')[-1]
        for filename in glob.glob(selected_models_path, recursive=True)
    }
    paths = [
        glob.glob(
            os.path.join(
                base_path,
                'performance',
                task, 
                config_filename, 
                '**', 
                filename
            ),
            recursive=True
        )
        for task, config_filename in selected_models_dict.items()
    ]
    paths = list(itertools.chain(*paths))
    result_df_baseline = df_dict_concat(
        {
            tuple(filename.split('/'))[-4:-1]:
            pd.read_parquet(filename)
            for filename in paths
        },
        ['task2', 'config_filename', 'fold_id']
    ).drop(columns='task2')
    return result_df_baseline

In [11]:
result_df_baseline = get_result_df_baseline(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_baseline,  
    )
)

In [12]:
result_df_baseline.task.unique()

array(['LOS_7', 'readmission_30', 'hospital_mortality'], dtype=object)

In [13]:
result_df_calibration_baseline = get_result_df_baseline(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_baseline,  
    ),
    filename='calibration_result.parquet'
)
id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'attribute', 'group']
result_df_calibration_baseline = result_df_calibration_baseline.melt(
    id_vars = id_vars,
    value_vars = set(result_df_calibration_baseline.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
).query('metric != "brier"')

In [14]:
result_df_calibration_baseline.metric.unique()

array(['brier_signed', 'calib_error', 'calib_error_signed',
       'calib_group_error', 'calib_group_error_signed'], dtype=object)

In [15]:
# Import fair_ova metrics
result_df_ova_baseline = get_result_df_baseline(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_baseline,  
    ),
    filename='result_df_group_fair_ova.parquet'
)
id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'attribute', 'group']
result_df_ova_baseline = result_df_ova_baseline.melt(
    id_vars = id_vars,
    value_vars = set(result_df_ova_baseline.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
)

In [16]:
result_df_baseline = pd.concat([result_df_baseline, result_df_calibration_baseline, result_df_ova_baseline], ignore_index=True)

In [17]:
result_df_baseline

Unnamed: 0,config_filename,fold_id,metric,phase,task,attribute,group,performance,performance_overall
0,15.yaml,7,auc,test,LOS_7,age_group,[18-30),0.812202,0.792947
1,15.yaml,7,auc,test,LOS_7,age_group,[30-45),0.830439,0.792947
2,15.yaml,7,auc,test,LOS_7,age_group,[45-55),0.806845,0.792947
3,15.yaml,7,auc,test,LOS_7,age_group,[55-65),0.793922,0.792947
4,15.yaml,7,auc,test,LOS_7,age_group,[65-75),0.762436,0.792947
...,...,...,...,...,...,...,...,...,...
14815,13.yaml,9,xauc_ova_1,val,hospital_mortality,race_eth,Hispanic or Latino,0.901913,
14816,13.yaml,9,xauc_ova_1,val,hospital_mortality,race_eth,Black or African American,0.869531,
14817,13.yaml,9,xauc_ova_1,val,hospital_mortality,race_eth,Other,0.897275,
14818,13.yaml,9,xauc_ova_1,val,hospital_mortality,race_eth,Asian,0.881732,


In [18]:
def flatten_multicolumns(df):
    """
    Converts multi-index columns into single colum
    """
    df.columns = ['_'.join([el for el in col if el != '']).strip() for col in df.columns.values if len(col) > 1]
    return df

In [19]:
result_df_baseline_mean = (
    result_df_baseline
    .groupby(list(set(result_df_baseline.columns) - {'fold_id', 'performance', 'performance_overall'}))
    [['performance', 'performance_overall']]
    .agg(['mean', 'std', 'sem'])
    .reset_index()
)
result_df_baseline_mean = result_df_baseline_mean.rename(
    columns={
        'performance': 'performance_baseline',
        'performance_overall': 'performance_overall_baseline'
    }
)
result_df_baseline_mean = flatten_multicolumns(result_df_baseline_mean)

In [20]:
result_df_baseline_mean

Unnamed: 0,config_filename,metric,task,phase,group,attribute,performance_baseline_mean,performance_baseline_std,performance_baseline_sem,performance_overall_baseline_mean,performance_overall_baseline_std,performance_overall_baseline_sem
0,13.yaml,auc,hospital_mortality,test,Asian,race_eth,0.879913,0.005568,0.001761,0.864627,0.003792,0.001199
1,13.yaml,auc,hospital_mortality,test,Black or African American,race_eth,0.906015,0.014008,0.004430,0.864627,0.003792,0.001199
2,13.yaml,auc,hospital_mortality,test,FEMALE,gender_concept_name,0.887575,0.002943,0.000931,0.864627,0.003792,0.001199
3,13.yaml,auc,hospital_mortality,test,Hispanic or Latino,race_eth,0.871576,0.006754,0.002136,0.864627,0.003792,0.001199
4,13.yaml,auc,hospital_mortality,test,MALE,gender_concept_name,0.835377,0.005570,0.001761,0.864627,0.003792,0.001199
...,...,...,...,...,...,...,...,...,...,...,...,...
1477,48.yaml,xauc_ova_1,readmission_30,val,[30-45),age_group,0.716632,0.025187,0.007965,,,
1478,48.yaml,xauc_ova_1,readmission_30,val,[45-55),age_group,0.758859,0.021833,0.006904,,,
1479,48.yaml,xauc_ova_1,readmission_30,val,[55-65),age_group,0.771841,0.017223,0.005446,,,
1480,48.yaml,xauc_ova_1,readmission_30,val,[65-75),age_group,0.758074,0.017058,0.005394,,,


In [21]:
result_df_baseline_mean.task.unique()

array(['hospital_mortality', 'LOS_7', 'readmission_30'], dtype=object)

In [22]:
def get_result_df_fair(base_path=None, filename='result_df_group_standard_eval.parquet', paths=None):
    if paths is None:
        performance_path = os.path.join(
            base_path,
            'performance',
        )
        paths = glob.glob(os.path.join(performance_path, '**', filename), recursive=True)
    result_df_fair = df_dict_concat(
        {
            tuple(file_name.split('/'))[-5:-1]:
            pd.read_parquet(file_name)
            for file_name in paths
        },
        ['task2', 'sensitive_attribute', 'config_filename', 'fold_id']
    ).drop(columns='task2')
    return result_df_fair

In [23]:
# Fair results
result_df_fair = get_result_df_fair(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    )
)

In [24]:
# List config_filenames without ten results
(
    result_df_fair
    .groupby(
        list(set(result_df_fair.columns) - set(['fold_id', 'performance', 'performance_overall']))
    )
    .agg(lambda x: len(x))
    .query("fold_id != 10")
    .reset_index()
    .config_filename
    .sort_values()
    .unique()
)

array([], dtype=float64)

In [25]:
result_df_calibration_fair = get_result_df_fair(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    ),
    filename='calibration_result.parquet'
)

id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'sensitive_attribute', 'attribute', 'group']
result_df_calibration_fair = result_df_calibration_fair.melt(
    id_vars = id_vars,
    value_vars = set(result_df_calibration_fair.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
).query('metric != "brier"')

In [26]:
result_df_ova_fair = get_result_df_fair(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    ),
    filename='result_df_group_fair_ova.parquet'
)

id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'sensitive_attribute', 'attribute', 'group']
result_df_ova_fair = result_df_ova_fair.melt(
    id_vars = id_vars,
    value_vars = set(result_df_ova_fair.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
)

In [27]:
# List config_filenames without ten results
(
    result_df_ova_fair
    .groupby(
        list(set(result_df_ova_fair.columns) - set(['fold_id', 'performance', 'performance_overall']))
    )
    .agg(lambda x: len(x))
    .query("fold_id != 10")
    .reset_index()
    .config_filename
    .sort_values()
    .unique()
)

array([], dtype=float64)

In [27]:
result_df_fair = pd.concat([result_df_fair, result_df_calibration_fair, result_df_ova_fair], ignore_index=True)

In [28]:
result_df_fair_mean = (
    result_df_fair
    .groupby(list(set(result_df_fair.columns) - set(['fold_id', 'performance', 'performance_overall'])))
    [['performance', 'performance_overall']]
    .agg(['mean', 'std', 'sem'])
    .reset_index()
)
result_df_fair_mean = flatten_multicolumns(result_df_fair_mean)

In [29]:
ci_func = lambda x: x * 1.96
result_df_fair_mean = result_df_fair_mean.assign(
    performance_CI = lambda x: ci_func(x['performance_sem']),
    performance_overall_CI = lambda x: ci_func(x['performance_overall_sem']),
)

In [30]:
def label_fair_mode(df):
    df['fair_mode'] = (
    df['regularization_metric']
    .where(~df['regularization_metric'].str.match('mmd'), 
           df['regularization_metric'].astype(str) + '_' + df['mmd_mode'].astype(str), 
           axis=0)
    )
    df['fair_mode'] = (
        df['fair_mode']
        .where(~df['fair_mode'].str.match('mean_prediction'), 
               df['fair_mode'].astype(str) + '_' + df['mean_prediction_mode'].astype(str), 
               axis=0
              )
    )
    return df

In [31]:
def get_fair_config_df(base_path):
    config_path = os.path.join(
        base_path,
        'config',
    )
    fair_config_files = glob.glob(
        os.path.join(config_path, '**', '*.yaml'),
        recursive=True
    )
    fair_config_dict_dict = {
        tuple(file_name.split('/'))[-2:]:
        yaml_read(file_name)
        for file_name in fair_config_files
    }

    fair_config_df = df_dict_concat(
        {
            key: pd.DataFrame(value, index=[key])
            for key, value in fair_config_dict_dict.items()
        },
        ['task', 'config_filename']
    )
    fair_config_df = label_fair_mode(fair_config_df)[['task', 'config_filename', 'fair_mode', 'lambda_group_regularization']]
    return fair_config_df

In [32]:
fair_config_df = get_fair_config_df(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    )
)

In [33]:
fair_config_df

Unnamed: 0,task,config_filename,fair_mode,lambda_group_regularization
0,LOS_7,8.yaml,mean_prediction_conditional_pos,0.002783
1,LOS_7,37.yaml,mean_prediction_unconditional,0.464159
2,LOS_7,36.yaml,mean_prediction_conditional,0.464159
3,LOS_7,0.yaml,mean_prediction_conditional,0.001000
4,LOS_7,35.yaml,mmd_conditional_pos,0.166810
...,...,...,...,...
175,hospital_mortality,40.yaml,mmd_unconditional,0.464159
176,hospital_mortality,57.yaml,mmd_conditional,10.000000
177,hospital_mortality,52.yaml,mmd_unconditional,3.593814
178,hospital_mortality,5.yaml,mmd_conditional_pos,0.001000


In [34]:
result_df_fair_mean.task.unique()

array(['LOS_7', 'hospital_mortality', 'readmission_30'], dtype=object)

In [35]:
result_df = pd.merge(result_df_baseline_mean.drop(columns='config_filename'), result_df_fair_mean,
                    how='outer', indicator=True).merge(fair_config_df)
assert result_df_fair_mean.shape[0] == result_df.shape[0]
result_df.head()

Unnamed: 0,task,attribute,metric,group,phase,performance_baseline_mean,performance_baseline_std,performance_baseline_sem,performance_overall_baseline_mean,performance_overall_baseline_std,...,performance_std,performance_sem,performance_overall_mean,performance_overall_std,performance_overall_sem,performance_CI,performance_overall_CI,_merge,fair_mode,lambda_group_regularization
0,LOS_7,age_group,auc,[18-30),test,0.809874,0.001614,0.00051,0.792915,0.0013,...,0.003531,0.001117,0.790199,0.001344,0.000425,0.002188,0.000833,both,mean_prediction_conditional,0.001
1,LOS_7,age_group,auc,[18-30),test,0.809874,0.001614,0.00051,0.792915,0.0013,...,0.0033,0.001044,0.792207,0.001222,0.000386,0.002045,0.000757,both,mean_prediction_conditional,0.001
2,LOS_7,age_group,auc,[18-30),test,0.809874,0.001614,0.00051,0.792915,0.0013,...,0.003101,0.000981,0.790876,0.001838,0.000581,0.001922,0.001139,both,mean_prediction_conditional,0.001
3,LOS_7,age_group,auc,[18-30),val,0.813017,0.010958,0.003465,0.798198,0.003871,...,0.013908,0.004398,0.795609,0.004146,0.001311,0.008621,0.00257,both,mean_prediction_conditional,0.001
4,LOS_7,age_group,auc,[18-30),val,0.813017,0.010958,0.003465,0.798198,0.003871,...,0.011854,0.003749,0.795987,0.003541,0.00112,0.007347,0.002194,both,mean_prediction_conditional,0.001


In [37]:
assert result_df.query('_merge == "right_only"').shape[0] == 0

In [38]:
result_df.metric.unique()

array(['auc', 'auprc', 'brier', 'brier_signed', 'calib_error',
       'calib_error_signed', 'calib_group_error',
       'calib_group_error_signed', 'emd_ova', 'emd_ova_0', 'emd_ova_1',
       'loss_bce', 'mean_prediction', 'mean_prediction_0',
       'mean_prediction_1', 'xauc_0', 'xauc_1', 'xauc_ova_0',
       'xauc_ova_1'], dtype=object)

In [39]:
result_df = result_df.query('phase == "test"')

In [43]:
result_df.head()

Unnamed: 0,task,attribute,metric,group,phase,performance_baseline_mean,performance_baseline_std,performance_baseline_sem,performance_overall_baseline_mean,performance_overall_baseline_std,...,performance_std,performance_sem,performance_overall_mean,performance_overall_std,performance_overall_sem,performance_CI,performance_overall_CI,_merge,fair_mode,lambda_group_regularization
0,LOS_7,age_group,auc,[18-30),test,0.809874,0.001614,0.00051,0.792915,0.0013,...,0.003531,0.001117,0.790199,0.001344,0.000425,0.002188,0.000833,both,mean_prediction_conditional,0.001
1,LOS_7,age_group,auc,[18-30),test,0.809874,0.001614,0.00051,0.792915,0.0013,...,0.0033,0.001044,0.792207,0.001222,0.000386,0.002045,0.000757,both,mean_prediction_conditional,0.001
2,LOS_7,age_group,auc,[18-30),test,0.809874,0.001614,0.00051,0.792915,0.0013,...,0.003101,0.000981,0.790876,0.001838,0.000581,0.001922,0.001139,both,mean_prediction_conditional,0.001
6,LOS_7,age_group,auc,[30-45),test,0.830717,0.002633,0.000833,0.792915,0.0013,...,0.001823,0.000577,0.790199,0.001344,0.000425,0.00113,0.000833,both,mean_prediction_conditional,0.001
7,LOS_7,age_group,auc,[30-45),test,0.830717,0.002633,0.000833,0.792915,0.0013,...,0.001433,0.000453,0.792207,0.001222,0.000386,0.000888,0.000757,both,mean_prediction_conditional,0.001


In [46]:
result_df.columns

Index(['task', 'attribute', 'metric', 'group', 'phase',
       'performance_baseline_mean', 'performance_baseline_std',
       'performance_baseline_sem', 'performance_overall_baseline_mean',
       'performance_overall_baseline_std', 'performance_overall_baseline_sem',
       'config_filename', 'sensitive_attribute', 'performance_mean',
       'performance_std', 'performance_sem', 'performance_overall_mean',
       'performance_overall_std', 'performance_overall_sem', 'performance_CI',
       'performance_overall_CI', '_merge', 'fair_mode',
       'lambda_group_regularization'],
      dtype='object')

In [47]:
result_df = result_df.drop(columns = '_merge')

In [49]:
result_df.to_csv(os.path.join(result_path, 'group_results.csv'), index=False)