In [None]:
library(tidyverse)
library(data.table)
library(RColorBrewer)
library(lemon)
library(ggsci)
library(egg)

In [None]:
# Display labels for metric keys. The worst-case variants (*_min / *_max)
# share the label of their base metric so they land in the same facet row.
metric_list <- c(
    auc = 'AUC',
    auc_min = 'AUC',
    loss_bce = 'Loss',
    loss_bce_max = 'Loss',
    ace_abs_logistic_log = 'ACE',
    ace_abs_logistic_log_max = 'ACE',
    ece_q_abs = 'ECE'
)

# Legend labels for each training/selection configuration ("tag").
tag_list <- c(
    erm_baseline = 'ERM (Pooled)',
    erm_subset = 'ERM (Stratified)',
    aware_balanced = 'ERM (Balanced)',
    aware_auc_min = 'ERM (Select AUC)',
    aware_loss_max = 'ERM (Select Loss)',
    dro_auc_min = 'DRO (Select AUC)',
    dro_loss_max = 'DRO (Select Loss)',
    # DRO ablations, model selection on worst-case loss.
    dro_loss_max_objective_loss = 'DRO (Obj. Loss; Sel. Loss)',
    dro_loss_max_objective_baselined_loss = 'DRO (Obj. Marg-BL; Sel. Loss)',
    dro_loss_max_objective_auc_proxy = 'DRO (Obj. AUC; Sel. Loss)',
    dro_loss_max_objective_size_adjusted_loss = 'DRO (Obj. Prop-Adj; Sel. Loss)',
    dro_loss_max_objective_size_adjusted_loss_reciprocal = 'DRO (Obj. Recip-Adj; Sel. Loss)',
    # DRO ablations, model selection on worst-case AUC.
    dro_auc_min_objective_loss = 'DRO (Obj. Loss; Sel. AUC)',
    dro_auc_min_objective_baselined_loss = 'DRO (Obj. Marg-BL; Sel. AUC)',
    dro_auc_min_objective_auc_proxy = 'DRO (Obj. AUC; Sel. AUC)',
    dro_auc_min_objective_size_adjusted_loss = 'DRO (Obj. Prop-Adj; Sel. AUC)',
    dro_auc_min_objective_size_adjusted_loss_reciprocal = 'DRO (Obj. Recip-Adj; Sel. AUC)'
)

# Facet labels for attribute keys. 'FALSE'/'TRUE' are the stringified values
# of the overall-vs-worst-case flag; several raw column names (which differ
# between datasets) collapse onto one display label.
attribute_list <- c(
    `FALSE` = 'Overall',
    `TRUE` = 'Worst-case',
    age_group = 'Age',
    gender_concept_name = 'Sex',
    gender = 'Sex',
    race_eth = 'Race/Ethnicity',
    ethnicity = 'Race/Ethnicity'
)

# Labels for the logical "is this a *_min/*_max metric" flag.
is_min_max_metric_list <- c(
    `FALSE` = 'Overall',
    `TRUE` = 'Worst-case'
)

# x-axis order for subgroups: each attribute header is followed by its groups.
# The age section appears to contain two alternative binnings (TODO confirm
# which datasets use which).
eval_group_list <- c(
    'Age',
    '18-30', '30-45', '45-55', '55-65', '65-75', '75-90',
    '40-50', '50-60', '60-75',
    'Sex',
    'Female', 'Male',
    'Race/Ethnicity',
    'Asian', 'Black', 'Hispanic', 'Other', 'White'
)

