In [1]:
import os

import numpy as np
import pickle
import time
import pandas as pd

from tqdm.auto import tqdm

import INN
import torch
from torch.optim import Adam

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import log_loss, brier_score_loss, accuracy_score, confusion_matrix

import GPy
import optunity as opt

import matplotlib.pyplot as plt

In [2]:
# Prefer the GPU when PyTorch can see one; models and tensors below are moved to `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using device: {device}')

Using device: cuda


In [3]:
# When False, the long-running training loop is skipped and saved artifacts are reloaded instead.
retrain: bool = True

# Train Data

In [4]:
# Training split stored as a pickled (features, labels) pair.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open('../../data/data_train.pt', 'rb') as file:
    X_train, y_train = pickle.load(file)

print(f'{X_train.shape = }')
print(f'{y_train.shape = }')

# Every feature row needs a matching label row before anything downstream can train on it.
assert len(X_train) == len(y_train), 'X_train and y_train have different numbers of rows'

X_train.shape = (2313665, 33)
y_train.shape = (2313665, 2)


# Parameters

In [5]:
# Constructor arguments shared by every INN instance; input/output widths follow the data.
INN_parameters = dict(
    in_features=X_train.shape[1],
    out_features=y_train.shape[1],
    device=device,
)

# Relative weights of the individual loss terms handed to `INN.fit`
# (the `bce` term is weighted 10x the others).
loss_weights = dict(
    bce_factor=10,
    dvg_factor=1,
    logdet_factor=1,
    rcst_factor=1,
)

# Adam learning rate for the final training runs.
lr = 5e-4

In [6]:
# [low, high] range per hyperparameter. The search itself runs in the unit cube and
# `scale_hyperparameters` maps points back; this dict's order defines the vector order.
hyperparameter_search_space_boundaries = dict(
    n_blocks=[1, 12],
    n_coupling_network_hidden_layers=[1, 5],
    n_coupling_network_hidden_nodes=[4, 512 + 256],
)

In [7]:
# Schedule for every final training run; batch_size is also reused for batched evaluation.
n_epochs, batch_size = 32, 512

# Helper Functions

In [8]:
def scale_hyperparameters(hyperparameters, boundaries=None):
    """Map unit-cube hyperparameters back to their search-space ranges.

    Each value ``h`` becomes ``h * (high - low) + low`` for its ``[low, high]`` boundary.

    Args:
        hyperparameters: Either an iterable of unit-cube values ordered like ``boundaries``
            (list, numpy row, ``dict.values()``), or a mapping from hyperparameter name to
            value. A mapping is looked up by key, so its own iteration order does not matter.
        boundaries: Mapping of name -> ``[low, high]``. Defaults to the notebook-level
            ``hyperparameter_search_space_boundaries``.

    Returns:
        A float ``np.ndarray`` with one scaled (not rounded) value per boundary.

    Raises:
        ValueError: If the number of values differs from the number of boundaries
            (previously ``zip`` silently dropped the extras).
    """
    from collections.abc import Mapping

    if boundaries is None:
        boundaries = hyperparameter_search_space_boundaries

    if isinstance(hyperparameters, Mapping):
        values = [hyperparameters[name] for name in boundaries]
    else:
        values = list(hyperparameters)

    if len(values) != len(boundaries):
        raise ValueError(
            f'expected {len(boundaries)} hyperparameters, got {len(values)}'
        )

    return np.array([h * (high - low) + low for h, (low, high) in zip(values, boundaries.values())])

