In [1]:
import numpy as np
import pandas as pd
import os
import glob
import torch
import torch.nn.functional as F
import joblib
import itertools
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import warnings
import string
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss, recall_score, precision_score
from prediction_utils.util import df_dict_concat, yaml_read
from matplotlib.ticker import FormatStrFormatter

In [44]:
project_dir = "/share/pi/nigam/projects/spfohl/cohorts/admissions/optum"
experiment_name_baseline = 'baseline_tuning_fold_1'
experiment_name_fair = 'fair_tuning_fold_1'
tasks = ['LOS_7', 'readmission_30']
cohort_path = os.path.join(project_dir, 'cohort', 'cohort.parquet')
row_id_map_path = os.path.join(
    project_dir, 'merged_features_binary/features_sparse/features_row_id_map.parquet'
)
result_path = os.path.join(project_dir, 'experiments', 'merged_results_fold_1')
os.makedirs(result_path, exist_ok=True)

In [3]:
attributes = ['gender_concept_name', 'age_group']

In [4]:
cohort = pd.read_parquet(cohort_path)
row_id_map = pd.read_parquet(row_id_map_path)
cohort = cohort.merge(row_id_map)

### Generate the cohort table

In [5]:
### Cohort table
cohort_df_long = (
    cohort
    .melt(
        id_vars = ['person_id'] + attributes,
        value_vars = tasks,
        var_name = 'task',
        value_name = 'labels'
    )
    .melt(
        id_vars = ['person_id', 'task', 'labels'],
        value_vars = attributes,
        var_name = 'attribute',
        value_name = 'group'
    )
)

In [6]:
cohort_statistics_df = (
    cohort_df_long
    .groupby(['task', 'attribute', 'group'])
    .agg(
        prevalence=('labels', 'mean'),
    )
    .reset_index()
    .groupby('attribute')
    .apply(lambda x: x.pivot_table(index = 'group', columns = 'task', values = 'prevalence'))
    .reset_index()
)

group_size_df = (
    cohort_df_long
    .groupby(['task', 'attribute', 'group'])
    .agg(
        size = ('labels', lambda x: x.shape[0])
    )
    .reset_index()
    .drop(columns = 'task')
    .drop_duplicates()
)

cohort_statistics_df = cohort_statistics_df.merge(group_size_df)
cohort_statistics_df = (
    cohort_statistics_df
    .set_index(['attribute', 'group'])
    [['size'] + tasks]
)

In [7]:
cohort_statistics_df

Unnamed: 0_level_0,Unnamed: 1_level_0,size,LOS_7,readmission_30
attribute,group,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
age_group,[18-30),1067423,0.060803,0.034595
age_group,[30-45),1854239,0.061145,0.034738
age_group,[45-55),1006924,0.137928,0.061149
age_group,[55-65),1173140,0.195259,0.080777
age_group,[65-75),1294273,0.25788,0.100376
age_group,[75-91),1678572,0.38577,0.168239
gender_concept_name,FEMALE,5040564,0.168399,0.07647
gender_concept_name,MALE,3032831,0.223706,0.093776
gender_concept_name,No matching concept,1176,0.213435,0.111395


In [8]:
## Write to Latex
table_path = './../figures/optum'
os.makedirs(table_path, exist_ok=True)
with open(os.path.join(table_path, 'cohort_table.txt'), 'w') as fp:
    (
        cohort_statistics_df
        .reset_index().drop(columns='attribute').set_index(['group'])
        .to_latex(
            fp, 
            float_format = '%.3g', 
            index_names = False, 
            index=True
        )
    )

### Get the results

In [9]:
def get_result_df_baseline(base_path, filename='result_df_group_standard_eval.parquet'):
    """
    Gets the results for training the baseline models
    """
    selected_models_path = os.path.join(
        base_path, 
        'config',
        'selected_models', '**', '*.yaml'
    )
    selected_models_dict = {
        filename.split('/')[-2]: filename.split('/')[-1]
        for filename in glob.glob(selected_models_path, recursive=True)
    }
    paths = [
        glob.glob(
            os.path.join(
                base_path,
                'performance',
                task, 
                config_filename, 
                '**', 
                filename
            ),
            recursive=True
        )
        for task, config_filename in selected_models_dict.items()
    ]
    paths = list(itertools.chain(*paths))
    result_df_baseline = df_dict_concat(
        {
            tuple(filename.split('/'))[-4:-1]:
            pd.read_parquet(filename)
            for filename in paths
        },
        ['task2', 'config_filename', 'fold_id']
    ).drop(columns='task2')
    return result_df_baseline

