In [27]:
import numpy as np
import pandas as pd
import os
import yaml

# Let long result tables render without row truncation.
pd.set_option("display.max_rows", 999)

#### Paths & Vars

In [1]:
# vars
adapt_methods = ['coral','al_layer']
gen_methods = ['irm','dro','coral','al_layer']
tasks = ['mortality','longlos','invasivevent','sepsis']
metrics = ['auc','auprc','ace_abs_logistic_log']
Ns = [100,500,1000,1500]

# paths
tables_fpath = '/hpf/projects/lsung/projects/mimic4ds/Experiments/domain_adapt/tables/'
adapt_fpath = '/hpf/projects/lsung/projects/mimic4ds/Experiments/domain_adapt'
gen_fpath = '/hpf/projects/lsung/projects/mimic4ds/Experiments/domain_gen'
base_fpath = '/hpf/projects/lsung/projects/mimic4ds/Experiments/baseline'

In [53]:
alpha = 0.05  # selects the `_{alpha}` variant of each results CSV read below


def _load_test_split(root, phase_query):
    """Read one framework's evaluation and comparison CSVs under `root`.

    Both frames are filtered with the same `phase_query` expression and returned
    as ``{"eval": DataFrame, "compare": DataFrame}``.
    """
    return {
        "eval": pd.read_csv(f"{root}/results/model_evaluation_by_gender_{alpha}.csv").query(phase_query),
        "compare": pd.read_csv(f"{root}/results/model_comparison_by_gender_{alpha}.csv").query(phase_query),
    }


# The baseline keeps every year group of the test phase; the generalization and
# adaptation frameworks keep only year group 1.
_test_group1 = "phase=='test' and group==1"
results = {
    "base": _load_test_split(base_fpath, "phase=='test'"),
    "gen": _load_test_split(gen_fpath, _test_group1),
    "adapt": _load_test_split(adapt_fpath, _test_group1),
}

# Build one formatted results table per gender. Each table covers the
# 2017 - 2019 test split and shows:
#   - the 2008-2010 baseline and the 2017-2019 oracle;
#   - ERM;
#   - the domain-generalization methods;
#   - the domain-adaptation methods, with one column group per unlabeled OOD
#     sample size in `Ns`.
# A "*" after the median marks rows whose interval in the model_comparison_*
# file excludes zero.
# Every step returns a new frame. `df['col'].replace(..., inplace=True)` is
# chained assignment and silently does nothing under pandas Copy-on-Write.

_CI_COLS = ['analysis_id', 'metric', 'ci_lower', 'ci_med', 'ci_upper']
_TASK_ORDER = ['Long LOS', 'Sepsis', 'Mortality', 'Invasive Ventilation']
_METRIC_ORDER = ['AUROC', 'AUPRC', 'Calibration']

# Raw codes -> display labels, keyed by the raw column they appear in.
_VALUE_LABELS = {
    'group': {0: "2008 - 2016 [ID]", 1: "2017 - 2019 [OOD]"},
    'metric': {'auc': "AUROC", 'auprc': "AUPRC", 'ace_abs_logistic_log': 'Calibration'},
    'train_method': {'al_layer': "AL", 'coral': "CORAL", 'erm': "ERM", 'irm': 'IRM', 'dro': 'GroupDRO'},
    'analysis_id': {
        'longlos': 'Long LOS',
        'sepsis': 'Sepsis',
        'mortality': 'Mortality',
        'invasivevent': 'Invasive Ventilation',
    },
}
# Raw column names -> display headers (these become the pivot's index/column level names).
_COLUMN_LABELS = {
    'metric': 'Metric',
    'group': 'Year Group',
    'analysis_id': 'Task',
    'n_ood': 'Unlabelled OOD Samples',
    'framework': 'Framework',
    'train_method': 'Method',
}


def _reference_rows(eval_df, train_group, label, metrics, gender):
    """Rows of the model trained on `train_group` and evaluated on 2017 - 2019, tagged with `label`."""
    rows = eval_df.query(
        "train_group==@train_group and eval_group=='2017 - 2019' "
        "and metric==@metrics and gender==@gender"
    )
    return rows[_CI_COLS].assign(group=1, train_method=label)


def _flag_significance(eval_rows, compare_df, keys, metrics, gender):
    """Left-join a boolean `sig` column from `compare_df` onto `eval_rows` by `keys`.

    `sig` is True when ci_lower * ci_upper > 0, i.e. the comparison interval lies
    entirely on one side of zero. Rows without a matching comparison get NaN.
    """
    stats = compare_df.query("metric==@metrics and gender==@gender")
    stats = stats[keys].assign(sig=stats['ci_lower'] * stats['ci_upper'] > 0)
    return eval_rows.merge(stats, how='left', on=keys)


