In [None]:
import numpy as np 
import pandas as pd
from sklearn.metrics import roc_auc_score, confusion_matrix, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap
import sys
from collections import defaultdict
import json
import pickle
# setting path
sys.path.append('../..')
from utils.eval_utils import get_temp_df

In [None]:
# Import azure-core elements
import azureml.core
from azureml.core.workspace import Workspace
from azureml.core import ScriptRunConfig, Environment, Experiment
from azureml.core.environment import CondaDependencies
from azureml.core import Workspace, Datastore, Dataset
from azureml.data.dataset_factory import DataType
# NOTE(review): Workspace is imported twice above, and ScriptRunConfig, Environment,
# Experiment, CondaDependencies and DataType are not used by any visible cell.

# Initiate workspace
# (presumably reads an Azure ML workspace config file from the working directory — verify)
workspace = Workspace.from_config()

# Define datastore and load dataset
# `datastore` is reused by later cells to load the encounters and departments tables.
datastore_name = 'sp_data'
datastore = Datastore.get(workspace, datastore_name)

# Patient table; later cells read its DurableKey, BirthDate and Sex columns.
datastore_paths = [(datastore, '/patients.parquet')] 
ds = Dataset.Tabular.from_parquet_files(path=datastore_paths)
dfP = ds.to_pandas_dataframe()
# NOTE(review): this renders patient-level rows; clear the cell output before sharing
# or committing the notebook so no personal data is leaked.
dfP.head()

In [None]:
events = pd.read_parquet("../../data/acuteReadmission/events_acute_labels.parquet")
import math
# Acute and non-acute admissions store their encounter key in different columns of the
# events dataframe: use EncounterKey_dis when present, otherwise fall back to EncounterKey.
# Vectorized: like int(), astype(int) truncates floats and raises if both keys are missing.
events["encounter"] = events["EncounterKey_dis"].fillna(events["EncounterKey"]).astype(int)

# Attach birth dates. Inner join: events whose patient is absent from dfP are dropped.
event = events.merge(
        dfP[["DurableKey", "BirthDate"]],
        left_on = 'PatientDurableKey',
        right_on = 'DurableKey')\
        .drop(columns='DurableKey')
# Age in whole years at discharge date, approximated with 365.25-day years.
event['Age'] = np.floor((pd.to_datetime(event.Date_dis) - pd.to_datetime(event.BirthDate)).dt.days / 365.25).astype(int)

# Encounter -> age lookup used by get_attributes.
age_df = event[["encounter", "Age"]].copy()

In [None]:
# Billing-diagnosis table (afregningsdiagnose); reduced below to one action diagnosis per encounter.
diagnosis_df = pd.read_parquet(path="../../data/acuteReadmission/afregningsdiagnose-copy.parquet")

def sort_SKS_hierarchy(SKSCode):
    """
    Map an SKS code to a sort key used to pick one action diagnosis per encounter.

    Schizophrenia codes (prefix 'DF20', e.g. 'DF20' or 'DF209') sort first,
    other 'DF' codes keep their natural lexicographic order, and all non-'DF'
    codes sort last.

    Parameters
    ------------
    - SKSCode. SKS diagnosis code as a string.

    Returns
    ------------
    A string sort key; smaller keys are preferred.
    """
    if SKSCode.startswith('DF20'):
        # Prefix match so sub-codes such as 'DF200' are prioritized too,
        # consistent with the 'DF20' prefix used by skscode_to_diagnosis.
        return '0' + SKSCode  # '0' sorts before 'D' and 'Z'
    if SKSCode.startswith('DF'):
        return SKSCode
    return 'Z' + SKSCode  # Place non-'DF' codes at the end

