In [1]:
import pickle
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.metrics import log_loss, brier_score_loss, accuracy_score, confusion_matrix

import matplotlib.pyplot as plt

In [2]:
with open('../data/data_test.pt', 'rb') as file:
    X_test, y_test = pickle.load(file)

print(f'{X_test.shape = }')
print(f'{y_test.shape = }')

X_test.shape = (630290, 28)
y_test.shape = (630290, 2)


In [3]:
model_name_list = ['LR', 'RF', 'SVC']

In [7]:
y_pred, results = {}, {}

for model_name in model_name_list:
    with open(f'./{model_name}.pt', 'rb') as file:
        y_pred[model_name] = pickle.load(file)


for i, y_label in enumerate(['hosp', 'death']):
    results[y_label] = {}
    
    for model_name in model_name_list:
        results[y_label][model_name] = {}
        results[y_label][model_name]['binary_crossentropy'] = log_loss(y_test[:, i], y_pred[model_name][y_label][:, 1])
        results[y_label][model_name]['brier_score'] = brier_score_loss(y_test[:, i], y_pred[model_name][y_label][:, 1])
        results[y_label][model_name]['accuracy'] = accuracy_score(y_test[:, i], y_pred[model_name][y_label][:, 1].round())
        [[results[y_label][model_name]['TN'], results[y_label][model_name]['FN']],
        [results[y_label][model_name]['FP'], results[y_label][model_name]['TP']]] = confusion_matrix(y_test[:, i], (y_pred[model_name][y_label][:, 1] < 0.5).astype(int))

# Results

## Hospitalization

In [8]:
pd.DataFrame(results['hosp']).T.sort_values('binary_crossentropy')

Unnamed: 0,binary_crossentropy,brier,accuracy,TN,FN,FP,TP
SVC,0.21331,0.057606,0.931294,415.0,585207.0,1778.0,42890.0
RF,0.363646,0.054708,0.934254,3734.0,581888.0,6960.0,37708.0
LR,0.404312,0.109761,0.918031,16953.0,568669.0,9957.0,34711.0


## Fatality

In [6]:
pd.DataFrame(results['death']).T.sort_values('binary_crossentropy')

Unnamed: 0,binary_crossentropy,brier,accuracy,TN,TP,FN,FP
RF,0.035165,0.009267,0.988689,310.0,6819.0,621176.0,1985.0
SVC,0.043526,0.011111,0.987379,442.0,7513.0,621044.0,1291.0
LR,0.06201,0.011576,0.987618,20.0,7784.0,621466.0,1020.0
