In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
import torch
import random
from torch.utils.data import Dataset, DataLoader
from datasets_dataloader_pytorch import CustomDataset, load_data
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, confusion_matrix
from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score, average_precision_score

from variants import CAMELOT_GMM, class_weight, Camelot_GRU, Camelot_denoising
from model_utils import MyLRScheduler, calc_l1_l2_loss
from utils import calc_pred_loss, calc_dist_loss, calc_clus_loss, torch_log

# Display names for the four columns of the per-seed `results` table,
# listed in column order.
metrics = [
    'AUC',
    'F1 score',
    'Recall',
    'NMI',
]

# Random seeds; each benchmark cell performs one complete run per seed.
seeds = [
    1001, 1012, 1134, 2475, 6138,
    7415, 1663, 7205, 9253, 1782,
]

In [2]:
# Benchmark the CAMELOT_GMM variant.
# For every seed: split the data, train with validation-based checkpointing,
# then score the best checkpoint on the held-out test split.
# results[k] = (mean AUROC, mean F1, mean recall, NMI) for seeds[k].
results = np.zeros((len(seeds), 4))

num_epochs = 100                   # training epochs per seed
warmup_epochs = 30                 # checkpointing / patience counting starts at this epoch index
patience = 50                      # non-improving epochs tolerated before rolling back
checkpoint_path = './best_model'   # NOTE: shared by every seed and every model variant