def get_action_diagnosis(afregnings, sort_key=None):
    """
    Function to get information about the action diagnosis or
    main diagnosis related to the admission, keeping one row per encounter.

    Parameters
    ------------
    - afregnings. Afregningsdiagnose dataframe with the columns
      'PatientDurableKey', 'EncounterKey', 'SKSCode' and 'IsActionDiagnosis'.
      It is not modified.
    - sort_key. Optional callable mapping an SKS code to a sortable key; per
      encounter the code with the smallest key is kept. Defaults to
      sort_SKS_hierarchy.

    Returns
    ------------
    DataFrame with columns PatientDurableKey, EncounterKey, SKSCode, at most one
    row per EncounterKey, sorted by EncounterKey, with a fresh RangeIndex.
    """
    if sort_key is None:
        sort_key = sort_SKS_hierarchy
    # Keep only action diagnoses (we want one main diagnosis per encounter since we
    # merge to encounters) that carry a usable EncounterKey (-1 cannot be connected
    # to the encounters table). .loc[...] yields a new frame, so nothing below
    # writes into a view of the caller's dataframe.
    selected = afregnings.loc[
        (afregnings['IsActionDiagnosis'] == 1) & (afregnings['EncounterKey'] != -1),
        ['PatientDurableKey', 'EncounterKey', 'SKSCode'],
    ]
    return (
        selected
        # Temporary sorting column based on the diagnosis hierarchy
        .assign(SortingColumn=lambda frame: frame['SKSCode'].map(sort_key))
        .sort_values(by=['EncounterKey', 'SortingColumn'])
        # First (highest-priority) row in each encounter
        .groupby('EncounterKey').head(1)
        .drop(columns='SortingColumn')
        .reset_index(drop=True)
    )

# One action diagnosis per encounter; get_attributes looks up SKSCode here by EncounterKey.
afregnings_action = get_action_diagnosis(afregnings=diagnosis_df)

In [None]:
# Load "encounters" parquet file (only the columns needed to link encounters to departments)
datastore_paths = [(datastore, '/encounters.parquet')]
ds = Dataset.Tabular.from_parquet_files(path=datastore_paths)
cols = ["EncounterKey","PatientDurableKey","DepartmentKey"]
ds = ds.keep_columns(cols)
dfE = ds.to_pandas_dataframe()

# Load "departments" parquet file
datastore_paths = [(datastore, '/departments.parquet')]
ds = Dataset.Tabular.from_parquet_files(path=datastore_paths)
dfDe = ds.to_pandas_dataframe()

# Encounter -> department attributes; get_attributes reads EncounterKey and RegionId from it.
# Inner join: encounters whose DepartmentKey is missing from the departments table are
# dropped here and therefore get Region "Unknown" in get_attributes.
department_df = pd.merge(dfE, dfDe, how="inner", on=["DepartmentKey"])

In [None]:
# Per-model prediction results, grouped by input text. Each name on the left is
# bound to the CSV in the same position on the right.
(dischargesum_psyroberta_p4_epoch12,
 dischargesum_roberta_epoch12,
 dischargesum_medabert,
 dischargesum_bert) = [pd.read_csv(csv_path) for csv_path in (
    "../../result_files/dischargesum_psyroberta_p4_epoch12_results.csv",
    "../../result_files/dischargesum_roberta_epoch12_results.csv",
    "../../result_files/dischargesum_medabert_results.csv",
    "../../result_files/dischargesum_bert_results.csv",
)]

# All-notes models.
# NOTE(review): the *_medabert_dedupcont and *_bert_dedupcont names read files without a
# "dedupcont" suffix — confirm these are the deduplicated runs.
(allnotes_psyroberta_p4_epoch12,
 allnotes_psyroberta_p4_dedupcont_epoch12,
 allnotes_roberta_dedupcont_epoch12,
 allnotes_medabert_dedupcont,
 allnotes_bert_dedupcont) = [pd.read_csv(csv_path) for csv_path in (
    "../../result_files/allnotes_psyroberta_p4_epoch12_results.csv",
    "../../result_files/allnotes_psyroberta_p4_dedupcont_epoch12_results.csv",
    "../../result_files/allnotes_roberta_dedupcont_epoch12_results.csv",
    "../../result_files/allnotes_medabert_results.csv",
    "../../result_files/allnotes_bert_results.csv",
)]

