In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
import itertools
import string

from prediction_utils.pytorch_utils.metrics import CalibrationEvaluator, FairOVAEvaluator
from sklearn.metrics import roc_curve
from zipcode_cvd.experiments.util import flatten_multicolumns

In [None]:
# Directory holding the precomputed `result_df_ci_<attribute>.parquet` tables
# that this notebook reads; the path is relative to the notebook's working directory.
figures_data_path = '../zipcode_cvd/experiments/figures_data/eo_rr/'

In [None]:
# Grouping attributes to plot. Each name selects an input parquet file, a
# group-label mapping, and a set of figure-layout options in this notebook.
attributes = ['race_eth', 'gender_concept_name', 'race_eth_gender']

In [None]:
# Resolve the parquet path for every attribute first, then read each table.
# `data_dict` maps attribute name -> result DataFrame.
_result_paths = {
    attribute: os.path.join(figures_data_path, f'result_df_ci_{attribute}.parquet')
    for attribute in attributes
}
data_dict = {
    attribute: pd.read_parquet(result_path)
    for attribute, result_path in _result_paths.items()
}

In [None]:
def _label_frame(attribute, short_names):
    """Build a two-column lookup table from raw group values to display labels.

    The result has one row per entry of `short_names`, in insertion order,
    with the raw value in column `attribute` and the display label in column
    `_<attribute>`.
    """
    wide = pd.DataFrame(short_names, index=[f'_{attribute}'])
    return wide.transpose().rename_axis(attribute).reset_index()


# Raw group value -> short label shown in figures, per grouping attribute.
_group_short_names = {
    'race_eth': {
        'Asian': 'Asian',
        'Black or African American': 'Black',
        'Hispanic or Latino': 'Hispanic',
        'Other': 'Other',
        'White': 'White',
    },
    'gender_concept_name': {
        'FEMALE': 'Female',
        'MALE': 'Male',
    },
    'race_eth_gender': {
        'Asian | FEMALE': 'A-F',
        'Asian | MALE': 'A-M',
        'Black or African American | MALE': 'B-M',
        'Black or African American | FEMALE': 'B-F',
        'Hispanic or Latino | MALE': 'H-M',
        'Hispanic or Latino | FEMALE': 'H-F',
        'Other | FEMALE': 'O-F',
        'Other | MALE': 'O-M',
        'White | FEMALE': 'W-F',
        'White | MALE': 'W-M',
    },
}

group_name_dict = {
    attribute: _label_frame(attribute, short_names)
    for attribute, short_names in _group_short_names.items()
}

In [None]:
# Replace each raw group value with its short display label.
#
# The join key is named explicitly instead of relying on pandas' detection of
# common columns, and `validate` guards against a label table that lists the
# same raw value twice (which would silently duplicate result rows).
# The join is an inner join, so rows whose group value has no entry in
# `group_name_dict` are dropped.
#
# NOTE: this cell is not idempotent. After it runs, the `key` column holds the
# short labels, so running it a second time without reloading `data_dict`
# would only match labels that equal their raw value and drop everything else.
data_dict = {
    key: (
        value
        .merge(group_name_dict[key], on=key, how='inner', validate='many_to_one')
        .drop(columns=key)
        .rename(columns={f'_{key}': key})
    )
    for key, value in data_dict.items()
}