for index, SEED in enumerate(seeds):
    torch.random.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CustomDataset(time_range=(0, 10))

    # Stratified 60/40 split into (train + val) and test.
    # The stratification label is the argmax over the last axis of dataset.y
    # (presumably one-hot labels -- TODO confirm).
    train_idx, test_idx = train_test_split(np.arange(len(dataset)),
                                           test_size=0.4,
                                           random_state=SEED,
                                           shuffle=True,
                                           stratify=np.argmax(dataset.y, axis=-1))

    train_val_dataset = dataset.get_subset(train_idx)
    test_dataset = dataset.get_subset(test_idx)

    # Second stratified 60/40 split of the remainder into train and val
    # (i.e. 36% / 24% / 40% of all samples for train / val / test).
    train_idx, val_idx = train_test_split(np.arange(len(train_val_dataset)),
                                          test_size=0.4,
                                          random_state=SEED,
                                          shuffle=True,
                                          stratify=np.argmax(train_val_dataset.y, axis=-1))

    train_dataset = train_val_dataset.get_subset(train_idx)
    val_dataset = train_val_dataset.get_subset(val_idx)

    train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)

    model = CAMELOT_GMM(input_shape=(train_dataset.x.shape[1], train_dataset.x.shape[2]), seed=SEED, num_clusters=10, latent_dim=64)
    model = model.to(device)

    # Full train/val tensors are only used for the model's own initialisation step.
    train_x = torch.tensor(train_dataset.x).to(device)
    train_y = torch.tensor(train_dataset.y).to(device)
    val_x = torch.tensor(val_dataset.x).to(device)
    val_y = torch.tensor(val_dataset.y).to(device)

    model.initialize((train_x, train_y), (val_x, val_y))

    # NOTE(review): `optimizer` is built from model.parameters(); if
    # cluster_rep_set is registered as a model parameter it is stepped by both
    # optimizers -- confirm this is intended.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cluster_optim = torch.optim.Adam([model.cluster_rep_set], lr=0.001)

    lr_scheduler = MyLRScheduler(optimizer, patience=15, min_lr=0.00001, factor=0.25)
    cluster_lr_scheduler = MyLRScheduler(cluster_optim, patience=15, min_lr=0.00001, factor=0.25)

    # loss_mat[epoch, loss, split]: loss 0..3 = encoder / identifier / predictor /
    # cluster; split 0 = train, 1 = val. Values are sums over batches.
    loss_mat = np.zeros((num_epochs, 4, 2))

    # Start from +inf so the first post-warm-up epoch with a finite validation
    # loss always produces a checkpoint, and remember whether THIS run wrote one
    # so a stale file from a previous seed/variant is never loaded.
    best_loss = float('inf')
    checkpoint_saved = False
    count = 0
    # NOTE(review): the model is never switched between train()/eval() here;
    # if forward_pass relies on dropout/batch-norm, confirm the mode is handled
    # inside the model.
    for i in trange(num_epochs):
        for x_train, y_train in train_loader:
            optimizer.zero_grad()
            cluster_optim.zero_grad()

            y_pred, probs = model.forward_pass(x_train)

            loss_weights = class_weight(y_train)

            common_loss = calc_pred_loss(y_train, y_pred, loss_weights)

            # Each loss is back-propagated only into its own component via
            # `inputs=`; the graph is retained until the last backward call.
            enc_loss = (common_loss
                        + model.alpha * calc_dist_loss(probs)
                        + calc_l1_l2_loss(part=model.Encoder))
            enc_loss.backward(retain_graph=True, inputs=list(model.Encoder.parameters()))

            identifier_loss = (common_loss
                               + model.alpha * calc_dist_loss(probs)
                               + calc_l1_l2_loss(layers=[model.Identifier.fc2]))
            identifier_loss.backward(retain_graph=True, inputs=list(model.Identifier.parameters()))

            pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])
            pred_loss.backward(retain_graph=True, inputs=list(model.Predictor.parameters()))

            clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
            clus_loss.backward(inputs=model.cluster_rep_set)

            optimizer.step()
            cluster_optim.step()

            loss_mat[i, 0, 0] += enc_loss.item()
            loss_mat[i, 1, 0] += identifier_loss.item()
            loss_mat[i, 2, 0] += pred_loss.item()
            loss_mat[i, 3, 0] += clus_loss.item()

        with torch.no_grad():
            for x_val, y_val in val_loader:
                y_pred, probs = model.forward_pass(x_val)

                loss_weights = class_weight(y_val)

                common_loss = calc_pred_loss(y_val, y_pred, loss_weights)

                enc_loss = (common_loss
                            + model.alpha * calc_dist_loss(probs)
                            + calc_l1_l2_loss(part=model.Encoder))

                identifier_loss = (common_loss
                                   + model.alpha * calc_dist_loss(probs)
                                   + calc_l1_l2_loss(layers=[model.Identifier.fc2]))

                pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])

                clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)

                loss_mat[i, 0, 1] += enc_loss.item()
                loss_mat[i, 1, 1] += identifier_loss.item()
                loss_mat[i, 2, 1] += pred_loss.item()
                loss_mat[i, 3, 1] += clus_loss.item()

            # Model selection is driven by the validation encoder loss.
            if i >= warmup_epochs:
                if loss_mat[i, 0, 1] < best_loss:
                    count = 0
                    best_loss = loss_mat[i, 0, 1]
                    torch.save(model.state_dict(), checkpoint_path)
                    checkpoint_saved = True
                else:
                    count += 1
                    # NOTE(review): once patience is exhausted the weights are
                    # rolled back to the best checkpoint but training continues
                    # (there is no `break`) -- confirm this is intended.
                    if count >= patience and checkpoint_saved:
                        model.load_state_dict(torch.load(checkpoint_path))
        lr_scheduler.step(loss_mat[i, 0, 1])
        cluster_lr_scheduler.step(loss_mat[i, 0, 1])

    # Evaluate the best checkpoint written by this run. Never fall back to a
    # file left behind by an earlier seed or a different model variant.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print(f'WARNING: seed {SEED}: validation loss never improved (NaN?); '
              f'evaluating the final weights instead of a stale {checkpoint_path}')

    real, preds = [], []
    with torch.no_grad():
        for x, y in test_loader:
            y_pred, _ = model.forward_pass(x)
            preds.extend(y_pred.detach().cpu().numpy())
            real.extend(y.detach().cpu().numpy())

    # Per-class AUROC from the raw scores; the remaining metrics use argmax labels.
    auc = roc_auc_score(real, preds, average=None)

    labels_true, labels_pred = np.argmax(real, axis=1), np.argmax(preds, axis=1)

    # Per-class F1
    f1 = f1_score(labels_true, labels_pred, average=None)

    # Per-class recall
    rec = recall_score(labels_true, labels_pred, average=None)

    # NMI between true and predicted class labels
    nmi = normalized_mutual_info_score(labels_true, labels_pred)

    print(f'AUCROC: \t{auc.mean():.5f}, \t{auc}')
    print(f'F1-score: \t{f1.mean():.5f}, \t{f1}')
    print(f'Recall: \t{rec.mean():.5f}, \t{rec}')
    print(f'NMI: \t\t{nmi:.5f}')

    # Macro (unweighted) averages over classes are what gets aggregated per seed.
    results[index, 0] = auc.mean()
    results[index, 1] = f1.mean()
    results[index, 2] = rec.mean()
    results[index, 3] = nmi


MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1077.80it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.62it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.37it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


AUCROC: 	0.77202, 	[0.88815746 0.77457713 0.7570725  0.66826324]
F1-score: 	0.31836, 	[0.         0.5        0.77342823 0.        ]
Recall: 	0.35247, 	[0.         0.71625767 0.69361702 0.        ]
NMI: 		0.09827

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1062.37it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.75it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.36it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.92it/s]


AUCROC: 	0.77126, 	[0.85081673 0.77689006 0.74723754 0.71010611]
F1-score: 	0.33555, 	[0.         0.49557522 0.84662577 0.        ]
Recall: 	0.34171, 	[0.         0.51533742 0.85148936 0.        ]
NMI: 		0.10981

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1054.09it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.79it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.41it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.86it/s]