In [None]:
def agegroups(x):
    """
    Bucket an age in years into the groups used for the fairness analysis.

    "Unknown" is passed through unchanged. Otherwise ages below 18 are
    "Children", [18, 35) "Young adults", [35, 55) "Adults", and anything
    else (including values that fail every comparison) "Seniors".
    """
    if x == "Unknown":
        return x
    for upper_bound, group in ((18, "Children"), (35, "Young adults"), (55, "Adults")):
        if x < upper_bound:
            return group
    return "Seniors"

def skscode_to_diagnosis(sks):
    """
    Map an SKS diagnosis code to a coarse diagnosis group.

    Prefixes are tested in order, so 'DF20' (schizophrenia) wins over the
    broader 'DF2' (other psychosis). Codes matching no prefix, including
    "Unknown", map to "Other".
    """
    prefix_groups = (
        (("DF20",), "Schizophrenia"),
        (("DF2",), "Other psychosis"),
        (("DF30", "DF31"), "Bipolar/manic"),
        (("DF32", "DF33"), "Depression"),
        (("DF40", "DF41", "DF42"), "Anxiety/OCD"),
        (("DF6",), "Personality disorder"),
        (("DF1",), "SUD"),
    )
    return next(
        (group for prefixes, group in prefix_groups if sks.startswith(prefixes)),
        "Other",
    )
    
def _lookup_by_key(keys, table, key_col, value_col, default="Unknown"):
    """
    Return table[value_col] for each key in `keys`, matched exactly on table[key_col].

    Keys without a matching row get `default`. A requested key that matches more
    than one row raises ValueError instead of silently picking one of them.
    Builds one dict, so the cost is O(len(table) + len(keys)) rather than a
    full table scan per key.
    """
    key_values = table[key_col].to_numpy()
    duplicated = table[key_col].duplicated(keep=False).to_numpy()
    ambiguous = set(key_values[duplicated].tolist()).intersection(keys)
    if ambiguous:
        raise ValueError(
            "{} matches several rows for {} requested key(s), e.g. {}".format(
                key_col, len(ambiguous), next(iter(ambiguous))))
    mapping = dict(zip(key_values[~duplicated].tolist(),
                       table[value_col].to_numpy()[~duplicated].tolist()))
    return [mapping.get(key, default) for key in keys]


def get_attributes(temp_test, threshold=0.5):
    """
    Add demographic and clinical attribute columns to a per-admission prediction frame.

    Parameters
    ------------
    - temp_test. DataFrame with an "ID" column formatted "<patient key>_<encounter key>"
      and a "p_mean" column of predicted probabilities. Modified in place.
    - threshold. Probability at or above which "Prediction" is 1 (default 0.5).

    Lookups use the notebook-level tables dfP (Sex by DurableKey), age_df (Age by
    encounter), department_df (RegionId by EncounterKey) and afregnings_action
    (SKSCode by EncounterKey). Missing matches become "Unknown" (for the diagnosis
    columns this ends up as "Other").

    Returns
    ------------
    The same DataFrame with pid, eid, Prediction, Sex, Age group, Region, Diagnosis
    and Diagnosis_specific columns added.
    """
    temp_test["pid"] = temp_test["ID"].apply(lambda x: x.split("_")[0])
    temp_test["eid"] = temp_test["ID"].apply(lambda x: x.split("_")[1])
    # Integer keys, converted once and reused by every lookup below.
    patient_keys = [int(pid) for pid in temp_test["pid"]]
    encounter_keys = [int(eid) for eid in temp_test["eid"]]

    print("Matching sex and age")
    test_sex = _lookup_by_key(patient_keys, dfP, "DurableKey", "Sex")
    test_age = _lookup_by_key(encounter_keys, age_df, "encounter", "Age")

    temp_test["Prediction"] = (temp_test["p_mean"] >= threshold).astype(int)
    temp_test["Sex"] = test_sex
    temp_test["Age group"] = [agegroups(age) for age in test_age]

    # region
    print("Matching region")
    temp_test["Region"] = _lookup_by_key(encounter_keys, department_df, "EncounterKey", "RegionId")

    # current diagnosis
    print("Matching diagnosis")
    test_diagnosis = _lookup_by_key(encounter_keys, afregnings_action, "EncounterKey", "SKSCode")
    # e.g. "DF200" -> "F2"; non-psychiatric and unknown codes -> "Other"
    temp_test["Diagnosis"] = [sks[1:3] if sks.startswith("DF") else "Other" for sks in test_diagnosis]
    temp_test["Diagnosis_specific"] = [skscode_to_diagnosis(sks) for sks in test_diagnosis]
    return temp_test

