In [None]:
import pandas as pd
import numpy as np
from evaluation.metrics import calibration,distributions_js, reliability_curve, expected_calibration_error, accuracy_confidence, roc_confidence
from sklearn.metrics import log_loss
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score,accuracy_score
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
import math
import torch
from common.common import load_obj, H5Recorder
import pyro.distributions as distribution
import torch.nn.functional as F
from sklearn.metrics import brier_score_loss
import seaborn as sns
import torch.distributions as dist
import matplotlib

In [None]:
# Model name -> argument handed to H5Recorder(...) by the analysis cells
# (presumably the path of that model's result file -- fill in before running).
fileConfig = dict([
    ('BEHRT', ''),
    ('Whitened-GP', ''),
    ('KISS-GP', ''),
    ('BE', ''),
    ('DBGP', ''),
    ('BO', ''),
    ('BE+BO', ''),
])

# Model name -> matplotlib colour used for that model's curve in the plots.
color_map = dict([
    ('BEHRT', 'b'),
    ('Whitened-GP', 'g'),
    ('KISS-GP', 'c'),
    ('BE', 'orange'),
    ('DBGP', 'r'),
    ('BO', 'y'),
    ('BE+BO', 'pink'),
])


In [None]:
# AUROC / AUPRC for every model listed in fileConfig
for k, v in fileConfig.items():
    recorder = H5Recorder(v)
    recorder.open(read=True)
    try:
        label = recorder.read('label')
        prob = recorder.read('prob')

        if k in ['BEHRT']:
            # BEHRT: the stored scores are passed to the metrics unchanged
            score = prob
        else:
            # every other model: average over axis 1 first
            # (presumably the axis of sampled predictions -- confirm)
            score = np.mean(prob, axis=1).reshape(-1)

        print('AUROC {}: {}'.format(k, roc_auc_score(y_true=label, y_score=score)))
        print('AUPRC {}: {}'.format(k, average_precision_score(y_true=label, y_score=score)))
    finally:
        # always release the file, even if reading or a metric call raised
        recorder.close()

In [None]:
# Text sizes written into matplotlib's global rcParams; they affect figures
# created after this cell has run.
font = {
    'font.size': 15,
    'axes.labelsize': 20,
    'xtick.labelsize': 15,
    'ytick.labelsize': 15,
}

for rc_key, rc_value in font.items():
    matplotlib.rcParams[rc_key] = rc_value

In [None]:
# Accuracy over confidence
plt.figure(figsize=(10, 3))
bins = 20

for k,v in fileConfig.items():
    recorder = H5Recorder(v)
    recorder.open(read=True)
    label = recorder.read('label')
    prob = recorder.read('prob')
    
    if k not in ['BEHRT']:
        # AUROC of mean of predictive probabilities
        prob_mean = np.mean(prob, axis=1).reshape(-1)
        
        x, y = accuracy_confidence(prob_mean,label, bins)
        idx = ~np.isnan(y)
        x = x[idx]
        y = y[idx]
        
        plt.plot(x, y, c=color_map.get(k), label=k, marker='o')
        
    recorder.close()

plt.xticks(np.arange(0,1,step=0.1))
plt.xlim(0,1)
plt.legend(bbox_to_anchor=(1.5, 1))
plt.tight_layout()
plt.xlabel('predictive probability')
plt.ylabel('accuracy')

In [None]:
# ROC over confidence
plt.figure(figsize=(10, 3))
bins = 15

for k,v in fileConfig.items():
    recorder = H5Recorder(v)
    recorder.open(read=True)
    label = recorder.read('label')
    prob = recorder.read('prob')
    
    if k not in ['BEHRT']:
        # AUROC of mean of predictive probabilities
        prob_mean = np.mean(prob, axis=1).reshape(-1)
        
        x, boundry, y = roc_confidence(prob_mean, label, bins)
        
        idx = ~np.isnan(y)
        x = x[idx]
        y = y[idx]
        idx = np.array(~(y==0))
        x = x[idx]
        y = y[idx]
        
        plt.plot(x, y, c=color_map.get(k), label=k)
        
    recorder.close()

plt.xticks(np.arange(0,1.1,step=0.1))
plt.xlim(0,1)
plt.legend(bbox_to_anchor=(1.5, 1))
plt.tight_layout()
plt.xlabel('predictive probability')
plt.ylabel('auroc')

In [None]:
# calibration curve
plt.figure(figsize=(10, 4))
bins=50