AUCROC: 	0.78316, 	[0.87332571 0.79366398 0.76300521 0.70263267]
F1-score: 	0.34148, 	[0.         0.52094972 0.84497957 0.        ]
Recall: 	0.35206, 	[0.         0.57208589 0.83617021 0.        ]
NMI: 		0.11841

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1121.10it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.96it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.91it/s]


Identifier initialization done!


100%|██████████| 100/100 [02:35<00:00,  1.56s/it]


AUCROC: 	0.76169, 	[0.86565665 0.78824379 0.76648805 0.62636148]
F1-score: 	0.32973, 	[0.         0.51324308 0.80565693 0.        ]
Recall: 	0.35505, 	[0.         0.66871166 0.75148936 0.        ]
NMI: 		0.10676

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:09<00:00, 839.21it/s]
 50%|█████     | 50/100 [00:54<00:54,  1.09s/it]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:49<00:49,  1.01it/s]


Identifier initialization done!


100%|██████████| 100/100 [04:06<00:00,  2.46s/it]


AUCROC: 	0.77123, 	[0.88985626 0.77862017 0.75399395 0.66245555]
F1-score: 	0.33313, 	[0.         0.50033715 0.83217391 0.        ]
Recall: 	0.34587, 	[0.         0.5690184  0.81446809 0.        ]
NMI: 		0.10442

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1020.29it/s]
 50%|█████     | 50/100 [00:15<00:15,  3.15it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:13<00:13,  3.61it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:54<00:00,  1.85it/s]


AUCROC: 	0.71272, 	[0.81028259 0.73395411 0.7169488  0.58971217]
F1-score: 	0.31027, 	[0.         0.49704724 0.74401382 0.        ]
Recall: 	0.35395, 	[0.         0.77453988 0.6412766  0.        ]
NMI: 		0.10105

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1103.90it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.83it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.45it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:51<00:00,  1.93it/s]


AUCROC: 	0.78109, 	[0.88351846 0.79232156 0.7638979  0.6846291 ]
F1-score: 	0.30421, 	[0.         0.49721707 0.71963331 0.        ]
Recall: 	0.35584, 	[0.         0.82208589 0.6012766  0.        ]
NMI: 		0.10720

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1106.56it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.89it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.50it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:51<00:00,  1.94it/s]


AUCROC: 	0.77779, 	[0.92788304 0.76809361 0.74467998 0.67049266]
F1-score: 	0.32941, 	[0.         0.50728155 0.81037204 0.        ]
Recall: 	0.35145, 	[0.         0.64110429 0.76468085 0.        ]
NMI: 		0.10910

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1105.77it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.90it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.48it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:51<00:00,  1.94it/s]


AUCROC: 	0.77956, 	[0.92758902 0.77341656 0.76717612 0.65006338]
F1-score: 	0.32964, 	[0.         0.51771429 0.80083083 0.        ]
Recall: 	0.35827, 	[0.         0.69478528 0.73829787 0.        ]
NMI: 		0.11422

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1111.46it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.89it/s]


Encoder initialization done!
GMM initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.50it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:51<00:00,  1.94it/s]

AUCROC: 	0.77340, 	[0.8854296  0.7782953  0.75083331 0.67904015]
F1-score: 	0.31077, 	[0.         0.49898854 0.6875308  0.05657238]
Recall: 	0.36231, 	[0.         0.56748466 0.59361702 0.28813559]
NMI: 		0.09783





In [3]:
# Summarise the runs: mean and population standard deviation (NumPy's default
# ddof=0) of every metric column across seeds.
summary_mean = results.mean(axis=0)
summary_std = results.std(axis=0)
for metric_name, mean_val, std_val in zip(metrics, summary_mean, summary_std):
    print(f'{metric_name}: {mean_val:.3f} ({std_val:.3f})')

AUC: 0.768 (0.019)
F1 score: 0.324 (0.012)
Recall: 0.353 (0.006)
NMI: 0.107 (0.006)


In [3]:
# Benchmark the Camelot_GRU variant.
# For every seed: split the data, train with validation-based checkpointing,
# then score the best checkpoint on the held-out test split.
# results[k] = (mean AUROC, mean F1, mean recall, NMI) for seeds[k].
results = np.zeros((len(seeds), 4))

num_epochs = 100                   # training epochs per seed
warmup_epochs = 30                 # checkpointing / patience counting starts at this epoch index
patience = 50                      # non-improving epochs tolerated before rolling back
checkpoint_path = './best_model'   # NOTE: shared by every seed and every model variant

