In [None]:
library(tidyverse)
library(data.table)
library(RColorBrewer)
library(lemon)
library(ggsci)
library(egg)

In [None]:
# Facet labels for every metric key. The worst-case variants (suffix _min or
# _max) share the label of their base metric.
metric_list <- c(
    "auc" = "AUC",
    "auc_min" = "AUC",
    "loss_bce" = "Loss",
    "loss_bce_max" = "Loss",
    "ace_abs_logistic_logit" = "ACE",
    "ace_abs_logistic_logit_max" = "ACE",
    "net_benefit_rr_0.075" = "NB (7.5%)",
    "net_benefit_rr_0.075_min" = "NB (7.5%)",
    "net_benefit_rr_recalib_0.075" = "cNB (7.5%)",
    "net_benefit_rr_recalib_0.075_min" = "cNB (7.5%)",
    "net_benefit_rr_0.2" = "NB (20%)",
    "net_benefit_rr_0.2_min" = "NB (20%)",
    "net_benefit_rr_recalib_0.2" = "cNB (20%)",
    "net_benefit_rr_recalib_0.2_min" = "cNB (20%)"
)

# Legend labels for each training method (the `tag` column).
tag_list <- c(
    "erm_baseline" = "ERM (Pooled)",
    "erm_subset" = "ERM (Stratified)",
    "regularized_loss_max" = "Regularized (Loss)",
    "regularized_auc_min" = "Regularized (AUC)",
    "dro_loss_max" = "DRO (Loss)",
    "dro_auc_min" = "DRO (AUC)"
)

# Facet-column labels. The "FALSE"/"TRUE" keys label the overall vs.
# worst-case summary columns; the rest label the subgroup attributes.
attribute_list <- c(
    "FALSE" = "Overall",
    "TRUE" = "Worst-case",
    "age_group" = "Age",
    "gender_concept_name" = "Sex",
    "race_eth" = "Race/Eth",
    "race_eth_gender" = "Race/Eth/Sex",
    "has_ckd_history" = "CKD",
    "has_ra_history" = "RA",
    "has_diabetes_type1_history" = "Diabetes T1",
    "has_diabetes_type2_history" = "Diabetes T2"
)



# Labels for the logical is_min_max_metric flag.
is_min_max_metric_list <- c(
    "FALSE" = "Overall",
    "TRUE" = "Worst-case"
)

# x-axis order for eval_group. Attribute labels ("Age", "Sex", ...) are
# included because the overall/worst-case rows use them as their group.
eval_group_list <- c(
    # Age
    "Age", "40-50", "50-60", "60-75",
    # Sex
    "Sex", "Female", "Male",
    # Race/ethnicity, alone and crossed with sex
    "Race/Eth", "Race/Eth/Sex",
    "Asian", "Black", "Hispanic", "Other", "White",
    "A-F", "A-M", "B-F", "B-M", "H-F", "H-M", "O-F", "O-M", "W-F", "W-M",
    # Comorbidities: raw group names, attribute labels, cleaned group labels
    "ckd_present", "ckd_absent", "ra_present", "ra_absent",
    "diabetes_type1_present", "diabetes_type1_absent",
    "diabetes_type2_present", "diabetes_type2_absent",
    "CKD", "RA", "Diabetes T1", "Diabetes T2",
    "Present", "Absent"
)

# Raw eval_group value -> short display label. Values that are not keys are
# left unchanged by clean_eval_group().
# NOTE(review): "other"/"white" are lowercase while the crossed keys use
# "Other"/"White"; presumably this mirrors the raw data — confirm.
eval_group_clean_map <- c(
    "FEMALE" = "Female",
    "MALE" = "Male",
    "Black or African American" = "Black",
    "Hispanic or Latino" = "Hispanic",
    "other" = "Other",
    "white" = "White",
    "Asian | FEMALE" = "A-F",
    "Asian | MALE" = "A-M",
    "Black or African American | FEMALE" = "B-F",
    "Black or African American | MALE" = "B-M",
    "Hispanic or Latino | FEMALE" = "H-F",
    "Hispanic or Latino | MALE" = "H-M",
    "Other | FEMALE" = "O-F",
    "Other | MALE" = "O-M",
    "White | FEMALE" = "W-F",
    "White | MALE" = "W-M",
    "ckd_present" = "Present",
    "ckd_absent" = "Absent",
    "ra_present" = "Present",
    "ra_absent" = "Absent",
    "diabetes_type1_present" = "Present",
    "diabetes_type1_absent" = "Absent",
    "diabetes_type2_present" = "Present",
    "diabetes_type2_absent" = "Absent"
)

