In [None]:
import pandas as pd
import numpy as np
import scipy
import sys
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
sys.path.append('../..')
from utils.eval_utils import get_temp_df, compute_mcc 
# Global RNG seed: the bootstrap helpers in this notebook draw their resamples
# through np.random, so this makes the p-values reproducible.
SEED = 12
np.random.seed(SEED)

# Plot styling shared by every figure below.
sns.set_palette('mako_r')
sns.set_style("white", {"axes.edgecolor": ".8"})

In [None]:
""" 

The functions in this cell are adapted from https://gist.github.com/rajpurkar/f96c131ba3aeffb1927255d4363496a9 
See the original gist for a description of the code.

"""

def calculate_metric(metric, model_predictions, test_labels, threshold=0.5):
    """
    Score a model's predicted scores against binary test labels.

    Adapted from the gist credited at the top of this cell; edited to add MCC.

    Parameters:
        metric (str): 'auc' for ROC AUC on the raw scores, or 'mcc' for the
            Matthews correlation coefficient of the thresholded scores.
        model_predictions (numpy.ndarray or pandas.Series): Predicted scores.
        test_labels (numpy.ndarray or pandas.Series): Binary labels, aligned
            row-for-row with ``model_predictions``.
        threshold (float): Cut-off used for 'mcc'; scores >= threshold count as
            positive. Defaults to 0.5.
    Returns:
        The value returned by ``roc_auc_score`` or ``compute_mcc``.
    Raises:
        ValueError: If ``metric`` is neither 'auc' nor 'mcc'.
    """
    if metric == 'auc':
        return roc_auc_score(test_labels, model_predictions)
    if metric == 'mcc':
        # MCC needs hard class predictions, so binarise the scores first.
        binary_predictions = (model_predictions >= threshold).astype(int)
        return compute_mcc(test_labels, binary_predictions)
    raise ValueError(f"Invalid metric: {metric}")

def print_model_results(model_name, metric, model_predictions, test_set_labels):
    """Print the model name, then its score for `metric`, followed by a blank line."""
    print(f"{model_name}:")
    score = calculate_metric(metric, model_predictions, test_set_labels)
    # end="\n\n" yields the trailing blank line separating consecutive reports.
    print(f"{metric.capitalize()}: {score}", end="\n\n")

def calculate_bootstrap_difference(metric, model_1_predictions, model_2_predictions, test_set_labels, n_resamples=1000, rng=None):
    """
    Calculate the metric differences between two models using a paired bootstrap.

    Each resample draws len(test_set_labels) row positions with replacement and
    scores both models on the same rows. The returned differences have the
    observed difference subtracted, so they are (approximately) centred on zero;
    calculate_p_value compares |observed| against that shifted distribution.

    Parameters:
        metric (str or callable): 'auc' or 'mcc' (scored via calculate_metric), or
            a callable ``metric(predictions, labels)`` returning a number.
        model_1_predictions (array-like): Model 1 scores, one per test row.
        model_2_predictions (array-like): Model 2 scores, one per test row.
        test_set_labels (array-like): Binary test labels, aligned row-for-row with
            both prediction vectors. pandas Series are used by position, not by
            index label.
        n_resamples (int): Number of bootstrap resamples.
        rng (numpy.random.Generator or numpy.random.RandomState, optional): Source
            of the resample indices. Defaults to the global ``np.random`` state,
            so ``np.random.seed(...)`` keeps results reproducible.
    Returns:
        numpy.ndarray: Shifted difference (model 1 - model 2) for each resample.
        float: Observed difference (model 1 - model 2) on the original test set.
    Raises:
        ValueError: If the prediction and label vectors differ in length.
    """
    if callable(metric):
        score = metric
    else:
        def score(predictions, labels):
            return calculate_metric(metric, predictions, labels)

    # Work on plain arrays: indexing a pandas Series with an integer array is
    # label-based, which raises KeyError or picks the wrong rows whenever the
    # index is not exactly 0..n-1 (e.g. a CSV read with index_col=0).
    preds_1 = np.asarray(model_1_predictions)
    preds_2 = np.asarray(model_2_predictions)
    labels = np.asarray(test_set_labels)
    n_samples = len(labels)
    if len(preds_1) != n_samples or len(preds_2) != n_samples:
        raise ValueError("Predictions and labels must have the same length.")

    observed_difference = score(preds_1, labels) - score(preds_2, labels)

    sampler = np.random if rng is None else rng
    differences = np.empty(n_resamples)
    for i in range(n_resamples):
        # Same draws as choice(range(n_samples), ...) for a given seed.
        idx = sampler.choice(n_samples, size=n_samples, replace=True)
        differences[i] = score(preds_1[idx], labels[idx]) - score(preds_2[idx], labels[idx])

    # Shift to approximate the null distribution of "no difference".
    differences = differences - observed_difference

    return differences, observed_difference