def specificity(targets, preds):
    """
    True negative rate: recall of the negative class (label 0), i.e. TN / (TN + FP).
    """
    return recall_score(y_true=targets, y_pred=preds, pos_label=0)

def bootstrap_conf_errors(targets, preds, metric, n_resamples=1000, random_state=22):
    """
    Point estimate of `metric` plus bootstrap error-bar lengths around it.

    Parameters
    ----------
    targets, preds : array-like of equal length
        Resampled together as a paired sample by scipy.stats.bootstrap.
    metric : callable(targets, preds) -> float
    n_resamples, random_state : forwarded to scipy.stats.bootstrap
        (defaults 1000 and 22).

    Returns
    -------
    (overall_res, low_err, upp_err), where low_err = overall_res - CI low and
    upp_err = CI high - overall_res. Both are clipped at 0 so they are always
    valid matplotlib `yerr` values; NaN bounds propagate as NaN.
    """
    overall_res = metric(targets, preds)
    conf = bootstrap((targets, preds), metric, vectorized=False, paired=True,
                     random_state=random_state, n_resamples=n_resamples)
    # Use the unrounded bounds: rounding them while keeping the point estimate
    # unrounded could produce small negative lengths, which matplotlib rejects.
    low_err = np.maximum(overall_res - conf.confidence_interval.low, 0)
    upp_err = np.maximum(conf.confidence_interval.high - overall_res, 0)
    return overall_res, low_err, upp_err

def bootstrap_conf(targets, preds, metric, n_resamples=1000, random_state=22, decimals=3):
    """
    Point estimate of `metric` plus its rounded bootstrap confidence bounds.

    Parameters
    ----------
    targets, preds : array-like of equal length
        Resampled together as a paired sample by scipy.stats.bootstrap.
    metric : callable(targets, preds) -> float
    n_resamples, random_state : forwarded to scipy.stats.bootstrap
        (defaults 1000 and 22).
    decimals : number of decimals the bounds are rounded to (default 3).

    Returns
    -------
    (overall_res, low_conf, high_conf). Only the bounds are rounded;
    overall_res is returned unrounded.
    """
    overall_res = metric(targets, preds)
    conf = bootstrap((targets, preds), metric, vectorized=False, paired=True,
                     random_state=random_state, n_resamples=n_resamples)
    low_conf = np.round(conf.confidence_interval.low, decimals)
    high_conf = np.round(conf.confidence_interval.high, decimals)
    return overall_res, low_conf, high_conf