In [None]:
# Attributes plotted together. Each set name is also used as a figure
# subdirectory by the plotting loop.
attribute_sets <- list(
    race_eth_sex = c("race_eth", "gender_concept_name", "race_eth_gender"),
    comorbidities = c(
        "has_ckd_history", "has_ra_history",
        "has_diabetes_type1_history", "has_diabetes_type2_history"
    )
)

# Subgroup-level metrics for each figure family (keys of metric_list).
metric_sets <- list(
    performance = c("auc", "loss_bce", "ace_abs_logistic_logit"),
    net_benefit = c(
        "net_benefit_rr_0.075", "net_benefit_rr_recalib_0.075",
        "net_benefit_rr_0.2", "net_benefit_rr_recalib_0.2"
    )
)
# The same families, with each metric followed by its worst-case (_min/_max)
# variant; these are used for the overall/worst-case summary columns.
metric_sets_min_max <- list(
    performance = c(
        "auc", "auc_min",
        "loss_bce", "loss_bce_max",
        "ace_abs_logistic_logit", "ace_abs_logistic_logit_max"
    ),
    net_benefit = c(
        "net_benefit_rr_0.075", "net_benefit_rr_0.075_min",
        "net_benefit_rr_recalib_0.075", "net_benefit_rr_recalib_0.075_min",
        "net_benefit_rr_0.2", "net_benefit_rr_0.2_min",
        "net_benefit_rr_recalib_0.2", "net_benefit_rr_recalib_0.2_min"
    )
)

In [None]:
# Replace raw eval_group values with their short labels from
# eval_group_clean_map. Values that are not keys in the map (including NA)
# pass through unchanged.
#
# df: data frame with an `eval_group` column (expected to be character).
# Returns df with `eval_group` relabelled.
clean_eval_group <- function(df) {
    mutate(df, eval_group = {
        # Position of each value among the map keys; NA where there is no key.
        key_pos <- match(eval_group, names(eval_group_clean_map))
        has_key <- !is.na(key_pos)
        replace(eval_group, has_key, unname(eval_group_clean_map[key_pos[has_key]]))
    })
}

# Reshape subgroup-level results into plotting form.
#
# Keeps the requested metrics and methods, drops the 'overall' rows, spreads
# CI_quantile_95 into one column per quantile label (the plot reads `lower`,
# `mid` and `upper`), cleans the subgroup names, and turns metric / tag /
# eval_attribute into labelled factors.
#
# df:              long results with columns metric, tag, eval_attribute,
#                  eval_group, CI_quantile_95 and the column `var_to_spread`.
# var_to_spread:   name of the value column to spread (this file uses
#                  'comparator' and 'delta').
# metrics_to_plot: keys of metric_list to keep.
# fold_id_to_plot: not used; kept so existing callers keep working.
# tags_to_plot:    keys of tag_list to keep, in legend order. Defaults to all
#                  of tag_list. (Previously this was read from a global.)
# Returns one row per (metric, eval_attribute, eval_group, tag).
transform_df <- function(
    df,
    var_to_spread,
    metrics_to_plot,
    fold_id_to_plot,
    tags_to_plot = names(tag_list)
) {
    df %>%
        filter(
            metric %in% metrics_to_plot,
            tag %in% tags_to_plot,
            eval_group != 'overall'
        ) %>%
        select(
            metric, eval_attribute, eval_group, tag, CI_quantile_95,
            all_of(var_to_spread)
        ) %>%
        distinct() %>%
        pivot_wider(
            names_from = CI_quantile_95,
            values_from = all_of(var_to_spread)
        ) %>%
        clean_eval_group() %>%
        mutate(
            # Strip interval brackets from binned groups; presumably age bins
            # such as "[40-50)" -> "40-50" (see eval_group_list).
            eval_group = gsub('[[)]', "", eval_group),
            metric = factor(
                metric,
                levels = names(metric_list[metrics_to_plot]),
                labels = metric_list[metrics_to_plot]
            ),
            tag = factor(
                tag,
                levels = names(tag_list[tags_to_plot]),
                labels = tag_list[tags_to_plot]
            ),
            eval_attribute = factor(
                eval_attribute,
                levels = names(attribute_list),
                labels = attribute_list
            )
        )
}