for index, SEED in enumerate(seeds):
    torch.random.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CustomDataset(time_range=(0, 10))

    # Stratified 60/40 split into (train + val) and test.
    # The stratification label is the argmax over the last axis of dataset.y
    # (presumably one-hot labels -- TODO confirm).
    train_idx, test_idx = train_test_split(np.arange(len(dataset)),
                                           test_size=0.4,
                                           random_state=SEED,
                                           shuffle=True,
                                           stratify=np.argmax(dataset.y, axis=-1))

    train_val_dataset = dataset.get_subset(train_idx)
    test_dataset = dataset.get_subset(test_idx)

    # Second stratified 60/40 split of the remainder into train and val
    # (i.e. 36% / 24% / 40% of all samples for train / val / test).
    train_idx, val_idx = train_test_split(np.arange(len(train_val_dataset)),
                                          test_size=0.4,
                                          random_state=SEED,
                                          shuffle=True,
                                          stratify=np.argmax(train_val_dataset.y, axis=-1))

    train_dataset = train_val_dataset.get_subset(train_idx)
    val_dataset = train_val_dataset.get_subset(val_idx)

    train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)

    model = Camelot_GRU(input_shape=(train_dataset.x.shape[1], train_dataset.x.shape[2]), seed=SEED, num_clusters=10, latent_dim=64)
    model = model.to(device)

    # Full train/val tensors are only used for the model's own initialisation step.
    train_x = torch.tensor(train_dataset.x).to(device)
    train_y = torch.tensor(train_dataset.y).to(device)
    val_x = torch.tensor(val_dataset.x).to(device)
    val_y = torch.tensor(val_dataset.y).to(device)

    model.initialize((train_x, train_y), (val_x, val_y))

    # NOTE(review): `optimizer` is built from model.parameters(); if
    # cluster_rep_set is registered as a model parameter it is stepped by both
    # optimizers -- confirm this is intended.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cluster_optim = torch.optim.Adam([model.cluster_rep_set], lr=0.001)

    lr_scheduler = MyLRScheduler(optimizer, patience=15, min_lr=0.00001, factor=0.25)
    cluster_lr_scheduler = MyLRScheduler(cluster_optim, patience=15, min_lr=0.00001, factor=0.25)

    # loss_mat[epoch, loss, split]: loss 0..3 = encoder / identifier / predictor /
    # cluster; split 0 = train, 1 = val. Values are sums over batches.
    loss_mat = np.zeros((num_epochs, 4, 2))

    # Start from +inf so the first post-warm-up epoch with a finite validation
    # loss always produces a checkpoint, and remember whether THIS run wrote one
    # so a stale file from a previous seed/variant is never loaded.
    best_loss = float('inf')
    checkpoint_saved = False
    count = 0
    # NOTE(review): the model is never switched between train()/eval() here;
    # if forward_pass relies on dropout/batch-norm, confirm the mode is handled
    # inside the model.
    for i in trange(num_epochs):
        for x_train, y_train in train_loader:
            optimizer.zero_grad()
            cluster_optim.zero_grad()

            y_pred, probs = model.forward_pass(x_train)

            loss_weights = class_weight(y_train)

            common_loss = calc_pred_loss(y_train, y_pred, loss_weights)

            # Each loss is back-propagated only into its own component via
            # `inputs=`; the graph is retained until the last backward call.
            enc_loss = (common_loss
                        + model.alpha * calc_dist_loss(probs)
                        + calc_l1_l2_loss(part=model.Encoder))
            enc_loss.backward(retain_graph=True, inputs=list(model.Encoder.parameters()))

            identifier_loss = (common_loss
                               + model.alpha * calc_dist_loss(probs)
                               + calc_l1_l2_loss(layers=[model.Identifier.fc2]))
            identifier_loss.backward(retain_graph=True, inputs=list(model.Identifier.parameters()))

            pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])
            pred_loss.backward(retain_graph=True, inputs=list(model.Predictor.parameters()))

            clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
            clus_loss.backward(inputs=model.cluster_rep_set)

            optimizer.step()
            cluster_optim.step()

            loss_mat[i, 0, 0] += enc_loss.item()
            loss_mat[i, 1, 0] += identifier_loss.item()
            loss_mat[i, 2, 0] += pred_loss.item()
            loss_mat[i, 3, 0] += clus_loss.item()

        with torch.no_grad():
            for x_val, y_val in val_loader:
                y_pred, probs = model.forward_pass(x_val)

                loss_weights = class_weight(y_val)

                common_loss = calc_pred_loss(y_val, y_pred, loss_weights)

                enc_loss = (common_loss
                            + model.alpha * calc_dist_loss(probs)
                            + calc_l1_l2_loss(part=model.Encoder))

                identifier_loss = (common_loss
                                   + model.alpha * calc_dist_loss(probs)
                                   + calc_l1_l2_loss(layers=[model.Identifier.fc2]))

                pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])

                clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)

                loss_mat[i, 0, 1] += enc_loss.item()
                loss_mat[i, 1, 1] += identifier_loss.item()
                loss_mat[i, 2, 1] += pred_loss.item()
                loss_mat[i, 3, 1] += clus_loss.item()

            # Model selection is driven by the validation encoder loss.
            if i >= warmup_epochs:
                if loss_mat[i, 0, 1] < best_loss:
                    count = 0
                    best_loss = loss_mat[i, 0, 1]
                    torch.save(model.state_dict(), checkpoint_path)
                    checkpoint_saved = True
                else:
                    count += 1
                    # NOTE(review): once patience is exhausted the weights are
                    # rolled back to the best checkpoint but training continues
                    # (there is no `break`) -- confirm this is intended.
                    if count >= patience and checkpoint_saved:
                        model.load_state_dict(torch.load(checkpoint_path))
        lr_scheduler.step(loss_mat[i, 0, 1])
        cluster_lr_scheduler.step(loss_mat[i, 0, 1])

    # Evaluate the best checkpoint written by this run. Never fall back to a
    # file left behind by an earlier seed or a different model variant.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print(f'WARNING: seed {SEED}: validation loss never improved (NaN?); '
              f'evaluating the final weights instead of a stale {checkpoint_path}')

    real, preds = [], []
    with torch.no_grad():
        for x, y in test_loader:
            y_pred, _ = model.forward_pass(x)
            preds.extend(y_pred.detach().cpu().numpy())
            real.extend(y.detach().cpu().numpy())

    # Per-class AUROC from the raw scores; the remaining metrics use argmax labels.
    auc = roc_auc_score(real, preds, average=None)

    labels_true, labels_pred = np.argmax(real, axis=1), np.argmax(preds, axis=1)

    # Per-class F1
    f1 = f1_score(labels_true, labels_pred, average=None)

    # Per-class recall
    rec = recall_score(labels_true, labels_pred, average=None)

    # NMI between true and predicted class labels
    nmi = normalized_mutual_info_score(labels_true, labels_pred)

    print(f'AUCROC: \t{auc.mean():.5f}, \t{auc}')
    print(f'F1-score: \t{f1.mean():.5f}, \t{f1}')
    print(f'Recall: \t{rec.mean():.5f}, \t{rec}')
    print(f'NMI: \t\t{nmi:.5f}')

    # Macro (unweighted) averages over classes are what gets aggregated per seed.
    results[index, 0] = auc.mean()
    results[index, 1] = f1.mean()
    results[index, 2] = rec.mean()
    results[index, 3] = nmi


MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1076.60it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.69it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.42it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:50<00:00,  1.97it/s]


AUCROC: 	0.76633, 	[0.88588697 0.77185946 0.75292808 0.65463438]
F1-score: 	0.32905, 	[0.         0.50461538 0.8115747  0.        ]
Recall: 	0.34966, 	[0.         0.62883436 0.76978723 0.        ]
NMI: 		0.10370

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1100.72it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.99it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.51it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:50<00:00,  1.99it/s]


AUCROC: 	0.76308, 	[0.88012088 0.74462874 0.70888232 0.71867043]
F1-score: 	0.32137, 	[0.         0.50482936 0.70271605 0.07792208]
Recall: 	0.39067, 	[0.         0.60122699 0.60553191 0.3559322 ]
NMI: 		0.09408

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1085.11it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.99it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.57it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.01it/s]


AUCROC: 	0.72043, 	[0.78025155 0.71248393 0.68482289 0.70416942]
F1-score: 	0.33558, 	[0.         0.50549451 0.83682732 0.        ]
Recall: 	0.34706, 	[0.         0.56441718 0.82382979 0.        ]
NMI: 		0.10594

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1107.13it/s]
 50%|█████     | 50/100 [00:12<00:12,  4.03it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.43it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:50<00:00,  1.99it/s]


AUCROC: 	0.75705, 	[0.86030709 0.78133911 0.75903426 0.62750283]
F1-score: 	0.31575, 	[0.         0.50331126 0.75970874 0.        ]
Recall: 	0.35591, 	[0.         0.75766871 0.66595745 0.        ]
NMI: 		0.10010

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1084.54it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.95it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.57it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.01it/s]


AUCROC: 	0.75598, 	[0.89910977 0.73690952 0.7277859  0.66009714]
F1-score: 	0.32616, 	[0.         0.50643275 0.72865182 0.06956522]
Recall: 	0.37528, 	[0.         0.66411043 0.63361702 0.20338983]
NMI: 		0.08907

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1106.86it/s]
 50%|█████     | 50/100 [00:12<00:12,  4.02it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.26it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


AUCROC: 	0.77352, 	[0.90526789 0.7793179  0.75719359 0.65228157]
F1-score: 	0.32591, 	[0.         0.5002978  0.80336058 0.        ]
Recall: 	0.34923, 	[0.         0.64417178 0.75276596 0.        ]
NMI: 		0.10292

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1086.58it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.87it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.48it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.01it/s]