In [10]:
result_df_baseline = get_result_df_baseline(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_baseline,  
    )
)

In [11]:
result_df_baseline.task.unique()

array(['LOS_7', 'readmission_30'], dtype=object)

In [12]:
result_df_baseline.task.unique()

array(['LOS_7', 'readmission_30'], dtype=object)

In [13]:
result_df_calibration_baseline = get_result_df_baseline(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_baseline,  
    ),
    filename='calibration_result.parquet'
)
id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'attribute', 'group']
result_df_calibration_baseline = result_df_calibration_baseline.melt(
    id_vars = id_vars,
    value_vars = set(result_df_calibration_baseline.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
).query('metric != "brier"')

In [14]:
result_df_calibration_baseline.metric.unique()

array(['brier_signed', 'calib_error_signed', 'calib_group_error_signed',
       'calib_error', 'calib_group_error'], dtype=object)

In [15]:
# Import fair_ova metrics
result_df_ova_baseline = get_result_df_baseline(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_baseline,  
    ),
    filename='result_df_group_fair_ova.parquet'
)
id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'attribute', 'group']
result_df_ova_baseline = result_df_ova_baseline.melt(
    id_vars = id_vars,
    value_vars = set(result_df_ova_baseline.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
)

In [16]:
result_df_baseline = pd.concat([result_df_baseline, result_df_calibration_baseline, result_df_ova_baseline], ignore_index=True)
# result_df_baseline = pd.concat([result_df_baseline, result_df_ova_baseline], ignore_index=True)

In [17]:
result_df_baseline

Unnamed: 0,config_filename,fold_id,metric,phase,task,attribute,group,performance,performance_overall
0,2.yaml,1,auc,test,LOS_7,age_group,[18-30),0.840362,0.802328
1,2.yaml,1,auc,test,LOS_7,age_group,[30-45),0.826019,0.802328
2,2.yaml,1,auc,test,LOS_7,age_group,[45-55),0.724967,0.802328
3,2.yaml,1,auc,test,LOS_7,age_group,[55-65),0.710528,0.802328
4,2.yaml,1,auc,test,LOS_7,age_group,[65-75),0.709761,0.802328
...,...,...,...,...,...,...,...,...,...
603,2.yaml,1,mean_prediction,val,readmission_30,age_group,[65-75),0.021172,
604,2.yaml,1,mean_prediction,val,readmission_30,age_group,[55-65),-0.001764,
605,2.yaml,1,mean_prediction,val,readmission_30,age_group,[45-55),-0.018385,
606,2.yaml,1,mean_prediction,val,readmission_30,gender_concept_name,MALE,0.010943,


In [18]:
def flatten_multicolumns(df):
    """
    Converts multi-index columns into single colum
    """
    df.columns = ['_'.join([el for el in col if el != '']).strip() for col in df.columns.values if len(col) > 1]
    return df

In [19]:
result_df_baseline_mean = (
    result_df_baseline
    .groupby(list(set(result_df_baseline.columns) - {'fold_id', 'performance', 'performance_overall'}))
    [['performance', 'performance_overall']]
    .agg(['mean', 'std', 'sem'])
    .reset_index()
)
result_df_baseline_mean = result_df_baseline_mean.rename(
    columns={
        'performance': 'performance_baseline',
        'performance_overall': 'performance_overall_baseline'
    }
)
result_df_baseline_mean = flatten_multicolumns(result_df_baseline_mean)

In [20]:
result_df_baseline_mean

Unnamed: 0,group,config_filename,attribute,phase,metric,task,performance_baseline_mean,performance_baseline_std,performance_baseline_sem,performance_overall_baseline_mean,performance_overall_baseline_std,performance_overall_baseline_sem
0,FEMALE,2.yaml,gender_concept_name,test,auc,LOS_7,0.839665,,,0.802328,,
1,FEMALE,2.yaml,gender_concept_name,test,auc,readmission_30,0.794727,,,0.774163,,
2,FEMALE,2.yaml,gender_concept_name,test,auprc,LOS_7,0.513063,,,0.496130,,
3,FEMALE,2.yaml,gender_concept_name,test,auprc,readmission_30,0.258544,,,0.256986,,
4,FEMALE,2.yaml,gender_concept_name,test,brier,LOS_7,0.106802,,,0.122795,,
...,...,...,...,...,...,...,...,...,...,...,...,...
603,[75-91),2.yaml,age_group,val,xauc_1,readmission_30,0.912136,,,,,
604,[75-91),2.yaml,age_group,val,xauc_ova_0,LOS_7,0.533867,,,,,
605,[75-91),2.yaml,age_group,val,xauc_ova_0,readmission_30,0.546223,,,,,
606,[75-91),2.yaml,age_group,val,xauc_ova_1,LOS_7,0.892120,,,,,


In [21]:
result_df_baseline_mean.task.unique()

array(['LOS_7', 'readmission_30'], dtype=object)

In [22]:
def get_result_df_fair(base_path=None, filename='result_df_group_standard_eval.parquet', paths=None):
    if paths is None:
        performance_path = os.path.join(
            base_path,
            'performance',
        )
        paths = glob.glob(os.path.join(performance_path, '**', filename), recursive=True)
    result_df_fair = df_dict_concat(
        {
            tuple(file_name.split('/'))[-5:-1]:
            pd.read_parquet(file_name)
            for file_name in paths
        },
        ['task2', 'sensitive_attribute', 'config_filename', 'fold_id']
    ).drop(columns='task2')
    return result_df_fair

In [23]:
# Fair results
result_df_fair = get_result_df_fair(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    )
)