In [None]:
#' Reshape per-subgroup results for plotting.
#'
#' Keeps the requested metrics and tags, drops the marginal ('overall') rows,
#' widens `CI_quantile_95` so each of its values becomes its own column (the
#' plot reads `lower`, `mid` and `upper`), harmonises subgroup spellings and
#' converts metric / tag / attribute keys into labelled factors.
#'
#' @param df Long-format results with columns metric, tag, eval_attribute,
#'   eval_group, CI_quantile_95 and the column named by `var_to_spread`.
#' @param var_to_spread Name of the value column to widen (this file passes
#'   'comparator' or 'delta').
#' @param metrics_to_plot Metric keys (names of `metric_list`) to keep.
#' @param fold_id_to_plot Not used by this function; kept so existing calls
#'   keep working.
#' @param tags_to_plot Tag keys (names of `tag_list`) to keep; their order
#'   fixes the legend order. Defaults to every tag in `tag_list`.
#' @return One row per (metric, eval_attribute, eval_group, tag) with one
#'   column per distinct `CI_quantile_95` value.
transform_df <- function(
    df,
    var_to_spread,
    metrics_to_plot,
    fold_id_to_plot,
    tags_to_plot = names(tag_list)
) {
    # Dataset-specific subgroup spellings mapped onto one display label.
    eval_group_aliases <- c(
        'FEMALE' = 'Female',
        'F' = 'Female',
        'MALE' = 'Male',
        'M' = 'Male',
        '[75-91)' = '[75-90)',
        'Black or African American' = 'Black',
        'black' = 'Black',
        'African American' = 'Black',
        'Hispanic or Latino' = 'Hispanic',
        'other' = 'Other',
        'white' = 'White',
        'Caucasian' = 'White'
    )

    df %>%
        filter(
            metric %in% metrics_to_plot,
            tag %in% tags_to_plot,
            eval_group != 'overall'
        ) %>%
        select(
            metric, eval_attribute, eval_group, tag, CI_quantile_95,
            all_of(var_to_spread)
        ) %>%
        distinct() %>%
        pivot_wider(
            names_from = CI_quantile_95,
            values_from = all_of(var_to_spread)
        ) %>%
        mutate(
            eval_group = as.character(eval_group),
            # Unknown spellings are kept as-is.
            eval_group = coalesce(unname(eval_group_aliases[eval_group]), eval_group),
            # Drop interval brackets, e.g. '[75-90)' -> '75-90'.
            eval_group = gsub('[[)]', '', eval_group),
            metric = factor(
                metric,
                levels = names(metric_list[metrics_to_plot]),
                labels = metric_list[metrics_to_plot]
            ),
            tag = factor(
                tag,
                levels = names(tag_list[tags_to_plot]),
                labels = tag_list[tags_to_plot]
            ),
            eval_attribute = factor(
                eval_attribute,
                levels = names(attribute_list),
                labels = attribute_list
            )
        )
}

#' Reshape marginal ('overall') results for plotting.
#'
#' Keeps only rows with eval_group == 'overall', splits metric keys such as
#' 'auc_min' into a base metric ('auc') and an overall/worst-case flag, widens
#' `CI_quantile_95` into columns, and relabels so the attribute becomes the
#' x-axis group and the flag becomes the facet column.
#'
#' @param df,var_to_spread,metrics_to_plot,fold_id_to_plot,tags_to_plot As
#'   for `transform_df()` (`fold_id_to_plot` is likewise unused).
#' @return A data frame shaped like `transform_df()`'s output, with
#'   eval_group = attribute label and eval_attribute = 'Overall' or
#'   'Worst-case', plus an `is_min_max_metric` column.
transform_df_combined_marginal <- function(
    df,
    var_to_spread,
    metrics_to_plot,
    fold_id_to_plot,
    tags_to_plot = names(tag_list)
) {
    df %>%
        filter(
            metric %in% metrics_to_plot,
            tag %in% tags_to_plot,
            eval_group == 'overall'
        ) %>%
        mutate(
            # Computed before `metric` is overwritten on the next line.
            is_min_max_metric = grepl('_(min|max)$', metric),
            metric = sub('_(min|max)$', '', metric)
        ) %>%
        select(
            metric, is_min_max_metric, eval_attribute, eval_group, tag,
            CI_quantile_95, all_of(var_to_spread)
        ) %>%
        distinct() %>%
        pivot_wider(
            names_from = CI_quantile_95,
            values_from = all_of(var_to_spread)
        ) %>%
        mutate(
            metric = factor(
                metric,
                levels = names(metric_list[metrics_to_plot]),
                labels = metric_list[metrics_to_plot]
            ),
            tag = factor(
                tag,
                levels = names(tag_list[tags_to_plot]),
                labels = tag_list[tags_to_plot]
            ),
            eval_attribute = factor(
                eval_attribute,
                levels = names(attribute_list),
                labels = attribute_list
            ),
            is_min_max_metric = factor(
                is_min_max_metric,
                levels = names(is_min_max_metric_list),
                labels = is_min_max_metric_list
            )
        ) %>%
        # Attribute goes on the x axis; overall vs worst-case becomes the facet.
        mutate(
            eval_group = eval_attribute,
            eval_attribute = is_min_max_metric
        )
}