def _to_display_long(df_long):
    """Collapse CI columns into a 'median[*] (lower,upper)' string and relabel codes for display."""
    fmt = '{:.3f}'.format
    # `sig` is NaN wherever no comparison row exists (always for baseline, oracle and ERM).
    star = df_long['sig'].map({True: '*', False: ''}).fillna('')
    performance = (
        df_long['ci_med'].map(fmt) + star
        + ' (' + df_long['ci_lower'].map(fmt)
        + ',' + df_long['ci_upper'].map(fmt) + ')'
    )
    return (
        df_long
        .assign(Performance=performance)
        .drop(columns=['ci_lower', 'ci_med', 'ci_upper', 'sig'])
        # Missing framework / n_ood become ' ' so they pivot into blank column headers.
        .fillna(' ')
        .replace(_VALUE_LABELS)
        .rename(columns=_COLUMN_LABELS)
    )


def _pivot_wide(df_display, n_oods):
    """Pivot to (Task, Year Group, Metric) rows x (Framework, samples, Method) columns in display order."""
    wide = df_display.pivot(
        index=['Task', 'Year Group', 'Metric'],
        columns=['Framework', 'Unlabelled OOD Samples', 'Method'],
        values=['Performance'],
    ).fillna(' ')
    # Drop the outer 'Performance' level that the list-valued `values` adds.
    wide.columns = wide.columns.droplevel(0)
    column_order = (
        [(' ', ' ', method) for method in ('Baseline [08-10]', 'Oracle [17-19]', 'ERM')]
        + [('Domain Generalization', ' ', method) for method in ('IRM', 'GroupDRO', 'AL', 'CORAL')]
        # n_ood values are floats in the pivoted columns (shown as 100.0, ...), hence float(n).
        + [('Domain Adaptation', float(n), method) for n in n_oods for method in ('AL', 'CORAL')]
    )
    return (
        wide[column_order]
        .reindex(labels=_TASK_ORDER, level=0)
        .reindex(labels=_METRIC_ORDER, level=2)
    )


def build_results_table(results, gender, metrics, n_oods):
    """Build the formatted comparison table for one gender.

    Parameters
    ----------
    results : dict
        ``{'base' | 'gen' | 'adapt': {'eval': DataFrame, 'compare': DataFrame}}``.
    gender : str
        Value matched against the ``gender`` column (the loop below uses 'F' and 'M').
    metrics : list of str
        Metric codes to keep.
    n_oods : iterable of int
        Unlabeled OOD sample sizes shown as Domain Adaptation column groups, in order.

    Returns
    -------
    DataFrame
        Indexed by (Task, Year Group, Metric), with (Framework, Unlabelled OOD
        Samples, Method) columns holding 'median[*] (lower,upper)' strings.

    Raises
    ------
    KeyError
        If an expected column, or a displayed (framework, n_ood, method)
        combination, is absent from the results.
    """
    base_eval = results['base']['eval']
    df_base = _reference_rows(base_eval, '2008 - 2010', 'Baseline [08-10]', metrics, gender)
    df_oracle = _reference_rows(base_eval, '2017 - 2019', 'Oracle [17-19]', metrics, gender)

    gen_eval = results['gen']['eval'].query("metric==@metrics and gender==@gender")
    gen_cols = _CI_COLS + ['group', 'train_method']
    df_erm = gen_eval.query("train_method=='erm'")[gen_cols]
    df_gen = _flag_significance(
        gen_eval.query("train_method!='erm'")[gen_cols].assign(framework='Domain Generalization'),
        results['gen']['compare'],
        ['analysis_id', 'metric', 'group', 'train_method'],
        metrics,
        gender,
    )

    adapt_cols = gen_cols + ['n_ood']
    df_adapt = _flag_significance(
        results['adapt']['eval']
        .query("metric==@metrics and gender==@gender")[adapt_cols]
        .assign(framework='Domain Adaptation'),
        results['adapt']['compare'],
        ['analysis_id', 'metric', 'group', 'train_method', 'n_ood'],
        metrics,
        gender,
    )

    df_long = pd.concat((df_base, df_oracle, df_erm, df_gen, df_adapt))
    return _pivot_wide(_to_display_long(df_long), n_oods)