In [24]:
# List config_filenames without a result
(
    result_df_fair
    .groupby(
        list(set(result_df_fair.columns) - set(['fold_id', 'performance', 'performance_overall']))
    )
    .agg(lambda x: len(x))
    .query("fold_id != 1")
    .reset_index()
    .config_filename
    .sort_values()
    .unique()
)

array([], dtype=float64)

In [25]:
result_df_calibration_fair = get_result_df_fair(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    ),
    filename='calibration_result.parquet'
)

id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'sensitive_attribute', 'attribute', 'group']
result_df_calibration_fair = result_df_calibration_fair.melt(
    id_vars = id_vars,
    value_vars = set(result_df_calibration_fair.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
).query('metric != "brier"')

In [26]:
result_df_calibration_fair.query(
    """
    task == "LOS_7" & sensitive_attribute == "gender_concept_name" & metric == "calib_group_error" & group == "FEMALE" & phase == "test"
    """
)

Unnamed: 0,fold_id,phase,config_filename,task,sensitive_attribute,attribute,group,metric,performance
19206,1,test,8.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,7.931658e-06
19222,1,test,37.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,0.0009581342
19238,1,test,36.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,0.0009366479
19254,1,test,0.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,5.450435e-07
19270,1,test,35.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,2.671223e-05
19286,1,test,1.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,8.002997e-07
19302,1,test,3.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,9.040098e-06
19318,1,test,13.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,5.47761e-05
19334,1,test,51.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,0.0004443074
19350,1,test,33.yaml,LOS_7,gender_concept_name,gender_concept_name,FEMALE,calib_group_error,8.153347e-05


In [27]:
result_df_ova_fair = get_result_df_fair(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    ),
    filename='result_df_group_fair_ova.parquet'
)

id_vars = ['fold_id', 'phase', 'config_filename', 'task', 'sensitive_attribute', 'attribute', 'group']
result_df_ova_fair = result_df_ova_fair.melt(
    id_vars = id_vars,
    value_vars = set(result_df_ova_fair.columns) - set(id_vars),
    var_name = 'metric',
    value_name = 'performance'
)

In [28]:
result_df_calibration_fair.query('metric == "calib_group_error"')