# Reshape the 'overall' (marginal) rows so they can share the facet grid with
# the subgroup rows.
#
# Metrics ending in _min/_max are the worst-case variants. Each one is folded
# onto its base metric and flagged in is_min_max_metric. In the result,
# eval_group holds the attribute label and eval_attribute holds
# 'Overall'/'Worst-case', so these rows get their own facet columns.
#
# Arguments are as for transform_df(). metrics_to_plot may contain both base
# and _min/_max keys. fold_id_to_plot is not used.
transform_df_combined_marginal <- function(
    df,
    var_to_spread,
    metrics_to_plot,
    fold_id_to_plot,
    tags_to_plot = names(tag_list)
) {
    # Base metric keys in first-appearance order; each becomes one factor level.
    bare_metrics <- unique(sub('_(min|max)$', '', metrics_to_plot))

    df %>%
        filter(
            metric %in% metrics_to_plot,
            tag %in% tags_to_plot,
            eval_group == 'overall'
        ) %>%
        mutate(
            # Evaluated in order: the flag reads metric before the suffix is dropped.
            is_min_max_metric = grepl('_(min|max)$', metric),
            metric = sub('_(min|max)$', '', metric)
        ) %>%
        select(
            metric, is_min_max_metric, eval_attribute, eval_group, tag,
            CI_quantile_95, all_of(var_to_spread)
        ) %>%
        distinct() %>%
        pivot_wider(
            names_from = CI_quantile_95,
            values_from = all_of(var_to_spread)
        ) %>%
        mutate(
            metric = factor(
                metric,
                levels = bare_metrics,
                labels = metric_list[bare_metrics]
            ),
            tag = factor(
                tag,
                levels = names(tag_list[tags_to_plot]),
                labels = tag_list[tags_to_plot]
            ),
            eval_attribute = factor(
                eval_attribute,
                levels = names(attribute_list),
                labels = attribute_list
            ),
            is_min_max_metric = factor(
                is_min_max_metric,
                levels = names(is_min_max_metric_list),
                labels = is_min_max_metric_list
            )
        ) %>%
        mutate(
            eval_group = eval_attribute,
            eval_attribute = is_min_max_metric
        )
}