#' Plot absolute and relative subgroup + marginal metrics per method.
#'
#' @param df Long-format results table (see `transform_df()`).
#' @param tags_to_plot Tag keys (names of `tag_list`) to plot, in legend order.
#' @param fold_id_to_plot Forwarded to the transform helpers (which currently
#'   ignore it).
#' @param combined Unused; kept for interface compatibility.
#' @param y_label y-axis label. Note the theme blanks `axis.title.y`, so it is
#'   not rendered.
#' @param mode NULL for all metric rows; 'relative' to keep only the relative
#'   (delta) rows, relabelled '<Metric> (relative)'. Other values behave like
#'   NULL.
#' @return A ggplot object with facets tagged A, B, ... by `egg::tag_facet()`.
make_plot_combined_marginal <- function(
    df,
    tags_to_plot=c('erm_baseline', 'erm_subset', 'aware_loss_max', 'dro_loss_max'),
    fold_id_to_plot='test',
    combined=FALSE,
    y_label='',
    mode=NULL
) {
    subgroup_metrics <- c('auc', 'loss_bce', 'ace_abs_logistic_log')
    marginal_metrics <- c(
        'auc', 'auc_min',
        'loss_bce', 'loss_bce_max',
        'ace_abs_logistic_log', 'ace_abs_logistic_log_max'
    )

    # Relative (delta) results get a ' (rel)' metric suffix and a zero
    # reference value for the dashed line; absolute rows get NA (no line).
    as_relative <- function(results) {
        results %>% mutate(metric = paste0(metric, ' (rel)'), erm_value = 0)
    }

    # Absolute vs relative rows differ in metric label and subgroup vs marginal
    # rows differ in eval_attribute, so stacking equals the former full joins.
    combined_results <- bind_rows(
        transform_df(
            df, var_to_spread='comparator', metrics_to_plot=subgroup_metrics,
            fold_id_to_plot=fold_id_to_plot, tags_to_plot=tags_to_plot
        ),
        as_relative(transform_df(
            df, var_to_spread='delta', metrics_to_plot=subgroup_metrics,
            fold_id_to_plot=fold_id_to_plot, tags_to_plot=tags_to_plot
        )),
        transform_df_combined_marginal(
            df, var_to_spread='comparator', metrics_to_plot=marginal_metrics,
            fold_id_to_plot=fold_id_to_plot, tags_to_plot=tags_to_plot
        ),
        as_relative(transform_df_combined_marginal(
            df, var_to_spread='delta', metrics_to_plot=marginal_metrics,
            fold_id_to_plot=fold_id_to_plot, tags_to_plot=tags_to_plot
        ))
    ) %>%
        mutate(
            eval_attribute=factor(eval_attribute, levels=attribute_list, labels=attribute_list),
            eval_group=factor(eval_group, levels=eval_group_list, labels=eval_group_list),
            metric=factor(metric, levels=c('AUC', 'AUC (rel)', 'ACE', 'ACE (rel)', 'Loss', 'Loss (rel)'))
        )

    if (identical(mode, 'relative')) {
        combined_results <- combined_results %>%
            filter(metric %in% c('AUC (rel)', 'ACE (rel)', 'Loss (rel)')) %>%
            mutate(metric=factor(
                sub(' (rel)', ' (relative)', as.character(metric), fixed=TRUE),
                levels=c('AUC (relative)', 'ACE (relative)', 'Loss (relative)')
            ))
    }

    g <- combined_results %>%
        ggplot(aes(eval_group, mid, color=tag)) +
        # Let facet tags and edge points draw outside the panel area.
        coord_cartesian(clip='off') +
        geom_point(position=position_dodge(width=0.75), size=0.5) +
        geom_linerange(
            aes(ymin=lower, ymax=upper),
            size=0.5,
            position=position_dodge(width=0.75)
        ) +
        lemon::facet_rep_grid(
            rows=vars(metric),
            cols=vars(eval_attribute),
            scales='free',
            switch='y'
        ) +
        theme_bw() +
        ggsci::scale_color_d3() +
        theme(
            axis.title = element_text(size = rel(1.75)),
            axis.title.y = element_blank(),
            axis.title.x = element_blank(),
            strip.text.x = element_text(size = rel(1.35), vjust=1),
            strip.text.y = element_text(size = rel(1.1)),
            strip.background = element_blank(),
            strip.placement = "outside",
            axis.text.x = element_text(angle = 45, vjust=0.95, hjust=1),
            axis.text = element_text(size=rel(1), color='black'),
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank(),
            panel.border = element_blank(),
            axis.line = element_line(color='black'),
            legend.text=element_text(size=rel(0.85)),
            legend.position='bottom'
        ) +
        labs(
            y=y_label,
            color = "Method"
        ) +
        # Reference line at 0 for relative rows (erm_value is NA elsewhere).
        geom_hline(aes(yintercept=erm_value), color='black', linetype='dashed', size=0.5, alpha=0.5)

    tag_facet(
        g, open="", close="",
        tag_pool=c(toupper(letters), as.character(tolower(as.roman(1:10)))),
        hjust = -0.5, vjust = 0.5
    )
}