for k,v in fileConfig.items():
    recorder = H5Recorder(v)
    recorder.open(read=True)
    label = recorder.read('label')
    prob = recorder.read('prob')
    
    if k not in ['BEHRT']:
        num_sample = prob.shape[1]
        acc_list = []
        center_list = []
        for i in range(num_sample):
            y_score, empirical_prob, center=reliability_curve(y_true=label, y_score=prob[:,i].reshape(-1), bins=bins)
            acc_list.append(empirical_prob)
            center_list.append(center)
        
       
        
        acc_list = np.stack(acc_list, axis=0)
        center_list = np.stack(center_list, axis=0)

        mean_acc = np.mean(acc_list, axis=0)
        std_acc = np.std(acc_list, axis=0)
        center = np.mean(center_list, axis=0)
        
        idx = ~np.isnan(mean_acc)
        x = center[idx]
        y = mean_acc[idx]
        sigma = std_acc[idx]
        
        plt.plot(x, y, color_map.get(k), label=k)
        plt.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.96 * sigma,(y + 1.96 * sigma)[::-1]]),
                 alpha=.5, fc=color_map.get(k), ec='None')
    
    recorder.close()

plt.plot(np.linspace(0,1,bins), np.linspace(0,1, bins), label='Perfect Calibrated')
plt.xticks(np.arange(0,1.1,step=0.1))
plt.ylim(0,1)
plt.xlim(0,1)
plt.legend(bbox_to_anchor=(1.5, 1))
plt.tight_layout()
plt.xlabel('predictive probability')
plt.ylabel('fraction of positives')

In [None]:
# Uncertainty difference between positive and negative patients whose mean
# predictive probability falls into a specific range
n_bins = 15  # NOTE(review): not used in this cell; kept in case another cell reads it

df_list = []

plt.figure(figsize=(9, 3))
# keep patients whose mean predictive probability lies strictly inside (0, 0.5)
for k, v in fileConfig.items():
    if k in ['BEHRT']:
        # BEHRT is excluded from this analysis, so its file is not opened at all
        continue

    recorder = H5Recorder(v)
    recorder.open(read=True)
    try:
        label = np.array(recorder.read('label'))
        prob = np.array(recorder.read('prob'))

        # mean / spread over axis 1 (presumably the sampled predictions -- confirm)
        mean_prob = np.mean(prob, axis=1).reshape(-1)
        std_prob = np.std(prob, axis=1).reshape(-1)

        idx = (mean_prob > 0) & (mean_prob < 0.5)

        patient_std = std_prob[idx]
        patient_label = label[idx]

        patient_pos = patient_std[patient_label == 1]
        patient_neg = patient_std[patient_label == 0]

        if patient_pos.size == 0 or patient_neg.size == 0:
            # mean/std of an empty group is NaN, which Normal cannot accept
            print('kl divergence {}: skipped (a label group is empty)'.format(k))
        else:
            # fit a Normal to each group's std values and compare them
            pos_dist = torch.distributions.Normal(loc=torch.tensor(np.mean(patient_pos)),
                                                  scale=torch.tensor(np.std(patient_pos)))
            neg_dist = torch.distributions.Normal(loc=torch.tensor(np.mean(patient_neg)),
                                                  scale=torch.tensor(np.std(patient_neg)))
            # fully qualified on purpose: the short name `dist` is rebound to a
            # distribution object by the embedding-analysis cell of this notebook
            kl = torch.distributions.kl_divergence(neg_dist, pos_dist)
            print('kl divergence {}:{}'.format(k, kl))

        temp_df = pd.DataFrame({'model': k, 'label': patient_label, 'std': patient_std})
        df_list.append(temp_df)
    finally:
        # always release the file, even if reading or the analysis raised
        recorder.close()

df = pd.concat(df_list)

sns.boxplot(x="model", y='std', hue="label", data=df, showfliers=False)
plt.legend(bbox_to_anchor=(1.5, 1))

In [None]:
# Embedding Analysis: per-token entropy of the word-embedding posterior
# map_location='cpu' lets a checkpoint that was saved from GPU load on a CPU-only machine
weights = torch.load('', map_location='cpu')
word_dict = load_obj('')['token2idx']
word_emb_loc = weights['bert.embeddings.word_embeddings.weight_posterior.loc'].cpu()
word_emb_scale = weights['bert.embeddings.word_embeddings.weight_posterior.scale'].cpu()
# softplus maps the stored scale to a strictly positive standard deviation
sigma = F.softplus(word_emb_scale)
# deliberately not named `dist`: that name is the torch.distributions alias from
# the import cell and is still needed when the KL-divergence cell is re-run
word_emb_dist = distribution.Normal(loc=word_emb_loc, scale=sigma)
# sum the element-wise entropies over dim 1 to get one value per row
entropy = word_emb_dist.entropy().sum(dim=1)
# NOTE(review): assumes the key order of token2idx matches the row order of the
# embedding matrix -- confirm against how the vocabulary was built
entropy_df = pd.DataFrame({'code': list(word_dict.keys()), 'entropy': entropy.detach().numpy()})


In [None]:
# the 25 codes with the smallest entropy
entropy_df.sort_values('entropy').head(25)

In [None]:
# the 15 codes with the largest entropy
entropy_df.sort_values('entropy', ascending=False).head(15)