In [None]:
import os
import random
import sys
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
import random
import torch
from os import listdir
from os.path import isfile, join

from cyclops.processors.column_names import EVENT_NAME
from cyclops.utils.file import load_pickle
from models.temporal.optimizer import Optimizer, EarlyStopper
from models.temporal.utils import (
    get_data,
    get_device,
    get_temporal_model,
    load_checkpoint,
)
from models.temporal.metrics import print_metrics_binary
from drift_detection.gemini.utils import prep, get_use_case_params, import_dataset_hospital
from drift_detection.drift_detector.plotter import plot_pretty_confusion_matrix
from drift_detection.gemini.constants import DIAGNOSIS_DICT, HOSPITALS

In [4]:
# Notebook-wide settings.
# DATASET is handed to get_use_case_params(); ID is the tag spliced into the
# vectorized-data pickle names and the checkpoint file names.
DATASET, ID = "gemini", "random"

In [None]:
# Evaluate the saved checkpoint of every use case on its test split, once per
# seed (1..5), and reshape the AUROC / AUPRC values for plotting.

# Root directory of the saved checkpoints; each use case keeps its models in a
# "saved_models" sub-directory underneath it.
SAVED_MODELS_ROOT = "/mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini"
USE_CASES = ["mortality", "mortality_cm", "mortality_DxC"]

# Loop-invariant settings, hoisted out of the loops. batch_size, learning_rate
# and weight_decay are not referenced by the evaluation below; they are kept
# so the names stay defined for any later cell that reads them.
output_dim = 1
batch_size = 64
hidden_dim = 64
layer_dim = 2
dropout = 0.2
learning_rate = 2e-3
weight_decay = 1e-6
last_timestep_only = False
model_name = "lstm"
device = get_device()

reps = list()
for i in range(1, 6):

    test = list()
    codes = list()

    for USE_CASE in USE_CASES:
        DIR = os.path.join(SAVED_MODELS_ROOT, USE_CASE, "saved_models")

        use_case_params = get_use_case_params(DATASET, USE_CASE)

        torch.manual_seed(i)
        random.seed(i)
        np.random.seed(i)

        # Only the training features are loaded here: their shape determines
        # the network's input size. The evaluation data itself comes from
        # import_dataset_hospital() below, so the other pickles are not read.
        X_train_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_train_X_" + ID)
        X_train = prep(X_train_vec.data)
        input_dim = X_train.shape[2]
        timesteps = X_train.shape[1]

        model_params = {
            "device": device,
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "layer_dim": layer_dim,
            "output_dim": output_dim,
            "dropout_prob": dropout,
            "last_timestep_only": last_timestep_only,
        }

        # Build the network once and hand it to load_checkpoint together with
        # the checkpoint saved for this use case and seed.
        filepath = os.path.join(
            DIR, ID + "_reweight_positive" + "_" + model_name + "_" + str(i) + ".pt"
        )
        model = get_temporal_model(model_name, model_params).to(device)
        model, opt, n_epochs = load_checkpoint(filepath, model)

        (X_train, y_train), (X_val, y_val), (X_test, y_test) = import_dataset_hospital(
            use_case_params.TAB_VEC_COMB, ID
        )

        test_dataset = get_data(X_test, y_test)
        test_loader = test_dataset.to_loader(batch_size=1, shuffle=True)

        y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(test_loader)

        # Drop every entry whose label is -1 before scoring
        # (presumably padding / missing targets -- TODO confirm).
        keep = y_test_labels != -1
        y_pred_values = y_pred_values[keep]
        y_pred_labels = y_pred_labels[keep]
        y_test_labels = y_test_labels[keep]

        test_pred_metrics = print_metrics_binary(
            y_test_labels, y_pred_values, y_pred_labels, verbose=False
        )

        test.append(test_pred_metrics)
        codes.append(USE_CASE)

    # Turn the list of per-use-case metric dicts into {metric: [one value per use case]}.
    test = {k: [d[k] for d in test if k in d] for k in set().union(*test)}

    reps.append(pd.DataFrame(test))

# Side-by-side concatenation: one row per use case, and every metric name
# appears once per repetition as a repeated column label. The melts below rely
# on those repeated labels to emit one row per (use case, repetition).
reps = pd.concat(reps, axis=1)
# `codes` is rebuilt identically on every repetition, so the list left over
# from the last one labels the rows.
reps['codes'] = codes
auroc = reps[['auroc', 'codes']]
auroc = pd.melt(auroc, id_vars=['codes'], value_vars=['auroc'])
auprc = reps[['auprc', 'codes']]
auprc = pd.melt(auprc, id_vars=['codes'], value_vars=['auprc'])
plot_metrics = pd.concat([auroc, auprc])
plot_metrics = plot_metrics.assign(
    variable=plot_metrics.variable.map({'auroc': 'AUROC', 'auprc': 'AUPRC'})
)

