In [1]:
import pandas as pd
import pickle
import numpy as np
import math
import warnings
import model_metrics_helper
from sklearn.metrics import confusion_matrix

# Silence every warning for the rest of the notebook.
# NOTE(review): this also hides pandas chained-assignment warnings and numpy divide-by-zero
# warnings from the metric ratios computed below (a 0/0 ratio silently becomes nan);
# consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")
# Render matplotlib figures inline, at high resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Show up to 500 columns/rows and format floats with thousands separators and 4 decimals.
pd.options.display.max_columns = 500
pd.options.display.max_rows = 500
pd.options.display.float_format = '{:,.4f}'.format

### Load Test Data

In [2]:
# Per-row test-set predictions; the ground-truth column is renamed to 'SepsisLabel',
# the name every metric below reads.
test_scores_df = pd.read_csv(
    '/gpfs/data/paulab/sepsis_floor_prediction/sepsis_real_time_prediction/model_results/test_set_results_20220704_221130.csv'
).rename(columns={'SepsisTrueLabel': 'SepsisLabel'})

In [7]:
# Preview the first 5 rows, skipping the first column.
# NOTE(review): judging by the rendered output, the skipped column is the encounter ID
# (presumably hidden on purpose) — confirm.
test_scores_df.iloc[:, 1:].head(5)

Unnamed: 0,LOS,PredictedProbability,SepsisLabel,t_timezero,SepsisLabel_0.005,SepsisLabel_0.01,SepsisLabel_0.02,SepsisLabel_0.03,SepsisLabel_0.04,SepsisLabel_0.05,SepsisLabel_0.06,SepsisLabel_0.07,SepsisLabel_0.09,SepsisLabel_0.1,SepsisLabel_0.15,SepsisLabel_0.2,SepsisLabel_0.25,SepsisLabel_0.3,SepsisLabel_0.4,SepsisLabel_0.5,PredictedProbabilityAblated,SepsisLabel_0.005Ablated,SepsisLabel_0.01Ablated,SepsisLabel_0.02Ablated,SepsisLabel_0.03Ablated,SepsisLabel_0.04Ablated,SepsisLabel_0.05Ablated,SepsisLabel_0.06Ablated,SepsisLabel_0.07Ablated,SepsisLabel_0.09Ablated,SepsisLabel_0.1Ablated,SepsisLabel_0.15Ablated,SepsisLabel_0.2Ablated,SepsisLabel_0.25Ablated,SepsisLabel_0.3Ablated,SepsisLabel_0.4Ablated,SepsisLabel_0.5Ablated
0,0,0.0084,0,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0084,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,3,0.0044,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0044,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,6,0.0047,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0047,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,9,0.0042,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0042,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,12,0.0061,0,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0061,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


### Compute SIRS (alert at score >= 2) and MEWS (alert at score >= 5) on the test data as rule-based baselines for comparison with the ML model at different thresholds

In [8]:
# Preprocessed feature table, restricted to encounters present in the test set and with
# column names renamed (WBC, SBP, gcs_total_score) before the SIRS/MEWS helpers are called.
main_df = pd.read_csv("/gpfs/data/paulab/bvg228/sepsis_real_time_prediction/data/NYU_6hr_preprocessed_48hr_main_df_20220526_100348.csv")
main_df = main_df.loc[main_df['ID'].isin(test_scores_df['ID'])].rename(
    columns={'WHITE BLOOD CELL COUNT': 'WBC', 'Systolic_BP': 'SBP', 'GCS_Score': 'gcs_total_score'}
)

# One admission timestamp row per (ID, AdmissionInstant) pair.
admission_time = main_df[['ID', 'AdmissionInstant']].drop_duplicates(keep='first')

# Attach rel_time (matched on ID + LOS) and the admission timestamp to every prediction row.
# AlertTime is taken as LOS + 1.
test_scores_df = (
    test_scores_df
    .merge(main_df[['ID', 'LOS', 'rel_time']], on=['ID', 'LOS'], how='left')
    .astype({'rel_time': int})
    .merge(admission_time, on=['ID'], how='left')
    .assign(AlertTime=lambda frame: frame['LOS'] + 1)
)

