In [None]:
#By Phyllis Thangaraj (pt2281@columbia.edu), Nicholas Tatonetti Lab at Columbia University Irving Medical Center
#Part of manuscript: "Comparative analysis, applications, and interpretation of electronic health record-based stroke phenotyping methods"
#This script calculates calibration scores and builds Supplementary Figures 5-19

In [None]:
import os
import sys
import csv
import random
import numpy as np
import scipy as sp
import scipy.optimize as opt
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
% matplotlib inline

def nearest(li, x):
    """Return the element of the sorted sequence ``li`` that is closest to ``x``.

    Uses binary search (``bisect``), so it costs O(log n) comparisons. Integer
    indexing keeps it valid on both Python 2 and 3.

    Args:
        li: non-empty sequence sorted in ascending order (the caller must sort it).
        x: value to match.

    Returns:
        The element of ``li`` with the smallest ``|element - x|``. On a tie,
        the smaller element is returned. Values outside ``[li[0], li[-1]]``
        snap to the nearest end.

    Raises:
        IndexError: if ``li`` is empty.
    """
    import bisect

    if len(li) == 0:
        raise IndexError("nearest() arg is an empty sequence")
    # idx is the first position whose element is >= x.
    idx = bisect.bisect_left(li, x)
    if idx == 0:
        return li[0]
    if idx == len(li):
        return li[-1]
    lower, upper = li[idx - 1], li[idx]
    # Choose whichever neighbour is closer; `<=` breaks ties toward the lower one.
    return lower if (x - lower) <= (upper - x) else upper

def calibrate(model_probas, calibration_hash):
    """Translate raw model probabilities into calibrated ones via a lookup table.

    Each probability is rounded to 3 decimals. It is then matched against the
    sorted keys of ``calibration_hash`` with ``nearest``, and that key's value
    is emitted.

    Args:
        model_probas: iterable of raw model scores.
        calibration_hash: dict mapping a rounded raw score to a calibrated score.

    Returns:
        A list of calibrated scores, one per input, in input order.
    """
    # Sorting the keys once is enough; nearest() expects an ascending sequence.
    lookup_keys = sorted(calibration_hash.keys())
    return [
        calibration_hash[nearest(lookup_keys, np.round(raw_score, 3))]
        for raw_score in model_probas
    ]

# Case/control cohort codes. Each code is appended to the path prefixes below
# to locate that cohort's input .npy files and name its output figure.
casecontrol = ['SN', 'SI', 'SC', 'SCI', 'SR', 'TN', 'TI', 'TC', 'TCI', 'TR', 'CN', 'CI', 'CC', 'CCI', 'CR']

# Per-cohort, per-model result tables (outer key: cohort code, inner key: model name).
yact_test = defaultdict(dict)   # test set: rounded predicted score -> observed stroke proportion
yhat_test = defaultdict(dict)   # not written anywhere in this section

train_rmse = defaultdict(dict)
test_rmse_01 = defaultdict(dict)
cal_train_rmse = defaultdict(dict)
cal_test_rmse_01 = defaultdict(dict)

# TODO: fill in these path *prefixes* before running. The cohort code plus
# ".npy" (inputs) or "models.png" (output figure) is appended to each one,
# e.g. "/data/labels_" -> "/data/labels_SN.npy". They were left as bare
# `name=` lines (a SyntaxError); empty strings keep the notebook parseable.
testing_set_labels_filename = ""
cross_val_testing_set_probabilities_filename = ""
calibration_figures_filename = ""

# The downsampling and plotting below use random.sample; seed it so the
# figures and RMSE values are reproducible across runs.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)

# NOTE(review): besides the path prefixes, this cell reads names that are not
# defined anywhere in this section: map_x_comb, map_y_comb, calibration_hashes,
# cal_y_hat_train, y_act_train and label_sub (indexed by cohort position).
# Presumably they come from the training-set calibration step in another
# cell/notebook. Confirm they exist in the kernel; otherwise this fails with
# NameError under Restart & Run All.