In [None]:
# Box plot of AUROC / AUPRC over the repetitions for the three model variants.
# Reads the long-format `auroc` / `auprc` frames (columns: codes, variable, value).
plot_metrics = pd.concat([auroc, auprc])
plot_metrics = plot_metrics.assign(
    # Display names for the two metrics that were melted.
    variable=plot_metrics.variable.map({'auroc': 'AUROC', 'auprc': 'AUPRC'}),
    # Display names for the use cases stored in `codes`.
    codes=plot_metrics.codes.map(
        {'mortality': 'base', 'mortality_cm': 'base+CM', 'mortality_DxC': 'base+DxC'}
    ),
)

# One theme call: white style plus the figure size (replaces three successive
# calls whose net effect was the same).
sns.set(style="white", rc={'figure.figsize': (2.7, 5.27)})
ax = sns.boxplot(x='codes', y='value', hue='variable', data=plot_metrics)
ax.set(xlabel="Model", ylabel="Score")
plt.xticks(rotation=45)
# Place the legend outside the axes, to the right; trailing ';' hides the repr.
plt.legend(bbox_to_anchor=(1.15, 1), loc=2, borderaxespad=0.);

In [None]:
# Evaluate every use case's saved checkpoint on each subgroup listed in
# DIAGNOSIS_DICT plus the 'all cause' cohort, once per seed (1..5).

# Root directory of the saved checkpoints; each use case keeps its models in a
# "saved_models" sub-directory underneath it.
SAVED_MODELS_ROOT = "/mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini"
USE_CASES = ["mortality", "mortality_cm", "mortality_DxC"]
# Subgroup values that are used verbatim as the dataset identifier.
DEMOGRAPHIC_GROUPS = ['male', 'female', 'adult_18_29', 'adult_30_44', 'adult_45_64', 'geriatric']
# Trajectories excluded from the evaluation (the reason is not visible in this notebook).
SKIPPED_TRAJECTORIES = ["H00_H59", "H60_H95", "O00_O99", "P00_P96", "Q00_Q99", "U07_U08"]

# Loop-invariant settings, hoisted out of the loops. batch_size, learning_rate
# and weight_decay are not referenced by the evaluation below; they are kept
# so the names stay defined for any later cell that reads them.
output_dim = 1
batch_size = 64
hidden_dim = 64
layer_dim = 2
dropout = 0.2
learning_rate = 2e-3
weight_decay = 1e-6
last_timestep_only = False
model_name = "lstm"
device = get_device()

reps = list()
for i in range(1, 6):

    test = list()
    codes = list()
    models = list()

    for USE_CASE in USE_CASES:
        DIR = os.path.join(SAVED_MODELS_ROOT, USE_CASE, "saved_models")

        use_case_params = get_use_case_params(DATASET, USE_CASE)

        torch.manual_seed(i)
        random.seed(i)
        np.random.seed(i)

        # Only the training features are loaded here: their shape determines
        # the network's input size. Each subgroup's evaluation data comes from
        # import_dataset_hospital() below, so the other pickles are not read.
        X_train_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_train_X_" + ID)
        X_train = prep(X_train_vec.data)
        input_dim = X_train.shape[2]
        timesteps = X_train.shape[1]

        model_params = {
            "device": device,
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "layer_dim": layer_dim,
            "output_dim": output_dim,
            "dropout_prob": dropout,
            "last_timestep_only": last_timestep_only,
        }

        # Build the network once and hand it to load_checkpoint together with
        # the checkpoint saved for this use case and seed.
        filepath = os.path.join(
            DIR, ID + "_reweight_positive" + "_" + model_name + "_" + str(i) + ".pt"
        )
        model = get_temporal_model(model_name, model_params).to(device)
        model, opt, n_epochs = load_checkpoint(filepath, model)

        for value in list(DIAGNOSIS_DICT.values()) + ['all cause']:
            torch.manual_seed(i)
            random.seed(i)

            # Derive the display name and the dataset identifier of the subgroup.
            # The original referenced an undefined name `SPLIT`; `ID` is the
            # identifier the rest of the notebook uses for the same purpose.
            if value == 'all cause':
                diagnosis_trajectory = 'all cause'
                diagnosis_id = ID
            elif value in DEMOGRAPHIC_GROUPS:
                diagnosis_trajectory = diagnosis_id = value
            else:
                # Remaining values are indexed as a pair and joined with "_"
                # (presumably the bounds of an ICD code range -- TODO confirm).
                diagnosis_trajectory = value[0] + "_" + value[1]
                diagnosis_id = ID + "_" + diagnosis_trajectory

            if diagnosis_trajectory in SKIPPED_TRAJECTORIES:
                continue

            (X_train, y_train), (X_val, y_val), (X_test, y_test) = import_dataset_hospital(
                use_case_params.TAB_VEC_COMB, diagnosis_id
            )

            # Score on the held-out test split, consistent with the all-data
            # evaluation cell. NOTE(review): the previous version passed
            # (X_train, y_train) here, i.e. it reported training-set metrics
            # under the "test" names -- confirm the test split is intended.
            test_dataset = get_data(X_test, y_test)
            test_loader = test_dataset.to_loader(batch_size=1, shuffle=True)

            y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(test_loader)

            # Drop every entry whose label is -1 before scoring
            # (presumably padding / missing targets -- TODO confirm).
            keep = y_test_labels != -1
            y_pred_values = y_pred_values[keep]
            y_pred_labels = y_pred_labels[keep]
            y_test_labels = y_test_labels[keep]

            test_pred_metrics = print_metrics_binary(
                y_test_labels, y_pred_values, y_pred_labels, verbose=False
            )

            test.append(test_pred_metrics)
            codes.append(diagnosis_trajectory)
            models.append(USE_CASE)

    # Turn the list of per-evaluation metric dicts into {metric: [one value per evaluation]}.
    test = {k: [d[k] for d in test if k in d] for k in set().union(*test)}

    reps.append(pd.DataFrame(test))

