In [None]:
import numpy as np
import pandas as pd
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_curve, precision_recall_curve
import torch
from sklearn.decomposition import PCA
# matplotlib
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.transforms as mtransfor
import seaborn as sns
# matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8*" and
# later releases dropped the old names, so prefer the new name when it is
# available and fall back to the legacy one on older installations.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

## Read the models' outcome prediction results

In [None]:
# Load the pickled outcome results, one file per (dataset, model) pair.
# NOTE(review): unpickling can execute arbitrary code — only read trusted files.
tj_adacare = pd.read_pickle('./saved_pkl/tongji_adacare_outcome.pkl')
tj_retain = pd.read_pickle('./saved_pkl/tongji_retain_outcome.pkl')
tj_agent = pd.read_pickle('./saved_pkl/tongji_agent_outcome.pkl')

hm_concare = pd.read_pickle('./saved_pkl/hm_concare_outcome.pkl')
hm_adacare = pd.read_pickle('./saved_pkl/hm_adacare_outcome.pkl')
hm_rnn = pd.read_pickle('./saved_pkl/hm_rnn_outcome.pkl')

# Pull the 'outcome_true' / 'outcome_pred' entries out of each result so the
# plotting cells below can refer to them by name.
tj_adacare_outcome_true = tj_adacare['outcome_true']
tj_adacare_outcome_pred = tj_adacare['outcome_pred']
tj_retain_outcome_true = tj_retain['outcome_true']
tj_retain_outcome_pred = tj_retain['outcome_pred']
tj_agent_outcome_true = tj_agent['outcome_true']
tj_agent_outcome_pred = tj_agent['outcome_pred']

hm_concare_outcome_true = hm_concare['outcome_true']
hm_concare_outcome_pred = hm_concare['outcome_pred']
hm_adacare_outcome_true = hm_adacare['outcome_true']
hm_adacare_outcome_pred = hm_adacare['outcome_pred']
hm_rnn_outcome_true = hm_rnn['outcome_true']
hm_rnn_outcome_pred = hm_rnn['outcome_pred']

## ROC Plot

In [None]:
# Reference curve plotted as 'Chance' in the ROC cells below: every sample gets
# the same constant score (0), and roc_curve is run on those scores against the
# AdaCare labels of each dataset.
tj_random_probs = [0] * len(tj_adacare_outcome_true)
tj_p_fpr, tj_p_tpr, _ = roc_curve(tj_adacare_outcome_true, tj_random_probs, pos_label=1)

hm_random_probs = [0] * len(hm_adacare_outcome_true)
hm_p_fpr, hm_p_tpr, _ = roc_curve(hm_adacare_outcome_true, hm_random_probs, pos_label=1)

In [None]:
# [TJH] ROC curves: the chance reference plus one curve per model

tj_adacare_fpr, tj_adacare_tpr, thresh1 = roc_curve(tj_adacare_outcome_true, tj_adacare_outcome_pred, pos_label=1)
tj_retain_fpr, tj_retain_tpr, thresh2 = roc_curve(tj_retain_outcome_true, tj_retain_outcome_pred, pos_label=1)
tj_agent_fpr, tj_agent_tpr, thresh3 = roc_curve(tj_agent_outcome_true, tj_agent_outcome_pred, pos_label=1)

# Plot order also fixes the legend order: Chance, AdaCare, RETAIN, Dr. Agent.
plt.plot(tj_p_fpr, tj_p_tpr, linestyle='-.', color='grey', label='Chance')
plt.plot(tj_adacare_fpr, tj_adacare_tpr, linestyle='solid',color='orange', label='AdaCare')
plt.plot(tj_retain_fpr, tj_retain_tpr, linestyle='dashed',color='dodgerblue', label='RETAIN')
plt.plot(tj_agent_fpr, tj_agent_tpr, linestyle='dotted',color='violet', label='Dr. Agent')

# Sentence-case axis labels, matching the CDSL ROC figure and the other plots
# in this notebook so the exported figures look alike.
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')

