In [1]:
library(tidyverse)
library(data.table)
library(ggplot2)
library(ggsci)
require(scales)

── Attaching packages ─────────────────────────────────────── tidyverse 1.3.0 ──

✔ ggplot2 3.3.2     ✔ purrr   0.3.4
✔ tibble  3.0.1     ✔ dplyr   1.0.0
✔ tidyr   1.1.0     ✔ stringr 1.4.0
✔ readr   1.3.1     ✔ forcats 0.5.0

── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()


Attaching package: ‘data.table’


The following objects are masked from ‘package:dplyr’:

    between, first, last


The following object is masked from ‘package:purrr’:

    transpose


Loading required package: scales


Attaching package: ‘scales’


The following object is masked from ‘package:purrr’:

    discard


The following object is masked from

In [2]:
# Inline-plot display settings for the notebook (repr/IRkernel options):
# canvas size, resolution and image quality.
# NOTE(review): repr documents repr.plot.quality as a JPEG quality percentage;
# 300 looks out of range -- confirm the intended value.
options(
    repr.plot.width = 16,
    repr.plot.height = 12,
    repr.plot.res = 300,
    repr.plot.quality = 300
)

In [3]:
# Root output directory for every figure written below (relative to the
# notebook's working directory); created on first run.
figure_path <- file.path("..", "figures")
if (!dir.exists(figure_path)) dir.create(figure_path)

In [4]:
# Short database key -> directory name under the admissions cohort root.
db_name_list <- c(
    starr = "starr_20200523",
    optum = "optum",
    mimic = "mimic_omop"
)
# db_name_list <- c(mimic = "mimic_omop")

# Fairness-regularization modes (values of the `fair_mode` column), grouped
# under short keys that are also used as output directory names.
fair_modes <- list(
    mmd = c("mmd_unconditional", "mmd_conditional", "mmd_conditional_pos"),
    mean = c(
        "mean_prediction_unconditional",
        "mean_prediction_conditional",
        "mean_prediction_conditional_pos"
    )
)

# Metrics (values of the `metric` column), grouped under short keys that are
# also used as output directory names.
metrics <- list(
    performance = c("auc", "auprc", "loss_bce"),
    calibration = c(
        "calib_error", "calib_error_signed",
        "calib_group_error", "calib_group_error_signed"
    ),
    fairness = c("emd_ova", "emd_ova_1", "mean_prediction", "mean_prediction_1"),
    xauc = c("xauc_0", "xauc_1", "xauc_ova_0", "xauc_ova_1")
)

# Display labels for every metric key that may appear in the results table,
# plus the ggplot2 facet labeller built from them.
metric_list <- c(
    auc = "AUROC",
    auprc = "Average Precision",
    loss_bce = "BCE Loss",
    brier = "Brier Score",
    brier_signed = "Sign. Brier Score",
    calib_error = "Abs. Calib. Err.",
    calib_error_signed = "Sign. Abs. Calib. Err.",
    calib_group_error = "Rel. Calib. Err.",
    calib_group_error_signed = "Sign. Rel. Calib. Err.",
    emd_ova = "EMD",
    emd_ova_1 = "EMD (y=1)",
    emd_ova_0 = "EMD (y=0)",
    mean_prediction = "Mean Pred. Diff.",
    mean_prediction_1 = "Mean Pred. Diff. (y=1)",
    mean_prediction_0 = "Mean Pred. Diff. (y=0)",
    xauc_0 = "xAUC (y=0)",
    xauc_1 = "xAUC (y=1)",
    xauc_ova_0 = "Bal. xAUC (y=0)",
    xauc_ova_1 = "Bal. xAUC (y=1)"
)
metric_labeler <- as_labeller(metric_list)

# Display labels for each fairness-regularization mode, plus its labeller.
fair_mode_list <- c(
    mmd_unconditional = "Uncond. MMD",
    mmd_conditional = "Cond. MMD",
    mmd_conditional_pos = "Pos. Cond. MMD",
    mean_prediction_unconditional = "Uncond. Mean",
    mean_prediction_conditional = "Cond. Mean",
    mean_prediction_conditional_pos = "Pos. Cond. Mean"
)
fair_mode_labeler <- as_labeller(fair_mode_list)