# Rule-based baseline scores; merged below on (ID, rel_time), so rel_time is cast to int
# to match the prediction table.
SIRS = model_metrics_helper.SIRS(main_df)
SIRS['rel_time'] = SIRS['rel_time'].astype(int)
MEWS = model_metrics_helper.MEWS(main_df)
MEWS['rel_time'] = MEWS['rel_time'].astype(int)

test_scores_df = (
    test_scores_df
    .merge(SIRS, on=['ID', 'rel_time'], how='left')
    .merge(MEWS, on=['ID', 'rel_time'], how='left')
    # Binarise into alerts (SIRS >= 2, MEWS >= 5); a missing score compares False -> 0.
    .assign(
        SIRS=lambda frame: (frame['SIRS'] >= 2).astype(int),
        MEWS=lambda frame: (frame['MEWS'] >= 5).astype(int),
    )
    .rename(columns={'SIRS': 'SepsisLabel_SIRS', 'MEWS': 'SepsisLabel_MEWS'})
)
# NOTE(review): this preview now includes AdmissionInstant (real admission timestamps);
# consider dropping it before sharing the rendered notebook.
test_scores_df.iloc[:, 1:].head(5)

Unnamed: 0,LOS,PredictedProbability,SepsisLabel,t_timezero,SepsisLabel_0.005,SepsisLabel_0.01,SepsisLabel_0.02,SepsisLabel_0.03,SepsisLabel_0.04,SepsisLabel_0.05,SepsisLabel_0.06,SepsisLabel_0.07,SepsisLabel_0.09,SepsisLabel_0.1,SepsisLabel_0.15,SepsisLabel_0.2,SepsisLabel_0.25,SepsisLabel_0.3,SepsisLabel_0.4,SepsisLabel_0.5,PredictedProbabilityAblated,SepsisLabel_0.005Ablated,SepsisLabel_0.01Ablated,SepsisLabel_0.02Ablated,SepsisLabel_0.03Ablated,SepsisLabel_0.04Ablated,SepsisLabel_0.05Ablated,SepsisLabel_0.06Ablated,SepsisLabel_0.07Ablated,SepsisLabel_0.09Ablated,SepsisLabel_0.1Ablated,SepsisLabel_0.15Ablated,SepsisLabel_0.2Ablated,SepsisLabel_0.25Ablated,SepsisLabel_0.3Ablated,SepsisLabel_0.4Ablated,SepsisLabel_0.5Ablated,rel_time,AdmissionInstant,AlertTime,SepsisLabel_SIRS,SepsisLabel_MEWS
0,0,0.0084,0,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0084,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,2015-08-13 21:17:00,1,0,0
1,3,0.0044,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0044,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,2015-08-13 21:17:00,4,0,0
2,6,0.0047,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0047,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7,2015-08-13 21:17:00,7,0,0
3,9,0.0042,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0042,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,2015-08-13 21:17:00,10,0,0
4,12,0.0061,0,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.0061,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,13,2015-08-13 21:17:00,13,0,0


### Compute classification and first-alert metrics for each alert threshold

In [9]:
def threshold_metrics(df, threshold, no_lactate=False):
    """Confusion-matrix counts and derived rates for one alert column.

    Compares the ground truth in ``df['SepsisLabel']`` with the binary alerts in
    ``df[f'SepsisLabel_{threshold}']``, treating 1 as the positive class.

    Parameters
    ----------
    df : pd.DataFrame
        Scored rows holding both label columns.
    threshold : float or str
        Suffix of the alert column (e.g. 0.05, '0.05Ablated', 'SIRS').
    no_lactate : bool
        Not used by this function; kept so existing calls keep working.

    Returns
    -------
    list
        [tp, fn, fp, tn, accuracy, ppv, npv, sensitivity, specificity, F1]. The counts
        are Python ints. A ratio whose denominator is zero comes out as nan, because the
        counts are numpy integers.
    """
    truth = df["SepsisLabel"]
    alerts = df[f'SepsisLabel_{threshold}']
    # labels=[1, 0] orders the 2x2 matrix positive-first, so it flattens to tp, fn, fp, tn.
    tp, fn, fp, tn = confusion_matrix(truth, alerts, labels=[1, 0]).ravel()

    pred_pos, pred_neg = tp + fp, tn + fn
    actual_pos, actual_neg = tp + fn, tn + fp

    return [
        int(tp), int(fn), int(fp), int(tn),
        (tp + tn) / (pred_pos + pred_neg),  # accuracy
        tp / pred_pos,                      # PPV (precision)
        tn / pred_neg,                      # NPV
        tp / actual_pos,                    # sensitivity (recall)
        tn / actual_neg,                    # specificity
        2 * tp / (2 * tp + fp + fn),        # F1
    ]