bootstrap_ratio = 100        # not used in this section
fold_downsample_ratio = 0.3  # not used in this section
downsample_ratio = 0.1       # fraction of test patients kept (sorting all of them is too slow)
plot_sample_ratio = 0.01     # fraction of sliding-window points drawn as faint dots
bin_size = 100               # window is [centre - bin_size, centre + bin_size) in score-sorted order
stride = 10                  # step between consecutive window centres
rmse_threshold = 0.1         # test RMSEs only use points whose score is >= this value
model_names = ['LR', 'RF', 'AB', 'GB', 'EN']  # order of axis 0 in the probabilities file


def _unzip_sorted(pairs):
    """Sort (x, y) pairs and split them into two tuples; ((), ()) if there are none."""
    ordered = sorted(pairs)
    if not ordered:
        return (), ()
    xs, ys = zip(*ordered)
    return xs, ys


def _sorted_downsample(scores, labels, ratio):
    """Keep a random ``ratio`` of the (score, label) pairs, returned sorted by score."""
    # list(): random.sample needs a sequence, and zip is lazy on Python 3.
    pairs = list(zip(scores, labels))
    return _unzip_sorted(random.sample(pairs, int(ratio * len(pairs))))


def _sliding_window_means(sorted_scores, sorted_labels, half_width, step):
    """Mean score and mean label over sliding windows of score-sorted samples.

    Window centres run over range(half_width, len(sorted_scores), step). Each
    window is the slice [centre - half_width, centre + half_width), truncated
    at the end of the data. Returns (preds, actuals) as two lists.
    """
    preds, actuals = [], []
    for centre in range(half_width, len(sorted_scores), step):
        window = slice(centre - half_width, centre + half_width)
        preds.append(np.mean(sorted_scores[window]))
        actuals.append(np.mean(sorted_labels[window]))
    return preds, actuals


def _rmse_above(preds, actuals, threshold):
    """RMSE over the points whose prediction is >= threshold, or 'N/A' if there are none."""
    keep = [j for j, p in enumerate(preds) if p >= threshold]
    if not keep:
        return 'N/A'
    return np.sqrt(mean_squared_error([preds[j] for j in keep], [actuals[j] for j in keep]))


if not (testing_set_labels_filename and cross_val_testing_set_probabilities_filename
        and calibration_figures_filename):
    raise ValueError("Set testing_set_labels_filename, cross_val_testing_set_probabilities_filename "
                     "and calibration_figures_filename before running this cell.")

sns.set(style='ticks', font_scale=1.2)