# Prediction tasks available per database: display name -> value of the
# `task` column.
task_list <- list(
    starr = c(
        "Prolonged Length of Stay" = "LOS_7",
        "Hospital Mortality" = "hospital_mortality",
        "30-Day Readmission" = "readmission_30"
    ),
    optum = c(
        "Prolonged Length of Stay" = "LOS_7",
        "30-Day Readmission" = "readmission_30"
    ),
    mimic = c(
        "ICU LOS > 3 Days" = "los_icu_3days",
        "ICU LOS > 7 Days" = "los_icu_7days",
        "ICU Mortality" = "mortality_icu",
        "Hospital Mortality" = "mortality_hospital"
    )
)

# Sensitive attributes available per database: display name -> value of the
# `attribute` / `sensitive_attribute` columns.
shared_attributes <- c("Age Group" = "age_group", "Gender" = "gender_concept_name")
attribute_list <- list(
    starr = c(shared_attributes, "Race and Ethnicity" = "race_eth"),
    optum = shared_attributes,
    mimic = c(shared_attributes, "Race and Ethnicity" = "race_eth")
)
rm(shared_attributes)

In [5]:
# Plot per-group metric values against the group-regularization strength.
#
# Keeps the rows of `df` for one task and one sensitive attribute (matched on
# both `attribute` and `sensitive_attribute`), the fair modes listed in
# `fair_mode_groups[[selected_fair_mode]]` and the metrics listed in
# `metric_groups[[selected_metrics]]`. It then draws a facet grid with rows =
# metric and cols = fair_mode. Each facet contains:
#   * a solid line of `performance_mean` vs `lambda_group_regularization`,
#   * error bars from `performance_interval_lower` to
#     `performance_interval_upper`,
#   * a dashed horizontal line at `performance_baseline_mean`,
# all coloured by `group`, with a log10 x axis.
#
# Args:
#   df: results table with the columns named above plus `task`, `attribute`,
#     `sensitive_attribute`, `fair_mode`, `metric` and `group`.
#   selected_task: value matched against `task`.
#   selected_attribute: value matched against `attribute` and
#     `sensitive_attribute`.
#   selected_fair_mode: name of an element of `fair_mode_groups`.
#   selected_metrics: name of an element of `metric_groups`.
#   fair_mode_groups, metric_groups: named lists of fair-mode / metric ids.
#     They default to the notebook-level `fair_modes` / `metrics`, which the
#     function previously read implicitly.
# Returns:
#   The ggplot object (not printed or saved).
metric_vs_lambda <- function(
    df, selected_task, selected_attribute, selected_fair_mode, selected_metrics,
    fair_mode_groups = fair_modes, metric_groups = metrics
    ) {
    # Raw group values -> shorter legend labels; other values are kept as-is.
    group_relabels <- c(
        "FEMALE" = "Female",
        "MALE" = "Male",
        "[75-91)" = "[75-90)",
        "Black or African American" = "Black",
        "Hispanic or Latino" = "Hispanic"
    )
    relabel_groups <- function(x) {
        x <- as.character(x)
        hit <- match(x, names(group_relabels))
        found <- !is.na(hit)
        x[found] <- group_relabels[hit[found]]
        x
    }

    wanted_modes <- fair_mode_groups[[selected_fair_mode]]
    wanted_metrics <- metric_groups[[selected_metrics]]

    plot_df <- df %>%
        filter(
            task == selected_task,
            attribute == selected_attribute,
            sensitive_attribute == selected_attribute,
            fair_mode %in% wanted_modes,
            metric %in% wanted_metrics
        ) %>%
        mutate(group = relabel_groups(group))

    ggplot(plot_df, aes(x = lambda_group_regularization, color = group)) +
        facet_grid(
            rows = vars(metric),
            cols = vars(fair_mode),
            scales = "free",
            switch = "y",
            labeller = labeller("metric" = metric_labeler, "fair_mode" = fair_mode_labeler)
        ) +
        geom_line(aes(x = lambda_group_regularization, y = performance_mean, color = group), size = 1.5) +
        geom_hline(aes(yintercept = performance_baseline_mean, color = group), linetype = "dashed", size = 1.5) +
        geom_errorbar(
            aes(x = lambda_group_regularization, ymin = performance_interval_lower, ymax = performance_interval_upper),
            size = 1.5,
            width = 0.1
        ) +
        scale_x_log10(labels = trans_format("log10", math_format(10^.x))) +
        theme_bw() +
        scale_color_d3() +
        theme(
            axis.title = element_text(size = rel(2)),
            axis.title.y = element_blank(),
            axis.line = element_line(color = "black"),
            axis.text = element_text(size = rel(1.5), color = "black"),
            axis.text.y = element_text(size = rel(1)),
            strip.text.x = element_text(size = rel(2)),
            strip.text.y = element_text(size = rel(2)),
            strip.background = element_blank(),
            strip.placement = "outside",
            legend.title = element_text(size = rel(2)),
            legend.text = element_text(size = rel(1.75)),
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank()
        ) +
        labs(color = "Group") +
        xlab(bquote(Regularization~lambda))
}