# Plot each method's metrics per subgroup, next to overall and worst-case
# summaries, as a (metric x attribute) facet grid.
#
# df:              long results (see transform_df()).
# tags_to_plot:    keys of tag_list (methods) to plot, in legend order. Tags
#                  with no tag_list entry are dropped with a warning.
#                  NOTE(review): the default lists 'aware_loss_max', which is
#                  not in tag_list; the callers in this file always pass tags.
# fold_id_to_plot: forwarded to the transform helpers (which ignore it).
# metric_set_key:  'performance' or 'net_benefit'. Selects the metrics from
#                  metric_sets / metric_sets_min_max and the facet-row order.
# combined:        not used; kept for interface compatibility.
# y_label:         y-axis label. The theme blanks axis titles, so it is not shown.
# mode:            NULL (absolute and relative rows), 'absolute' or 'relative'.
# Returns a ggplot object whose facets are tagged by egg::tag_facet().
make_plot_combined_marginal <- function(
    df,
    tags_to_plot=c('erm_baseline', 'erm_subset', 'aware_loss_max', 'dro_loss_max'),
    fold_id_to_plot='test',
    metric_set_key='performance',
    combined=FALSE,
    y_label='',
    mode=NULL
) {
    # Fail early with a clear message; previously an unknown key surfaced as
    # "object 'combined_results' not found" and an unknown mode was ignored.
    if (length(metric_set_key) != 1 || !metric_set_key %in% c('performance', 'net_benefit')) {
        stop("`metric_set_key` must be 'performance' or 'net_benefit'.", call. = FALSE)
    }
    if (!is.null(mode) && !identical(mode, 'absolute') && !identical(mode, 'relative')) {
        stop("`mode` must be NULL, 'absolute' or 'relative'.", call. = FALSE)
    }

    unknown_tags <- setdiff(tags_to_plot, names(tag_list))
    if (length(unknown_tags) > 0) {
        warning(
            'Dropping tags with no label in tag_list: ',
            paste(unknown_tags, collapse = ', '),
            call. = FALSE
        )
    }
    tags_to_plot <- tags_to_plot[tags_to_plot %in% names(tag_list)]

    metrics <- metric_sets[[metric_set_key]]
    metrics_min_max <- metric_sets_min_max[[metric_set_key]]

    results_absolute <- transform_df(
        df,
        var_to_spread='comparator',
        metrics_to_plot=metrics,
        fold_id_to_plot=fold_id_to_plot,
        tags_to_plot=tags_to_plot
    )
    # Relative rows are differences ('delta'), so their reference line is at 0.
    results_relative <- transform_df(
        df,
        var_to_spread='delta',
        metrics_to_plot=metrics,
        fold_id_to_plot=fold_id_to_plot,
        tags_to_plot=tags_to_plot
    ) %>% mutate(metric=paste0(metric, ' (rel)'), erm_value=0)
    results_absolute_marginal <- transform_df_combined_marginal(
        df,
        var_to_spread='comparator',
        metrics_to_plot=metrics_min_max,
        fold_id_to_plot=fold_id_to_plot,
        tags_to_plot=tags_to_plot
    )
    results_relative_marginal <- transform_df_combined_marginal(
        df,
        var_to_spread='delta',
        metrics_to_plot=metrics_min_max,
        fold_id_to_plot=fold_id_to_plot,
        tags_to_plot=tags_to_plot
    ) %>% mutate(metric=paste0(metric, ' (rel)'), erm_value=0)

    # Facet-row labels for the absolute metrics and their relative counterparts.
    # Net-benefit labels already end in ")", so the suffix goes inside the
    # parentheses: "NB (7.5%)" -> "NB (7.5%; rel)".
    if (metric_set_key == 'performance') {
        absolute_levels <- c('AUC', 'ACE', 'Loss')
        relative_levels <- paste0(absolute_levels, ' (rel)')
    } else {
        absolute_levels <- c('NB (7.5%)', 'cNB (7.5%)', 'NB (20%)', 'cNB (20%)')
        relative_levels <- sub(')', '; rel)', absolute_levels, fixed=TRUE)
    }

    combined_results <- full_join(results_absolute, results_relative) %>%
        full_join(results_absolute_marginal) %>%
        full_join(results_relative_marginal) %>%
        # "NB (7.5%) (rel)" -> "NB (7.5%; rel)". No-op for performance labels.
        mutate(metric=str_replace(metric, stringr::fixed(') (rel)'), '; rel)')) %>%
        mutate(
            eval_attribute=factor(eval_attribute, levels=attribute_list, labels=attribute_list),
            eval_group=factor(eval_group, levels=eval_group_list, labels=eval_group_list),
            # Each absolute row directly followed by its relative row.
            metric=factor(metric, levels=as.vector(rbind(absolute_levels, relative_levels)))
        )

    if (identical(mode, 'relative')) {
        combined_results <- combined_results %>%
            filter(metric %in% relative_levels) %>%
            mutate(metric=factor(
                absolute_levels[match(metric, relative_levels)],
                levels=absolute_levels
            ))
    } else if (identical(mode, 'absolute')) {
        combined_results <- combined_results %>%
            filter(metric %in% absolute_levels)
    }

    g <- combined_results %>%
        ggplot(aes(eval_group, mid, color=tag)) +
        # clip accepts "on"/"off"; grid would read the old FALSE as "inherit".
        # "off" presumably keeps facet tags at the panel edge from being cut.
        coord_cartesian(clip='off') +
        geom_point(position=position_dodge(width=0.75), size=0.5) +
        geom_linerange(
            aes(ymin=lower, ymax=upper),
            size=0.5,
            position=position_dodge(width=0.75)
        ) +
        lemon::facet_rep_grid(
            rows=vars(metric),
            cols=vars(eval_attribute),
            scales='free',
            switch='y'
        ) +
        theme_bw() +
        ggsci::scale_color_d3() +
        theme(
            axis.title = element_text(size = rel(1.75)),
            axis.title.y = element_blank(),
            axis.title.x = element_blank(),
            strip.text.x = element_text(size = rel(1.35), vjust=1),
            strip.text.y = element_text(size = rel(1.1)),
            strip.background = element_blank(),
            strip.placement = "outside",
            axis.text.x = element_text(angle = 45, vjust=0.95, hjust=1),
            axis.text = element_text(size=rel(1), color='black'),
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank(),
            panel.border = element_blank(),
            axis.line = element_line(color='black'),
            legend.text=element_text(size=rel(0.85)),
            legend.position='bottom'
        ) +
        labs(
            y=y_label,
            color = "Method"
        )
    # Dashed zero line for the relative rows. Absolute rows have no erm_value,
    # so ggplot drops them from this layer.
    g <- g + geom_hline(aes(yintercept=erm_value), color='black', linetype='dashed', size=0.5, alpha=0.5)
    tag_facet(
        g,
        open="",
        close="",
        tag_pool=c(toupper(letters), as.character(tolower(as.roman(1:10)))),
        hjust = -0.5,
        vjust = 0.5
    )
}