def calculate_rates_for_table(temp_test, model_name):
    """
    Compute subgroup TPR, TNR and AUROC with bootstrap confidence intervals and
    pickle them as strings for a results table.

    Subgroups: sex (Female/Male), age group, region (10 = Capital Region,
    20 = Region Zealand) and each Diagnosis_specific category (ordered from
    least to most frequent, with "Other" last).

    Parameters
    ------------
    - temp_test. Output of get_attributes (needs target, Prediction, p_mean,
      Sex, Age group, Region and Diagnosis_specific).
    - model_name. Used verbatim in the output file name.

    Side effects
    ------------
    Prints the overall rates and the diagnosis order, and writes
    fairness_results_dict_<model_name>.pkl to the working directory. The file
    holds {section: {subgroup: {"TPR"|"TNR"|"AUC": "value, [low, high]"}}}.
    """
    tn, fp, fn, tp = confusion_matrix(temp_test.target, temp_test.Prediction).ravel()
    base_tpr = tp/(tp+fn)
    base_tnr = tn/(tn+fp)
    base_auc = roc_auc_score(temp_test.target, temp_test.p_mean)
    print("Base: TPR={}, TNR={}, AUROC={}".format(base_tpr, base_tnr, base_auc))

    # Sanity check that recall_score/specificity match the confusion-matrix definitions.
    female = temp_test[temp_test["Sex"] == "Kvinde"]
    tn, fp, fn, tp = confusion_matrix(female.target, female.Prediction).ravel()
    assert tn/(tn+fp) == specificity(female.target, female.Prediction)
    assert tp/(tp+fn) == recall_score(female.target, female.Prediction, pos_label=1)

    # Diagnosis categories from least to most frequent, with the catch-all "Other"
    # moved to the end (regardless of how frequent it is).
    ascending = temp_test.Diagnosis_specific.value_counts()[::-1].keys().tolist()
    labels = [diag for diag in ascending if diag != "Other"]
    if "Other" in ascending:
        labels.append("Other")
    print(labels)

    # (table section, row label, row mask); "Kvinde"/"Mand" are the Sex values in dfP.
    age_group = temp_test["Age group"]
    subgroups = [
        ("Sex", "Female", temp_test["Sex"] == "Kvinde"),
        ("Sex", "Male", temp_test["Sex"] == "Mand"),
        ("Age", "Above 54", age_group == "Seniors"),
        ("Age", "35-54", age_group == "Adults"),
        ("Age", "18-34", age_group == "Young adults"),
        ("Age", "Below 18", age_group == "Children"),
        ("Region", "Capital Region", temp_test.Region == 10),
        ("Region", "Region Zealand", temp_test.Region == 20),
    ]
    subgroups += [("Diagnosis", diag, temp_test.Diagnosis_specific == diag) for diag in labels]

    cell = '{}, [{}, {}]'
    data = {}
    for section, row, mask in subgroups:
        subset = temp_test[mask]
        data.setdefault(section, {})[row] = {
            "TPR": cell.format(*bootstrap_conf(subset.target, subset.Prediction, recall_score)),
            "TNR": cell.format(*bootstrap_conf(subset.target, subset.Prediction, specificity)),
            "AUC": cell.format(*bootstrap_conf(subset.target, subset.p_mean, roc_auc_score)),
        }

    # NOTE(review): unlike calculate_rates, model_name is not sanitized here, so the
    # file name keeps spaces and parentheses.
    with open(f'fairness_results_dict_{model_name}.pkl', 'wb') as file:
        pickle.dump(data, file)