In [6]:
# Plot all the figures.
# One PNG is written per (database, metric group, fair-mode group, task,
# attribute) to:
#   <figure_path>/<db dir>/<metric group>/<fair-mode group>/<task>/<attribute>/plot.png
# The figure height scales with the number of metric rows in the facet grid.
for (db_name in names(db_name_list)) {
    db_dir <- db_name_list[[db_name]]
    data_path <- file.path('/share/pi/nigam/projects/spfohl/cohorts/admissions/', db_dir, 'experiments/merged_results_fold_1_10/group_results.csv')
#     data_path <- file.path('./data', paste0(db_name, '.csv'))

    # Error bars span performance_mean +/- performance_sem.
    df <- fread(data_path) %>%
        mutate(
            performance_interval_lower = performance_mean - performance_sem,
            performance_interval_upper = performance_mean + performance_sem
        )
    for (selected_metrics in names(metrics)) {
        num_metrics <- length(metrics[[selected_metrics]])
        for (selected_fair_mode in names(fair_modes)) {
            for (selected_task in task_list[[db_name]]) {
                for (selected_attribute in attribute_list[[db_name]]) {
                    g <- metric_vs_lambda(
                        df = df,
                        selected_task = selected_task,
                        selected_attribute = selected_attribute,
                        selected_fair_mode = selected_fair_mode,
                        selected_metrics = selected_metrics
                    )
                    write_path <- file.path(
                        figure_path,
                        db_dir,
                        selected_metrics,
                        selected_fair_mode,
                        selected_task,
                        selected_attribute
                    )
                    print(write_path)
                    if (!dir.exists(write_path)) {
                        dir.create(write_path, recursive = TRUE)
                    }
                    # Save `g` explicitly instead of relying on ggsave()'s
                    # last_plot() default.
                    ggsave(filename = 'plot.png', plot = g, path = write_path, width = 12, height = 3 * num_metrics, units = 'in', dpi = 90)
                }
            }
        }
    }
}

[1] "../figures/starr_20200523/performance/mmd/LOS_7/age_group"
[1] "../figures/starr_20200523/performance/mmd/LOS_7/gender_concept_name"
[1] "../figures/starr_20200523/performance/mmd/LOS_7/race_eth"
[1] "../figures/starr_20200523/performance/mmd/hospital_mortality/age_group"
[1] "../figures/starr_20200523/performance/mmd/hospital_mortality/gender_concept_name"
[1] "../figures/starr_20200523/performance/mmd/hospital_mortality/race_eth"
[1] "../figures/starr_20200523/performance/mmd/readmission_30/age_group"
[1] "../figures/starr_20200523/performance/mmd/readmission_30/gender_concept_name"
[1] "../figures/starr_20200523/performance/mmd/readmission_30/race_eth"
[1] "../figures/starr_20200523/performance/mean/LOS_7/age_group"
[1] "../figures/starr_20200523/performance/mean/LOS_7/gender_concept_name"
[1] "../figures/starr_20200523/performance/mean/LOS_7/race_eth"
[1] "../figures/starr_20200523/performance/mean/hospital_mortality/age_group"
[1] "../figures/starr_20200523/performance/mean/h

In [80]:
# Short display labels for every metric shown in the combined figures.
metric_list <- c(
    'auc'='AUROC', 
    'auprc'='Avg. Prec.', 
    'loss_bce'='CE. Loss',
    'emd_ova'='EMD',
    'mean_prediction'='Mean Diff.',
    'emd_ova_1'='EMD (y=1)',
    'mean_prediction_1'='Mean Diff. (y=1)',
    'calib_error'='ACE',
    'calib_error_signed'='Sign. ACE',
    'calib_group_error'='RCE',
    'calib_group_error_signed'='Sign. RCE',
    'xauc_0'= 'xAUC (y=0)',
    'xauc_1'= 'xAUC (y=1)'
)

