In [None]:
import itertools
import os
import string

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from zipcode_cvd.experiments.util import flatten_multicolumns

In [None]:
# Directory holding the parquet result tables that are loaded below.
figures_data_path = '../zipcode_cvd/experiments/figures_data/eo_rr'

In [None]:
# Grouping attributes for which a result table exists.
attributes = ['race_eth', 'gender_concept_name', 'race_eth_gender']

# Load one parquet table per attribute, keyed by the attribute name.
data_dict = {}
for attribute in attributes:
    parquet_name = f'result_df_ci_performance_{attribute}.parquet'
    data_dict[attribute] = pd.read_parquet(os.path.join(figures_data_path, parquet_name))

In [None]:
def _group_label_frame(attribute, labels):
    """Build a two-column lookup frame from raw group values to display labels.

    The raw values are stored in a column named ``attribute`` and the display
    labels in a column named ``_<attribute>``; rows follow the order of ``labels``.
    """
    return pd.DataFrame({
        attribute: list(labels.keys()),
        f'_{attribute}': list(labels.values()),
    })


# One lookup frame per attribute, used to swap raw group values for display labels.
group_name_dict = {
    'race_eth': _group_label_frame('race_eth', {
        'Asian': 'Asian',
        'Black or African American': 'Black',
        'Hispanic or Latino': 'Hispanic',
        'Other': 'Other',
        'White': 'White',
        'overall': 'Overall',
    }),
    'gender_concept_name': _group_label_frame('gender_concept_name', {
        'FEMALE': 'Female',
        'MALE': 'Male',
        'overall': 'Overall',
    }),
    'race_eth_gender': _group_label_frame('race_eth_gender', {
        'Asian | FEMALE': 'A-F',
        'Asian | MALE': 'A-M',
        'Black or African American | MALE': 'B-M',
        'Black or African American | FEMALE': 'B-F',
        'Hispanic or Latino | MALE': 'H-M',
        'Hispanic or Latino | FEMALE': 'H-F',
        'Other | FEMALE': 'O-F',
        'Other | MALE': 'O-M',
        'White | FEMALE': 'W-F',
        'White | MALE': 'W-M',
        'overall': 'Overall',
    }),
}

In [None]:
def _with_display_labels(frame, key):
    """Replace the raw group column ``key`` of ``frame`` with its display labels.

    The lookup frame in ``group_name_dict[key]`` is joined on the shared column
    (pandas' default inner merge), the raw column is dropped, and the label
    column ``_<key>`` takes over the name ``key``.
    """
    merged = frame.merge(group_name_dict[key])
    return merged.drop(columns=key).rename(columns={f'_{key}': key})


# NOTE: this rebinds data_dict, so the cell is not safe to run twice. On a second
# run the group column already holds display labels, and the inner merge would
# drop every row whose label differs from its raw value (e.g. 'Black', 'Overall').
data_dict = {key: _with_display_labels(frame, key) for key, frame in data_dict.items()}

In [None]:
# Sanity check: the gender column should now contain only display labels.
data_dict['gender_concept_name']['gender_concept_name'].unique()

In [None]:
# Columns that are spread across the pivoted columns (metric, CI_quantile_95),
# used as values (comparator, delta), or discarded (baseline). Every other
# column identifies a row of the wide table.
non_index_columns = {'comparator', 'baseline', 'delta', 'CI_quantile_95', 'metric'}

# Reshape each long result table to one row per identifier combination, with one
# column per (value, metric, CI quantile) triple.
data_dict_pivot = {}
for key, value in data_dict.items():
    # An ordered list keeps the identifier columns in their original order. A set
    # iterates strings in hash-seed-dependent order, so the resulting column
    # order could change between kernel restarts.
    index_columns = [column for column in value.columns if column not in non_index_columns]
    data_dict_pivot[key] = value.pivot(
        index=index_columns,
        columns=['metric', 'CI_quantile_95'],
        values=['comparator', 'delta']
    ).pipe(flatten_multicolumns).reset_index()