for c, cc in enumerate(casecontrol):
    labels = np.load(testing_set_labels_filename + cc + ".npy")
    # Load the file once. Axis 0 selects the model. Each model slice is averaged
    # over its axis 1, which the original per-fold loop indexed as the CV fold,
    # to give one score per patient.
    probabilities = np.load(cross_val_testing_set_probabilities_filename + cc + ".npy")
    datadict = dict()
    for model_index, model_name in enumerate(model_names):
        datadict[model_name] = probabilities[model_index, :, :]

    fig = plt.figure(figsize=(3 * len(datadict), 6), dpi=300)

    # Top row: uncalibrated score vs observed stroke proportion (yhat vs yact).
    for fig_i, (key, model_output) in enumerate(datadict.items()):
        plt.subplot(2, len(datadict), fig_i + 1)
        sorted_values, sorted_labels = _sorted_downsample(
            np.mean(model_output, axis=1), labels, downsample_ratio)
        preds, actuals = _sliding_window_means(sorted_values, sorted_labels, bin_size, stride)

        # Group window means by rounded predicted score; averaged further below.
        yact_test[cc][key] = defaultdict(list)
        for yhat, yact in zip(preds, actuals):
            yact_test[cc][key][np.round(yhat, 3)].append(yact)

        preds_s, acts_s = _sorted_downsample(preds, actuals, plot_sample_ratio)
        if preds_s:  # fewer than 1/plot_sample_ratio windows -> nothing to draw
            plt.plot(preds_s, acts_s, 'k.', alpha=0.4)
        for rounded_yhat, yactuals in list(yact_test[cc][key].items()):
            yact_test[cc][key][rounded_yhat] = np.mean(yactuals)
        map_x, map_y = _unzip_sorted(yact_test[cc][key].items())

        train_rmse[cc][key] = np.sqrt(mean_squared_error(map_x_comb[cc][key], map_y_comb[cc][key]))
        # Restrict to scores >= rmse_threshold, matching cal_test_rmse_01 below
        # (the threshold used to be computed here but never applied).
        test_rmse_01[cc][key] = _rmse_above(map_x, map_y, rmse_threshold)

        plt.plot(map_x_comb[cc][key], map_y_comb[cc][key], c='#34579e', marker='.', linestyle="none", label='train')
        plt.plot(map_x, map_y, 'k.', markersize=10, label='test')
        plt.plot([0, 1], [0, 1], '#6c6d70', linestyle=":", alpha=.5)
        plt.xlabel('Stroke Score')
        if fig_i == 0:
            plt.ylabel('Proportion Stroke Pts')
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.legend(loc='upper left', prop={'size': 12}, frameon=False, handlelength=0,
                   labelspacing=0, columnspacing=0, edgecolor='white')
        plt.title(label_sub[c] + " " + key)
        sns.despine()

    # Bottom row: calibrated score (training-set lookup table) vs observed proportion.
    for fig_i, key in enumerate(datadict):
        plt.subplot(2, len(datadict), len(datadict) + fig_i + 1)
        calibration_hash = calibration_hashes[cc][key]
        calibration_keys = sorted(calibration_hash.keys())  # loop-invariant, sort once
        preds = []
        actuals = []
        for rounded_yhat, mean_yact in yact_test[cc][key].items():
            nearest_rounded_yhat = nearest(calibration_keys, rounded_yhat)
            if nearest_rounded_yhat not in calibration_hash:
                raise Exception("rounded_yhat: %f was not present" % rounded_yhat)
            preds.append(calibration_hash[nearest_rounded_yhat])
            actuals.append(mean_yact)

        train_preds = cal_y_hat_train[cc][key]
        train_acts = y_act_train[cc][key]
        cal_train_rmse[cc][key] = np.sqrt(mean_squared_error(train_preds, train_acts))
        # Sampling 100% of the points and then sorting is simply sorting.
        t_preds_s, t_actuals_s = _unzip_sorted(zip(train_preds, train_acts))
        plt.plot(t_preds_s, t_actuals_s, marker='.', linestyle="none", c='#34579e', alpha=.6,
                 label='train $rmse=%.4f$' % cal_train_rmse[cc][key])
        plt.plot([0, 1], [0, 1], '#6c6d70', linestyle=":", alpha=.5)

        cal_test_rmse_01[cc][key] = _rmse_above(preds, actuals, rmse_threshold)
        preds_s, actuals_s = _unzip_sorted(zip(preds, actuals))
        if isinstance(cal_test_rmse_01[cc][key], str):
            test_label = 'test $rmse(%g)=N/A$' % rmse_threshold
        else:
            test_label = 'test $rmse(%g)=%.4f$' % (rmse_threshold, cal_test_rmse_01[cc][key])
        plt.plot(preds_s, actuals_s, 'k.', alpha=1.0, markersize=10, label=test_label)

        plt.xlabel('Calibrated Stroke Score')
        if fig_i == 0:
            plt.ylabel('Proportion Stroke Pts')
        plt.legend(loc='lower right', prop={'size': 12}, frameon=False, borderpad=0, borderaxespad=0,
                   labelspacing=0, columnspacing=0, handlelength=0, edgecolor='white')
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.title(label_sub[c] + " " + key)
        sns.despine()
    plt.tight_layout()
    plt.savefig(calibration_figures_filename + cc + "models.png", dpi=300)
    # One large dpi=300 figure per cohort: close it once saved so they do not
    # pile up in memory (they are written to disk, not rendered inline).
    plt.close(fig)