def calculate_rates(temp_test, model_name, save_diagnosis_figure=False):
    """
    Plot subgroup TPR, TNR and AUROC with bootstrap error bars against the overall rates.

    Draws two figures, with dashed lines marking the rates on the whole of temp_test:
    - a 3x3 grid (rows TPR/TNR/AUROC; columns sex, age group, region), saved to
      ../figures_to_download/equal_odds_auc_demog_<model>.pdf/.png and shown;
    - a 4x1 figure with the same metrics per Diagnosis_specific category plus a
      stacked histogram of targets per category, saved to
      ../figures_to_download/equal_odds_auc_diag_<model>.pdf/.png only when
      save_diagnosis_figure is True.

    Parameters
    ------------
    - temp_test. Output of get_attributes (needs target, Prediction, p_mean, Sex,
      Age group, Region and Diagnosis_specific). Not modified.
    - model_name. Used in file names with spaces and parentheses removed.
    - save_diagnosis_figure. Also save the diagnosis figure (default False).
    """
    sns.set(style="whitegrid")
    sns.set_palette('mako_r', n_colors=6)
    model_name = model_name.replace(" ", "").replace("(", "").replace(")", "")

    tn, fp, fn, tp = confusion_matrix(temp_test.target, temp_test.Prediction).ravel()
    base_tpr = tp/(tp+fn)
    base_tnr = tn/(tn+fp)
    base_auc = roc_auc_score(temp_test.target, temp_test.p_mean)
    print("Base: TPR={}, TNR={}, AUROC={}".format(base_tpr, base_tnr, base_auc))
    baselines = (base_tpr, base_tnr, base_auc)
    metric_names = ("TPR", "TNR", "AUROC")

    # Sanity check that recall_score/specificity match the confusion-matrix definitions.
    female = temp_test[temp_test["Sex"] == "Kvinde"]
    tn, fp, fn, tp = confusion_matrix(female.target, female.Prediction).ravel()
    assert tn/(tn+fp) == specificity(female.target, female.Prediction)
    assert tp/(tp+fn) == recall_score(female.target, female.Prediction, pos_label=1)

    def subgroup_rates(mask):
        # [TPR, TNR, AUROC] as (value, low_err, upp_err) triples on temp_test[mask].
        subset = temp_test[mask]
        return [
            bootstrap_conf_errors(subset.target, subset.Prediction, recall_score),
            bootstrap_conf_errors(subset.target, subset.Prediction, specificity),
            bootstrap_conf_errors(subset.target, subset.p_mean, roc_auc_score),
        ]

    def draw_bars(axis, results, labels, baseline):
        # One bar per subgroup with asymmetric error bars; overall rate as dashed line.
        values = [r[0] for r in results]
        x = np.arange(len(values))
        axis.bar(x, values, yerr=[[r[1] for r in results], [r[2] for r in results]])
        axis.set_ylim([0, 1])
        axis.set_xticks(x)
        axis.set_xticklabels(labels)
        axis.axhline(y=baseline, xmin=0, xmax=1, color="black", ls="--", lw=1)

    # ---- Demographic figure: one column per attribute ----
    age_group = temp_test["Age group"]
    demographic_columns = [
        (["Female", "Male"],
         [temp_test["Sex"] == "Kvinde", temp_test["Sex"] == "Mand"]),
        (["Below 18", "18-34", "35-54", "Above 54"],
         [age_group == "Children", age_group == "Young adults",
          age_group == "Adults", age_group == "Seniors"]),
        (["Capital Region", "Region Zealand"],
         [temp_test.Region == 10, temp_test.Region == 20]),
    ]
    column_results = [[subgroup_rates(mask) for mask in masks] for _, masks in demographic_columns]

    fig, axes = plt.subplots(3, 3, figsize=(15, 10))
    for col, ((labels, _), results) in enumerate(zip(demographic_columns, column_results)):
        print(labels)
        for row, metric_name in enumerate(metric_names):
            per_group = [group[row] for group in results]
            print(metric_name, [r[0] for r in per_group])
            draw_bars(axes[row, col], per_group, labels, baselines[row])
    for row, metric_name in enumerate(metric_names):
        axes[row, 0].set_ylabel(metric_name)

    fig.savefig("../figures_to_download/equal_odds_auc_demog_{}.pdf".format(model_name), bbox_inches="tight")
    fig.savefig("../figures_to_download/equal_odds_auc_demog_{}.png".format(model_name), bbox_inches="tight")
    plt.show()

    # ---- Diagnosis figure ----
    # Categories from least to most frequent, with the catch-all "Other" moved to the end.
    ascending = temp_test.Diagnosis_specific.value_counts()[::-1].keys().tolist()
    labels = [diag for diag in ascending if diag != "Other"]
    if "Other" in ascending:
        labels.append("Other")
    print(labels)
    diagnosis_results = [subgroup_rates(temp_test.Diagnosis_specific == diag) for diag in labels]

    fig, ax = plt.subplots(4, 1, figsize=(15, 15))
    for row, metric_name in enumerate(metric_names):
        per_group = [group[row] for group in diagnosis_results]
        print("Diagnosis " + metric_name, [r[0] for r in per_group])
        draw_bars(ax[row], per_group, labels, baselines[row])
        ax[row].set_ylabel(metric_name)

    # Plot from a copy with an ordered categorical so the caller's frame is untouched.
    plot_df = temp_test.assign(
        Diagnosis_specific=pd.Categorical(temp_test['Diagnosis_specific'], labels))
    sns.histplot(data=plot_df, x="Diagnosis_specific", hue="target", multiple="stack", shrink=.8, ax=ax[3])
    if save_diagnosis_figure:
        fig.savefig("../figures_to_download/equal_odds_auc_diag_{}.pdf".format(model_name), bbox_inches="tight")
        fig.savefig("../figures_to_download/equal_odds_auc_diag_{}.png".format(model_name), bbox_inches="tight")