In [9]:
def GP_log_loss_upper_confidence_bound(n_blocks, n_coupling_network_hidden_layers, n_coupling_network_hidden_nodes, gp, kappa=1.0):
    """Pessimistic (upper) bound of the GP-predicted objective at one unit-cube point.

    Intended as the objective that the hyperparameter search minimizes. The objective is
    presumably the validation log loss, judging by the name; confirm against the search
    that produced the GP's training data.

    Args:
        n_blocks, n_coupling_network_hidden_layers, n_coupling_network_hidden_nodes:
            Unit-cube coordinates of the query point, in search-space order.
        gp: Fitted model exposing ``predict_noiseless(X) -> (mean, var)`` (GPy interface).
        kappa: Number of predictive standard deviations added to the mean. The default
            1.0 reproduces the original ``mean + sqrt(var)``.

    Returns:
        ``mean + kappa * sqrt(var)`` with the array shape returned by ``predict_noiseless``
        (for a single GPy query, presumably ``(1, 1)``).
    """
    query = np.array([[n_blocks, n_coupling_network_hidden_layers, n_coupling_network_hidden_nodes]])
    mean, var = gp.predict_noiseless(query)
    return mean + kappa * np.sqrt(var)

# Load Hyperparameter-Search Results and Fit the GP

In [10]:
# Previously evaluated search points: Q holds unit-cube hyperparameters (one row per trial),
# E the corresponding objective values (presumably log loss — confirm).
# NOTE: pickle.load can execute arbitrary code — only load trusted result files.
print('Loading Results ...')
with open('../../hyperparameter_results/INN.pt', 'rb') as file:
    Q, E = pickle.load(file)
print('Loaded Results')

# Surrogate model of the objective. The kernel's input dimension follows Q rather than a
# hard-coded 3, so it stays in sync with `hyperparameter_search_space_boundaries`.
GP = GPy.models.GPRegression(Q, E, kernel=GPy.kern.Matern52(input_dim=Q.shape[1]))
GP.optimize(messages=False);

Loading Results ...
Loaded Results


# Find Best Hyperparameters

In [11]:
# Minimize the GP's pessimistic bound over the unit cube; [0] is optunity's solution dict
# (hyperparameter name -> unit-cube value).
hyperparameter_best_upper_confidence_bound = opt.minimize(
    lambda **kwargs: GP_log_loss_upper_confidence_bound(gp=GP, **kwargs),
    **{k: [0, 1] for k in hyperparameter_search_space_boundaries.keys()}
)[0]

# Read the solution by name. Relying on the returned dict's iteration order could silently
# pair a value with the wrong boundaries.
hyperparameter_best_upper_confidence_bound_scaled = scale_hyperparameters(
    [hyperparameter_best_upper_confidence_bound[k] for k in hyperparameter_search_space_boundaries]
).round().astype(int)
print(f'{hyperparameter_best_upper_confidence_bound_scaled=}')

In [12]:
# The trial with the lowest observed objective, mapped back to integer hyperparameters
# in search-space order: [n_blocks, n_hidden_layers, n_hidden_nodes].
best_sampled_index = np.argmin(E)
best_sampled_unit_point = Q[best_sampled_index]
best_sampled_hyperparameters = (
    scale_hyperparameters(best_sampled_unit_point)
    .round()
    .astype(int)
)
print(f'{best_sampled_hyperparameters=}')

best_sampled_hyperparameters=array([  3,   1, 748])


# Final Training

In [13]:
# Standardize features using training-set statistics only. `sc_X_train` is reused later
# to transform the test set.
sc_X_train = StandardScaler()
X_train_scaled = sc_X_train.fit_transform(X_train)

n_runs = 5  # independent training runs with the best sampled hyperparameters
model_dir = '../../evaluation_results/models'
loss_history_dir = '../../evaluation_results/loss_history'