# Named metric subsets, one per figure family. Each element's name becomes a
# directory in the output path, and its key order is the facet row order.
# Subsets are taken from `metric_list` by key, so every label is defined in
# one place instead of being re-typed for each subset.
metric_list_list <- list(
    'all_metrics' = metric_list,
    'all_performance' = metric_list[c(
        'auc', 'auprc', 'loss_bce', 'calib_error', 'calib_error_signed'
    )],
    'all_fairness' = metric_list[c(
        'emd_ova', 'mean_prediction', 'emd_ova_1', 'mean_prediction_1',
        'calib_group_error', 'calib_group_error_signed', 'xauc_0', 'xauc_1'
    )]
)

# Display labels for each fairness-regularization mode. When this vector is
# passed as `selected_fair_mode_list`, its key order sets the facet column
# order (unconditional, conditional, positive-conditional; MMD before Mean).
fair_mode_list <- c(
    mmd_unconditional = "Uncond. MMD",
    mean_prediction_unconditional = "Uncond. Mean",
    mmd_conditional = "Cond. MMD",
    mean_prediction_conditional = "Cond. Mean",
    mmd_conditional_pos = "Pos. Cond. MMD",
    mean_prediction_conditional_pos = "Pos. Cond. Mean"
)

# ggplot2 facet labellers built from the label lookups.
metric_labeler <- as_labeller(metric_list)
fair_mode_labeler <- as_labeller(fair_mode_list)

# Plot per-group metric values against the group-regularization strength,
# for an explicit set of fair modes and metrics.
#
# Keeps the rows of `df` for one task and one sensitive attribute (matched on
# both `attribute` and `sensitive_attribute`), whose `fair_mode` is a key of
# `selected_fair_mode_list` and whose `metric` is a key of
# `selected_metric_list`. It then draws a facet grid:
#   * rows: metrics, ordered and labelled by `selected_metric_list`;
#   * cols: fair modes, ordered and labelled by `selected_fair_mode_list`.
# Each facet contains:
#   * a solid line of `performance_mean` vs `lambda_group_regularization`,
#   * error bars from `performance_interval_lower` to
#     `performance_interval_upper`,
#   * a dashed line at `performance_baseline_mean`,
# all coloured by `group`, with a log10 x axis.
#
# Args:
#   df: results table with the columns named above plus `task`, `attribute`,
#     `sensitive_attribute`, `fair_mode`, `metric` and `group`.
#   selected_task: value matched against `task`.
#   selected_attribute: value matched against `attribute` and
#     `sensitive_attribute`.
#   selected_fair_mode_list: named character vector, fair-mode id -> label.
#   selected_metric_list: named character vector, metric id -> label.
# Returns:
#   The ggplot object (not printed or saved).
metric_vs_lambda <- function(
    df, selected_task, selected_attribute, selected_fair_mode_list, selected_metric_list
    ) {
    # Raw group values -> shorter legend labels; other values are kept as-is.
    group_relabels <- c(
        "FEMALE" = "Female",
        "MALE" = "Male",
        "[75-91)" = "[75-90)",
        "Black or African American" = "Black",
        "Hispanic or Latino" = "Hispanic"
    )
    relabel_groups <- function(x) {
        x <- as.character(x)
        hit <- match(x, names(group_relabels))
        found <- !is.na(hit)
        x[found] <- group_relabels[hit[found]]
        x
    }

    plot_df <- df %>%
        filter(
            task == selected_task,
            attribute == selected_attribute,
            sensitive_attribute == selected_attribute,
            fair_mode %in% names(selected_fair_mode_list),
            metric %in% names(selected_metric_list)
        ) %>%
        mutate(group = relabel_groups(group))

    ggplot(plot_df, aes(x = lambda_group_regularization, color = group)) +
        facet_grid(
            # The factor levels fix the facet order; the labels become the
            # strip text.
            rows = vars(factor(metric, levels = names(selected_metric_list), labels = selected_metric_list)),
            cols = vars(factor(fair_mode, levels = names(selected_fair_mode_list), labels = selected_fair_mode_list)),
            scales = "free",
            switch = "y"
        ) +
        geom_line(aes(x = lambda_group_regularization, y = performance_mean, color = group), size = 1) +
        geom_hline(aes(yintercept = performance_baseline_mean, color = group), linetype = "dashed", size = 1) +
        geom_errorbar(
            aes(x = lambda_group_regularization, ymin = performance_interval_lower, ymax = performance_interval_upper),
            size = 1.5,
            width = 0.1
        ) +
        scale_x_log10(labels = trans_format("log10", math_format(10^.x))) +
        theme_bw() +
        scale_color_d3() +
        theme(
            axis.title = element_text(size = rel(1.75)),
            axis.title.y = element_blank(),
            axis.line = element_line(color = "black"),
            axis.text = element_text(size = rel(1), color = "black"),
            axis.text.y = element_text(size = rel(0.95)),
            strip.text.x = element_text(size = rel(1.5)),
            strip.text.y = element_text(size = rel(1.5)),
            strip.background = element_blank(),
            strip.placement = "outside",
            legend.title = element_text(size = rel(1.5)),
            legend.text = element_text(size = rel(1.25)),
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank()
        ) +
        labs(color = "Group") +
        xlab(bquote(Regularization~lambda))
}