AUCROC: 	0.78835, 	[0.90450016 0.7826086  0.75749047 0.70881614]
F1-score: 	0.29946, 	[0.         0.48165569 0.71619914 0.        ]
Recall: 	0.34706, 	[0.         0.78527607 0.60297872 0.        ]
NMI: 		0.08944

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1083.94it/s]
 50%|█████     | 50/100 [00:12<00:12,  4.04it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:10<00:10,  4.57it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:49<00:00,  2.00it/s]


AUCROC: 	0.78076, 	[0.92784221 0.75823952 0.76469366 0.67225936]
F1-score: 	0.33205, 	[0.09929078 0.45175936 0.77714286 0.        ]
Recall: 	0.41372, 	[0.35       0.61042945 0.69446809 0.        ]
NMI: 		0.12201

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1088.79it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.94it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.43it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:50<00:00,  1.96it/s]


AUCROC: 	0.78416, 	[0.92479582 0.79696257 0.7754283  0.6394351 ]
F1-score: 	0.32020, 	[0.         0.51378958 0.7670303  0.        ]
Recall: 	0.36117, 	[0.         0.77147239 0.67319149 0.        ]
NMI: 		0.11311

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1080.58it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.87it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.49it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:50<00:00,  1.98it/s]

AUCROC: 	0.74896, 	[0.84658608 0.77079771 0.74461885 0.63384054]
F1-score: 	0.32546, 	[0.         0.49759615 0.80425436 0.        ]
Recall: 	0.34778, 	[0.         0.63496933 0.75617021 0.        ]
NMI: 		0.09260





In [6]:
# Summarise the runs: mean and population standard deviation (NumPy's default
# ddof=0) of every metric column across seeds.
summary_mean = results.mean(axis=0)
summary_std = results.std(axis=0)
for metric_name, mean_val, std_val in zip(metrics, summary_mean, summary_std):
    print(f'{metric_name}: {mean_val:.3f} ({std_val:.3f})')

AUC: 0.764 (0.019)
F1 score: 0.323 (0.010)
Recall: 0.364 (0.022)
NMI: 0.101 (0.010)