df_results_all = {}
for gender in ['F', 'M']:
    df_results = build_results_table(results, gender, metrics, Ns)
    df_results_all[gender] = df_results

In [54]:
# Display the formatted results table built for rows with gender == 'F'.
df_results_all['F']

Unnamed: 0_level_0,Unnamed: 1_level_0,Framework,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Domain Generalization,Domain Generalization,Domain Generalization,Domain Generalization,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation
Unnamed: 0_level_1,Unnamed: 1_level_1,Unlabelled OOD Samples,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,100.0,100.0,500.0,500.0,1000.0,1000.0,1500.0,1500.0
Unnamed: 0_level_2,Unnamed: 1_level_2,Method,Baseline [08-10],Oracle [17-19],ERM,IRM,GroupDRO,AL,CORAL,AL,CORAL,AL,CORAL,AL,CORAL,AL,CORAL
Task,Year Group,Metric,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3
Long LOS,2017 - 2019 [OOD],AUROC,"0.614 (0.600,0.627)","0.702 (0.690,0.714)","0.651 (0.640,0.661)","0.651 (0.641,0.662)","0.652 (0.641,0.662)","0.651 (0.641,0.661)","0.652 (0.642,0.663)","0.652 (0.641,0.662)","0.652 (0.642,0.662)","0.654* (0.643,0.664)","0.654* (0.644,0.664)","0.652 (0.642,0.662)","0.653 (0.643,0.663)","0.652 (0.642,0.662)","0.653 (0.643,0.663)"
Long LOS,2017 - 2019 [OOD],AUPRC,"0.479 (0.459,0.499)","0.534 (0.513,0.554)","0.501 (0.486,0.516)","0.505 (0.489,0.520)","0.503 (0.487,0.518)","0.502 (0.487,0.518)","0.504 (0.488,0.519)","0.503 (0.488,0.518)","0.503 (0.487,0.518)","0.504 (0.489,0.520)","0.506* (0.491,0.521)","0.502 (0.486,0.517)","0.503 (0.488,0.519)","0.503 (0.487,0.518)","0.507* (0.492,0.522)"
Long LOS,2017 - 2019 [OOD],Calibration,"0.096 (0.087,0.105)","0.039 (0.032,0.047)","0.065 (0.058,0.071)","0.058* (0.052,0.064)","0.077 (0.070,0.083)","0.064 (0.057,0.070)","0.063 (0.057,0.070)","0.061* (0.054,0.067)","0.062* (0.056,0.069)","0.060* (0.054,0.066)","0.061* (0.054,0.067)","0.062* (0.056,0.069)","0.064 (0.057,0.070)","0.064 (0.057,0.070)","0.066 (0.060,0.073)"
Sepsis,2017 - 2019 [OOD],AUROC,"0.661 (0.636,0.686)","0.691 (0.667,0.716)","0.646 (0.626,0.666)","0.643 (0.623,0.663)","0.648 (0.629,0.667)","0.650 (0.630,0.670)","0.644 (0.624,0.665)","0.644 (0.624,0.664)","0.646 (0.626,0.667)","0.643 (0.623,0.664)","0.644 (0.624,0.665)","0.646 (0.625,0.666)","0.645 (0.625,0.666)","0.646 (0.625,0.666)","0.644 (0.624,0.664)"
Sepsis,2017 - 2019 [OOD],AUPRC,"0.204 (0.181,0.229)","0.284 (0.250,0.319)","0.229 (0.207,0.251)","0.230 (0.208,0.253)","0.208* (0.190,0.227)","0.234 (0.212,0.257)","0.229 (0.208,0.252)","0.227 (0.205,0.249)","0.232 (0.210,0.255)","0.226 (0.205,0.248)","0.233 (0.211,0.256)","0.228 (0.207,0.251)","0.233 (0.211,0.256)","0.235 (0.213,0.258)","0.230 (0.208,0.252)"
Sepsis,2017 - 2019 [OOD],Calibration,"0.044 (0.037,0.052)","0.020 (0.015,0.026)","0.038 (0.032,0.043)","0.037 (0.032,0.043)","0.038 (0.033,0.044)","0.035* (0.030,0.041)","0.036 (0.030,0.042)","0.039 (0.033,0.044)","0.038 (0.033,0.044)","0.040* (0.034,0.046)","0.038 (0.032,0.044)","0.039* (0.034,0.045)","0.037 (0.032,0.043)","0.038 (0.032,0.043)","0.039 (0.033,0.045)"
Mortality,2017 - 2019 [OOD],AUROC,"0.863 (0.846,0.879)","0.887 (0.870,0.901)","0.892 (0.880,0.904)","0.892 (0.880,0.903)","0.850* (0.836,0.863)","0.892 (0.881,0.904)","0.893 (0.881,0.904)","0.892 (0.880,0.904)","0.892 (0.881,0.903)","0.892 (0.880,0.904)","0.893 (0.881,0.904)","0.893 (0.881,0.904)","0.894 (0.882,0.905)","0.892 (0.880,0.904)","0.892 (0.880,0.903)"
Mortality,2017 - 2019 [OOD],AUPRC,"0.484 (0.442,0.523)","0.490 (0.449,0.531)","0.557 (0.526,0.588)","0.559 (0.527,0.590)","0.487* (0.456,0.518)","0.559 (0.528,0.590)","0.558 (0.527,0.589)","0.554 (0.523,0.585)","0.548 (0.517,0.579)","0.560 (0.529,0.591)","0.560 (0.529,0.590)","0.554 (0.523,0.585)","0.564 (0.532,0.594)","0.556 (0.524,0.587)","0.558 (0.527,0.589)"
Mortality,2017 - 2019 [OOD],Calibration,"0.021 (0.017,0.026)","0.022 (0.018,0.026)","0.023 (0.020,0.026)","0.023 (0.019,0.026)","0.045* (0.041,0.049)","0.023 (0.020,0.027)","0.023 (0.020,0.027)","0.024 (0.020,0.027)","0.021* (0.018,0.024)","0.023 (0.020,0.027)","0.022* (0.019,0.025)","0.023 (0.020,0.026)","0.022* (0.019,0.025)","0.022 (0.019,0.026)","0.022* (0.019,0.025)"
Invasive Ventilation,2017 - 2019 [OOD],AUROC,"0.872 (0.850,0.892)","0.870 (0.847,0.891)","0.877 (0.859,0.893)","0.874 (0.857,0.891)","0.878 (0.860,0.893)","0.874 (0.856,0.890)","0.876 (0.860,0.892)","0.876 (0.859,0.892)","0.875 (0.858,0.891)","0.874 (0.856,0.890)","0.878 (0.861,0.893)","0.873 (0.855,0.889)","0.880 (0.863,0.895)","0.873 (0.856,0.890)","0.873 (0.856,0.890)"