def calculate_p_value(differences, observed_difference, add_one=False):
    """
    Two-sided bootstrap p-value.

    Returns the fraction of bootstrap differences whose magnitude is at least
    the magnitude of the observed difference. The estimate has resolution
    1 / len(differences), so with 1000 resamples it can be exactly 0.

    Parameters:
        differences (array-like): Zero-centred bootstrap differences, e.g. the
            first value returned by calculate_bootstrap_difference.
        observed_difference (float): Difference on the original test set.
        add_one (bool): If True, return (count + 1) / (n + 1), which is never
            exactly zero. Defaults to False (plain count / n).
    Returns:
        float: The p-value.
    Raises:
        ValueError: If ``differences`` is empty.
    """
    diffs = np.asarray(differences, dtype=float)
    if diffs.size == 0:
        raise ValueError("differences must contain at least one bootstrap sample")
    n_extreme = int(np.count_nonzero(np.abs(diffs) >= abs(observed_difference)))
    if add_one:
        return (n_extreme + 1) / (diffs.size + 1)
    return n_extreme / diffs.size

def interpret_p_value(p_value, alpha=0.05):
    """
    Print the verdict of the model-comparison test at significance level ``alpha``.

    Parameters:
        p_value (float): p-value, e.g. from calculate_p_value.
        alpha (float): Significance level; H0 is rejected when p_value < alpha.
            Defaults to 0.05.
    """
    null_hypothesis = "There is no difference in the performance of the two models."
    alternative_hypothesis = "There is a difference in the performance of the two models."
    if p_value < alpha:
        print("Reject the null hypothesis in favor of the alternative hypothesis.")
        print(f"{alternative_hypothesis} (p-value = {p_value:e})")
    else:
        print("Fail to reject the null hypothesis.")
        # Note: failing to reject is not evidence of equal performance; the
        # line below only restates H0.
        print(f"{null_hypothesis} (p-value = {p_value:e})")

