### Explainability API ## 

## Imports

In [44]:
import os
import random
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
import random
import torch

from cyclops.processors.column_names import EVENT_NAME
from cyclops.utils.file import load_pickle
from models.temporal.optimizer import Optimizer, EarlyStopper
from models.temporal.utils import (
    get_data,
    get_device,
    get_temporal_model,
    load_checkpoint,
)

from models.temporal.metrics import print_metrics_binary
from drift_detection.gemini.utils import prep, get_use_case_params, import_dataset_hospital, random_shuffle_and_split
from drift_detection.drift_detector.plotter import plot_pretty_confusion_matrix
from drift_detection.gemini.constants import DIAGNOSIS_DICT, HOSPITALS
from models.static.utils import run_model
from drift_detection.drift_detector.explainer import Explainer

## Get data

In [45]:
DATASET = "gemini"
USE_CASE = "mortality"
DIR=os.path.join("/mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini",USE_CASE,"saved_models")
ID = SPLIT = "seasonal_winter"#input("Select data split: ")
DIAGNOSIS_TRAJECTORY = input("Select diagnosis trajectory to filter on: ") 
HOSPITAL = input("Select hospital to filter on: ") 

splice_map = {
    "hospital_id": HOSPITALS
}

if DIAGNOSIS_TRAJECTORY != "all":
    diagnosis_trajectory = '_'.join(DIAGNOSIS_DICT[DIAGNOSIS_TRAJECTORY])
    ID = ID +"_"+ diagnosis_trajectory
    splice_map["diagnosis_trajectory"] = [diagnosis_trajectory]
    
if HOSPITAL != "all":
    ID = HOSPITAL + "_" + ID 
    splice_map["hospital_id"] = [HOSPITAL]
    
use_case_params = get_use_case_params(DATASET, USE_CASE)

Select diagnosis trajectory to filter on:  all
Select hospital to filter on:  all


In [46]:
seed = 1
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

X_train_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_train_X_"+ID)
y_train_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_train_y_"+ID)
X_val_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_val_X_"+ID)
y_val_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_val_y_"+ID)
X_test_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_test_X_"+ID)
y_test_vec = load_pickle(use_case_params.TAB_VEC_COMB + "comb_test_y_"+ID)

X_train = prep(X_train_vec.data)
y_train = prep(y_train_vec.data)
X_val = prep(X_val_vec.data)
y_val = prep(y_val_vec.data)
X_test = prep(X_test_vec.data)
y_test = prep(y_test_vec.data)

(X_train, y_train), (X_val, y_val) = random_shuffle_and_split(X_train, y_train, X_val, y_val)