In [81]:
# Plot all the figures.
# One PNG is written per (database, metric subset, task, attribute) to:
#   <figure_path>/<db dir>/<metric subset>/<task>/<attribute>/plot.png
# All fair modes in `fair_mode_list` share one figure as facet columns. The
# height scales with the number of metric rows; limitsize = FALSE permits
# tall figures.
for (db_name in names(db_name_list)) {
    db_dir <- db_name_list[[db_name]]
    data_path <- file.path('/share/pi/nigam/projects/spfohl/cohorts/admissions/', db_dir, 'experiments/merged_results_fold_1_10/group_results.csv')

    # Error bars span performance_mean +/- performance_sem.
    df <- fread(data_path) %>%
        mutate(
            performance_interval_lower = performance_mean - performance_sem,
            performance_interval_upper = performance_mean + performance_sem
        )
    for (selected_metrics in names(metric_list_list)) {
        the_metrics <- metric_list_list[[selected_metrics]]
        num_metrics <- length(the_metrics)

        for (selected_task in task_list[[db_name]]) {
            for (selected_attribute in attribute_list[[db_name]]) {
                g <- metric_vs_lambda(
                    df = df,
                    selected_task = selected_task,
                    selected_attribute = selected_attribute,
                    selected_fair_mode_list = fair_mode_list,
                    selected_metric_list = the_metrics
                )
                write_path <- file.path(
                    figure_path,
                    db_dir,
                    selected_metrics,
                    selected_task,
                    selected_attribute
                )
                print(write_path)
                if (!dir.exists(write_path)) {
                    dir.create(write_path, recursive = TRUE)
                }
                # Save `g` explicitly instead of relying on ggsave()'s
                # last_plot() default.
                ggsave(filename = 'plot.png', plot = g, path = write_path, width = 12, height = 1.5 * num_metrics, units = 'in', dpi = 90, limitsize = FALSE)
            }
        }
    }
}

[1] "../figures/starr_20200523/all_metrics/LOS_7/age_group"
[1] "../figures/starr_20200523/all_metrics/LOS_7/gender_concept_name"
[1] "../figures/starr_20200523/all_metrics/LOS_7/race_eth"
[1] "../figures/starr_20200523/all_metrics/hospital_mortality/age_group"
[1] "../figures/starr_20200523/all_metrics/hospital_mortality/gender_concept_name"
[1] "../figures/starr_20200523/all_metrics/hospital_mortality/race_eth"
[1] "../figures/starr_20200523/all_metrics/readmission_30/age_group"
[1] "../figures/starr_20200523/all_metrics/readmission_30/gender_concept_name"
[1] "../figures/starr_20200523/all_metrics/readmission_30/race_eth"
[1] "../figures/starr_20200523/all_performance/LOS_7/age_group"
[1] "../figures/starr_20200523/all_performance/LOS_7/gender_concept_name"
[1] "../figures/starr_20200523/all_performance/LOS_7/race_eth"
[1] "../figures/starr_20200523/all_performance/hospital_mortality/age_group"
[1] "../figures/starr_20200523/all_performance/hospital_mortality/gender_concept_name"
[1