def plot_histogram(differences, observed_difference, metric, ax=None):
    """
    Histogram of bootstrapped metric differences, with the observed difference marked.

    Parameters:
        differences (array-like): Bootstrapped differences.
        observed_difference (float): Drawn as a dashed red vertical line.
        metric (str): Metric name, shown upper-cased in the title and x label.
        ax (matplotlib.axes.Axes, optional): Axes to draw into. If omitted, a
            new figure is created and becomes the current pyplot figure, so a
            following ``plt.savefig(...)`` saves it.
    Returns:
        (Figure, Axes): The figure and axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    label = metric.upper()
    ax.hist(differences, bins='auto')
    ax.axvline(observed_difference, color='r', linestyle='dashed', linewidth=2, label='Observed Difference')
    ax.set_title(f'Histogram of Bootstrapped Differences ({label})')
    ax.set_xlabel(f'Difference in {label}')
    ax.set_ylabel('Frequency')
    ax.legend()
    return fig, ax

In [None]:
# Load the per-model result files. Model outputs live in
# ../../result_files/<stem>_results.csv; the logistic-regression (LR) test
# predictions are read from ../../logistic_regression/.
RESULT_DIR = "../../result_files"


def _load_results(stem):
    """Read RESULT_DIR/<stem>_results.csv into a DataFrame."""
    return pd.read_csv(f"{RESULT_DIR}/{stem}_results.csv")


# Discharge summaries
dischargesum_psyroberta_p4_epoch12 = _load_results("dischargesum_psyroberta_p4_epoch12")
dischargesum_roberta_epoch12 = _load_results("dischargesum_roberta_epoch12")

# All notes
allnotes_psyroberta_p4_epoch12 = _load_results("allnotes_psyroberta_p4_epoch12")
allnotes_psyroberta_p4_dedupcont_epoch12 = _load_results("allnotes_psyroberta_p4_dedupcont_epoch12")
allnotes_roberta_dedupcont_epoch12 = _load_results("allnotes_roberta_dedupcont_epoch12")

dischargesum_medabert = _load_results("dischargesum_medabert")
# NOTE(review): the *_dedupcont variables below read files without "dedupcont"
# in their name — confirm these are the intended runs.
allnotes_medabert_dedupcont = _load_results("allnotes_medabert")

dischargesum_bert = _load_results("dischargesum_bert")
allnotes_bert_dedupcont = _load_results("allnotes_bert")

# Logistic-regression test predictions (first CSV column is the index).
lr_results_path1 = "../../logistic_regression/LogRegDischargeBest_temp_test.csv"
lr_dis_temp_test = pd.read_csv(lr_results_path1, index_col=0)

lr_results_path2 = "../../logistic_regression/LogRegAllNotDeDupBest_temp_test.csv"
lr_all_temp_test = pd.read_csv(lr_results_path2, index_col=0)

In [None]:
# The five models compared below, in a fixed order: the pairwise p-value
# matrices are indexed in this order and pval_heatmap's default labels assume it.
models_discharge = [
    dischargesum_psyroberta_p4_epoch12,
    dischargesum_roberta_epoch12,
    dischargesum_medabert,
    dischargesum_bert,
    lr_dis_temp_test,
]

models_allnotes = [
    allnotes_psyroberta_p4_dedupcont_epoch12,
    allnotes_roberta_dedupcont_epoch12,
    allnotes_medabert_dedupcont,
    allnotes_bert_dedupcont,
    lr_all_temp_test,
]

# Display names; the "LR" substring is what the comparison loops use to detect
# the logistic-regression frame.
_base_model_names = ["PsyRoBERTa", "RøBÆRTa", "MeDa-BERT", "BERT", "LR"]
model_names_discharge = [f"{name} (Discharge Summaries)" for name in _base_model_names]
model_names_allnotes = [f"{name} (All Notes)" for name in _base_model_names]

In [None]:
EPOCH = 11
METRIC = "mcc"  # set to "auc" to rerun the comparison on ROC AUC
N_RESAMPLES = 1000
# MCC outputs keep their original untagged file names; other metrics get e.g. "AUC_".
metric_tag = "" if METRIC == "mcc" else f"{METRIC.upper()}_"

# One test-split frame per model, built once rather than once per pair.
# LR frames already hold the test predictions; the other result files are
# filtered to EPOCH and passed through get_temp_df.
test_frames_discharge = [
    m if "LR" in name else get_temp_df(m[m.epoch == EPOCH], split="test")
    for m, name in zip(models_discharge, model_names_discharge)
]

arr1 = np.zeros(shape=(len(models_discharge), len(models_discharge)))

for i, (temp_test1, model_name1) in enumerate(zip(test_frames_discharge, model_names_discharge)):
    for j, (temp_test2, model_name2) in enumerate(zip(test_frames_discharge, model_names_discharge)):
        print(model_name1, "VS", model_name2)
        print()

        print_model_results(model_name1, METRIC, temp_test1.p_mean, temp_test1.target)
        print_model_results(model_name2, METRIC, temp_test2.p_mean, temp_test2.target)

        # The paired bootstrap is only valid if both frames list the same test rows in the same order.
        assert temp_test1.target.values.tolist() == temp_test2.target.values.tolist(), \
            f"Test labels differ between {model_name1} and {model_name2}"

        # Pass positional numpy arrays: the frames can carry different index labels
        # (the LR CSV is read with index_col=0), and integer-array indexing of a
        # Series inside the bootstrap is label-based.
        differences, observed_difference = calculate_bootstrap_difference(
            METRIC,
            temp_test1.p_mean.to_numpy(),
            temp_test2.p_mean.to_numpy(),
            temp_test1.target.to_numpy(),
            n_resamples=N_RESAMPLES,
        )

        pval = calculate_p_value(differences, observed_difference)
        arr1[i, j] = pval
        print(pval)
        interpret_p_value(pval)

        plot_histogram(differences, observed_difference, METRIC)
        stem = f"../../output/pval_bootstraphist_{metric_tag}{model_name1}_VS_{model_name2}"
        plt.savefig(f"{stem}.png", bbox_inches="tight")
        plt.savefig(f"{stem}.pdf", bbox_inches="tight")
        # Show inline, then release the figure so the pairwise plots don't
        # accumulate (matplotlib warns once more than 20 figures are open).
        plt.show()
        plt.close()

        print("________________________________")
        print()

In [None]:
EPOCH = 11
METRIC = "mcc"  # set to "auc" to rerun the comparison on ROC AUC
N_RESAMPLES = 1000
# MCC outputs keep their original untagged file names; other metrics get e.g. "AUC_".
metric_tag = "" if METRIC == "mcc" else f"{METRIC.upper()}_"

# One test-split frame per model, built once rather than once per pair.
# LR frames already hold the test predictions; the other result files are
# filtered to EPOCH and passed through get_temp_df.
test_frames_allnotes = [
    m if "LR" in name else get_temp_df(m[m.epoch == EPOCH], split="test")
    for m, name in zip(models_allnotes, model_names_allnotes)
]

arr2 = np.zeros(shape=(len(models_allnotes), len(models_allnotes)))

for i, (temp_test1, model_name1) in enumerate(zip(test_frames_allnotes, model_names_allnotes)):
    for j, (temp_test2, model_name2) in enumerate(zip(test_frames_allnotes, model_names_allnotes)):
        print(model_name1, "VS", model_name2)
        print()

        print_model_results(model_name1, METRIC, temp_test1.p_mean, temp_test1.target)
        print_model_results(model_name2, METRIC, temp_test2.p_mean, temp_test2.target)

        # The paired bootstrap is only valid if both frames list the same test rows in the same order.
        assert temp_test1.target.values.tolist() == temp_test2.target.values.tolist(), \
            f"Test labels differ between {model_name1} and {model_name2}"

        # Pass positional numpy arrays: the frames can carry different index labels
        # (the LR CSV is read with index_col=0), and integer-array indexing of a
        # Series inside the bootstrap is label-based.
        differences, observed_difference = calculate_bootstrap_difference(
            METRIC,
            temp_test1.p_mean.to_numpy(),
            temp_test2.p_mean.to_numpy(),
            temp_test1.target.to_numpy(),
            n_resamples=N_RESAMPLES,
        )

        pval = calculate_p_value(differences, observed_difference)
        arr2[i, j] = pval
        print(np.round(pval, 10))
        interpret_p_value(pval)

        plot_histogram(differences, observed_difference, METRIC)
        stem = f"../../output/pval_bootstraphist_{metric_tag}AllNotes_{model_name1}_VS_{model_name2}"
        plt.savefig(f"{stem}.png", bbox_inches="tight")
        plt.savefig(f"{stem}.pdf", bbox_inches="tight")
        # Show inline, then release the figure so the pairwise plots don't
        # accumulate (matplotlib warns once more than 20 figures are open).
        plt.show()
        plt.close()

        print("________________________________")
        print()

In [None]:
import pickle

# Persist the pairwise p-value matrices (rows/columns follow the model-list order).
# For an AUC run, save to arr1_discharge_auc_pval.pkl / arr2_allnotes_auc_pval.pkl instead.
for pkl_path, pval_matrix in (
    ("arr1_discharge_mcc_pval.pkl", arr1),
    ("arr2_allnotes_mcc_pval.pkl", arr2),
):
    with open(pkl_path, "wb") as fh:
        pickle.dump(pval_matrix, fh)

In [None]:
import matplotlib
from matplotlib.ticker import AutoMinorLocator

# Functions adapted from the matplotlib gallery example "Creating annotated heatmaps"

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", mask=True, **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N).
    row_labels
        A list or array of length M with the labels for the rows.
    col_labels
        A list or array of length N with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
        not provided, use current Axes or create a new one.  Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`.  Optional.
    cbarlabel
        The label for the colorbar.  Optional.
    mask
        If true (default), cells strictly above the main diagonal are masked
        and left blank, so only the lower triangle and the diagonal are drawn.
        Pass False to draw every cell.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    im
        The `AxesImage` created by `imshow`.
    cbar
        The colorbar attached to it.
    """
    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    n_rows, n_cols = np.shape(data)

    if mask:
        # True strictly above the diagonal; built from data's own shape so it
        # also fits non-square inputs.
        upper_mask = np.triu(np.ones((n_rows, n_cols), dtype=bool), k=1)
        im = ax.imshow(np.ma.array(data, mask=upper_mask), **kwargs)
    else:
        im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(n_cols), minor=False)
    ax.set_yticks(np.arange(n_rows), minor=False)
    ax.set_xticklabels(labels=col_labels,
                       rotation=0, ha="center", rotation_mode="anchor", fontsize=9)
    ax.set_yticklabels(labels=row_labels, fontsize=9)

    # Column labels along the bottom; hide the top/bottom tick marks.
    ax.tick_params(top=False, bottom=False,
                   labeltop=False, labelbottom=True)

    # Turn all spines off.
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)

    # White grid between cells: one minor tick halfway between major ticks,
    # thick white minor gridlines, minor tick marks hidden.
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.minorticks_on()
    ax.tick_params(which="minor", bottom=False, left=False)
    ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))

    return im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    A function to annotate a heatmap.

    Masked cells (e.g. the triangle hidden by ``heatmap(..., mask=True)``) are
    skipped.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate (2D array or nested list).  If None, the image's
        data is used.  Optional.
    valfmt
        The format of the annotations inside the heatmap.  This should either
        use the string format method, e.g. "$ {x:.2f}", or be a callable
        ``valfmt(value, pos)`` such as a `matplotlib.ticker.Formatter`.  Optional.
    textcolors
        A pair of colors.  The first is used for cells whose normalized value
        is at or above the threshold, the second for cells below it.  Optional.
    threshold
        Value in data units; it is normalized with ``im.norm`` before the
        comparison.  If None (the default), half of the normalized maximum of
        ``data`` is used.  Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels.

    Returns
    -------
    list
        The created `Text` objects, one per unmasked cell.
    """
    if isinstance(data, list):
        data = np.asarray(data)
    elif not isinstance(data, np.ndarray):
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # One Text per unmasked cell; its color depends on the normalized value.
    texts = []
    n_rows, n_cols = data.shape
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            if np.ma.is_masked(value):
                continue
            kw.update(color=textcolors[int(im.norm(value) < threshold)])
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts

In [None]:
from matplotlib.colors import ListedColormap

#sns.set_palette('mako')
#sns.set_style("white", {"axes.edgecolor": ".8"})

def pval_heatmap(arr, vec, metric, dataset, model_labels=None):
    """
    Plot a pairwise p-value heatmap above a strip of per-model scores, and save it.

    arr: square array of p-values for the difference in model performance;
        only the lower triangle and the diagonal are drawn, annotated with
        significance stars (* p<0.05, ** p<0.01, *** p<0.001, **** p<0.0001).
    vec: the models' classification scores as a single-row 2D array, e.g.
        shape (1, n_models), annotated with their values.
    metric: label for the score colorbar and output file name, e.g. "MCC" or "AUC".
    dataset: tag for the output file name, e.g. "discharge" or "allnotes".
    model_labels: tick labels in the row/column order of ``arr`` and the column
        order of ``vec``. Defaults to the five models compared in this notebook.

    Saves ../../output/pval_{metric}_{dataset}_results.png and .pdf and returns
    (fig, ax), where ax holds the two subplot axes.
    """
    if model_labels is None:
        model_labels = ["PsyRoBERTa", "RøBÆRTa", "MeDa-BERT", "BERT", "LR"]

    cmap = ListedColormap(sns.color_palette("mako", n_colors=10))

    fig, ax = plt.subplots(2, 1, figsize=(6, 6), gridspec_kw={'height_ratios': [4, 0.8]})

    def make_stars(pval):
        """Map a p-value to its significance stars ("" when p >= 0.05)."""
        for cutoff, stars in ((0.0001, "****"), (0.001, "***"), (0.01, "**"), (0.05, "*")):
            if pval < cutoff:
                return stars
        return ""

    star_formatter = matplotlib.ticker.FuncFormatter(lambda x, pos: make_stars(x))

    im_pvals, _ = heatmap(arr, model_labels, model_labels, ax=ax[0],
                          cmap=cmap,
                          cbarlabel="p-value")
    annotate_heatmap(im_pvals, valfmt=star_formatter)

    im_scores, _ = heatmap(vec, [""], model_labels, ax=ax[1],
                           cmap=cmap,
                           mask=False,
                           cbarlabel=metric)
    annotate_heatmap(im_scores)

    fig.savefig(f"../../output/pval_{metric}_{dataset}_results.png", bbox_inches="tight")
    fig.savefig(f"../../output/pval_{metric}_{dataset}_results.pdf", bbox_inches="tight")
    return fig, ax

In [None]:
# Per-model MCC scores for the bottom strip, in the heatmap's column order.
# NOTE(review): hard-coded — presumably copied from the print_model_results
# output above; recompute them if the result files change.
mcc_discharge = np.array([[
    0.2849685192985258,   # PsyRoBERTa
    0.22584728413447375,  # RøBÆRTa
    0.26413445007234954,  # MeDa-BERT
    0.21451372766702514,  # BERT
    0.22276459614793295,  # LR
]])
pval_heatmap(arr1, mcc_discharge, "MCC", "discharge");

# AUC variant (pair with the AUC p-value matrix):
# auc_discharge = np.array([[0.7113156785886554, 0.6887363504523216,
#                            0.7002935070331648, 0.6787924770709964,
#                            0.6798484690301403]])
# pval_heatmap(arr1, auc_discharge, "AUC", "discharge")

In [None]:
# Per-model MCC scores for the bottom strip, in the heatmap's column order.
# NOTE(review): hard-coded — presumably copied from the print_model_results
# output above; recompute them if the result files change.
mcc_allnotes = np.array([[
    0.3033716717576357,   # PsyRoBERTa
    0.26337014027043776,  # RøBÆRTa
    0.29526487056699585,  # MeDa-BERT
    0.1927258837645364,   # BERT
    0.25815224591611685,  # LR
]])
pval_heatmap(arr2, mcc_allnotes, "MCC", "allnotes");

# AUC variant (pair with the AUC p-value matrix):
# auc_allnotes = np.array([[0.7362246997573167, 0.7280923025865467,
#                           0.734360125282612, 0.7178771027358901,
#                           0.7182119848167431]])
# pval_heatmap(arr2, auc_allnotes, "AUC", "allnotes")

In [None]:
def calculate_bootstrap_metric_distributions(metric, model_1_predictions, model_2_predictions, test_set_labels, n_resamples=1000, rng=None):
    """
    Bootstrap each model's metric and t-test the two bootstrap distributions.

    Each resample draws row positions with replacement and scores both models
    on the same rows. Prints the t-statistic and p-value of
    scipy.stats.ttest_ind on the two lists of bootstrap estimates.

    NOTE(review): this t-test is not a valid significance test for the model
    difference. Bootstrap replicates are not independent observations, and the
    two lists are paired by resample, so the p-value can be made arbitrarily
    small by increasing n_resamples whenever the means differ at all. Use
    calculate_bootstrap_difference + calculate_p_value for inference.

    Parameters:
        metric (str or callable): 'auc' or 'mcc' (scored via calculate_metric), or
            a callable ``metric(predictions, labels)`` returning a number.
        model_1_predictions, model_2_predictions (array-like): Scores per test
            row; pandas Series are used by position, not by index label.
        test_set_labels (array-like): Binary labels aligned with the predictions.
        n_resamples (int): Number of bootstrap resamples.
        rng (numpy.random.Generator or RandomState, optional): Source of the
            resample indices; defaults to the global ``np.random`` state.
    Returns:
        tuple[list, list]: Bootstrap metric estimates for model 1 and model 2.
    """
    if callable(metric):
        score = metric
    else:
        def score(predictions, labels):
            return calculate_metric(metric, predictions, labels)

    # Positional arrays: integer-array indexing of a pandas Series is label-based.
    preds_1 = np.asarray(model_1_predictions)
    preds_2 = np.asarray(model_2_predictions)
    labels = np.asarray(test_set_labels)
    n_samples = len(labels)
    sampler = np.random if rng is None else rng

    bootstrap_estimates_A = []
    bootstrap_estimates_B = []
    for _ in range(n_resamples):
        idx = sampler.choice(n_samples, size=n_samples, replace=True)
        bootstrap_estimates_A.append(score(preds_1[idx], labels[idx]))
        bootstrap_estimates_B.append(score(preds_2[idx], labels[idx]))

    t_stat, p_value = scipy.stats.ttest_ind(bootstrap_estimates_A, bootstrap_estimates_B)
    print(f"The t-statistic is {t_stat:.3f} and the p-value is {p_value:.3f}")

    return bootstrap_estimates_A, bootstrap_estimates_B

# Exploratory comparison of discharge-summary models 0 and 2
# (PsyRoBERTa vs MeDa-BERT in the model lists above).
idx_a, idx_b = 0, 2
model_name1, m1 = model_names_discharge[idx_a], models_discharge[idx_a]
model_name2, m2 = model_names_discharge[idx_b], models_discharge[idx_b]
EPOCH = 11

res1 = m1[m1.epoch == EPOCH]
temp_test1 = get_temp_df(res1, split="test")

res2 = m2[m2.epoch == EPOCH]
temp_test2 = get_temp_df(res2, split="test")

# Both frames must list the same test rows in the same order for a paired bootstrap.
assert temp_test1.target.values.tolist() == temp_test2.target.values.tolist()

print(model_name1, "VS", model_name2)
print()

print_model_results(model_name1, "mcc", temp_test1.p_mean, temp_test1.target)
print_model_results(model_name2, "mcc", temp_test2.p_mean, temp_test2.target)

# Positional arrays: integer-array indexing of a Series in the bootstrap is label-based.
distA, distB = calculate_bootstrap_metric_distributions(
    "mcc",
    temp_test1.p_mean.to_numpy(),
    temp_test2.p_mean.to_numpy(),
    temp_test1.target.to_numpy(),
    n_resamples=1000,
)

fig, ax = plt.subplots()
ax.hist(distA, bins=100, alpha=0.6, label=model_name1)
ax.hist(distB, bins=100, alpha=0.6, label=model_name2)
ax.set(title="Bootstrapped MCC distributions", xlabel="MCC", ylabel="Frequency")
ax.legend();

In [None]:
# Welch's t-test (unequal variances). equal_var must be the boolean False: any
# non-empty string, including 'False', is truthy and selects Student's pooled test.
# NOTE(review): bootstrap replicates are not independent samples, so this
# p-value mostly reflects n_resamples; treat it as descriptive only.
t_stat, p_value = scipy.stats.ttest_ind(distA, distB, equal_var=False)
print(f"The t-statistic is {t_stat:.3f} and the p-value is {p_value:.3e}")