In [None]:
def plot_data(
    df,
    ax=None,
    x_var='score',
    y_var='calibration_density',
    group_var_name='race_eth_gender',
    ci_lower_var=None,
    ci_upper_var=None,
    ci_type='fill',
    drawstyle=None,
    ylim=(None, None),
    xlim=(None, None),
    ylabel=None,
    xlabel=None,
    legend=True,
    bbox_to_anchor=(1.04, 1),
    plot_y_equals_x=False,
    plot_x_axis=False,
    despine=True,
    hide_yticks=False,
    hide_xticks=False,
    linestyle=None,
    label_group=True,
    title=None,
    axvline=None,
    y_labelpad=None,
    ylabel_fontsize=None,
    xticks=None,
    xticklabels=None,
    use_symlog_x=False,
    symlog_x_linthresh=1e-2,
    errorbar_capsize=0.0,
    groupby_sort=True
):
    """Draw one line per group of ``df`` on an axis, optionally with intervals.

    Parameters
    ----------
    df : DataFrame containing ``x_var``, ``y_var``, ``group_var_name`` and, when
        intervals are requested, ``ci_lower_var`` / ``ci_upper_var``.
    ax : axis to draw on; a new figure and axis are created when None.
    x_var, y_var : names of the columns plotted on the x and y axes.
    group_var_name : passed to ``df.groupby``; each group becomes one line,
        coloured from ``plt.rcParams['axes.prop_cycle']`` (wrapping around).
    ci_lower_var, ci_upper_var : names of the interval-bound columns. Intervals
        are drawn only when both are given.
    ci_type : ``'fill'`` for a shaded band or ``'errorbar'`` for error bars.
    drawstyle, linestyle : forwarded to ``ax.plot``.
    ylim, xlim : passed to ``ax.set_ylim`` / ``ax.set_xlim``.
    ylabel, xlabel, title : axis texts, set only when not None.
    legend : add a per-axis legend with one entry per group.
    bbox_to_anchor : anchor of that legend.
    plot_y_equals_x : draw a dashed y = x reference line from 1e-4 to 1 - 1e-4.
    plot_x_axis : draw a dashed horizontal line at y = 0.
    despine : call ``sns.despine`` on ``ax``.
    hide_yticks, hide_xticks : blank out the tick labels of that axis.
    label_group : attach the group id as the label of each line.
    axvline : x position of an optional dashed vertical line.
    y_labelpad, ylabel_fontsize : forwarded to ``ax.set_ylabel``.
    xticks, xticklabels : explicit x ticks and labels; ignored when
        ``hide_xticks`` is True, and ``xticklabels`` only applies with ``xticks``.
    use_symlog_x, symlog_x_linthresh : switch the x axis to a symlog scale with
        the given linear threshold.
    errorbar_capsize : cap size used when ``ci_type == 'errorbar'``.
    groupby_sort : ``sort`` argument of ``df.groupby``.

    Returns
    -------
    The axis that was drawn on.

    Raises
    ------
    ValueError
        If intervals are requested and ``ci_type`` is not 'fill' or 'errorbar'.
    """
    if ax is None:
        plt.figure()
        ax = plt.gca()

    draw_ci = ci_upper_var is not None and ci_lower_var is not None
    if draw_ci and ci_type not in ('fill', 'errorbar'):
        raise ValueError('Invalid ci_type')

    # Look the colour cycle up once rather than on every iteration.
    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    groups = []
    line_handles = []
    for i, (group_id, group_df) in enumerate(df.groupby(group_var_name, sort=groupby_sort)):
        groups.append(group_id)
        color = cycle_colors[i % len(cycle_colors)]
        (line,) = ax.plot(
            group_df[x_var],
            group_df[y_var],
            drawstyle=drawstyle,
            color=color,
            linestyle=linestyle,
            label=group_id if label_group else None
        )
        line_handles.append(line)

        if not draw_ci:
            continue

        if ci_type == 'fill':
            ax.fill_between(
                group_df[x_var],
                group_df[ci_lower_var],
                group_df[ci_upper_var],
                alpha=0.25,
                color=color,
                label='_nolegend_'
            )
        else:
            # fmt='none' draws only the bars: the data line already exists, and
            # redrawing it here would ignore `linestyle` / `drawstyle`. The colour
            # is passed explicitly so the bars always match their line instead of
            # depending on the state of the axis colour cycle.
            ax.errorbar(
                x=group_df[x_var],
                y=group_df[y_var],
                yerr=[group_df[y_var] - group_df[ci_lower_var], group_df[ci_upper_var] - group_df[y_var]],
                fmt='none',
                color=color,
                capsize=errorbar_capsize,
                label='_nolegend_'
            )

    if use_symlog_x:
        ax.set_xscale('symlog', linthresh=symlog_x_linthresh)

    if plot_y_equals_x:
        # Densely sampled so the reference stays correct on a non-linear x scale.
        diagonal = np.linspace(1e-4, 1 - 1e-4, 1000)
        ax.plot(diagonal, diagonal, linestyle='--', color='k', label='_nolegend_')

    if axvline is not None:
        ax.axvline(axvline, linestyle='--', color='k', label='_nolegend_')

    if plot_x_axis:
        ax.axhline(0, linestyle='--', color='k', label='_nolegend_')

    if legend:
        # Pair each label with its own line. Passing labels alone matches them to
        # the axis artists purely by order, which mislabels entries once interval
        # artists are present.
        ax.legend(handles=line_handles, labels=groups, bbox_to_anchor=bbox_to_anchor, frameon=False)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    if xlabel is not None:
        ax.set_xlabel(xlabel)

    if ylabel is not None:
        ax.set_ylabel(ylabel, labelpad=y_labelpad, fontsize=ylabel_fontsize)

    if title is not None:
        ax.set_title(title)

    if hide_xticks:
        ax.xaxis.set_ticklabels([])
    elif xticks is not None:
        ax.set_xticks(xticks)
        if xticklabels is not None:
            ax.set_xticklabels(xticklabels)

    if hide_yticks:
        ax.yaxis.set_ticklabels([])

    if despine:
        # Restrict to this axis; without `ax`, seaborn despines whatever figure
        # happens to be current.
        sns.despine(ax=ax)

    return ax