Unnamed: 0,fold_id,phase,config_filename,task,sensitive_attribute,attribute,group,metric,performance
19200,1,test,8.yaml,LOS_7,gender_concept_name,age_group,[18-30),calib_group_error,0.000524
19201,1,test,8.yaml,LOS_7,gender_concept_name,age_group,[30-45),calib_group_error,0.000811
19202,1,test,8.yaml,LOS_7,gender_concept_name,age_group,[45-55),calib_group_error,0.000456
19203,1,test,8.yaml,LOS_7,gender_concept_name,age_group,[55-65),calib_group_error,0.000146
19204,1,test,8.yaml,LOS_7,gender_concept_name,age_group,[65-75),calib_group_error,0.000008
...,...,...,...,...,...,...,...,...,...
23035,1,val,49.yaml,readmission_30,age_group,age_group,[55-65),calib_group_error,0.000018
23036,1,val,49.yaml,readmission_30,age_group,age_group,[65-75),calib_group_error,0.000198
23037,1,val,49.yaml,readmission_30,age_group,age_group,[75-91),calib_group_error,0.006054
23038,1,val,49.yaml,readmission_30,age_group,gender_concept_name,FEMALE,calib_group_error,0.000041


In [29]:
result_df_fair = pd.concat([result_df_fair, result_df_calibration_fair, result_df_ova_fair], ignore_index=True)
# result_df_fair = pd.concat([result_df_fair, result_df_ova_fair], ignore_index=True)

In [30]:
result_df_fair_mean = (
    result_df_fair
    .groupby(list(set(result_df_fair.columns) - set(['fold_id', 'performance', 'performance_overall'])))
    [['performance', 'performance_overall']]
    .agg(['mean', 'std', 'sem'])
    .reset_index()
)
result_df_fair_mean = flatten_multicolumns(result_df_fair_mean)

In [31]:
ci_func = lambda x: x * 1.96
result_df_fair_mean = result_df_fair_mean.assign(
    performance_CI = lambda x: ci_func(x['performance_sem']),
    performance_overall_CI = lambda x: ci_func(x['performance_overall_sem']),
)

In [32]:
def label_fair_mode(df):
    df['fair_mode'] = (
    df['regularization_metric']
    .where(~df['regularization_metric'].str.match('mmd'), 
           df['regularization_metric'].astype(str) + '_' + df['mmd_mode'].astype(str), 
           axis=0)
    )
    df['fair_mode'] = (
        df['fair_mode']
        .where(~df['fair_mode'].str.match('mean_prediction'), 
               df['fair_mode'].astype(str) + '_' + df['mean_prediction_mode'].astype(str), 
               axis=0
              )
    )
    return df

In [33]:
def get_fair_config_df(base_path):
    config_path = os.path.join(
        base_path,
        'config',
    )
    fair_config_files = glob.glob(
        os.path.join(config_path, '**', '*.yaml'),
        recursive=True
    )
    fair_config_dict_dict = {
        tuple(file_name.split('/'))[-2:]:
        yaml_read(file_name)
        for file_name in fair_config_files
    }

    fair_config_df = df_dict_concat(
        {
            key: pd.DataFrame(value, index=[key])
            for key, value in fair_config_dict_dict.items()
        },
        ['task', 'config_filename']
    )
    fair_config_df = label_fair_mode(fair_config_df)[['task', 'config_filename', 'fair_mode', 'lambda_group_regularization']]
    return fair_config_df

In [34]:
fair_config_df = get_fair_config_df(
    os.path.join(
        project_dir,
        'experiments',
        experiment_name_fair
    )
)

In [35]:
fair_config_df

Unnamed: 0,task,config_filename,fair_mode,lambda_group_regularization
0,LOS_7,8.yaml,mean_prediction_conditional_pos,0.002783
1,LOS_7,37.yaml,mean_prediction_unconditional,0.464159
2,LOS_7,36.yaml,mean_prediction_conditional,0.464159
3,LOS_7,0.yaml,mean_prediction_conditional,0.001000
4,LOS_7,35.yaml,mmd_conditional_pos,0.166810
...,...,...,...,...
115,readmission_30,40.yaml,mmd_unconditional,0.464159
116,readmission_30,57.yaml,mmd_conditional,10.000000
117,readmission_30,52.yaml,mmd_unconditional,3.593814
118,readmission_30,5.yaml,mmd_conditional_pos,0.001000


In [36]:
result_df_fair_mean.task.unique()