In [None]:
# Driver: for every task, read the aggregated CI results and write one
# relative-only and one combined-marginal figure (png + pdf) per method
# comparison into ../figures/<task>/.

data_path <- '../figures_data/'
figures_root <- '../figures'

# Trailing comma removed: c(..., ) fails with "argument is empty".
task_path_prefixes <- c(
    'admissions/los',
    'admissions/mortality',
    'admissions/readmission',
    'mimic/mortality',
    'eicu/mortality'
)

metrics_to_plot <- c('auc', 'loss_bce', 'ace_abs_logistic_log')
fold_id_to_plot <- 'test'

# Output file-name stem -> tags (methods) compared in that figure.
comparisons <- list(
    # Main method comparison.
    method_comparison = c(
        'erm_baseline',
        'erm_subset',
        'aware_balanced',
        'aware_auc_min',
        'aware_loss_max',
        'dro_auc_min',
        'dro_loss_max'
    ),
    # DRO ablation, with max loss selection.
    DRO_comparison_loss_selection = c(
        'erm_baseline',
        'dro_loss_max_objective_loss',
        'dro_loss_max_objective_baselined_loss',
        'dro_loss_max_objective_auc_proxy',
        'dro_loss_max_objective_size_adjusted_loss',
        'dro_loss_max_objective_size_adjusted_loss_reciprocal'
    ),
    # DRO ablation, with min AUC selection.
    DRO_comparison_auc_selection = c(
        'erm_baseline',
        'dro_auc_min_objective_loss',
        'dro_auc_min_objective_baselined_loss',
        'dro_auc_min_objective_auc_proxy',
        'dro_auc_min_objective_size_adjusted_loss',
        'dro_auc_min_objective_size_adjusted_loss_reciprocal'
    )
)

# Save `plot` as <figure_path>/<file_stem>.png and .pdf (in that order).
save_plot_png_pdf <- function(plot, figure_path, file_stem, width, height) {
    for (device in c('png', 'pdf')) {
        ggsave(
            filename=file.path(figure_path, paste0(file_stem, '.', device)),
            plot=plot, device=device, width=width, height=height, units='in'
        )
    }
    invisible(plot)
}

for (task_path_prefix in task_path_prefixes) {
    results_path <- file.path(data_path, task_path_prefix, 'result_df_ci_no_agg.csv')
    aggregated_results <- fread(results_path)
    if ('exp' %in% names(aggregated_results)) {
        aggregated_results <- aggregated_results %>% rename(experiment_name=exp)
    }

    figure_path <- file.path(figures_root, task_path_prefix)
    # Reruns are expected, so an existing directory is not worth a warning.
    dir.create(figure_path, recursive=TRUE, showWarnings=FALSE)

    for (comparison_name in names(comparisons)) {
        # Assigned at top level (not inside a function) so it is also visible
        # to any helper that looks `tags_to_plot` up globally.
        tags_to_plot <- comparisons[[comparison_name]]

        g <- make_plot_combined_marginal(
            df=aggregated_results,
            tags_to_plot=tags_to_plot,
            fold_id_to_plot=fold_id_to_plot,
            mode='relative'
        )
        save_plot_png_pdf(g, figure_path, paste0(comparison_name, '_relative'), width=8, height=6)

        g <- make_plot_combined_marginal(
            df=aggregated_results,
            tags_to_plot=tags_to_plot,
            fold_id_to_plot=fold_id_to_plot
        )
        save_plot_png_pdf(g, figure_path, paste0(comparison_name, '_combined_marginal'), width=8, height=8)
    }
}