In [None]:
# task_path_prefixes = c(
#     'optum'
# )

# Render every (attribute set x metric set) figure as PNG and PDF.
# Paths are relative to the notebook's working directory.
data_path <- '../zipcode_cvd/experiments/figures_data'
figures_root <- '../zipcode_cvd/experiments/figures'
fold_id_to_plot <- 'test'

# Methods to compare, in legend order. Kept at top level because the plot
# helpers may read `tags_to_plot` from the global environment.
tags_to_plot <- c(
    'erm_baseline',
    'erm_subset',
    'regularized_loss_max',
    'regularized_auc_min',
    'dro_loss_max',
    'dro_auc_min'
)

# The results table is the same for every loop iteration, so read (and
# normalise the experiment column name) once.
all_results <- fread(file.path(data_path, 'performance', 'result_df_ci.csv'))
if ('exp' %in% names(all_results)) {
    all_results <- all_results %>% rename(experiment_name=exp)
}

# Save `plot` as <file_stem>.png and <file_stem>.pdf in `out_dir`, 10 in wide.
save_png_and_pdf <- function(plot, out_dir, file_stem, height) {
    for (device in c('png', 'pdf')) {
        ggsave(
            filename=file.path(out_dir, paste0(file_stem, '.', device)),
            plot=plot,
            device=device,
            width=10,
            height=height,
            units='in'
        )
    }
    invisible(NULL)
}

for (attribute_set in names(attribute_sets)) {
    eval_attributes <- attribute_sets[[attribute_set]]
    aggregated_results <- all_results %>% filter(eval_attribute %in% eval_attributes)

    for (metric_set_key in names(metric_sets)) {
    # for (metric_set_key in c('net_benefit')) {
        figure_path <- file.path(figures_root, metric_set_key, attribute_set)
        # Re-running the notebook should not warn about existing directories.
        dir.create(figure_path, recursive=TRUE, showWarnings=FALSE)

        # Absolute and relative metrics, each in its own figure.
        for (plot_mode in c('absolute', 'relative')) {
            g <- make_plot_combined_marginal(
                df=aggregated_results,
                tags_to_plot=tags_to_plot,
                fold_id_to_plot=fold_id_to_plot,
                metric_set_key=metric_set_key,
                mode=plot_mode
            )
            save_png_and_pdf(g, figure_path, paste0('method_comparison_', plot_mode), height=6)
        }

        # Absolute and relative metrics together in one taller figure.
        g <- make_plot_combined_marginal(
            df=aggregated_results,
            tags_to_plot=tags_to_plot,
            fold_id_to_plot=fold_id_to_plot,
            metric_set_key=metric_set_key
        )
        save_png_and_pdf(g, figure_path, 'method_comparison_combined_marginal', height=8)
    }
}
    