In [1]:
import os

import numpy as np
import pickle
import time
import pandas as pd

from tqdm.auto import tqdm

import INN
import torch
from torch.optim import Adam

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import log_loss, brier_score_loss, accuracy_score, confusion_matrix

import GPy
import optunity as opt

import matplotlib.pyplot as plt

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cpu


In [3]:
retrain = True

# Train Data

In [4]:
with open('../../data/data_train.pt', 'rb') as file:
    X_train, y_train = pickle.load(file)

print(f'{X_train.shape = }')
print(f'{y_train.shape = }')

X_train.shape = (2521156, 28)
y_train.shape = (2521156, 2)


# Parameters

In [5]:
INN_parameters = {
    'in_features': X_train.shape[1],
    'out_features': y_train.shape[1],
    'device': device
}

loss_weights = {
    'bce_factor': 10,
    'dvg_factor': 1,
    'logdet_factor': 1,
    'rcst_factor': 1
}

lr = 5e-4

In [6]:
hyperparameter_search_space_boundaries = {
    'n_blocks': [1, 12],
    'n_coupling_network_hidden_layers': [1, 5],
    'n_coupling_network_hidden_nodes': [4, 512 + 256],
}

In [7]:
n_epochs = 32
batch_size = 512

# Helper Functions

In [8]:
def scale_hyperparameters(hyperparameters):
    return np.array([h * (boundaries[1] - boundaries[0]) + boundaries[0] for h, boundaries in zip(hyperparameters, hyperparameter_search_space_boundaries.values())])

In [9]:
def GP_log_loss_upper_confidence_bound(n_blocks, n_coupling_network_hidden_layers, n_coupling_network_hidden_nodes, gp):
    mean, var = gp.predict_noiseless(np.array([[n_blocks, n_coupling_network_hidden_layers, n_coupling_network_hidden_nodes]]))
    return mean + np.sqrt(var)

# Load GP-Results

In [10]:
print(f'Loading Results ...')
with open(f'../../hyperparameter_results/INN.pt', 'rb') as file:
    Q, E = pickle.load(file)
print(f'Loaded Results')

GP = GPy.models.GPRegression(Q, E, kernel=GPy.kern.Matern52(3))
GP.optimize(messages=False);

Loading Results ...
Loaded Results


# Find Best Hyperparameters

In [11]:
hyperparameter_best_upper_confidence_bound = opt.minimize(
    lambda **kwargs: GP_log_loss_upper_confidence_bound(gp=GP, **kwargs),
    **{k: [0, 1] for k in hyperparameter_search_space_boundaries.keys()}
)[0]

hyperparameter_best_upper_confidence_bound_scaled = scale_hyperparameters(hyperparameter_best_upper_confidence_bound.values()).round().astype(int)

In [12]:
best_sampled_hyperparameters = scale_hyperparameters(Q[np.argmin(E)]).round().astype(int)
print(f'{best_sampled_hyperparameters=}')

best_sampled_hyperparameters=array([  3,   1, 748])


# Final Training

In [19]:
if retrain:
    for i in range(5):
        # scale features
        sc_X_train = StandardScaler()
        X_train_scaled = sc_X_train.fit_transform(X_train)

        #create classifier
        inn = INN.INN(**INN_parameters, 
            n_blocks=best_sampled_hyperparameters[0], 
            coupling_network_layers=[best_sampled_hyperparameters[2]] * best_sampled_hyperparameters[1]
        )
        inn.train()

        X_train_scaled_cuda = torch.Tensor(X_train_scaled).to(device)
        y_train_cuda = torch.Tensor(y_train).to(device)

        # fit
        loss_history = inn.fit(X_train_scaled_cuda, y_train_cuda, 
            n_epochs=n_epochs,
            batch_size=batch_size,
            optimizer=Adam(inn.parameters(), lr=lr), 
            loss_weights=loss_weights,
            verbose=1,
        );

        with open(f'../../evaluation_results/models/INN_{i}.pt', 'wb') as file:
            pickle.dump(inn.to('cpu'), file)

        with open(f'../../evaluation_results/loss_history/INN_{i}.pt', 'wb') as file:
            pickle.dump(loss_history, file)

        del inn, X_train_scaled_cuda, y_train_cuda

else:
    if os.path.exists('../../evaluation_results/models/INN.pt'):
        with open('../../evaluation_results/models/INN.pt', 'rb') as file:
            inn = pickle.load(file)
    if os.path.exists('../../evaluation_results/loss_history/INN.pt'):
        with open('../../evaluation_results/loss_history/INN.pt', 'rb') as file:
            loss_history = pickle.load(file)