In [None]:
# Result tables and their display names, paired by position.
models = [dischargesum_psyroberta_p4_epoch12, 
          dischargesum_roberta_epoch12,
          allnotes_psyroberta_p4_epoch12,
          allnotes_psyroberta_p4_dedupcont_epoch12,
          allnotes_roberta_dedupcont_epoch12,
          dischargesum_medabert,
          allnotes_medabert_dedupcont,
          dischargesum_bert,
          allnotes_bert_dedupcont]

# model_name is also used verbatim in the pickle file name written by
# calculate_rates_for_table.
model_names = ["PsyRoBERTa (Discharge Summaries)",
               "RoBERTa (Discharge Summaries)",
               "PsyRoBERTa (All Notes, Not Dedup.)",
               "PsyRoBERTa (All Notes)",
               "RoBERTa (All Notes)",
               "MeDa-BERT (Discharge Summaries)",
               "MeDa-BERT (All Notes)",
               "BERT (Discharge Summaries)",
               "BERT (All Notes)"]
# zip() would silently drop unmatched entries.
assert len(models) == len(model_names), "models and model_names must have the same length"

# Epoch whose predictions are evaluated.
# NOTE(review): result files are named "epoch12" while EPOCH is 11 — presumably
# epochs are 0-indexed; confirm.
EPOCH=11

for m, model_name in zip(models, model_names):
    print(model_name)
    res = m[m.epoch==EPOCH]
    assert not res.empty, f"no results for epoch {EPOCH} in {model_name}"
    temp_test_ = get_temp_df(res, split="test")
    temp_test = get_attributes(temp_test_)
    calculate_rates_for_table(temp_test, model_name)

In [None]:
def _evaluate_lr_predictions(csv_path, display_name):
    """Load saved LR test predictions, attach attributes and write the fairness table.

    Returns (raw frame, frame after get_attributes)."""
    raw = pd.read_csv(csv_path, index_col=0)
    enriched = get_attributes(raw)
    calculate_rates_for_table(enriched, display_name)
    return raw, enriched

# Logistic-regression baselines, evaluated like the transformer models above.
lr_discharge_path = "../../logistic_regression/LogRegDischargeBest_temp_test.csv"
lr_temp_test_, lr_temp_test = _evaluate_lr_predictions(lr_discharge_path, "LR (Discharge Summaries)")

# NOTE(review): confirm that "LogRegAllNotDeDup" matches the all-notes setting that the
# transformer rows label "(All Notes)".
lr_allnotes_path = "../../logistic_regression/LogRegAllNotDeDupBest_temp_test.csv"
lr_temp_test2_, lr_temp_test2 = _evaluate_lr_predictions(lr_allnotes_path, "LR (All Notes)")