In [56]:
# Display the formatted results table built for rows with gender == 'M'.
df_results_all['M']

Unnamed: 0_level_0,Unnamed: 1_level_0,Framework,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Domain Generalization,Domain Generalization,Domain Generalization,Domain Generalization,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation,Domain Adaptation
Unnamed: 0_level_1,Unnamed: 1_level_1,Unlabelled OOD Samples,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,100.0,100.0,500.0,500.0,1000.0,1000.0,1500.0,1500.0
Unnamed: 0_level_2,Unnamed: 1_level_2,Method,Baseline [08-10],Oracle [17-19],ERM,IRM,GroupDRO,AL,CORAL,AL,CORAL,AL,CORAL,AL,CORAL,AL,CORAL
Task,Year Group,Metric,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3
Long LOS,2017 - 2019 [OOD],AUROC,"0.664 (0.652,0.675)","0.705 (0.694,0.715)","0.696 (0.688,0.704)","0.696 (0.688,0.704)","0.698 (0.690,0.706)","0.697 (0.689,0.705)","0.695 (0.687,0.703)","0.697 (0.689,0.705)","0.696 (0.688,0.705)","0.695 (0.686,0.703)","0.698 (0.690,0.706)","0.697 (0.689,0.705)","0.698 (0.690,0.706)","0.697 (0.689,0.705)","0.696 (0.688,0.704)"
Long LOS,2017 - 2019 [OOD],AUPRC,"0.538 (0.522,0.556)","0.554 (0.538,0.571)","0.559 (0.546,0.572)","0.561 (0.548,0.574)","0.561 (0.548,0.574)","0.562 (0.549,0.575)","0.560 (0.547,0.573)","0.562 (0.549,0.575)","0.561 (0.548,0.575)","0.559 (0.546,0.572)","0.561 (0.548,0.574)","0.562 (0.549,0.574)","0.564* (0.550,0.576)","0.561 (0.548,0.574)","0.558 (0.545,0.571)"
Long LOS,2017 - 2019 [OOD],Calibration,"0.085 (0.077,0.093)","0.039 (0.033,0.046)","0.061 (0.055,0.067)","0.055* (0.049,0.060)","0.079 (0.073,0.084)","0.060 (0.054,0.066)","0.062 (0.057,0.068)","0.058* (0.053,0.064)","0.060 (0.054,0.065)","0.059* (0.053,0.064)","0.058* (0.052,0.063)","0.059 (0.054,0.065)","0.061 (0.055,0.066)","0.061 (0.055,0.066)","0.066* (0.060,0.071)"
Sepsis,2017 - 2019 [OOD],AUROC,"0.633 (0.614,0.651)","0.762 (0.744,0.780)","0.725 (0.712,0.737)","0.720* (0.707,0.732)","0.673* (0.660,0.686)","0.720* (0.708,0.733)","0.720* (0.708,0.732)","0.722 (0.709,0.734)","0.722 (0.709,0.734)","0.722 (0.709,0.734)","0.722 (0.710,0.734)","0.725 (0.712,0.737)","0.720* (0.708,0.732)","0.723 (0.711,0.735)","0.720* (0.707,0.732)"
Sepsis,2017 - 2019 [OOD],AUPRC,"0.144 (0.131,0.157)","0.328 (0.297,0.360)","0.194 (0.181,0.207)","0.190 (0.177,0.203)","0.162* (0.151,0.173)","0.189* (0.177,0.202)","0.191 (0.178,0.204)","0.189* (0.177,0.202)","0.189* (0.177,0.202)","0.192 (0.179,0.205)","0.191 (0.178,0.204)","0.193 (0.181,0.207)","0.188* (0.176,0.201)","0.191 (0.179,0.205)","0.189* (0.176,0.202)"
Sepsis,2017 - 2019 [OOD],Calibration,"0.081 (0.074,0.087)","0.015 (0.011,0.020)","0.030 (0.026,0.034)","0.031 (0.027,0.035)","0.048* (0.043,0.052)","0.030 (0.026,0.034)","0.029 (0.025,0.033)","0.032 (0.028,0.036)","0.033* (0.029,0.037)","0.033* (0.029,0.037)","0.031 (0.027,0.035)","0.031 (0.028,0.035)","0.031 (0.027,0.035)","0.030 (0.026,0.034)","0.033* (0.029,0.037)"
Mortality,2017 - 2019 [OOD],AUROC,"0.889 (0.876,0.902)","0.917 (0.909,0.925)","0.898 (0.890,0.906)","0.899 (0.890,0.907)","0.841* (0.827,0.854)","0.897 (0.888,0.905)","0.896 (0.887,0.904)","0.896 (0.887,0.904)","0.895 (0.886,0.904)","0.896 (0.888,0.904)","0.898 (0.889,0.906)","0.895* (0.886,0.903)","0.900 (0.891,0.908)","0.896 (0.888,0.905)","0.896 (0.888,0.905)"
Mortality,2017 - 2019 [OOD],AUPRC,"0.466 (0.428,0.504)","0.521 (0.484,0.556)","0.554 (0.526,0.581)","0.555 (0.527,0.582)","0.521* (0.492,0.549)","0.554 (0.526,0.581)","0.554 (0.527,0.581)","0.551 (0.523,0.578)","0.555 (0.528,0.582)","0.555 (0.527,0.582)","0.554 (0.526,0.581)","0.553 (0.525,0.580)","0.558 (0.530,0.585)","0.555 (0.527,0.582)","0.556 (0.528,0.583)"
Mortality,2017 - 2019 [OOD],Calibration,"0.013 (0.011,0.016)","0.012 (0.010,0.014)","0.015 (0.013,0.018)","0.015 (0.013,0.017)","0.042* (0.039,0.045)","0.016 (0.013,0.018)","0.016 (0.014,0.019)","0.016 (0.014,0.019)","0.013* (0.011,0.015)","0.016 (0.014,0.018)","0.015 (0.013,0.017)","0.016 (0.014,0.018)","0.014* (0.012,0.016)","0.015 (0.013,0.018)","0.015 (0.013,0.017)"
Invasive Ventilation,2017 - 2019 [OOD],AUROC,"0.878 (0.864,0.890)","0.882 (0.870,0.894)","0.886 (0.877,0.895)","0.887 (0.877,0.896)","0.883 (0.874,0.893)","0.886 (0.876,0.895)","0.886 (0.877,0.895)","0.885 (0.876,0.894)","0.886 (0.876,0.895)","0.885 (0.876,0.894)","0.887 (0.878,0.896)","0.887 (0.878,0.896)","0.887 (0.877,0.896)","0.880* (0.871,0.890)","0.885 (0.875,0.894)"