2023-01-31 00:01:28,525 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_train_X_seasonal_winter.pkl
2023-01-31 00:01:28,910 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_train_y_seasonal_winter.pkl
2023-01-31 00:01:28,949 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_val_X_seasonal_winter.pkl
2023-01-31 00:01:29,050 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/aligned_comb_val_y_seasonal_winter.pkl
2023-01-31 00:01:29,063 [1;37mINFO[0m cyclops.utils.file - Loading pickled data from /mnt/nfs/project/delirium/drift_exp/OCT-18-2022/gemini/mortality/./data/4-final/align

## Get temporal model

In [53]:
output_dim = 1
batch_size = 64
input_dim = X_train.shape[2]
timesteps = X_train.shape[1]
hidden_dim = 64
layer_dim = 2
dropout = 0.2
n_epochs = 128
learning_rate = 2e-3
weight_decay = 1e-6
last_timestep_only = False
device = get_device()

model_params = {
    "device": device,
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "layer_dim": layer_dim,
    "output_dim": output_dim,
    "dropout_prob": dropout,
    "last_timestep_only": last_timestep_only,
}

model_name = "lstm"
model = get_temporal_model(model_name, model_params).to(device)

filepath=os.path.join(DIR,ID+"_reweight_positive"+"_"+model_name+"_"+str(seed)+".pt")
if os.path.exists(filepath):
    model, opt, n_epochs = load_checkpoint(filepath, model)


In [54]:
import shap
explainer = shap.DeepExplainer(model, X_train[:500])
shapvalues = explainer.shap_values(X_val[:500])

TypeError: 'int' object is not callable

## Get static model

In [None]:
MODEL_NAME = input("Select Model: ")
MODEL_PATH = PATH + "_".join([SHIFT, OUTCOME, "_".join(HOSPITALS), MODEL_NAME]) + ".pkl"
if os.path.exists(MODEL_PATH):
    optimised_model = pickle.load(open(MODEL_PATH, "rb"))
else:
    optimised_model = run_model(MODEL_NAME, X_tr_final, y_tr, X_val_final, y_val)
    pickle.dump(optimised_model, open(MODEL_PATH, "wb"))

## Explain difference in static model predictions ## 

In [None]:
explainer = Explainer(optimised_model, X_tr_final)
explainer.get_explainer()

In [None]:
timesteps = ["T1_", "T2_", "T3_", "T4_", "T5_", "T6_"]

flattened_feats = []
for ts in timesteps:
    flattened_feats.append(ts + feats)
flattened_feats = list(itertools.chain.from_iterable(flattened_feats))

X_val_df = pd.DataFrame(X_val_final, columns=flattened_feats)
val_shap_values = explainer.get_shap_values(X_val_df)
X_test_df = pd.DataFrame(X_t_final, columns=flattened_feats)
test_shap_values = explainer.get_shap_values(X_test_df)

shap_diff = np.mean(np.abs(test_shap_values.values), axis=0) - np.mean(
    np.abs(val_shap_values.values), axis=0
)
shap_min = -0.001
shap_max = 0.001
shap_diff_sorted, feats_sorted = zip(
    *sorted(zip(shap_diff, flattened_feats), reverse=True)
)
shap_diff_sorted, feats_sorted = zip(
    *(
        (
            (x, y)
            for x, y in zip(shap_diff_sorted, feats_sorted)
            if (x > shap_max or x < shap_min)
        )
    )
)

shap_feats = {"feature": feats_sorted, "shap_diff": list(shap_diff_sorted)}

fig, ax = plt.subplots(figsize=(10, 18))
y_pos = np.arange(len(shap_feats["shap_diff"]))
ax.barh(y_pos, shap_feats["shap_diff"], align="center")
ax.set_yticks(y_pos, labels=shap_feats["feature"])
ax.invert_yaxis()  # labels read top-to-bottom
ax.set_xlabel("Mean Difference in Shap Value")
ax.set_title("Features")
plt.show()

shap_diff_sorted, feats_sorted = zip(
    *sorted(zip(shap_diff, flattened_feats), reverse=True)
)
shap_diff_sorted, feats_sorted = zip(
    *(((x, y) for x, y in zip(shap_diff_sorted, feats_sorted) if (x != 0)))
)

for t in ["T1_", "T2_", "T4_", "T4_", "T5_", "T6_"]:
    shap_feats = {"feature": feats_sorted, "shap_diff": list(shap_diff_sorted)}
    shap_feats = {
        k: [
            x
            for i, x in enumerate(v)
            if any(ts in shap_feats["feature"][i] for ts in [t])
        ]
        for k, v in shap_feats.items()
    }
    shap_feats["feature"] = list(map(lambda x: x.replace(t, ""), shap_feats["feature"]))
    fig, ax = plt.subplots(figsize=(12, 12))
    y_pos = np.arange(len(shap_feats["shap_diff"]))
    ax.barh(y_pos, shap_feats["shap_diff"], align="center")
    ax.set_yticks(y_pos, labels=shap_feats["feature"])
    ax.invert_yaxis()  # labels read top-to-bottom
    ax.set_xlabel("Mean Difference in Shap Value |Target - Source|")
    ax.set_title("Features")
    plt.show()