In [2]:
results = np.zeros((len(seeds), 4))
for index, SEED in enumerate(seeds):
    torch.random.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = CustomDataset(time_range=(0, 10))

    # Stratified Sampling for train and val
    train_idx, test_idx = train_test_split(np.arange(len(dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(dataset.y,axis=-1))

    # Subset dataset for train and val
    train_val_dataset = dataset.get_subset(train_idx)
    test_dataset = dataset.get_subset(test_idx)

    train_idx,  val_idx = train_test_split(np.arange(len(train_val_dataset)),
                                                test_size=0.4,
                                                random_state=SEED,
                                                shuffle=True,
                                                stratify=np.argmax(train_val_dataset.y,axis=-1))

    train_dataset = train_val_dataset.get_subset(train_idx)
    val_dataset = train_val_dataset.get_subset(val_idx)

    train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)

    model = Camelot_denoising(input_shape=(train_dataset.x.shape[1], train_dataset.x.shape[2]), seed=SEED, num_clusters=10, latent_dim=64)
    model = model.to(device)

    train_x = torch.tensor(train_dataset.x).to(device)
    train_y = torch.tensor(train_dataset.y).to(device)
    val_x = torch.tensor(val_dataset.x).to(device)
    val_y = torch.tensor(val_dataset.y).to(device)

    model.initialize((train_x, train_y), (val_x, val_y))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cluster_optim = torch.optim.Adam([model.cluster_rep_set], lr=0.001)

    lr_scheduler = MyLRScheduler(optimizer, patience=15, min_lr=0.00001, factor=0.25)
    cluster_lr_scheduler = MyLRScheduler(cluster_optim, patience=15, min_lr=0.00001, factor=0.25)

    loss_mat = np.zeros((100, 4, 2))

    best_loss = 1e5
    count = 0
    for i in trange(100):
        for step, (x_train, y_train) in enumerate(train_loader):
            optimizer.zero_grad()
            cluster_optim.zero_grad()
            
            y_pred, probs = model.forward_pass(x_train)
            
            loss_weights = class_weight(y_train)
            
            common_loss = calc_pred_loss(y_train, y_pred, loss_weights)
            
            enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
            + calc_l1_l2_loss(part=model.Encoder) 
            enc_loss.backward(retain_graph=True, inputs=list(model.Encoder.parameters()))
            
            idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
            + calc_l1_l2_loss(layers=[model.Identifier.fc2])
            idnetifier_loss.backward(retain_graph=True, inputs=list(model.Identifier.parameters()))
            
            pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])
            pred_loss.backward(retain_graph=True, inputs=list(model.Predictor.parameters()))
            
            clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
            clus_loss.backward(inputs=model.cluster_rep_set)

            optimizer.step()
            cluster_optim.step()

            loss_mat[i, 0, 0] += enc_loss.item()
            loss_mat[i, 1, 0] += idnetifier_loss.item()
            loss_mat[i, 2, 0] += pred_loss.item()
            loss_mat[i, 3, 0] += clus_loss.item()
                            
        # Validation pass: recompute the same four objectives on `val_loader` with
        # autograd disabled and accumulate them into loss_mat[i, k, 1]. No
        # backward pass or optimizer step happens in this section.
        with torch.no_grad():
            for step, (x_val, y_val) in enumerate(val_loader):
                y_pred, probs = model.forward_pass(x_val)
                
                # Class weights are recomputed from the labels of this batch.
                loss_weights = class_weight(y_val)
                
                common_loss = calc_pred_loss(y_val, y_pred, loss_weights)
            
                # NOTE(review): as in the training step, the continuation lines
                # below start with a redundant unary `+`.
                enc_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                + calc_l1_l2_loss(part=model.Encoder) 

                idnetifier_loss = common_loss + model.alpha * calc_dist_loss(probs) + \
                + calc_l1_l2_loss(layers=[model.Identifier.fc2])

                pred_loss = common_loss + calc_l1_l2_loss(layers=[model.Predictor.fc2, model.Predictor.fc3])

                clus_loss = common_loss + model.beta * calc_clus_loss(model.cluster_rep_set)
                
                # Running validation totals (summed over batches, not averaged).
                loss_mat[i, 0, 1] += enc_loss.item()
                loss_mat[i, 1, 1] += idnetifier_loss.item()
                loss_mat[i, 2, 1] += pred_loss.item()
                loss_mat[i, 3, 1] += clus_loss.item()
            
            # Checkpointing is only considered once i reaches 30; it tracks the
            # validation encoder loss, loss_mat[i, 0, 1].
            if i >= 30:
                if loss_mat[i, 0, 1] < best_loss:
                    # New best: reset the patience counter and save the weights.
                    count = 0
                    best_loss = loss_mat[i, 0, 1]
                    torch.save(model.state_dict(), './best_model')
                else:
                    count += 1
                    # NOTE(review): once 50 non-improving checks have accumulated,
                    # the best weights are reloaded, but there is no `break` and
                    # `count` is not reset, so the loop keeps running and reloads
                    # the checkpoint on every further non-improving pass —
                    # confirm whether early stopping was intended here.
                    if count >= 50:
                        model.load_state_dict(torch.load('./best_model'))
        # Both learning-rate schedulers are stepped with the validation encoder loss.
        lr_scheduler.step(loss_mat[i, 0, 1])
        cluster_lr_scheduler.step(loss_mat[i, 0, 1])

    # Restore the weights saved to './best_model' during training before scoring.
    model.load_state_dict(torch.load('./best_model'))

    # Collect one row per test sample: targets in `real`, model outputs in `preds`.
    real, preds = [], []
    with torch.no_grad():
        for x, y in test_loader:
            y_pred, _ = model.forward_pass(x)
            preds += list(y_pred.cpu().detach().numpy())
            real += list(y.cpu().detach().numpy())

    # Per-class AUROC; average=None yields one score per class.
    auc = roc_auc_score(real, preds, average=None)

    # Hard labels are the index of the largest entry in each row.
    labels_true = np.argmax(real, axis=1)
    labels_pred = np.argmax(preds, axis=1)

    # Per-class F1 and recall, plus NMI between true and predicted labels.
    f1 = f1_score(labels_true, labels_pred, average=None)
    rec = recall_score(labels_true, labels_pred, average=None)
    nmi = normalized_mutual_info_score(labels_true, labels_pred)

    # Report the unweighted class mean followed by the per-class values.
    print(f'AUCROC: \t{auc.mean():.5f}, \t{auc}')
    print(f'F1-score: \t{f1.mean():.5f}, \t{f1}')
    print(f'Recall: \t{rec.mean():.5f}, \t{rec}')
    print(f'NMI: \t\t{nmi:.5f}')

    # Store this seed's scores in column order AUC, F1, Recall, NMI.
    results[index] = [auc.mean(), f1.mean(), rec.mean(), nmi]


MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1023.75it/s]
 50%|█████     | 50/100 [00:14<00:14,  3.37it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:12<00:12,  3.95it/s]


Identifier initialization done!


100%|██████████| 100/100 [01:01<00:00,  1.63it/s]