def threshold_first_alert(test_scores_df, threshold, n_days=365.):
    """Summarise alert burden and first-alert timing for one alert column.

    Parameters
    ----------
    test_scores_df : pd.DataFrame
        Must contain 'ID', 'AlertTime', 't_timezero', 'AdmissionInstant' and the alert
        column f'SepsisLabel_{threshold}' (1 = alert fired on that row).
    threshold : float or str
        Suffix of the alert column to evaluate.
    n_days : float, default 365.
        Divisor for the per-day rates.
        NOTE(review): 365 assumes the test set spans one year — confirm.

    Returns
    -------
    list
        [n_alerts, n_alerts / n_days,
         n_first_alerts, n_first_alerts / n_days,
         n_first_alerts_with_time_zero, n_early_first_alerts,
         median_lead_time (AlertTime - t_timezero, hours),
         mean_time_to_first_alert_from_admission (hours)].
        The last two are NaN when no first alert has a t_timezero. Previously this
        case raised IndexError.
    """
    alert_col = f'SepsisLabel_{threshold}'
    alerts = test_scores_df[test_scores_df[alert_col] == 1]
    n_total_possible_sepsis_alerts = alerts.shape[0]
    n_total_possible_sepsis_alerts_day_yr = n_total_possible_sepsis_alerts / n_days

    # Keep the earliest alert per encounter. Sort explicitly so "first" means the earliest
    # AlertTime rather than whatever row order the caller passed in. The final .copy()
    # detaches the frame from test_scores_df so the column assignments below are safe.
    first = (
        alerts.sort_values(['ID', 'AlertTime'], kind='stable')
        .drop_duplicates(subset=['ID'], keep='first')
        .copy()
    )
    # Negative lead time = the alert fired before the sepsis time zero.
    first['LeadTimeValue'] = first['AlertTime'] - first['t_timezero']
    # Use this frame's own rows (the old code read the unfiltered frame and relied on
    # index alignment).
    admitted = pd.to_datetime(first['AdmissionInstant'])
    first['abs_time'] = admitted + pd.to_timedelta(first['AlertTime'], unit='h')
    first['time_to_first_alert'] = (first['abs_time'] - admitted).dt.total_seconds() / 3600.

    n_total_first_sepsis_alert = first.shape[0]
    n_total_first_sepsis_alert_day_yr = n_total_first_sepsis_alert / n_days

    # Rows without a t_timezero are dropped; per the original note, what remains contains
    # no false positives.
    with_time_zero = first.dropna(subset=['t_timezero'])
    n_tp_first_sepsis_alert = with_time_zero.shape[0]
    n_early_first_sepsis_alert = int((with_time_zero['LeadTimeValue'] < 0).sum())
    # Series.median/mean return NaN on empty input instead of raising.
    median_first_alert_time = with_time_zero['LeadTimeValue'].median()
    mean_first_alert_time_from_admission = with_time_zero['time_to_first_alert'].mean()
    return [
        n_total_possible_sepsis_alerts,
        n_total_possible_sepsis_alerts_day_yr,
        n_total_first_sepsis_alert,
        n_total_first_sepsis_alert_day_yr,
        n_tp_first_sepsis_alert,
        n_early_first_sepsis_alert,
        median_first_alert_time,
        mean_first_alert_time_from_admission,
    ]