if retrain:
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(loss_history_dir, exist_ok=True)

    # The training data is identical for every run, so upload it to the device once
    # instead of re-copying all ~2.3M rows on each iteration.
    X_train_scaled_cuda = torch.Tensor(X_train_scaled).to(device)
    y_train_cuda = torch.Tensor(y_train).to(device)

    for i in range(n_runs):
        # Seed per run: each run is reproducible, while runs still differ from one another.
        torch.manual_seed(i)

        # Fresh classifier built from [n_blocks, n_hidden_layers, n_hidden_nodes].
        inn = INN.INN(**INN_parameters,
            n_blocks=best_sampled_hyperparameters[0],
            coupling_network_layers=[best_sampled_hyperparameters[2]] * best_sampled_hyperparameters[1]
        )
        inn.train()

        loss_history = inn.fit(X_train_scaled_cuda, y_train_cuda,
            n_epochs=n_epochs,
            batch_size=batch_size,
            optimizer=Adam(inn.parameters(), lr=lr),
            loss_weights=loss_weights,
            verbose=1,
        )

        # Move to CPU before pickling so the saved model can be loaded without a GPU.
        with open(f'{model_dir}/INN_{i}.pt', 'wb') as file:
            pickle.dump(inn.to('cpu'), file)

        with open(f'{loss_history_dir}/INN_{i}.pt', 'wb') as file:
            pickle.dump(loss_history, file)

        del inn

    del X_train_scaled_cuda, y_train_cuda

else:
    # Reload the per-run artifacts written by the branch above (INN_{i}.pt). The previous
    # code looked for a single INN.pt, which the training loop never writes.
    inns, loss_histories = [], []
    for i in range(n_runs):
        model_path = f'{model_dir}/INN_{i}.pt'
        if os.path.exists(model_path):
            with open(model_path, 'rb') as file:
                inns.append(pickle.load(file))
        loss_history_path = f'{loss_history_dir}/INN_{i}.pt'
        if os.path.exists(loss_history_path):
            with open(loss_history_path, 'rb') as file:
                loss_histories.append(pickle.load(file))

100%|██████████| 32/32 [26:04<00:00, 48.88s/it, batch=4517/4518, weighted_loss=-74.899, bce=+0.188, dvg=+7.697, rcst=+0.476, logdet=-84.955]
100%|██████████| 32/32 [25:19<00:00, 47.48s/it, batch=4517/4518, weighted_loss=-74.375, bce=+0.163, dvg=+8.384, rcst=+0.508, logdet=-84.900]
100%|██████████| 32/32 [24:45<00:00, 46.43s/it, batch=4517/4518, weighted_loss=-74.739, bce=+0.165, dvg=+8.354, rcst=+0.489, logdet=-85.231]
100%|██████████| 32/32 [26:18<00:00, 49.32s/it, batch=4517/4518, weighted_loss=-73.860, bce=+0.208, dvg=+8.502, rcst=+0.469, logdet=-84.913]
100%|██████████| 32/32 [26:42<00:00, 50.06s/it, batch=4517/4518, weighted_loss=-72.386, bce=+0.169, dvg=+9.753, rcst=+0.478, logdet=-84.303]


# Test Data

In [14]:
# Held-out test split stored as a pickled (features, labels) pair.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open('../../data/data_test.pt', 'rb') as file:
    X_test, y_test = pickle.load(file)

print(f'{X_test.shape = }')
print(f'{y_test.shape = }')

# Sanity checks: one label row per sample, and the same label layout as the training data.
assert len(X_test) == len(y_test), 'X_test and y_test have different numbers of rows'
assert y_test.shape[1] == y_train.shape[1], 'test labels do not match training label columns'

# Transform with the scaler fitted on the training set (no refit) to avoid test leakage.
X_test_scaled = torch.Tensor(sc_X_train.transform(X_test)).to(device)

X_test.shape = (578417, 33)
y_test.shape = (578417, 2)


# Evaluate

In [15]:
# Per target: one (n_samples, 2) array [1 - p, p] per run. The key order matches the
# column order of the model's first output.
evaluation_results = {'hosp': [], 'death': []}
n_runs = 5