In [None]:
# Panel titles, keyed by (metric name, value kind).
metric_label_dict = {
    # absolute values
    ('auc', 'comparator'): 'AUC',
    ('loss_bce', 'comparator'): 'Loss',
    ('ace_abs_logistic_logit', 'comparator'): 'ACE',
    ('net_benefit_rr_0.075', 'comparator'): 'NB (7.5%)',
    ('net_benefit_rr_0.2', 'comparator'): 'NB (20%)',
    ('net_benefit_rr_recalib_0.075', 'comparator'): 'cNB (7.5%)',
    ('net_benefit_rr_recalib_0.2', 'comparator'): 'cNB (20%)',
    # relative values
    ('auc', 'delta'): 'AUC (rel)',
    ('loss_bce', 'delta'): 'Loss (rel)',
    ('ace_abs_logistic_logit', 'delta'): 'ACE (rel)',
    ('net_benefit_rr_0.075', 'delta'): 'NB (7.5%, rel)',
    ('net_benefit_rr_0.2', 'delta'): 'NB (20%, rel)',
    ('net_benefit_rr_recalib_0.075', 'delta'): 'cNB (7.5%, rel)',
    ('net_benefit_rr_recalib_0.2', 'delta'): 'cNB (20%, rel)',
}

# Row labels, keyed by value kind.
ylabel_dict = dict(comparator='Absolute', delta='Relative')

# Metrics shown in each figure, in column order.
plot_keys_metrics_dict = {
    'performance': ['auc', 'loss_bce', 'ace_abs_logistic_logit'],
    'net_benefit': [
        'net_benefit_rr_0.075',
        'net_benefit_rr_recalib_0.075',
        'net_benefit_rr_0.2',
        'net_benefit_rr_recalib_0.2',
    ],
}

# Legend anchor per grouping attribute.
plot_grid_config = {
    'race_eth': {'bbox_to_anchor': (1.05, 0.65)},
    'gender_concept_name': {'bbox_to_anchor': (1.05, 0.65)},
    'race_eth_gender': {'bbox_to_anchor': (1.02, 0.85)},
}

# Per-panel overrides keyed by (attribute, value kind, figure key). Net-benefit
# panels get a fixed y range per value kind; performance panels get no overrides.
_net_benefit_ylim = {'comparator': (-0.0025, 0.025), 'delta': (-0.01, 0.0025)}
plot_config = {}
for _attribute in ('race_eth', 'gender_concept_name', 'race_eth_gender'):
    for _value_kind in ('comparator', 'delta'):
        plot_config[(_attribute, _value_kind, 'net_benefit')] = {'ylim': _net_benefit_ylim[_value_kind]}
        plot_config[(_attribute, _value_kind, 'performance')] = {}