plt.legend(loc='lower right')
plt.savefig('tjh_roc.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show();

In [None]:
# [CDSL] ROC curves: the chance reference plus one curve per model.
# Each model's curve is computed and drawn in turn; the draw order
# (Chance, ConCare, AdaCare, RNN) is also the legend order.

plt.plot(hm_p_fpr, hm_p_tpr, color='grey', linestyle='-.', label='Chance')

hm_concare_fpr, hm_concare_tpr, thresh1 = roc_curve(
    hm_concare_outcome_true, hm_concare_outcome_pred, pos_label=1
)
plt.plot(hm_concare_fpr, hm_concare_tpr, color='orange', linestyle='solid', label='ConCare')

hm_adacare_fpr, hm_adacare_tpr, thresh2 = roc_curve(
    hm_adacare_outcome_true, hm_adacare_outcome_pred, pos_label=1
)
plt.plot(hm_adacare_fpr, hm_adacare_tpr, color='dodgerblue', linestyle='dashed', label='AdaCare')

hm_rnn_fpr, hm_rnn_tpr, thresh3 = roc_curve(
    hm_rnn_outcome_true, hm_rnn_outcome_pred, pos_label=1
)
plt.plot(hm_rnn_fpr, hm_rnn_tpr, color='violet', linestyle='dotted', label='RNN')

# Axis labels
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')

plt.legend(loc='lower right')
plt.savefig('cdsl_roc.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show();

## PRC Plot

In [None]:
# [TJH] precision-recall curves, one per model

tj_adacare_precision, tj_adacare_recall, thresh1 = precision_recall_curve(tj_adacare_outcome_true, tj_adacare_outcome_pred, pos_label=1)
tj_retain_precision, tj_retain_recall, thresh2 = precision_recall_curve(tj_retain_outcome_true, tj_retain_outcome_pred, pos_label=1)
tj_agent_precision, tj_agent_recall, thresh3 = precision_recall_curve(tj_agent_outcome_true, tj_agent_outcome_pred, pos_label=1)

# Recall goes on the x-axis and precision on the y-axis so the data agree with
# the axis labels set below (the arguments used to be passed the other way round).
plt.plot(tj_adacare_recall, tj_adacare_precision, linestyle='solid',color='orange', label='AdaCare')
plt.plot(tj_retain_recall, tj_retain_precision, linestyle='dashed',color='dodgerblue', label='RETAIN')
plt.plot(tj_agent_recall, tj_agent_precision, linestyle='dotted',color='violet', label='Dr. Agent')

# Axis labels
plt.xlabel('Recall')
plt.ylabel('Precision')

plt.legend(loc='lower left')
plt.savefig('tjh_prc.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show();

In [None]:
# [CDSL] precision-recall curves, one per model

hm_concare_precision, hm_concare_recall, thresh1 = precision_recall_curve(hm_concare_outcome_true, hm_concare_outcome_pred, pos_label=1)
hm_adacare_precision, hm_adacare_recall, thresh2 = precision_recall_curve(hm_adacare_outcome_true, hm_adacare_outcome_pred, pos_label=1)
hm_rnn_precision, hm_rnn_recall, thresh3 = precision_recall_curve(hm_rnn_outcome_true, hm_rnn_outcome_pred, pos_label=1)

# Recall goes on the x-axis and precision on the y-axis so the data agree with
# the axis labels set below (the arguments used to be passed the other way round).
plt.plot(hm_concare_recall, hm_concare_precision, linestyle='solid',color='orange', label='ConCare')
plt.plot(hm_adacare_recall, hm_adacare_precision, linestyle='dashed',color='dodgerblue', label='AdaCare')
plt.plot(hm_rnn_recall, hm_rnn_precision, linestyle='dotted',color='violet', label='RNN')

# Axis labels
plt.xlabel('Recall')
plt.ylabel('Precision')

plt.legend(loc='lower left')
plt.savefig('cdsl_prc.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show();

## Calibration Plot

In [None]:
# [TJH] calibration curves: per-bin observed positive fraction vs. mean predicted probability
tj_adacare_prob_true, tj_adacare_prob_pred = calibration_curve(tj_adacare_outcome_true, tj_adacare_outcome_pred, n_bins=10)
tj_retain_prob_true, tj_retain_prob_pred = calibration_curve(tj_retain_outcome_true, tj_retain_outcome_pred, n_bins=10)
tj_agent_prob_true, tj_agent_prob_pred = calibration_curve(tj_agent_outcome_true, tj_agent_outcome_pred, n_bins=10)

fig, ax = plt.subplots()
# one calibration curve per model
ax.plot(tj_adacare_prob_pred, tj_adacare_prob_true, marker='o', linewidth=1, label='AdaCare')
ax.plot(tj_retain_prob_pred, tj_retain_prob_true, marker='v', linewidth=1, label='RETAIN')
ax.plot(tj_agent_prob_pred, tj_agent_prob_true, marker='s', linewidth=1, label='Dr. Agent')

# Perfect-calibration reference (y = x), drawn in data coordinates. Drawing it
# corner-to-corner in axes coordinates only coincides with y = x when both axis
# limits are exactly [0, 1], which autoscaling does not guarantee.
# It carries no label, so it stays out of the legend.
ax.plot([0, 1], [0, 1], linestyle='-.', color='grey')
ax.set_xlabel('Predicted probability')
ax.set_ylabel('True probability in each bin')
ax.legend(loc='lower right')
fig.savefig('tjh_calibration.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# [CDSL] calibration curves: per-bin observed positive fraction vs. mean predicted probability
hm_concare_prob_true, hm_concare_prob_pred = calibration_curve(hm_concare_outcome_true, hm_concare_outcome_pred, n_bins=10)
hm_adacare_prob_true, hm_adacare_prob_pred = calibration_curve(hm_adacare_outcome_true, hm_adacare_outcome_pred, n_bins=10)
hm_rnn_prob_true, hm_rnn_prob_pred = calibration_curve(hm_rnn_outcome_true, hm_rnn_outcome_pred, n_bins=10)

fig, ax = plt.subplots()
# one calibration curve per model
ax.plot(hm_concare_prob_pred, hm_concare_prob_true, marker='o', linewidth=1, label='ConCare')
ax.plot(hm_adacare_prob_pred, hm_adacare_prob_true, marker='v', linewidth=1, label='AdaCare')
ax.plot(hm_rnn_prob_pred, hm_rnn_prob_true, marker='s', linewidth=1, label='RNN')

# Perfect-calibration reference (y = x), drawn in data coordinates. Drawing it
# corner-to-corner in axes coordinates only coincides with y = x when both axis
# limits are exactly [0, 1], which autoscaling does not guarantee.
# It carries no label, so it stays out of the legend.
ax.plot([0, 1], [0, 1], linestyle='-.', color='grey')
ax.set_xlabel('Predicted probability')
ax.set_ylabel('True probability in each bin')
ax.legend(loc='lower right')
fig.savefig('cdsl_calibration.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show()

## Draw OSMAE/EMP scores at different thresholds

In [None]:
# Load the saved evaluation scores.
# NOTE(review): unpickling can execute arbitrary code — only read trusted files.
covid_scores = pd.read_pickle('./saved_pkl/covid_evaluation_scores.pkl')

# Skip the first entry of every series; the remaining entries are what gets plotted below.
emp = covid_scores["emp"][1:]
osmae = covid_scores["osmae"][1:]
thresholds = covid_scores["threshold"][1:]

In [None]:
## EMP score against the threshold: scatter plus a fitted regression line
ax = sns.regplot(
    x=thresholds,
    y=emp,
    color="g",
    marker="o",
    ci=99.9999,
    line_kws={"color": "grey", "linestyle": "dashed"},
)
ax.set_xlabel('Threshold γ')
ax.set_ylabel('EMP score')

plt.savefig('emp_trend.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show();

In [None]:
## OSMAE score against the threshold: scatter plus a fitted regression line
ax = sns.regplot(
    x=thresholds,
    y=osmae,
    color="dodgerblue",
    marker="o",
    ci=99.9999,
    line_kws={"color": "grey", "linestyle": "dashed"},
)
ax.set_xlabel('Threshold γ')
ax.set_ylabel('OSMAE score')

plt.savefig('osmae_trend.pdf', dpi=500, format="pdf", bbox_inches="tight")
plt.show();

## Draw hidden-state PCA results

In [None]:
from app import models

# One RETAIN instance is shared by the cells below; each of them loads a
# different checkpoint's backbone weights into it before running it.
# NOTE(review): hidden_dim=128 presumably matches the "hid128" checkpoints — confirm.
model = models.RETAIN(
    hidden_dim=128,
    input_dim=99,
)

In [None]:
def extract_backbone_param(ckpt):
    """Select the backbone entries of a checkpoint mapping and strip their prefix.

    Args:
        ckpt: mapping of parameter name to value (anything exposing ``.items()``).

    Returns:
        A new dict holding only the entries whose key contains ``"backbone"``,
        keyed by the original name with every ``"backbone."`` occurrence removed.
        Entry order follows ``ckpt``.
    """
    return {
        name.replace("backbone.", ""): value
        for name, value in ckpt.items()
        if "backbone" in name
    }

In [None]:
# Load the processed HM dataset.
# NOTE(review): unpickling can execute arbitrary code — only read trusted files.
x = pd.read_pickle("datasets/hm/processed_data/x.pkl")
y = pd.read_pickle("datasets/hm/processed_data/y.pkl")
visits_length = pd.read_pickle("datasets/hm/processed_data/visits_length.pkl")
device = torch.device("cpu")

# Entry [0, 0] of every row of y, kept as a column (trailing axis of size 1)
# so it can be concatenated next to the 2-D PCA projection later on.
# NOTE(review): presumably the outcome label at the first visit — confirm the layout of y.
outcome_status = y[:, 0, 0].unsqueeze(-1)

# Index 0 along the second axis of x, with that axis restored as size 1,
# converted to float before it is fed to the model.
# NOTE(review): assumes x is (patient, visit, feature) — confirm.
patient = torch.unsqueeze(x[:, 0, :], dim=1).float()

### Multitask Model

In [None]:
# Load the multitask checkpoint and keep only its backbone weights.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load trusted checkpoints.
multitask_ckpt = torch.load("./checkpoints/hm_multitask_retain_ep100_kf10_bs512_hid128_1_seed0.pth", map_location=torch.device('cpu'))
multitask_backbone = extract_backbone_param(multitask_ckpt)
model.load_state_dict(multitask_backbone)

# Inference only: eval() switches off train-time behaviour (e.g. dropout, if the
# model has any) so the extracted hidden states are deterministic, and no_grad()
# avoids building an autograd graph that is never used.
model.eval()
with torch.no_grad():
    out = model(patient, device)
out = torch.squeeze(out)
out = out.detach().numpy()

In [None]:
pca = PCA(2)  # keep the first two principal components of the model output
projected = pca.fit_transform(out)

# Put the outcome column next to the two components so seaborn can colour by it.
concatenated = np.concatenate([projected, outcome_status], axis=1)
df = pd.DataFrame(concatenated, columns = ['Component 1', 'Component 2', 'Outcome'])
# Assign the result back instead of calling replace(..., inplace=True) on the
# column selection: that chained in-place form is deprecated and leaves df
# unchanged under pandas Copy-on-Write.
df['Outcome'] = df['Outcome'].replace({1: 'Dead', 0: 'Alive'})

sns.scatterplot(data=df, x="Component 1", y="Component 2", hue="Outcome", style="Outcome", palette=["C1", "C2"])
plt.savefig('multitask_pca.pdf', dpi=500, format="pdf", bbox_inches="tight")

### Outcome Prediction Model

In [None]:
# Load the outcome-prediction checkpoint and keep only its backbone weights.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load trusted checkpoints.
outcome_ckpt = torch.load("./checkpoints/hm_outcome_retain_ep100_kf10_bs512_hid128_1_seed0.pth", map_location=torch.device('cpu'))
outcome_backbone = extract_backbone_param(outcome_ckpt)
model.load_state_dict(outcome_backbone)

# Inference only: eval() switches off train-time behaviour (e.g. dropout, if the
# model has any) so the extracted hidden states are deterministic, and no_grad()
# avoids building an autograd graph that is never used.
model.eval()
with torch.no_grad():
    out = model(patient, device)
out = torch.squeeze(out)
out = out.detach().numpy()

In [None]:
pca = PCA(2)  # keep the first two principal components of the model output
projected = pca.fit_transform(out)

# Put the outcome column next to the two components so seaborn can colour by it.
concatenated = np.concatenate([projected, outcome_status], axis=1)
df = pd.DataFrame(concatenated, columns = ['Component 1', 'Component 2', 'Outcome'])
# Assign the result back instead of calling replace(..., inplace=True) on the
# column selection: that chained in-place form is deprecated and leaves df
# unchanged under pandas Copy-on-Write.
df['Outcome'] = df['Outcome'].replace({1: 'Dead', 0: 'Alive'})

sns.scatterplot(data=df, x="Component 1", y="Component 2", hue="Outcome", style="Outcome", palette=["C1", "C2"])
plt.savefig('outcome_pca.pdf', dpi=500, format="pdf", bbox_inches="tight")

### LOS Prediction Model

In [None]:
# Load the LOS-prediction checkpoint and keep only its backbone weights.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load trusted checkpoints.
los_ckpt = torch.load("./checkpoints/hm_los_retain_ep100_kf10_bs512_hid128_1_seed0.pth", map_location=torch.device('cpu'))
los_backbone = extract_backbone_param(los_ckpt)
model.load_state_dict(los_backbone)

# Inference only: eval() switches off train-time behaviour (e.g. dropout, if the
# model has any) so the extracted hidden states are deterministic, and no_grad()
# avoids building an autograd graph that is never used.
model.eval()
with torch.no_grad():
    out = model(patient, device)
out = torch.squeeze(out)
out = out.detach().numpy()

In [None]:
pca = PCA(2)  # keep the first two principal components of the model output
projected = pca.fit_transform(out)

# Put the outcome column next to the two components so seaborn can colour by it.
concatenated = np.concatenate([projected, outcome_status], axis=1)
df = pd.DataFrame(concatenated, columns = ['Component 1', 'Component 2', 'Outcome'])
# Assign the result back instead of calling replace(..., inplace=True) on the
# column selection: that chained in-place form is deprecated and leaves df
# unchanged under pandas Copy-on-Write.
df['Outcome'] = df['Outcome'].replace({1: 'Dead', 0: 'Alive'})

sns.scatterplot(data=df, x="Component 1", y="Component 2", hue="Outcome", style="Outcome", palette=["C1", "C2"])
plt.savefig('los_pca.pdf', dpi=500, format="pdf", bbox_inches="tight")