for j in range(n_runs):

    with open(f'../../evaluation_results/models/INN_{j}.pt', 'rb') as file:
        inn = pickle.load(file).to(device)
    # The model was pickled right after training (still in train mode). Switch to
    # inference mode so train-only behavior (e.g. dropout, batch-norm updates), if any,
    # is disabled.
    inn.eval()

    n_samples = len(X_test)
    y_proba_pred = np.empty((n_samples, len(evaluation_results)))
    # Stepping by batch_size covers the final partial batch and never runs an empty batch.
    with torch.no_grad():
        for start in tqdm(range(0, n_samples, batch_size)):
            stop = start + batch_size
            y_proba_pred[start:stop] = inn.forward(X_test_scaled[start:stop])[0].detach().cpu().numpy()

    for i, y_label in enumerate(evaluation_results):
        print(f'--- {y_label} ---')
        p = y_proba_pred[:, i]
        evaluation_results[y_label].append(np.column_stack([1 - p, p]))
        # One 0.5 threshold shared by accuracy and the confusion matrix.
        y_pred = (p > 0.5).astype(int)

        print(f'binary cross-entropy: {np.round(log_loss(y_test[:, i], p), 4)}')
        print(f'brier loss: {np.round(brier_score_loss(y_test[:, i], p), 4)}')
        print(f'accuracy: {np.round(accuracy_score(y_test[:, i], y_pred), 4)}')
        print('confusion matrix:')
        print(confusion_matrix(y_test[:, i], y_pred))
        print()
        time.sleep(0.5)  # give the tqdm bar time to flush before the next block of output

100%|██████████| 1130/1130 [00:02<00:00, 527.13it/s]


--- hosp ---
binary cross-entropy: 0.2771
brier loss: 0.067
accuracy: 0.9192
confusion matrix:
[[526415  12607]
 [ 34116   5279]]

--- death ---
binary cross-entropy: 0.0679
brier loss: 0.0125
accuracy: 0.9858
confusion matrix:
[[568592   2107]
 [  6134   1584]]



100%|██████████| 1130/1130 [00:02<00:00, 550.72it/s]


--- hosp ---
binary cross-entropy: 0.2229
brier loss: 0.0584
accuracy: 0.932
confusion matrix:
[[536576   2446]
 [ 36864   2531]]

--- death ---
binary cross-entropy: 0.0568
brier loss: 0.0118
accuracy: 0.987
confusion matrix:
[[569493   1206]
 [  6334   1384]]



100%|██████████| 1130/1130 [00:01<00:00, 606.23it/s]


--- hosp ---
binary cross-entropy: 0.2129
brier loss: 0.057
accuracy: 0.9323
confusion matrix:
[[535849   3173]
 [ 35982   3413]]

--- death ---
binary cross-entropy: 0.072
brier loss: 0.0134
accuracy: 0.9843
confusion matrix:
[[567678   3021]
 [  6044   1674]]



100%|██████████| 1130/1130 [00:01<00:00, 632.73it/s]


--- hosp ---
binary cross-entropy: 0.2395
brier loss: 0.0611
accuracy: 0.9265
confusion matrix:
[[530798   8224]
 [ 34297   5098]]

--- death ---
binary cross-entropy: 0.0646
brier loss: 0.0108
accuracy: 0.9881
confusion matrix:
[[570187    512]
 [  6357   1361]]



100%|██████████| 1130/1130 [00:01<00:00, 631.23it/s]


--- hosp ---
binary cross-entropy: 0.2321
brier loss: 0.0598
accuracy: 0.9296
confusion matrix:
[[534499   4523]
 [ 36217   3178]]

--- death ---
binary cross-entropy: 0.0543
brier loss: 0.0116
accuracy: 0.9871
confusion matrix:
[[569584   1115]
 [  6341   1377]]



In [16]:
# Persist the per-target, per-run probability arrays for later use; create the directory
# on a fresh checkout instead of failing after all evaluations have run.
predictions_path = '../../evaluation_results/predictions/INN.pt'
os.makedirs(os.path.dirname(predictions_path), exist_ok=True)
with open(predictions_path, 'wb') as file:
    pickle.dump(evaluation_results, file)