array(['LOS_7', 'readmission_30'], dtype=object)

In [37]:
result_df = pd.merge(result_df_baseline_mean.drop(columns='config_filename'), result_df_fair_mean,
                    how='outer', indicator=True).merge(fair_config_df)
assert result_df_fair_mean.shape[0] == result_df.shape[0]
result_df.head()

Unnamed: 0,group,attribute,phase,metric,task,performance_baseline_mean,performance_baseline_std,performance_baseline_sem,performance_overall_baseline_mean,performance_overall_baseline_std,...,performance_std,performance_sem,performance_overall_mean,performance_overall_std,performance_overall_sem,performance_CI,performance_overall_CI,_merge,fair_mode,lambda_group_regularization
0,FEMALE,gender_concept_name,test,auc,LOS_7,0.839665,,,0.802328,,...,,,0.801219,,,,,both,mean_prediction_conditional,0.001
1,FEMALE,gender_concept_name,test,auc,LOS_7,0.839665,,,0.802328,,...,,,0.802113,,,,,both,mean_prediction_conditional,0.001
2,FEMALE,gender_concept_name,test,auprc,LOS_7,0.513063,,,0.49613,,...,,,0.494226,,,,,both,mean_prediction_conditional,0.001
3,FEMALE,gender_concept_name,test,auprc,LOS_7,0.513063,,,0.49613,,...,,,0.495495,,,,,both,mean_prediction_conditional,0.001
4,FEMALE,gender_concept_name,test,brier,LOS_7,0.106802,,,0.122795,,...,,,0.123059,,,,,both,mean_prediction_conditional,0.001


In [38]:
result_df.query('_merge == "right_only"')

Unnamed: 0,group,attribute,phase,metric,task,performance_baseline_mean,performance_baseline_std,performance_baseline_sem,performance_overall_baseline_mean,performance_overall_baseline_std,...,performance_std,performance_sem,performance_overall_mean,performance_overall_std,performance_overall_sem,performance_CI,performance_overall_CI,_merge,fair_mode,lambda_group_regularization


In [39]:
result_df.metric.unique()

array(['auc', 'auprc', 'brier', 'brier_signed', 'calib_error',
       'calib_error_signed', 'calib_group_error',
       'calib_group_error_signed', 'emd_ova', 'emd_ova_0', 'emd_ova_1',
       'loss_bce', 'mean_prediction', 'mean_prediction_0',
       'mean_prediction_1', 'xauc_0', 'xauc_1', 'xauc_ova_0',
       'xauc_ova_1'], dtype=object)

In [40]:
result_df = result_df.query('phase == "test"')

In [41]:
result_df.query(
    """
    task == "LOS_7" & sensitive_attribute == "gender_concept_name" & metric == "calib_group_error" & group == "FEMALE" & phase == "test"
    """
)['performance_mean']

13       5.450435e-07
621      8.002997e-07
1229     3.873992e-06
1837     4.323141e-06
2445     7.028620e-05
3053     5.477610e-05
3661     8.851241e-06
4269     2.359575e-06
4877     2.404759e-06
5485     4.135908e-06
6093     2.206707e-04
6701     2.513765e-04
7309     6.542721e-06
7917     1.170078e-05
8525     1.784844e-06
9133     3.156953e-06
9741     4.958354e-06
10349    4.519681e-04
10957    4.879118e-04
11565    1.597107e-05
12173    1.066927e-05
12781    2.684063e-06
13389    1.250092e-05
13997    9.040098e-06
14605    6.746761e-04
15213    7.984086e-04
15821    2.806319e-05
16429    8.153347e-05
17037    4.825646e-05
17645    2.671223e-05
18253    9.366479e-04
18861    9.581342e-04
19469    4.273849e-05
20077    3.243833e-04
20685    4.974112e-06
21293    2.770984e-04
21901    4.938817e-05
22509    1.159088e-03
23117    9.532767e-04
23725    3.793634e-05
24333    5.166612e-04
24941    6.660116e-04
25549    7.846712e-05
26157    1.117048e-03
26765    9.340641e-04
27373    4

In [42]:
result_df = result_df.drop(columns = '_merge')

In [45]:
result_df.to_csv(os.path.join(result_path, 'group_results.csv'), index=False)