In [None]:
def make_plot_grid(
    df,
    group_var_name,
    metrics,
    plot_key,
    metric_prefixes=['comparator', 'delta'],
    wspace=0.2,
    hspace=0.25,
    bbox_to_anchor=(1.05, 0.65),
    xlabel_height=-0.05,
    sharey=None
):
    """Draw a grid of metric-versus-regularization panels and return the figure.

    Rows correspond to ``metric_prefixes`` and columns to ``metrics``. Each panel
    plots ``<prefix>_<metric>_mid`` against ``lambda_group_regularization`` with
    error bars from the matching ``_lower`` / ``_upper`` columns, one line per
    value of ``group_var_name``. Rows whose group is "Overall" are excluded.

    Titles, row labels and per-panel overrides are read from the module-level
    ``metric_label_dict``, ``ylabel_dict`` and ``plot_config``. ``plot_config``
    is indexed by ``(group_var_name, metric_prefix, plot_key)``, so this function
    must be called while the three-key ``plot_config`` is bound. Note that this
    notebook later rebinds both ``plot_config`` and ``make_plot_grid``.

    Parameters
    ----------
    df : wide result table providing the columns described above.
    group_var_name : column whose values define the plotted groups.
    metrics : metric names, one per grid column.
    plot_key : third component of the ``plot_config`` key.
    metric_prefixes : value kinds, one per grid row (never mutated here).
    wspace, hspace : subplot spacing passed to ``fig.subplots_adjust``.
    bbox_to_anchor : anchor of the figure-level legend.
    xlabel_height : figure-coordinate y position of the shared x label.
    sharey : y-axis sharing mode for ``plt.subplots``; None means no sharing.

    Returns
    -------
    The created matplotlib figure.
    """
    n_rows, n_cols = len(metric_prefixes), len(metrics)
    fig, ax_list = plt.subplots(
        n_rows, n_cols, squeeze=False, figsize=(10, 1.5 * n_rows), dpi=180,
        # Some matplotlib versions reject None here, so spell the default out.
        sharey='none' if sharey is None else sharey
    )
    fig.subplots_adjust(wspace=wspace, hspace=hspace)

    # The filter does not depend on the panel, so apply it once.
    group_df = df.query(f'{group_var_name} != "Overall"')

    for j, metric in enumerate(metrics):
        for i, metric_prefix in enumerate(metric_prefixes):
            panel_ax = ax_list[i, j]
            plot_data(
                group_df,
                ax=panel_ax,
                x_var='lambda_group_regularization',
                y_var=f'{metric_prefix}_{metric}_mid',
                group_var_name=group_var_name,
                use_symlog_x=True,
                symlog_x_linthresh=1e-2,
                ci_lower_var=f'{metric_prefix}_{metric}_lower',
                ci_upper_var=f'{metric_prefix}_{metric}_upper',
                ci_type='errorbar',
                errorbar_capsize=3,
                plot_x_axis=metric_prefix == "delta",
                legend=False,
                # Only the bottom row keeps its x tick labels.
                hide_xticks=i + 1 < n_rows,
                title=metric_label_dict[(metric, metric_prefix)] if i == 0 else None,
                ylabel=ylabel_dict[metric_prefix] if j == 0 else None,
                ylabel_fontsize=12,
                **plot_config[(group_var_name, metric_prefix, plot_key)]
            )
            # Panel letter (A, B, ...) in row-major order.
            panel_ax.text(
                0.02, 1,
                string.ascii_uppercase[i * n_cols + j],
                transform=panel_ax.transAxes,
                size=12, weight='bold')
    fig.align_ylabels(ax_list[:, 0])

    # Every panel shows the same groups, so one panel's handles serve the figure.
    handles, labels = ax_list[-1, -1].get_legend_handles_labels()
    fig.text(0.5, xlabel_height, r'Regularization $\lambda$', ha='center', size=18)
    fig.legend(
        handles, labels, bbox_to_anchor=bbox_to_anchor, frameon=False
    )
    # Callers assign the result (`fig = make_plot_grid(...)`), so hand the figure back.
    return fig

In [None]:
# Inspect the flattened column names produced by the pivot above.
data_dict_pivot['race_eth'].columns

In [None]:
# Spot-check the relative recalibrated net-benefit (threshold 0.075) columns
# for the "mmd" objective.
id_cols = ['lambda_group_regularization', 'race_eth', 'experiment', 'phase', 'group_objective_metric']
race_eth_wide = data_dict_pivot['race_eth']
delta_cnb_cols = [col for col in race_eth_wide.columns if 'delta_net_benefit_rr_recalib_0.075' in col]
race_eth_wide[id_cols + delta_cnb_cols].query('group_objective_metric == "mmd"')