In [None]:
def plot_data(
    df, 
    ax=None, 
    x_var='score',
    y_var='calibration_density', 
    group_var_name='race_eth_gender', 
    ci_lower_var=None,
    ci_upper_var=None,
    drawstyle=None, 
    ylim=(None, None),
    xlim=(None, None),
    ylabel=None,
    xlabel=None,
    legend=True,
    bbox_to_anchor=(1.04, 1),
    plot_y_equals_x=False,
    plot_x_axis=False,
    despine=True,
    hide_yticks=False,
    hide_xticks=False,
    linestyle=None,
    label_group=True,
    title=None,
    axvline=None,
    y_labelpad=None,
    titlepad=None,
    xticks=None,
    xticklabels=None,
):
    """Draw one line per group on a single axes, optionally with a shaded band.

    Parameters
    ----------
    df : DataFrame
        Must contain the columns named by `x_var`, `y_var` and
        `group_var_name`, plus `ci_lower_var` / `ci_upper_var` when given.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure is created when omitted.
    x_var, y_var : str
        Columns plotted on the x and y axes.
    group_var_name : str
        Column whose distinct values define the lines (via `df.groupby`).
    ci_lower_var, ci_upper_var : str, optional
        Columns bounding the shaded band. The band is drawn only when both
        are provided.
    drawstyle, linestyle : optional
        Forwarded to `ax.plot`.
    ylim, xlim : tuple
        Forwarded to `ax.set_ylim` / `ax.set_xlim`.
    ylabel, xlabel, title : str, optional
        Axis labels and title; each is set only when not None.
    legend : bool
        Whether to add a per-axes legend listing every group.
    bbox_to_anchor : tuple
        Legend anchor, used only when `legend` is true.
    plot_y_equals_x : bool
        Add a dashed black identity line spanning (1e-4, 1 - 1e-4).
    plot_x_axis : bool
        Add a dashed black horizontal line at y = 0.
    despine : bool
        Remove the top and right spines of `ax` with seaborn.
    hide_yticks, hide_xticks : bool
        Blank out the tick labels of the respective axis. `hide_xticks`
        takes precedence over `xticks` / `xticklabels`.
    label_group : bool
        Whether each line carries its group value as its artist label.
    axvline : float, optional
        x position of a dashed black vertical line.
    y_labelpad, titlepad : float, optional
        Padding forwarded to `set_ylabel` / `set_title`.
    xticks, xticklabels : sequence, optional
        Explicit tick positions and labels; `xticklabels` is applied only
        together with `xticks`.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        plt.figure()
        ax = plt.gca()

    # Loop-invariant lookups, hoisted out of the per-group loop.
    palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    draw_band = ci_upper_var is not None and ci_lower_var is not None

    groups = []
    line_handles = []
    for i, (group_id, group_df) in enumerate(df.groupby(group_var_name)):
        color = palette[i % len(palette)]
        line = ax.plot(
            group_df[x_var],
            group_df[y_var],
            drawstyle=drawstyle,
            color=color,
            linestyle=linestyle,
            label=group_id if label_group else None,
        )[0]
        groups.append(group_id)
        line_handles.append(line)
        if draw_band:
            ax.fill_between(
                group_df[x_var],
                group_df[ci_lower_var],
                group_df[ci_upper_var],
                alpha=0.25,
                color=color,
                label='_nolegend_',
            )

    if plot_y_equals_x:
        identity = np.linspace(1e-4, 1 - 1e-4, 1000)
        ax.plot(identity, identity, linestyle='--', color='k', label='_nolegend_')

    if axvline is not None:
        ax.axvline(axvline, linestyle='--', color='k', label="_nolegend_")

    if plot_x_axis:
        ax.axhline(0, linestyle='--', color='k', label="_nolegend_")

    if legend:
        # Pair every label with its own line explicitly. Passing only
        # `labels=` makes matplotlib match labels to artists by position,
        # which can attach a group name to a shaded band instead of a line.
        ax.legend(
            handles=line_handles,
            labels=groups,
            bbox_to_anchor=bbox_to_anchor,
            frameon=False,
        )

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    if xlabel is not None:
        ax.set_xlabel(xlabel)

    if ylabel is not None:
        ax.set_ylabel(ylabel, labelpad=y_labelpad)

    if title is not None:
        ax.set_title(title, pad=titlepad)

    if hide_xticks:
        ax.xaxis.set_ticklabels([])
    elif xticks is not None:
        ax.set_xticks(xticks)
        if xticklabels is not None:
            ax.set_xticklabels(xticklabels)

    if hide_yticks:
        ax.yaxis.set_ticklabels([])

    if despine:
        # Restrict to the axes we drew on; a bare `sns.despine()` targets
        # whatever figure happens to be current.
        sns.despine(ax=ax)

    return ax

In [None]:
# Columns that are spread across the pivoted column labels ('metric',
# 'CI_quantile_95') or supply the pivoted values ('comparator', 'delta').
# 'baseline' is excluded from the row identifiers as well and, since it is not
# listed in `values`, does not appear in the pivoted table.
_pivot_non_index_columns = {'comparator', 'baseline', 'delta', 'CI_quantile_95', 'metric'}

# Reshape each long result table to one row per remaining identifier
# combination. The identifier columns are passed as an ordered list (in the
# table's own column order) rather than a set, so the index-level order, and
# with it the column and row order of the result, does not depend on hash order.
data_dict_pivot = {}
for key, value in data_dict.items():
    index_columns = [
        column for column in value.columns
        if column not in _pivot_non_index_columns
    ]
    # `flatten_multicolumns` is a project helper; presumably it joins the
    # column levels into single names such as 'comparator_tpr_mid' (TODO confirm).
    data_dict_pivot[key] = value.pivot(
        index=index_columns,
        columns=['metric', 'CI_quantile_95'],
        values=['comparator', 'delta']
    ).pipe(flatten_multicolumns).reset_index()

In [None]:
def _shared_axis_options(ylim):
    """Return the options common to every panel, with freshly built lists."""
    return {
        'xlim': (0, 0.4),
        'ylim': ylim,
        'xticks': [0, 0.2, 0.4],
        'xticklabels': ['0', '0.2', '0.4'],
        'legend': False,
    }


def _interval_columns(stem):
    """Return the point-estimate and interval column names for `stem`."""
    return {
        'y_var': f'{stem}_mid',
        'ci_lower_var': f'{stem}_lower',
        'ci_upper_var': f'{stem}_upper',
    }


def _panel_config(stem, ylim, ylabel=None, x_var='score', **extra):
    """Assemble the `plot_data` keyword arguments for one panel type.

    `ylabel` and `x_var` are omitted from the result when None; `extra`
    carries panel-specific options such as `axvline` or `plot_x_axis`.
    """
    config = _shared_axis_options(ylim)
    if ylabel is not None:
        config['ylabel'] = ylabel
    if x_var is not None:
        config['x_var'] = x_var
    config.update(_interval_columns(stem))
    config.update(extra)
    return config


def _treat_all_config():
    """Return the options for the dashed, unlabeled treat-all overlay."""
    return _panel_config(
        'comparator_nb_all', (0, 0.05), x_var=None,
        linestyle='--', label_group=False,
    )


_NB_YLIM = (0, 0.05)
_NB_SYMMETRIC_YLIM = (-0.025, 0.025)

# Panel type -> keyword arguments for `plot_data`.
plot_config_dict = {
    'calibration_curve': _panel_config(
        'comparator_calibration_density', (0, 0.4), 'Incidence',
        plot_y_equals_x=True,
    ),
    'tpr': _panel_config('comparator_tpr', (0, 1), 'TPR'),
    'fpr': _panel_config('comparator_fpr', (0, 1), 'FPR'),
    'decision_curve': _panel_config('comparator_nb', _NB_YLIM, 'NB'),
    'decision_curve_treat_all': _treat_all_config(),
    'decision_curve_diff': _panel_config(
        'delta_nb', _NB_SYMMETRIC_YLIM, 'NB (rel)',
        plot_x_axis=True,
    ),
    'decision_curve_implied': _panel_config('comparator_nb_implied', _NB_YLIM, 'cNB'),
    # Uses the same columns as 'decision_curve_treat_all'.
    'decision_curve_treat_all_implied': _treat_all_config(),
    'decision_curve_implied_diff': _panel_config(
        'delta_nb_implied', _NB_SYMMETRIC_YLIM, 'cNB (rel)',
        plot_x_axis=True,
    ),
    'decision_curve_075': _panel_config(
        'comparator_nb_0.075', _NB_YLIM, 'NB (7.5%)',
        axvline=0.075,
    ),
    'decision_curve_075_diff': _panel_config(
        'delta_nb_0.075', _NB_SYMMETRIC_YLIM, 'NB (7.5%, rel)',
        axvline=0.075, plot_x_axis=True,
    ),
    'decision_curve_075_implied': _panel_config(
        'comparator_nb_0.075_implied', _NB_YLIM, 'cNB (7.5%)',
        axvline=0.075,
    ),
    'decision_curve_075_implied_diff': _panel_config(
        'delta_nb_0.075_implied', _NB_SYMMETRIC_YLIM, 'cNB (7.5%, rel)',
        axvline=0.075, plot_x_axis=True,
    ),
    # The 20% panels use the symmetric y-range and a zero line for the
    # absolute curves as well as for the relative ones.
    'decision_curve_20': _panel_config(
        'comparator_nb_0.2', _NB_SYMMETRIC_YLIM, 'NB (20%)',
        axvline=0.2, plot_x_axis=True,
    ),
    'decision_curve_20_diff': _panel_config(
        'delta_nb_0.2', _NB_SYMMETRIC_YLIM, 'NB (20%, rel)',
        axvline=0.2, plot_x_axis=True,
    ),
    'decision_curve_20_implied': _panel_config(
        'comparator_nb_0.2_implied', _NB_SYMMETRIC_YLIM, 'cNB (20%)',
        axvline=0.2, plot_x_axis=True,
    ),
    'decision_curve_20_implied_diff': _panel_config(
        'delta_nb_0.2_implied', _NB_SYMMETRIC_YLIM, 'cNB (20%, rel)',
        axvline=0.2, plot_x_axis=True,
    ),
}

# Panel types that get an additional overlay, mapped to the config key of
# that overlay in `plot_config_dict`.
_TREAT_ALL_OVERLAY_KEYS = {
    'decision_curve': 'decision_curve_treat_all',
    'decision_curve_implied': 'decision_curve_treat_all_implied',
}


def _panel_label(index):
    """Return the letter label for a zero-based panel index.

    Labels follow spreadsheet-column order: 0 -> 'A', ..., 25 -> 'Z',
    26 -> 'AA', 27 -> 'AB', and so on, so grids with more than 26 panels
    still get a label instead of raising IndexError.
    """
    letters = string.ascii_uppercase
    remaining = index + 1  # bijective base-26 is simplest on a one-based count
    label = ''
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, len(letters))
        label = letters[offset] + label
    return label


def make_plot_grid(
    result_df, 
    plot_keys, 
    group_var_name, 
    bbox_to_anchor=(1.1, 0.6), 
    xlabel_height=0.02, 
    wspace=0.2, 
    hspace=0.2,
    titlepad=None
):
    """Draw a grid of panels: one row per plot key, one column per lambda.

    Parameters
    ----------
    result_df : DataFrame
        Must contain a `lambda_group_regularization` column plus every
        column referenced by the selected `plot_config_dict` entries.
    plot_keys : sequence of str
        Keys of `plot_config_dict`, one per grid row, top to bottom.
    group_var_name : str
        Grouping column forwarded to `plot_data`.
    bbox_to_anchor : tuple
        Anchor of the figure-level legend.
    xlabel_height : float
        Vertical figure coordinate of the shared 'Threshold' label.
    wspace, hspace : float
        Subplot spacing forwarded to `plt.subplots_adjust`.
    titlepad : float, optional
        Padding of the per-column lambda titles on the top row.

    Columns follow the order in which lambda values first appear in
    `result_df`. Y-axis labels and y tick labels are shown only in the first
    column, x tick labels only in the last row.

    Returns None. The figure is created with `plt.subplots`, so it is the
    current pyplot figure afterwards and can be saved with `plt.savefig`.
    """
    lambda_values = result_df.lambda_group_regularization.unique()
    n_rows = len(plot_keys)
    n_cols = len(lambda_values)

    fig, ax_list = plt.subplots(
        n_rows, n_cols, squeeze=False, figsize=(10,1.5*n_rows), dpi=180
    )
    plt.subplots_adjust(wspace=wspace, hspace=hspace)
    for j, plot_key in enumerate(plot_keys):
        for i, lambda_value in enumerate(lambda_values):
            ax = ax_list[j, i]
            # `lambda_value` is referenced by name from the query string.
            the_df = result_df.query('lambda_group_regularization == @lambda_value')

            # Arguments shared by the main curve and any overlay on this panel.
            panel_kwargs = dict(
                ax=ax,
                hide_yticks=i > 0,
                hide_xticks=j < n_rows - 1,
                group_var_name=group_var_name,
            )

            config = plot_config_dict[plot_key].copy()
            if i > 0:
                config['ylabel'] = None
            if j == 0:
                text_title = r'$\lambda$ = {0:.3}'.format(lambda_value)
                ax.set_title(text_title, pad=titlepad)
            plot_data(the_df, **panel_kwargs, **config)

            # Add treat-all line to decision curves
            overlay_key = _TREAT_ALL_OVERLAY_KEYS.get(plot_key)
            if overlay_key is not None:
                plot_data(
                    the_df,
                    x_var=plot_config_dict[plot_key]['x_var'],
                    **panel_kwargs,
                    **plot_config_dict[overlay_key]
                )

            ax.text(
                0.02, 1.02, 
                _panel_label(j*n_cols + i),
                transform=ax.transAxes, 
                size=12, weight='bold')

    # Every panel plots the same groups, so the last panel's handles are
    # used for the single figure-level legend.
    handles, labels = ax_list[-1, -1].get_legend_handles_labels()
    fig.text(0.5, xlabel_height, 'Threshold', ha='center', size=18)
    fig.align_ylabels(ax_list[:, 0])
    plt.figlegend(
        handles, labels, bbox_to_anchor=bbox_to_anchor, frameon=False
    )

In [None]:
def _decision_curve_rows(tag=''):
    """Return the four panel keys of one decision-curve figure, top to bottom.

    `tag` is the suffix inserted after 'decision_curve' (for example '_075').
    """
    base = f'decision_curve{tag}'
    return [base, f'{base}_diff', f'{base}_implied', f'{base}_implied_diff']


# Figure type -> ordered panel keys (rows of the grid).
plot_keys_dict = {
    'performance': ['calibration_curve', 'tpr', 'fpr'],
    'decision_curves': _decision_curve_rows(),
    'decision_curves_threshold_075': _decision_curve_rows('_075'),
    'decision_curves_threshold_20': _decision_curve_rows('_20'),
}


In [None]:
_GRID_FIGURE_TYPES = (
    'performance',
    'decision_curves',
    'decision_curves_threshold_075',
    'decision_curves_threshold_20',
)

# Attribute -> (legend anchor for the 'performance' figure,
#               legend anchor for the decision-curve figures).
_GRID_LEGEND_ANCHORS = {
    'race_eth': ((1.05, 0.6), (1.05, 0.6)),
    'gender_concept_name': ((1.02, 0.55), (1.02, 0.55)),
    'race_eth_gender': ((1.0, 0.73), (1.0, 0.7)),
}


def _grid_layout(figure_type, performance_anchor, curve_anchor):
    """Return the `make_plot_grid` layout options for one figure.

    The 'performance' figure places the shared x label at height 0.0 and the
    other figure types at 0.02; all figures use a title padding of 15.
    """
    is_performance = figure_type == 'performance'
    return {
        'bbox_to_anchor': performance_anchor if is_performance else curve_anchor,
        'xlabel_height': 0.0 if is_performance else 0.02,
        'titlepad': 15,
    }


# (attribute, figure type) -> keyword arguments for `make_plot_grid`.
plot_grid_config = {
    (attribute, figure_type): _grid_layout(figure_type, performance_anchor, curve_anchor)
    for attribute, (performance_anchor, curve_anchor) in _GRID_LEGEND_ANCHORS.items()
    for figure_type in _GRID_FIGURE_TYPES
}

In [None]:
# Root output directory for saved figures; the export loop writes into
# `<figures_path>/<attribute>/<group_objective_metric>/` beneath it.
figures_path = '../zipcode_cvd/experiments/figures/optum/eo_rr/bootstrapped'

In [None]:
# Inline preview of a single grid: the 'decision_curves' figure for the
# 'race_eth' attribute, restricted to rows whose group_objective_metric is 'mmd'.
# Close the current pyplot figure first so the preview starts from a new one.
plt.close()
attribute='race_eth'
plot_key='decision_curves'
# Referenced by name (as `@group_objective_metric`) inside the query string below.
group_objective_metric = 'mmd'
make_plot_grid(
    data_dict_pivot[attribute].query('group_objective_metric == @group_objective_metric'), 
    plot_keys=plot_keys_dict[plot_key], 
    group_var_name=attribute,
    **plot_grid_config[(attribute, plot_key)]
)

In [None]:
# Render every (attribute, objective metric, figure type) grid and save it as
# both PNG and PDF under <figures_path>/<attribute>/<group_objective_metric>/.
for attribute in attributes:
    for group_objective_metric in ['mmd', 'threshold_rate']:
        for plot_key in plot_keys_dict:
            # `group_objective_metric` is referenced by name from the query string.
            make_plot_grid(
                data_dict_pivot[attribute].query(
                    'group_objective_metric == @group_objective_metric'
                ),
                plot_keys=plot_keys_dict[plot_key],
                group_var_name=attribute,
                **plot_grid_config[(attribute, plot_key)]
            )

            figure_path = os.path.join(figures_path, attribute, group_objective_metric)
            os.makedirs(figure_path, exist_ok=True)

            # `plt.savefig` writes the current figure, i.e. the grid just drawn.
            plt.savefig(
                os.path.join(figure_path, 'eo_grid_{}.png'.format(plot_key)),
                dpi=180,
                bbox_inches='tight',
            )
            plt.savefig(
                os.path.join(figure_path, 'eo_grid_{}.pdf'.format(plot_key)),
                bbox_inches='tight',
            )
            # Close the figure so figures do not accumulate across iterations.
            plt.close()