def final_threshold_metric_df(test_scores_df, thresholds):
    """One summary row per threshold, combining classification and first-alert metrics.

    Parameters
    ----------
    test_scores_df : pd.DataFrame
        Scored rows with 'SepsisLabel', one 'SepsisLabel_<threshold>' alert column per
        entry of ``thresholds``, and the columns threshold_first_alert reads.
    thresholds : list
        Alert-column suffixes to evaluate (e.g. 0.05, '0.05Ablated', 'SIRS', 'MEWS').

    Returns
    -------
    pd.DataFrame
        Columns: 'Threshold', the threshold_metrics outputs, then the threshold_first_alert
        outputs. Count columns are cast to int. 'Threshold' is relabelled 'ML_<suffix>' for
        model columns and to the bare name for the SIRS/MEWS baselines.
    """
    classification_rows = ['tp', 'fn', 'fp', 'tn', 'acc', 'ppv', 'npv', 'sens', 'spec', 'F1']
    alert_rows = [
        'N_Total_Possible_Sepsis_Alerts', 'N_Total_Possible_Sepsis_Alerts_Per_Day',
        'N_Total_First_Sepsis_Alert', 'N_Total_First_Sepsis_Alert_Per_Day',
        'N_TP_First_Sepsis_Alert', 'N_Early_First_Sepsis_Alert',
        'Median_First_Alert_Time_From_TimeZero', 'Avg_First_Alert_Time_From_Admission',
    ]
    baselines = ['SepsisLabel_SIRS', 'SepsisLabel_MEWS']

    def _tabulate(metric_fn, row_labels, int_cols):
        # Build one column per threshold, then transpose so each threshold becomes a row.
        table = pd.DataFrame()
        table.index = row_labels
        for thr in thresholds:
            table[f'SepsisLabel_{thr}'] = metric_fn(test_scores_df, thr)
        table = table.T.reset_index().rename(columns={'index': 'Threshold'})
        return table.astype({col: int for col in int_cols})

    def _display_name(label):
        # 'SepsisLabel_0.05' -> 'ML_0.05'; 'SepsisLabel_SIRS' -> 'SIRS'.
        suffix = label.split('_')[1]
        return suffix if label in baselines else 'ML_' + suffix

    classification = _tabulate(threshold_metrics, classification_rows, classification_rows[:4])
    first_alert = _tabulate(threshold_first_alert, alert_rows, alert_rows[:6])

    summary = classification.merge(first_alert, on=['Threshold'], how='left')
    summary['Threshold'] = summary['Threshold'].apply(_display_name)
    return summary
    


In [10]:
# Probability cut-offs for the model's alert columns, followed by the two rule-based baselines.
thresholds = [0.005, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.09, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 'SIRS', 'MEWS']
# The same cut-offs for the '...Ablated' alert columns (there are no ablated SIRS/MEWS columns).
ablated_thresholds = [f'{cutoff}Ablated' for cutoff in thresholds if cutoff not in ('SIRS', 'MEWS')]
thresholds = [*ablated_thresholds, *thresholds]

# Scenario 1: every row of the test set.
final_df_w_poa = final_threshold_metric_df(test_scores_df, thresholds)

# Scenario 2: keep rows with no t_timezero or with t_timezero > 1
# (presumably excluding sepsis present on admission, hence "no_poa").
test_scores_df_no_poa = test_scores_df.loc[
    test_scores_df['t_timezero'].isna() | test_scores_df['t_timezero'].gt(1)
]
final_df_no_poa = final_threshold_metric_df(test_scores_df_no_poa, thresholds)

# Put both scenarios side by side under a two-level column header.
final_full_df = pd.concat(
    [final_df_w_poa, final_df_no_poa],
    keys=['Inclusion Criteria: All Test Set', 'Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1'],
    axis=1,
)
# NOTE(review): export disabled; its filename stamp (20220627_172236) predates the
# test-set results loaded above (20220704_221130) — update before re-enabling.
#final_full_df.to_excel('/gpfs/data/paulab/sepsis_floor_prediction/sepsis_real_time_prediction/model_results/threshold_metrics_20220627_172236.xlsx')

In [11]:
# Display both inclusion-criteria scenarios side by side, one row per threshold.
final_full_df