In [None]:
# Preview a single metric grid before generating all of them.
attribute = 'race_eth'
group_objective_metric = "mmd"
plot_key = 'performance'

is_net_benefit = plot_key == 'net_benefit'
preview_df = data_dict_pivot[attribute].query('group_objective_metric == @group_objective_metric')
make_plot_grid(
    preview_df,
    attribute,
    metrics=plot_keys_metrics_dict[plot_key],
    plot_key=plot_key,
    # Net-benefit panels share a y axis per row and therefore need less horizontal room.
    wspace=0.2 if is_net_benefit else 0.3,
    sharey='row' if is_net_benefit else 'none',
    **plot_grid_config[attribute]
);

In [None]:
# Root directory under which the generated figures are written.
figures_path = '../zipcode_cvd/experiments/figures/optum/eo_rr/bootstrapped'

In [None]:
# Render and save one metric grid per (attribute, objective, figure key) combination.
grid_combinations = itertools.product(
    attributes,
    ['mmd', 'threshold_rate'],
    plot_keys_metrics_dict.keys()
)
for attribute, group_objective_metric, plot_key in grid_combinations:
    is_net_benefit = plot_key == 'net_benefit'
    objective_df = data_dict_pivot[attribute].query('group_objective_metric == @group_objective_metric')

    fig = make_plot_grid(
        objective_df,
        metrics=plot_keys_metrics_dict[plot_key],
        group_var_name=attribute,
        plot_key=plot_key,
        wspace=0.2 if is_net_benefit else 0.35,
        sharey='row' if is_net_benefit else 'none',
        **plot_grid_config[attribute]
    )

    # Output layout: <figures_path>/<attribute>/<objective>/metric_grid_<plot_key>.{png,pdf}
    figure_path = os.path.join(figures_path, attribute, group_objective_metric)
    os.makedirs(figure_path, exist_ok=True)
    png_path = os.path.join(figure_path, f'metric_grid_{plot_key}.png')
    pdf_path = os.path.join(figure_path, f'metric_grid_{plot_key}.pdf')
    plt.savefig(png_path, dpi=180, bbox_inches='tight')
    plt.savefig(pdf_path, bbox_inches='tight')
    # Close the figure so open figures do not accumulate across iterations.
    plt.close()

## Plots of TPR/FPR variance

In [None]:
def prepare_plot_df(attribute, group_objective_metric):
    """Build the table behind the TPR/FPR variance panels for one setting.

    Takes the "Overall" rows of ``data_dict_pivot[attribute]`` for the given
    ``group_objective_metric``, keeps the columns whose name contains 'var' and
    either 'recall' or 'fpr', melts them to long form, and derives from each
    column name:

    * ``comparator`` / ``comparator_str``: 'Absolute' when the name contains
      'comparator', else 'Relative';
    * ``tpr`` / ``tpr_str``: 'TPR' when the name contains 'recall', else 'FPR';
    * ``t_075`` / ``t_075_str``: '(7.5%)' when the name contains '0.075',
      else '(20%)';
    * ``recalib`` / ``recalib_str``: 'Calibrated' when the name contains
      'recalib', else 'Unadjusted';
    * ``tpr_t_str``: ``tpr_str`` and ``t_075_str`` joined by a space;
    * ``ci_var_str``: the text after the last 'var_' in the name.

    The long table is then pivoted so each distinct ``ci_var_str`` value becomes
    its own column of ``performance`` values.

    Parameters
    ----------
    attribute : key of ``data_dict_pivot`` and name of its group column.
    group_objective_metric : value of the ``group_objective_metric`` column to keep.

    Returns
    -------
    The pivoted DataFrame with a fresh integer index.
    """
    plot_df = (
        data_dict_pivot[attribute]
        .query(f'{attribute} == "Overall"')
        .query('group_objective_metric == @group_objective_metric')
    )

    keep_cols = ['lambda_group_regularization', 'experiment', 'group_objective_metric', attribute]
    metric_cols = [x for x in plot_df.columns if 'var' in x and (('recall' in x) or ('fpr' in x))]
    plot_df_long = plot_df.melt(
        id_vars=keep_cols, value_vars=metric_cols, var_name='metric', value_name='performance'
    )
    plot_df_long = (
        plot_df_long.assign(
            # regex=False: these are literal substrings. As a regex, the '.' in
            # '0.075' would match any character.
            comparator=lambda x: x.metric.str.contains('comparator', regex=False),
            tpr=lambda x: x.metric.str.contains('recall', regex=False),
            t_075=lambda x: x.metric.str.contains('0.075', regex=False),
            recalib=lambda x: x.metric.str.contains('recalib', regex=False)
        )
        .assign(
            # Start from the label used when the flag is True ...
            comparator_str='Absolute',
            tpr_str='TPR',
            t_075_str='(7.5%)',
            recalib_str='Calibrated'
        )
        .assign(
            # ... and swap in the alternative wherever the flag is False.
            comparator_str=lambda x: x.comparator_str.where(x.comparator, 'Relative'),
            tpr_str=lambda x: x.tpr_str.where(x.tpr, 'FPR'),
            t_075_str=lambda x: x.t_075_str.where(x.t_075, '(20%)'),
            recalib_str=lambda x: x.recalib_str.where(x.recalib, 'Unadjusted'),
            tpr_t_str=lambda x: x.tpr_str.str.cat(x.t_075_str, sep=' '),
            ci_var_str=lambda x: x.metric.str.split('var_').str[-1]
        )
    )

    # An ordered list gives a stable column order. A set of strings iterates in
    # hash-seed-dependent order, which made the output layout vary between runs.
    pivoted_away = {'ci_var_str', 'performance', 'metric'}
    index_columns = [column for column in plot_df_long.columns if column not in pivoted_away]
    plot_df_long_pivot = plot_df_long.pivot(
        index=index_columns,
        columns='ci_var_str',
        values='performance'
    ).reset_index()
    return plot_df_long_pivot