100%|██████████| 32/32 [26:03<00:00, 48.85s/it, batch=4860/4861, weighted_loss=-75.312, bce=+0.116, dvg=+8.347, rcst=+0.458, logdet=-85.273]
100%|██████████| 32/32 [25:53<00:00, 48.53s/it, batch=4860/4861, weighted_loss=-74.539, bce=+0.178, dvg=+8.578, rcst=+0.463, logdet=-85.358]
100%|██████████| 32/32 [24:52<00:00, 46.63s/it, batch=4860/4861, weighted_loss=-76.840, bce=+0.125, dvg=+7.933, rcst=+0.449, logdet=-86.474]
100%|██████████| 32/32 [25:01<00:00, 46.91s/it, batch=4860/4861, weighted_loss=-74.860, bce=+0.138, dvg=+8.318, rcst=+0.455, logdet=-85.014]
100%|██████████| 32/32 [25:11<00:00, 47.24s/it, batch=4860/4861, weighted_loss=-77.321, bce=+0.129, dvg=+7.636, rcst=+0.460, logdet=-86.710]


# Test Data

In [20]:
with open('../../data/data_test.pt', 'rb') as file:
    X_test, y_test = pickle.load(file)

print(f'{X_test.shape = }')
print(f'{y_test.shape = }')

X_test_scaled = torch.Tensor(sc_X_train.transform(X_test)).to(device)

X_test.shape = (622230, 33)
y_test.shape = (622230, 2)


# Evaluate

In [21]:
evaluation_results = {'hosp': [], 'death': []}

for j in range(5):

    with open(f'../../evaluation_results/models/INN_{j}.pt', 'rb') as file:
        inn = pickle.load(file).to(device)

    n_batches = len(X_test) // batch_size
    y_proba_pred = np.empty((len(X_test), 2))
    for i_batch in tqdm(range(n_batches + 1)):
        y_proba_pred[i_batch * batch_size: (i_batch+1) * batch_size] = inn.forward(X_test_scaled[i_batch * batch_size: (i_batch+1) * batch_size])[0].detach().cpu().numpy()

    for i, y_label in enumerate(['hosp', 'death']):
        print(f'--- {y_label} ---')
        evaluation_results[y_label].append(np.concatenate([1 - y_proba_pred[:, i].reshape(-1, 1), y_proba_pred[:, i].reshape(-1, 1)], axis=1))

        print(f'binary cross-entropy: {np.round(log_loss(y_test[:, i], evaluation_results[y_label][-1][:, 1]), 4)}')
        print(f'brier loss: {brier_score_loss(y_test[:, i], evaluation_results[y_label][-1][:, 1]).round(4)}')
        print(f'accuracy: {accuracy_score(y_test[:, i], evaluation_results[y_label][-1][:, 1].round()).round(4)}')
        print('confusion matrix:')
        print(confusion_matrix(y_test[:, i], (evaluation_results[y_label][-1][:, 1] > 0.5).astype(int)))
        print()
        time.sleep(0.5)

100%|██████████| 1216/1216 [00:01<00:00, 724.91it/s]


--- hosp ---
binary cross-entropy: 0.2284
brier loss: 0.0598
accuracy: 0.9295
confusion matrix:
[[575679   3968]
 [ 39900   2683]]

--- death ---
binary cross-entropy: 0.0506
brier loss: 0.012
accuracy: 0.9867
confusion matrix:
[[612476   1502]
 [  6782   1470]]



100%|██████████| 1216/1216 [00:01<00:00, 735.92it/s]


--- hosp ---
binary cross-entropy: 0.2839
brier loss: 0.0666
accuracy: 0.9217
confusion matrix:
[[568667  10980]
 [ 37747   4836]]

--- death ---
binary cross-entropy: 0.0876
brier loss: 0.0136
accuracy: 0.9848
confusion matrix:
[[611165   2813]
 [  6663   1589]]



100%|██████████| 1216/1216 [00:01<00:00, 731.46it/s]


--- hosp ---
binary cross-entropy: 0.2344
brier loss: 0.0588
accuracy: 0.9322
confusion matrix:
[[577727   1920]
 [ 40282   2301]]

--- death ---
binary cross-entropy: 0.0544
brier loss: 0.0119
accuracy: 0.9868
confusion matrix:
[[612587   1391]
 [  6823   1429]]



100%|██████████| 1216/1216 [00:01<00:00, 733.72it/s]


--- hosp ---
binary cross-entropy: 0.2482
brier loss: 0.0617
accuracy: 0.927
confusion matrix:
[[572557   7090]
 [ 38316   4267]]

--- death ---
binary cross-entropy: 0.054
brier loss: 0.011
accuracy: 0.9882
confusion matrix:
[[613585    393]
 [  6949   1303]]



100%|██████████| 1216/1216 [00:01<00:00, 734.23it/s]


--- hosp ---
binary cross-entropy: 0.2558
brier loss: 0.0629
accuracy: 0.9257
confusion matrix:
[[570955   8692]
 [ 37547   5036]]

--- death ---
binary cross-entropy: 0.0658
brier loss: 0.0125
accuracy: 0.9856
confusion matrix:
[[611607   2371]
 [  6596   1656]]



In [22]:
with open('../../evaluation_results/predictions/INN.pt', 'wb') as file:
    pickle.dump(evaluation_results, file)