# Side-by-side concatenation: one row per (use case, subgroup), and every
# metric name appears once per repetition as a repeated column label.
reps = pd.concat(reps, axis=1)
# `codes` / `models` are rebuilt identically on every repetition, so the lists
# left over from the last one label the rows.
reps['codes'] = codes
reps['model'] = models

In [None]:
# Per-subgroup AUROC of the three model variants across repetitions, with each
# variant's 'all cause' AUROC drawn as a dashed reference line.
model_labels = {'mortality': 'base', 'mortality_cm': 'base+CM', 'mortality_DxC': 'base+DxC'}

auroc = reps[['auroc', 'codes', 'model']]
auroc = pd.melt(auroc, id_vars=['codes', 'model'], value_vars=['auroc'])
auroc = auroc[auroc.codes != "all cause"]
# Relabel using this frame's own 'model' column. The previous version read
# `plot_metrics.model`, but `plot_metrics` has no 'model' column.
auroc = auroc.assign(model=auroc['model'].map(model_labels))

# One theme call: white style plus the figure size.
sns.set(style="white", rc={'figure.figsize': (15.7, 5.27)})
ax = sns.boxplot(x='codes', y='value', hue='model', data=auroc)
ax.set(xlabel="ICD Diagnosis Trajectory", ylabel="AUROC")

# Reference lines: mean over the repetitions of each variant's 'all cause' row.
# The row is selected by its label instead of by position, and axhline spans
# the axis whatever the number of categories.
for model_key, line_color in (('mortality', 'blue'), ('mortality_cm', 'orange'), ('mortality_DxC', 'green')):
    all_cause = reps.loc[(reps['model'] == model_key) & (reps['codes'] == 'all cause'), 'auroc']
    ax.axhline(
        all_cause.to_numpy().mean(),
        linestyle='--',
        color=line_color,
        label='Avg ' + model_labels[model_key],
    )

plt.xticks(rotation=45)
# Place the legend outside the axes, to the right; trailing ';' hides the repr.
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.);

In [None]:
# Per-subgroup AUPRC of the three model variants across repetitions, with each
# variant's 'all cause' AUPRC drawn as a dashed reference line.
model_labels = {'mortality': 'base', 'mortality_cm': 'base+CM', 'mortality_DxC': 'base+DxC'}

auprc = reps[['auprc', 'codes', 'model']]
auprc = pd.melt(auprc, id_vars=['codes', 'model'], value_vars=['auprc'])
auprc = auprc[auprc.codes != "all cause"]
# Relabel using this frame's own 'model' column. The previous version read
# `plot_metrics.model`, but `plot_metrics` has no 'model' column.
auprc = auprc.assign(model=auprc['model'].map(model_labels))

# One theme call: white style plus the figure size.
sns.set(style="white", rc={'figure.figsize': (15.7, 5.27)})
ax = sns.boxplot(x='codes', y='value', hue='model', data=auprc)
ax.set(xlabel="ICD Diagnosis Trajectory", ylabel="AUPRC")

# Reference lines: mean over the repetitions of each variant's 'all cause' row.
# The row is selected by its label instead of by position, and axhline spans
# the axis whatever the number of categories.
for model_key, line_color in (('mortality', 'blue'), ('mortality_cm', 'orange'), ('mortality_DxC', 'green')):
    all_cause = reps.loc[(reps['model'] == model_key) & (reps['codes'] == 'all cause'), 'auprc']
    ax.axhline(
        all_cause.to_numpy().mean(),
        linestyle='--',
        color=line_color,
        label='Avg ' + model_labels[model_key],
    )

plt.xticks(rotation=45)
# Place the legend outside the axes, to the right; trailing ';' hides the repr.
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.);