In [None]:
# Prepare one variance table per (attribute, objective) pair.
plot_df_dict = {}
for attribute_name, objective_name in itertools.product(
    ['race_eth', 'gender_concept_name', 'race_eth_gender'],
    ['mmd', 'threshold_rate']
):
    plot_df_dict[(attribute_name, objective_name)] = prepare_plot_df(attribute_name, objective_name)

In [None]:
# Display the prepared variance table for one (attribute, objective) pair.
plot_df_dict[('race_eth', 'mmd')]

In [None]:
# y-axis range of the variance panels, per attribute and row.
_ig_var_ylim = {
    'race_eth': {'Absolute': (-0.0025, 0.03), 'Relative': (-0.015, 0.015)},
    'gender_concept_name': {'Absolute': (-0.0025, 0.015), 'Relative': (-0.01, 0.012)},
    'race_eth_gender': {'Absolute': (-0.0025, 0.04), 'Relative': (-0.015, 0.04)},
}
# y-axis label per row.
_ig_var_ylabel = {'Absolute': 'IG-Var', 'Relative': 'IG-Var (rel)'}

# NOTE: this rebinds plot_config with (attribute, row) keys. The earlier metric
# grid indexes plot_config with three-part keys, so its cells cannot be re-run
# after this cell without first re-running the earlier configuration cell.
plot_config = {
    (_attribute, _row): {'ylim': _ylim, 'ylabel': _ig_var_ylabel[_row]}
    for _attribute, _ylim_by_row in _ig_var_ylim.items()
    for _row, _ylim in _ylim_by_row.items()
}

# Grid columns (left to right) and rows (top to bottom).
plot_keys = ['TPR (7.5%)', 'FPR (7.5%)', 'TPR (20%)', 'FPR (20%)']
comparator_keys = ['Absolute', 'Relative']