Unnamed: 0_level_0,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: All Test Set,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR t_timezero > 1,Inclusion Criteria: Encounters with t_timezero.isna() OR 
t_timezero > 1
Unnamed: 0_level_1,Threshold,tp,fn,fp,tn,acc,ppv,npv,sens,spec,F1,N_Total_Possible_Sepsis_Alerts,N_Total_Possible_Sepsis_Alerts_Per_Day,N_Total_First_Sepsis_Alert,N_Total_First_Sepsis_Alert_Per_Day,N_TP_First_Sepsis_Alert,N_Early_First_Sepsis_Alert,Median_First_Alert_Time_From_TimeZero,Avg_First_Alert_Time_From_Admission,Threshold,tp,fn,fp,tn,acc,ppv,npv,sens,spec,F1,N_Total_Possible_Sepsis_Alerts,N_Total_Possible_Sepsis_Alerts_Per_Day,N_Total_First_Sepsis_Alert,N_Total_First_Sepsis_Alert_Per_Day,N_TP_First_Sepsis_Alert,N_Early_First_Sepsis_Alert,Median_First_Alert_Time_From_TimeZero,Avg_First_Alert_Time_From_Admission
0,ML_0.005Ablated,9514,3,66848,24716,0.3386,0.1246,0.9999,0.9997,0.2699,0.2216,76362,209,6601,18,648,416,-0.6667,1.0046,ML_0.005Ablated,5864,2,66847,24716,0.3139,0.0806,0.9999,0.9997,0.2699,0.1493,72711,199,6369,17,416,416,-2.575,1.0
1,ML_0.01Ablated,9498,19,47505,44059,0.5298,0.1666,0.9996,0.998,0.4812,0.2856,57003,156,5426,14,648,415,-0.6667,1.088,ML_0.01Ablated,5850,16,47504,44059,0.5123,0.1096,0.9996,0.9973,0.4812,0.1976,53354,146,5194,14,416,415,-2.5583,1.1226
2,ML_0.02Ablated,9435,82,31767,59797,0.6849,0.229,0.9986,0.9914,0.6531,0.372,41202,112,4054,11,647,403,-0.55,1.4776,ML_0.02Ablated,5801,65,31766,59797,0.6733,0.1544,0.9989,0.9889,0.6531,0.2671,37567,102,3822,10,415,403,-2.3833,1.7229
3,ML_0.03Ablated,9298,219,24347,67217,0.757,0.2764,0.9968,0.977,0.7341,0.4308,33645,92,3314,9,644,384,-0.3917,1.8804,ML_0.03Ablated,5699,167,24346,67217,0.7484,0.1897,0.9975,0.9715,0.7341,0.3174,30045,82,3082,8,412,384,-1.8833,2.318
4,ML_0.04Ablated,9170,347,19890,71674,0.7998,0.3156,0.9952,0.9635,0.7828,0.4754,29060,79,2930,8,644,367,-0.25,2.1925,ML_0.04Ablated,5612,254,19889,71674,0.7933,0.2201,0.9965,0.9567,0.7828,0.3578,25501,69,2698,7,412,367,-1.7917,2.733
5,ML_0.05Ablated,9048,469,16958,74606,0.8276,0.3479,0.9938,0.9507,0.8148,0.5094,26006,71,2660,7,642,348,-0.1833,2.4953,ML_0.05Ablated,5522,344,16957,74606,0.8224,0.2457,0.9954,0.9414,0.8148,0.3896,22479,61,2428,6,410,348,-1.5667,3.122
6,ML_0.06Ablated,8973,544,14754,76810,0.8487,0.3782,0.993,0.9428,0.8389,0.5398,23727,65,2441,6,642,340,-0.1167,2.7383,ML_0.06Ablated,5470,396,14753,76810,0.8445,0.2705,0.9949,0.9325,0.8389,0.4193,20223,55,2209,6,410,340,-1.5167,3.4512
7,ML_0.07Ablated,8862,655,13142,78422,0.8635,0.4027,0.9917,0.9312,0.8565,0.5623,22004,60,2280,6,640,333,-0.0667,3.0625,ML_0.07Ablated,5384,482,13141,78422,0.8602,0.2906,0.9939,0.9178,0.8565,0.4415,18525,50,2048,5,408,333,-1.4417,3.9191
8,ML_0.09Ablated,8690,827,10717,80847,0.8858,0.4478,0.9899,0.9131,0.883,0.6009,19407,53,2038,5,638,317,0.0167,3.5439,ML_0.09Ablated,5250,616,10716,80847,0.8837,0.3288,0.9924,0.895,0.883,0.4809,15966,43,1806,4,406,317,-1.1833,4.6429
9,ML_0.1Ablated,8598,919,9729,81835,0.8947,0.4691,0.9889,0.9034,0.8937,0.6176,18327,50,1930,5,633,301,0.1167,3.763,ML_0.1Ablated,5187,679,9728,81835,0.8932,0.3478,0.9918,0.8842,0.8938,0.4992,14915,40,1698,4,401,301,-1.1333,4.9052