AUCROC: 	0.76494, 	[0.9008984  0.75622021 0.77199639 0.63063523]
F1-score: 	0.31936, 	[0.         0.51303015 0.76442075 0.        ]
Recall: 	0.36025, 	[0.         0.76993865 0.67106383 0.        ]
NMI: 		0.11115

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 978.71it/s] 
 50%|█████     | 50/100 [00:15<00:15,  3.26it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.31it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:54<00:00,  1.82it/s]


AUCROC: 	0.75550, 	[0.85320157 0.73391938 0.73845301 0.69644359]
F1-score: 	0.28916, 	[0.         0.47324193 0.68341183 0.        ]
Recall: 	0.34407, 	[0.         0.82055215 0.55574468 0.        ]
NMI: 		0.08705

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1100.36it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.86it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.48it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


AUCROC: 	0.77313, 	[0.89078732 0.73689279 0.75440522 0.71044543]
F1-score: 	0.33815, 	[0.12972973 0.44101956 0.78184826 0.        ]
Recall: 	0.46860, 	[0.6        0.57055215 0.70382979 0.        ]
NMI: 		0.11617

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1078.95it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.84it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.41it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


AUCROC: 	0.76211, 	[0.89385005 0.74876556 0.75739878 0.64841165]
F1-score: 	0.30739, 	[0.         0.49763033 0.7319406  0.        ]
Recall: 	0.35598, 	[0.         0.80521472 0.6187234  0.        ]
NMI: 		0.10082

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1103.46it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.87it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.49it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:54<00:00,  1.85it/s]


AUCROC: 	0.76504, 	[0.89524665 0.75798443 0.73264982 0.67427565]
F1-score: 	0.32580, 	[0.         0.48758752 0.81560284 0.        ]
Recall: 	0.34260, 	[0.         0.58742331 0.78297872 0.        ]
NMI: 		0.09139

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1079.96it/s]
 50%|█████     | 50/100 [00:13<00:13,  3.79it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.35it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


AUCROC: 	0.75280, 	[0.88160732 0.76639854 0.74733941 0.61587343]
F1-score: 	0.31935, 	[0.         0.49781659 0.77958127 0.        ]
Recall: 	0.35112, 	[0.         0.6993865  0.70510638 0.        ]
NMI: 		0.09828

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1100.15it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.89it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.46it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


AUCROC: 	0.74928, 	[0.87807906 0.74374095 0.72225864 0.65302191]
F1-score: 	0.33245, 	[0.         0.48218347 0.84760705 0.        ]
Recall: 	0.33672, 	[0.         0.48773006 0.85914894 0.        ]
NMI: 		0.10524

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:07<00:00, 1087.18it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.85it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.43it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:53<00:00,  1.88it/s]


AUCROC: 	0.76488, 	[0.92445279 0.7316153  0.756213   0.64723665]
F1-score: 	0.31285, 	[0.         0.49824209 0.75317693 0.        ]
Recall: 	0.35412, 	[0.         0.7607362  0.65574468 0.        ]
NMI: 		0.10024

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1100.24it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.87it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.45it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]


AUCROC: 	0.78280, 	[0.93373081 0.78482871 0.77156271 0.64107561]
F1-score: 	0.33098, 	[0.         0.52765487 0.79625731 0.        ]
Recall: 	0.36396, 	[0.         0.73159509 0.72425532 0.        ]
NMI: 		0.12287

MIMIC data has been subsettted to the following features: 
 ['DBP', 'ESI', 'HR', 'RR', 'SBP', 'SPO2', 'TEMP', 'age', 'gender'].


100%|██████████| 7701/7701 [00:06<00:00, 1103.17it/s]
 50%|█████     | 50/100 [00:12<00:12,  3.87it/s]


Encoder initialization done!
Kmeans initialization done!
Cluster initialization done!


 50%|█████     | 50/100 [00:11<00:11,  4.43it/s]


Identifier initialization done!


100%|██████████| 100/100 [00:52<00:00,  1.89it/s]

AUCROC: 	0.76780, 	[0.90457367 0.77043117 0.74609745 0.65009142]
F1-score: 	0.32388, 	[0.         0.49182314 0.80370036 0.        ]
Recall: 	0.34514, 	[0.         0.62269939 0.75787234 0.        ]
NMI: 		0.09293





In [3]:
# Summarise every metric across all seeds as "mean (std)".
# Note: ndarray.std() defaults to the population standard deviation (ddof=0).
per_metric_mean = results.mean(axis=0)
per_metric_std = results.std(axis=0)
for m, u, std in zip(metrics, per_metric_mean, per_metric_std):
    print(f'{m}: {u:.3f} ({std:.3f})')

AUC: 0.764 (0.009)
F1 score: 0.320 (0.013)
Recall: 0.362 (0.036)
NMI: 0.103 (0.011)