In [None]:
def make_plot_grid(
    df,
    attribute,
    wspace=0.2,
    hspace=0.25,
    bbox_to_anchor=(1.07, 0.6),
    xlabel_height=-0.05,
    sharey=None
):
    """Draw the grid of TPR/FPR variance panels and return the figure.

    Rows correspond to the module-level ``comparator_keys`` and columns to
    ``plot_keys``. Each panel shows the rows of ``df`` matching that
    ``comparator_str`` / ``tpr_t_str`` pair, plotting ``mid`` against
    ``lambda_group_regularization`` with error bars from ``lower`` / ``upper``,
    one line per ``recalib_str`` value. The y range and y label come from
    ``plot_config[(attribute, comparator_key)]``.

    This definition replaces the earlier ``make_plot_grid`` of the notebook, so
    the earlier metric-grid cells cannot be re-run once this cell has executed.

    Parameters
    ----------
    df : table providing the columns named above.
    attribute : first component of the ``plot_config`` key.
    wspace, hspace : subplot spacing passed to ``fig.subplots_adjust``.
    bbox_to_anchor : anchor of the figure-level legend.
    xlabel_height : figure-coordinate y position of the shared x label.
    sharey : y-axis sharing mode for ``plt.subplots``; None means no sharing.

    Returns
    -------
    The created matplotlib figure.
    """
    # Kept from the original: discard whatever figure is current before drawing.
    plt.close()

    n_rows, n_cols = len(comparator_keys), len(plot_keys)
    fig, ax_list = plt.subplots(
        n_rows, n_cols, squeeze=False, figsize=(10, 1.5 * n_rows), dpi=180,
        # `sharey` used to be accepted but silently ignored. None keeps the
        # previous behaviour of independent y axes.
        sharey=False if sharey is None else sharey
    )
    fig.subplots_adjust(wspace=wspace, hspace=hspace)

    for j, plot_key in enumerate(plot_keys):
        for i, comparator_key in enumerate(comparator_keys):
            panel_ax = ax_list[i, j]
            panel_config = plot_config[(attribute, comparator_key)]

            # Descending sort with groupby_sort=False below fixes the order in
            # which the recalib_str groups are drawn and coloured.
            panel_df = (
                df
                .query('tpr_t_str == @plot_key')
                .query('comparator_str == @comparator_key')
                .sort_values('recalib_str', ascending=False)
            )

            plot_data(
                panel_df,
                ax=panel_ax,
                x_var='lambda_group_regularization',
                y_var='mid',
                group_var_name='recalib_str',
                use_symlog_x=True,
                symlog_x_linthresh=1e-2,
                ci_lower_var='lower',
                ci_upper_var='upper',
                ci_type='errorbar',
                errorbar_capsize=5,
                plot_x_axis=comparator_key == 'Relative',
                ylim=panel_config.get('ylim'),
                legend=False,
                # Only the first column keeps y tick labels and only the last
                # row keeps x tick labels.
                hide_yticks=j > 0,
                hide_xticks=i + 1 < n_rows,
                ylabel=panel_config.get('ylabel') if j == 0 else None,
                title=plot_key if i == 0 else None,
                groupby_sort=False,
                ylabel_fontsize=12
            )
            # Panel letter (A, B, ...) in row-major order.
            panel_ax.text(
                0.02, 1,
                string.ascii_uppercase[i * n_cols + j],
                transform=panel_ax.transAxes,
                size=12, weight='bold')

    handles, labels = ax_list[-1, -1].get_legend_handles_labels()
    fig.text(0.5, xlabel_height, r'Regularization $\lambda$', ha='center', size=18)
    fig.align_ylabels(ax_list[:, 0])
    fig.legend(
        handles, labels, bbox_to_anchor=bbox_to_anchor, frameon=False
    )
    # Callers assign the result (`fig = make_plot_grid(...)`), so hand the figure back.
    return fig

In [None]:
# Preview one variance grid. Other available settings:
#   attribute: 'race_eth', 'gender_concept_name'
#   group_objective_metric: 'threshold_rate'
attribute = 'race_eth_gender'
group_objective_metric = 'mmd'
plot_df = plot_df_dict[(attribute, group_objective_metric)]

make_plot_grid(df=plot_df, attribute=attribute);

In [None]:
# Render and save one variance grid per (attribute, objective) pair.
variance_combinations = itertools.product(attributes, ['mmd', 'threshold_rate'])
for attribute, group_objective_metric in variance_combinations:
    variance_df = plot_df_dict[(attribute, group_objective_metric)]
    fig = make_plot_grid(variance_df, attribute=attribute)

    # Output layout: <figures_path>/<attribute>/<objective>/tpr_fpr_var.{png,pdf}
    figure_path = os.path.join(figures_path, attribute, group_objective_metric)
    os.makedirs(figure_path, exist_ok=True)
    png_path = os.path.join(figure_path, 'tpr_fpr_var.png')
    pdf_path = os.path.join(figure_path, 'tpr_fpr_var.pdf')
    plt.savefig(png_path, dpi=180, bbox_inches='tight')
    plt.savefig(pdf_path, bbox_inches='tight')
    # Close the figure so open figures do not accumulate across iterations.
    plt.close()