In [1]:
SEED = 1234
_EPSILON = 1e-08

In [2]:
import sys
sys.path.append('../scr/')

import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime

from ddh_torch import DynamicDeepHitTorch
import import_data as impt
from utils_eval import c_index, brier_score
from utils_helper import f_get_boosted_trainset
from utils_helper import loss_Log_Likelihood, loss_Ranking, loss_RNN_Prediction
from utils_helper import f_get_risk_predictions, _f_get_pred
from utils_helper import save_checkpoint, load_checkpoint

In [3]:
## for reproducibility
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

In [5]:
def model_eval(model, dataloader, device):
    model.eval()
    total_loss = 0
    Loss1, Loss2, Loss3 = 0, 0, 0

    for i, (b_data, b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3) in enumerate(dataloader):
        b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
            b_data.float(), b_data_mi.float(), b_time.float(), b_label.float(), \
            b_mask1.float(), b_mask2.float(), b_mask3.float()
        b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
            b_data.to(device), b_time.to(device), b_label.to(device), \
            b_mask1.to(device), b_mask2.to(device), b_mask3.to(device)
        
        longitudinal_prediction, out = model(b_data)
        out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)
        loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
        loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
        loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
        loss = alpha * loss1 + beta * loss2 + gamma * loss3
        
        total_loss += loss
        Loss1 += loss1
        Loss2 += loss2
        Loss3 += loss3
        
    return total_loss/len(dataloader), Loss1/len(dataloader), Loss2/len(dataloader), Loss3/len(dataloader)

### Import Dataset

In [9]:
data_mode                   = 'EEG_nocr' # no competing risk, only one event
data_path                   = '../../eeg/competing-risk/EEG_processed_data_long_nocr_exclude_3.csv'
seed                        = SEED

##### IMPORT DATASET
'''
    num_Category            = max event/censoring time * 1.2
    num_Event               = number of evetns i.e. len(np.unique(label))-1
    max_length              = maximum number of measurements
    x_dim                   = data dimension including delta (1 + num_features)
    x_dim_cont              = dim of continuous features
    x_dim_bin               = dim of binary features
    mask1, mask2, mask3     = used for cause-specific network (FCNet structure)
'''

if data_mode == 'EEG_nocr':
    (x_dim, x_dim_cont, x_dim_bin), (data, time, label, time_original, time_to_last), \
        (mask1, mask2, mask3), (data_mi), trans_discrete_time = \
                    impt.import_dataset(data_path)
    
    # This must be changed depending on the datasets, prediction/evaliation times of interest
    pred_time = [6, 12] # prediction time (in hours)
    eval_time = [12, 24,] # hours evaluation time (for C-index and Brier-Score)
else:
    print ('ERROR:  DATA_MODE NOT FOUND !!!')

_, num_Event, num_Category  = np.shape(mask1)  # dim of mask3: [subj, Num_Event, Num_Category]
max_length                  = np.shape(data)[1]


# file_path = '{}_{}'.format(data_mode, seed)

# if not os.path.exists(file_path):
#     os.makedirs(file_path)
    
num_Event, num_Category, max_length



(1, 119, 12)

### Set Hyper-Parameters

In [11]:
from sklearn.model_selection import ParameterGrid
model_name = 'ddh_eeg_nocr'
param_grid = {'dropout': [0.2], 
              'lr_train': [1e-4, 5e-4],
              'beta': [0.5, 1],
              'gamma': [0.05, 0.1, 0.5],
             }
param_grid_list = list(ParameterGrid(param_grid))
print(len(param_grid_list))
param_grid_list

12


[{'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}]

In [12]:
boost_mode = 'ON' #{'ON', 'OFF'}

param = {'batch_size': 32,

         'num_epoch_burn_in': 20,
         'num_epoch': 100,

         'dropout': 0.2, # 1 - keep_prob
         'lr_train': 1e-4,

         'hidden_rnn': 64,
         'hidden_dim_FC': 64,
         'layers_rnn': 2,
         'hidden_att': 2,
         'hidden_cs' : 2,

         'alpha' :1.0,
         'beta'  :1,
         'gamma' :0.1
        }


batch_size = param['batch_size']
num_epoch  = param['num_epoch']

dropout    = param['dropout']
lr_train   = param['lr_train']

hidden_rnn = param['hidden_rnn']
hidden_dim_FC = param['hidden_dim_FC']
layers_rnn = param['layers_rnn']
hidden_att = param['hidden_att']
hidden_cs = param['hidden_cs']

alpha      = param['alpha']
beta       = param['beta']
gamma      = param['gamma']

device     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### Split Dataset into Train/Valid/Test Sets

In [13]:
### TRAINING-TESTING SPLIT
(tr_data,te_data, tr_data_mi, te_data_mi, tr_time,te_time, tr_label,te_label, tr_time_orinal,te_time_original,
 tr_time_to_last,te_time_to_last,
 tr_mask1,te_mask1, tr_mask2,te_mask2, tr_mask3,te_mask3) = \
        train_test_split(data, data_mi, time, label, time_original, time_to_last,
                         mask1, mask2, mask3, test_size=0.2, stratify=label, random_state=seed) 

(tr_data,va_data, tr_data_mi, va_data_mi, tr_time,va_time, tr_label,va_label, tr_time_orinal,va_time_original,
 tr_time_to_last, va_time_to_last,
 tr_mask1,va_mask1, tr_mask2,va_mask2, tr_mask3,va_mask3) = \
        train_test_split(tr_data, tr_data_mi, tr_time, tr_label, tr_time_orinal, tr_time_to_last,
                         tr_mask1, tr_mask2, tr_mask3, test_size=0.2, stratify=tr_label, random_state=seed) 

if boost_mode == 'ON':
    tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3 = \
            f_get_boosted_trainset(tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3)

In [15]:
tr_loader = DataLoader(list(zip(tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3)), 
                       batch_size=batch_size, shuffle=True)

va_loader = DataLoader(list(zip(va_data, va_data_mi, va_time, va_label, va_mask1, va_mask2, va_mask3)), 
                       batch_size=batch_size, shuffle=False)

te_loader = DataLoader(list(zip(te_data, te_data_mi, te_time, te_label, te_mask1, te_mask2, te_mask3)), 
                       batch_size=batch_size, shuffle=False)

### Train the Network

This implementation does not consider missing values.

In [16]:
num_Event

1

In [17]:
pred_time = [6, 12]
eval_time = [24, 48, 72]

In [18]:
Best_C_Index = 0

if not os.path.exists(f"checkpoints/{model_name}"):
    os.makedirs(f"checkpoints/{model_name}")

for params in tqdm(param_grid_list):
    dropout = params['dropout']
    lr_train = params['lr_train']
    beta = params['beta']
    gamma = params['gamma']
    print('='*80, '\n', params)

    model = \
        DynamicDeepHitTorch(input_dim  = x_dim,
                            output_dim = num_Category,
                            layers_rnn = param['layers_rnn'],
                            hidden_rnn = param['hidden_rnn'],
                            long_param = {'layers': 2*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            att_param  = {'layers': param['hidden_att']*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            cs_param   = {'layers': param['hidden_cs']*[param['hidden_dim_FC']], 'dropout': dropout},
                            risks      = num_Event
                           )
    optimizer_burn_in = torch.optim.Adam(model.parameters(), lr = lr_train)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr_train)


    model.train()
    model.to(device)

    min_valid = 0.1
    early_stop_patience = 10
    ep = 0 # for tracking early stopping

    experiment_runtime = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    tensorboard_string_extension = f"{model_name}_{SEED}"
    tensorboard_string_extension += f"_alpha{alpha}_beta{beta}_gamma{gamma}_BSZ{batch_size}_Lr{lr_train}"
    tensorboard_string_extension += f"_hiddenRNN{hidden_rnn}_hiddenFC{hidden_dim_FC}"
    tensorboard_string_extension += f"_layersRNN{layers_rnn}_hiddenAtt{hidden_att}_hiddenCS{hidden_cs}"
    tensorboard_string_extension += f"_dropout{dropout}"
    tensorboard_string_extension += "_" + experiment_runtime
    tensorboard_filename = tensorboard_string_extension
    writer = SummaryWriter(f"runs/{model_name}/{tensorboard_filename}")


    for epoch in range(param['num_epoch_burn_in']):
        model.train()
        for i, (b_data, b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3) in enumerate(tr_loader):
            b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.float(), b_data_mi.float(), b_time.float(), b_label.float(), \
                b_mask1.float(), b_mask2.float(), b_mask3.float()
            b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.to(device), b_time.to(device), b_label.to(device), \
                b_mask1.to(device), b_mask2.to(device), b_mask3.to(device)

            longitudinal_prediction, out = model(b_data)
            out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)

            # loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
            # loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
            loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
            tr_loss = loss3

            optimizer_burn_in.zero_grad()
            tr_loss.backward()
            optimizer_burn_in.step()


    for epoch in range(num_epoch):
        model.train()
        for i, (b_data, b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3) in enumerate(tr_loader):
            b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.float(), b_data_mi.float(), b_time.float(), b_label.float(), \
                b_mask1.float(), b_mask2.float(), b_mask3.float()
            b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.to(device), b_time.to(device), b_label.to(device), \
                b_mask1.to(device), b_mask2.to(device), b_mask3.to(device)

            longitudinal_prediction, out = model(b_data)
            out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)

            loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
            loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
            loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
            tr_loss = alpha * loss1 + beta * loss2 + gamma * loss3

            optimizer.zero_grad()
            tr_loss.backward()
            optimizer.step()

        model.eval()
        tr_loss, tr_loss1, tr_loss2, tr_loss3 = model_eval(model, tr_loader, device)
        va_loss, va_loss1, va_loss2, va_loss3 = model_eval(model, va_loader, device)

        ### VALIDATION  (based on average C-index of our interest)
        risk_all = f_get_risk_predictions(model, va_data, pred_time, eval_time, device)

        for p, p_time in enumerate(pred_time):
            pred_horizon = int(p_time)
            val_result1 = np.zeros([num_Event, len(eval_time)])

            for t, t_time in enumerate(eval_time):                
                eval_horizon = int(t_time) + pred_horizon
                for k in range(num_Event):
                    val_result1[k, t] = c_index(risk_all[k][:, p, t], va_time, (va_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

            if p == 0:
                val_final1 = val_result1
            else:
                val_final1 = np.append(val_final1, val_result1, axis=0)

        tmp_valid = np.mean(val_final1)

        if tmp_valid >  min_valid:
            ep = 0
            min_valid = tmp_valid
            save_checkpoint(f"checkpoints/{model_name}/{writer.get_logdir().split('/')[-1]}_{SEED}.pt", model, optimizer, va_loss)
            print( 'updated.... average c-index = ' + str('%.4f' %(tmp_valid)))
        else:
            ep += 1  
            print("Counter {} of {}".format(ep, early_stop_patience))

        if ep >= early_stop_patience:
            print("Early stopping with best_cindex: {:.4f}".format(min_valid), 
                  "and val_cindex for this epoch: {:.4f}".format(tmp_valid))
            break

        # Tensorboard
        writer.add_scalar("Loss-Training/Loss", tr_loss, epoch)
        writer.add_scalar("Loss-Training/NLL Loss", tr_loss1, epoch)
        writer.add_scalar("Loss-Training/Ranking Loss", tr_loss2, epoch)
        writer.add_scalar("Loss-Training/LSTM MSELoss", tr_loss3, epoch)

        writer.add_scalar("Loss-Validation/Loss", va_loss, epoch)
        writer.add_scalar("Loss-Validation/NLL Loss", va_loss1, epoch)
        writer.add_scalar("Loss-Validation/Ranking Loss", va_loss2, epoch)
        writer.add_scalar("Loss-Validation/LSTM MSELoss", va_loss3, epoch)

        writer.add_scalar("C-index/Validation Avg C-index", min_valid, epoch)

        print('epoch {}, tr_loss: {:.4f}, va_loss: {:.4f}, va c-index: {:.4f}'.\
                      format(epoch+1, tr_loss, va_loss, tmp_valid))
        
    if min_valid > Best_C_Index:
        Best_C_Index = min_valid
        print('Better valudation c-index!!!')
        print(writer.get_logdir().split('/')[-1])

  0%|                                                              | 0/12 [00:00<?, ?it/s]

 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.7240
epoch 1, tr_loss: 28.5015, va_loss: 22.3150, va c-index: 0.7240
updated.... average c-index = 0.7335
epoch 2, tr_loss: 27.8481, va_loss: 22.0414, va c-index: 0.7335
updated.... average c-index = 0.7452
epoch 3, tr_loss: 27.5822, va_loss: 21.9818, va c-index: 0.7452
updated.... average c-index = 0.7512
epoch 4, tr_loss: 27.4168, va_loss: 21.7861, va c-index: 0.7512
updated.... average c-index = 0.7689
epoch 5, tr_loss: 27.2831, va_loss: 21.5237, va c-index: 0.7689
Counter 1 of 10
epoch 6, tr_loss: 27.3608, va_loss: 22.1702, va c-index: 0.7306
Counter 2 of 10
epoch 7, tr_loss: 26.6839, va_loss: 21.9517, va c-index: 0.7535
Counter 3 of 10
epoch 8, tr_loss: 26.3426, va_loss: 22.0553, va c-index: 0.7559
Counter 4 of 10
epoch 9, tr_loss: 26.4737, va_loss: 22.0746, va c-index: 0.7577
Counter 5 of 10
epoch 10, tr_loss: 26.4061, va_loss: 22.1043, va c-index: 0.7513
Counter 6 of 10
epoch 11, tr

  8%|████▎                                              | 1/12 [10:47<1:58:37, 647.01s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.8343 and val_cindex for this epoch: 0.8162
Better valudation c-index!!!
ddh_eeg_nocr_4444_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2257
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.6979
epoch 1, tr_loss: 24.8727, va_loss: 24.1308, va c-index: 0.6979
Counter 1 of 10
epoch 2, tr_loss: 24.0410, va_loss: 26.9288, va c-index: 0.5716
updated.... average c-index = 0.7379
epoch 3, tr_loss: 21.9254, va_loss: 24.1405, va c-index: 0.7379
updated.... average c-index = 0.7380
epoch 4, tr_loss: 20.7134, va_loss: 24.7415, va c-index: 0.7380
updated.... average c-index = 0.7894
epoch 5, tr_loss: 19.7963, va_loss: 23.6015, va c-index: 0.7894
Counter 1 of 10
epoch 6, tr_loss: 20.2364, va_loss: 25.3740, va c-index: 0.7545
Counter 2 of 10
epoch 7, tr_loss: 20.6514, va_loss: 23.0886, va c-index: 0.7559
Counter 3 of 10
epoch 8, t

 17%|████████▌                                          | 2/12 [14:36<1:06:54, 401.45s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7894 and val_cindex for this epoch: 0.7832
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.6027
epoch 1, tr_loss: 53.8996, va_loss: 47.6985, va c-index: 0.6027
updated.... average c-index = 0.6727
epoch 2, tr_loss: 52.5273, va_loss: 41.9549, va c-index: 0.6727
updated.... average c-index = 0.6923
epoch 3, tr_loss: 52.2670, va_loss: 42.1783, va c-index: 0.6923
updated.... average c-index = 0.7156
epoch 4, tr_loss: 51.2909, va_loss: 41.9006, va c-index: 0.7156
Counter 1 of 10
epoch 5, tr_loss: 50.8592, va_loss: 41.9939, va c-index: 0.7062
Counter 2 of 10
epoch 6, tr_loss: 50.9600, va_loss: 42.6431, va c-index: 0.6625
Counter 3 of 10
epoch 7, tr_loss: 50.0426, va_loss: 42.4700, va c-index: 0.6886
Counter 4 of 10
epoch 8, tr_loss: 48.9581, va_loss: 41.8110, va c-index: 0.7137
updated.... average c-index = 0.7246
epoch 9, tr_loss: 48.4540, va_loss: 41.7042, va c-index: 0.7246
updated.... ave

 25%|█████████████▎                                       | 3/12 [19:58<54:44, 364.97s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7724 and val_cindex for this epoch: 0.7627
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005}
updated.... average c-index = 0.7046
epoch 1, tr_loss: 42.5558, va_loss: 45.7540, va c-index: 0.7046
updated.... average c-index = 0.7238
epoch 2, tr_loss: 36.5528, va_loss: 45.4380, va c-index: 0.7238
updated.... average c-index = 0.7706
epoch 3, tr_loss: 34.9782, va_loss: 45.2802, va c-index: 0.7706
Counter 1 of 10
epoch 4, tr_loss: 38.9529, va_loss: 42.6032, va c-index: 0.7637
Counter 2 of 10
epoch 5, tr_loss: 33.3932, va_loss: 42.2031, va c-index: 0.7642
updated.... average c-index = 0.7966
epoch 6, tr_loss: 33.0001, va_loss: 44.5708, va c-index: 0.7966
Counter 1 of 10
epoch 7, tr_loss: 32.5891, va_loss: 45.2162, va c-index: 0.7643
Counter 2 of 10
epoch 8, tr_loss: 35.3913, va_loss: 41.9945, va c-index: 0.7652
Counter 3 of 10
epoch 9, tr_loss: 32.6156, va_loss: 40.4413, va c-index: 0.7779
Counter 4 of 10
epoch 10, tr_loss: 3

 33%|█████████████████▋                                   | 4/12 [24:30<43:47, 328.47s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.8025 and val_cindex for this epoch: 0.7696
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001}
updated.... average c-index = 0.7094
epoch 1, tr_loss: 236.9356, va_loss: 189.0510, va c-index: 0.7094
updated.... average c-index = 0.7385
epoch 2, tr_loss: 234.0529, va_loss: 189.2617, va c-index: 0.7385
Counter 1 of 10
epoch 3, tr_loss: 235.3138, va_loss: 189.7303, va c-index: 0.7189
Counter 2 of 10
epoch 4, tr_loss: 235.2318, va_loss: 189.9984, va c-index: 0.7320
Counter 3 of 10
epoch 5, tr_loss: 229.3929, va_loss: 190.7396, va c-index: 0.7348
updated.... average c-index = 0.7481
epoch 6, tr_loss: 229.9112, va_loss: 191.5585, va c-index: 0.7481
Counter 1 of 10
epoch 7, tr_loss: 221.9945, va_loss: 191.8805, va c-index: 0.7433
Counter 2 of 10
epoch 8, tr_loss: 222.9827, va_loss: 192.7674, va c-index: 0.7340
Counter 3 of 10
epoch 9, tr_loss: 219.2255, va_loss: 193.9492, va c-index: 0.7007
Counter 4 of 10
epoch 10, tr_loss: 217.

 42%|██████████████████████                               | 5/12 [31:45<42:48, 366.87s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7875 and val_cindex for this epoch: 0.7722
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}
updated.... average c-index = 0.6878
epoch 1, tr_loss: 166.8927, va_loss: 198.2120, va c-index: 0.6878
Counter 1 of 10
epoch 2, tr_loss: 149.9650, va_loss: 203.7801, va c-index: 0.6867
updated.... average c-index = 0.7436
epoch 3, tr_loss: 148.5080, va_loss: 194.7031, va c-index: 0.7436
updated.... average c-index = 0.7474
epoch 4, tr_loss: 151.3192, va_loss: 203.1647, va c-index: 0.7474
Counter 1 of 10
epoch 5, tr_loss: 150.5337, va_loss: 221.4686, va c-index: 0.6829
Counter 2 of 10
epoch 6, tr_loss: 146.0023, va_loss: 232.5849, va c-index: 0.7256
Counter 3 of 10
epoch 7, tr_loss: 150.5153, va_loss: 222.3080, va c-index: 0.6781
Counter 4 of 10
epoch 8, tr_loss: 147.2883, va_loss: 189.0318, va c-index: 0.7004
updated.... average c-index = 0.7500
epoch 9, tr_loss: 155.9337, va_loss: 193.3545, va c-index: 0.7500
updated.... averag

 50%|██████████████████████████▌                          | 6/12 [42:06<45:19, 453.28s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.8333 and val_cindex for this epoch: 0.8290
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.5581
epoch 1, tr_loss: 32.8978, va_loss: 30.3575, va c-index: 0.5581
updated.... average c-index = 0.6011
epoch 2, tr_loss: 31.2863, va_loss: 24.6635, va c-index: 0.6011
updated.... average c-index = 0.6379
epoch 3, tr_loss: 30.5365, va_loss: 25.2995, va c-index: 0.6379
updated.... average c-index = 0.6917
epoch 4, tr_loss: 30.0521, va_loss: 25.1509, va c-index: 0.6917
updated.... average c-index = 0.7167
epoch 5, tr_loss: 29.8177, va_loss: 24.8524, va c-index: 0.7167
updated.... average c-index = 0.7390
epoch 6, tr_loss: 28.6844, va_loss: 24.4518, va c-index: 0.7390
updated.... average c-index = 0.7485
epoch 7, tr_loss: 28.7183, va_loss: 24.0713, va c-index: 0.7485
updated.... average c-index = 0.7511
epoch 8, tr_loss: 28.0447, va_loss: 24.0600, va c-index: 0.7511
Counter 1 of 10
epoch 9, tr_loss:

 58%|██████████████████████████████▉                      | 7/12 [46:30<32:36, 391.23s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7511 and val_cindex for this epoch: 0.7316
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.6868
epoch 1, tr_loss: 24.1166, va_loss: 26.9528, va c-index: 0.6868
updated.... average c-index = 0.7503
epoch 2, tr_loss: 22.5905, va_loss: 24.7879, va c-index: 0.7503
Counter 1 of 10
epoch 3, tr_loss: 35.1212, va_loss: 46.0706, va c-index: 0.6261
Counter 2 of 10
epoch 4, tr_loss: 22.0027, va_loss: 26.2714, va c-index: 0.7205
Counter 3 of 10
epoch 5, tr_loss: 25.9875, va_loss: 37.1879, va c-index: 0.6191
Counter 4 of 10
epoch 6, tr_loss: 20.1428, va_loss: 26.0986, va c-index: 0.6584
Counter 5 of 10
epoch 7, tr_loss: 22.6185, va_loss: 25.6059, va c-index: 0.6789
Counter 6 of 10
epoch 8, tr_loss: 20.6735, va_loss: 25.3721, va c-index: 0.7358
Counter 7 of 10
epoch 9, tr_loss: 19.4777, va_loss: 24.4631, va c-index: 0.7407
Counter 8 of 10
epoch 10, tr_loss: 19.1221, va_loss: 24.6869, va c-index: 0.737

 67%|███████████████████████████████████▎                 | 8/12 [53:12<26:18, 394.68s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.8118 and val_cindex for this epoch: 0.7954
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.7183
epoch 1, tr_loss: 57.1739, va_loss: 43.4013, va c-index: 0.7183
Counter 1 of 10
epoch 2, tr_loss: 56.3162, va_loss: 43.4946, va c-index: 0.7122
Counter 2 of 10
epoch 3, tr_loss: 54.5881, va_loss: 42.8325, va c-index: 0.7142
Counter 3 of 10
epoch 4, tr_loss: 54.4398, va_loss: 42.2293, va c-index: 0.7151
updated.... average c-index = 0.7512
epoch 5, tr_loss: 52.9368, va_loss: 42.4564, va c-index: 0.7512
Counter 1 of 10
epoch 6, tr_loss: 51.8407, va_loss: 41.9505, va c-index: 0.7436
Counter 2 of 10
epoch 7, tr_loss: 53.3850, va_loss: 43.2499, va c-index: 0.7181
Counter 3 of 10
epoch 8, tr_loss: 52.4019, va_loss: 42.7496, va c-index: 0.7364
Counter 4 of 10
epoch 9, tr_loss: 51.1778, va_loss: 42.8899, va c-index: 0.7418
Counter 5 of 10
epoch 10, tr_loss: 51.1152, va_loss: 43.0343, va c-index: 0.7355

 75%|██████████████████████████████████████▎            | 9/12 [1:02:03<21:52, 437.45s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.8165 and val_cindex for this epoch: 0.7914
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005}
updated.... average c-index = 0.7476
epoch 1, tr_loss: 44.7993, va_loss: 44.3479, va c-index: 0.7476
updated.... average c-index = 0.7904
epoch 2, tr_loss: 40.5681, va_loss: 45.3091, va c-index: 0.7904
Counter 1 of 10
epoch 3, tr_loss: 38.3100, va_loss: 46.4318, va c-index: 0.7732
updated.... average c-index = 0.7934
epoch 4, tr_loss: 39.8132, va_loss: 44.2094, va c-index: 0.7934
Counter 1 of 10
epoch 5, tr_loss: 36.6112, va_loss: 44.8292, va c-index: 0.7687
Counter 2 of 10
epoch 6, tr_loss: 39.3990, va_loss: 53.6279, va c-index: 0.7283
Counter 3 of 10
epoch 7, tr_loss: 33.9207, va_loss: 47.5428, va c-index: 0.7594
Counter 4 of 10
epoch 8, tr_loss: 35.6092, va_loss: 47.8368, va c-index: 0.7510
Counter 5 of 10
epoch 9, tr_loss: 33.5873, va_loss: 47.2838, va c-index: 0.7839
Counter 6 of 10
epoch 10, tr_loss: 34.5932, va_loss: 44.707

 83%|█████████████████████████████████████████▋        | 10/12 [1:12:18<16:24, 492.34s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.8442 and val_cindex for this epoch: 0.8411
Better valudation c-index!!!
ddh_eeg_nocr_4444_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2359
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001}
updated.... average c-index = 0.6687
epoch 1, tr_loss: 240.3563, va_loss: 188.5174, va c-index: 0.6687
updated.... average c-index = 0.7002
epoch 2, tr_loss: 236.5426, va_loss: 187.4814, va c-index: 0.7002
updated.... average c-index = 0.7134
epoch 3, tr_loss: 237.2842, va_loss: 188.5211, va c-index: 0.7134
updated.... average c-index = 0.7342
epoch 4, tr_loss: 238.0649, va_loss: 186.6342, va c-index: 0.7342
updated.... average c-index = 0.7367
epoch 5, tr_loss: 231.1864, va_loss: 187.1715, va c-index: 0.7367
Counter 1 of 10
epoch 6, tr_loss: 229.7157, va_loss: 187.0078, va c-index: 0.7332
updated.... average c-index = 0.7537
epoch 7, tr_loss: 229.0677, va_loss: 188.3

 92%|█████████████████████████████████████████████▊    | 11/12 [1:20:06<08:04, 484.64s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.8134 and val_cindex for this epoch: 0.8069
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}
updated.... average c-index = 0.7534
epoch 1, tr_loss: 181.5776, va_loss: 198.1051, va c-index: 0.7534
updated.... average c-index = 0.7847
epoch 2, tr_loss: 178.5399, va_loss: 203.9655, va c-index: 0.7847
updated.... average c-index = 0.7980
epoch 3, tr_loss: 176.7306, va_loss: 210.6639, va c-index: 0.7980
Counter 1 of 10
epoch 4, tr_loss: 163.1818, va_loss: 207.9450, va c-index: 0.7636
Counter 2 of 10
epoch 5, tr_loss: 170.5546, va_loss: 195.4999, va c-index: 0.7753
Counter 3 of 10
epoch 6, tr_loss: 153.9863, va_loss: 221.0785, va c-index: 0.7676
Counter 4 of 10
epoch 7, tr_loss: 154.3016, va_loss: 205.5509, va c-index: 0.7706
Counter 5 of 10
epoch 8, tr_loss: 148.7608, va_loss: 211.4980, va c-index: 0.7865
Counter 6 of 10
epoch 9, tr_loss: 141.1074, va_loss: 206.4012, va c-index: 0.7723
Counter 7 of 10
epoch 10, tr_loss: 144.10

100%|██████████████████████████████████████████████████| 12/12 [1:23:37<00:00, 418.10s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7980 and val_cindex for this epoch: 0.7874





In [19]:
model

DynamicDeepHitTorch(
  (embedding): LSTM(116, 64, num_layers=2, bias=False, batch_first=True)
  (longitudinal): Sequential(
    (0): Dropout(p=0.4, inplace=False)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=116, bias=True)
  )
  (attention): Sequential(
    (0): Dropout(p=0.4, inplace=False)
    (1): Linear(in_features=180, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=1, bias=True)
  )
  (attention_soft): Softmax(dim=1)
  (cause_specific): ModuleList(
    (0): Sequential(
      (0): Dropout(p=0.4, inplace=False)
      (1): Linear(in_features=180, out_features=64, bias=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=64, bias=True)
      (4): ReLU()
      (5): Linear(in_features=64, out_features=1

### Test the Trained Network

## Evaluation Metrics

In [22]:
from sklearn import metrics
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib

seed_list = [1234, 1111, 2222, 3333, 4444]
best_pt_path_list = ['ddh_eeg_nocr_1234_alpha1.0_beta0.5_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_1709',
                     'ddh_eeg_nocr_1111_alpha1.0_beta0.5_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_1855',
                     'ddh_eeg_nocr_2222_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2103',
                     'ddh_eeg_nocr_3333_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2140',
                     'ddh_eeg_nocr_4444_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2359'
                    ]

data_mode = 'EEG_nocr' # ['EEG_nocr', 'EEG_cr2', 'EEG_cr3', 'EEG_cr4']

if data_mode == 'EEG_nocr':
    path = '../../eeg/competing-risk/EEG_processed_data_long_nocr_exclude_3.csv'

(x_dim, x_dim_cont, x_dim_bin), (data, time, label, time_original, time_to_last), \
    (mask1, mask2, mask3), (data_mi), trans_discrete_time = \
                impt.import_dataset(path)

_, num_Event, num_Category  = np.shape(mask1)  # dim of mask3: [subj, Num_Event, Num_Category]
max_length                  = np.shape(data)[1]

pred_time = list(range(1,13))
eval_time = [24, 48, 72]
mean_fpr = np.linspace(1e-10, 1, 1000)



In [21]:
def eval_model(model, te_data, pred_time, eval_time, num_Event, device='cuda', te_label_cp=None):
    # pred_time = [4, 8, 12]
    # eval_time = [12, 24, 36, 48, 60, 72]
    
    if te_label_cp is None:
        te_label_cp = te_label
    
    risk_all = f_get_risk_predictions(model, te_data, pred_time, eval_time, device=device)

    for p, p_time in enumerate(pred_time):
        pred_horizon = int(p_time)
        result1, result2 = np.zeros([num_Event, len(eval_time)]), np.zeros([num_Event, len(eval_time)])

        for t, t_time in enumerate(eval_time):                
            eval_horizon = int(t_time) + pred_horizon
            for k in range(num_Event):
                result1[k, t] = c_index(risk_all[k][:, p, t], te_time, (te_label_cp[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)
                result2[k, t] = brier_score(risk_all[k][:, p, t], te_time, (te_label_cp[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

        if p == 0:
            final1, final2 = result1, result2
        else:
            final1, final2 = np.append(final1, result1, axis=0), np.append(final2, result2, axis=0)


    row_header = []
    for p_time in pred_time:
        for k in range(num_Event):
            row_header.append('pred_time {}: event_{}'.format(p_time, k+1))

    col_header = []
    for t_time in eval_time:
        col_header.append('eval_time {}'.format(t_time))

    # c-index result
    df1 = pd.DataFrame(final1, index = row_header, columns=col_header)

    # brier-score result
    df2 = pd.DataFrame(final2, index = row_header, columns=col_header)

    ### PRINT RESULTS
    print('========================================================')
    print('--------------------------------------------------------')
    print('- C-INDEX: ')
    print(df1)
    print('--------------------------------------------------------')
    print('- BRIER-SCORE: ')
    print(df2)
    print('========================================================')
    
    return df1, df2

In [28]:
pred_time = [6,12]
eval_time = [24, 48, 72]

seed_list = [1234, 1111, 2222, 3333, 4444]
best_pt_path_list = ['ddh_eeg_nocr_1234_alpha1.0_beta0.5_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_1709',
                     'ddh_eeg_nocr_1111_alpha1.0_beta0.5_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_1855',
                     'ddh_eeg_nocr_2222_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2103',
                     'ddh_eeg_nocr_3333_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2140',
                     'ddh_eeg_nocr_4444_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2359'
                    ]

C_INDEX = []

for r, seed in tqdm(enumerate(seed_list)):
    best_pt_path = best_pt_path_list[r]

    ### TRAINING-TESTING SPLIT
    (tr_data,te_data, tr_data_mi, te_data_mi, tr_time,te_time, tr_label,te_label, tr_time_orinal,te_time_original,
     tr_time_to_last,te_time_to_last,
     tr_mask1,te_mask1, tr_mask2,te_mask2, tr_mask3,te_mask3) = \
            train_test_split(data, data_mi, time, label, time_original, time_to_last,
                             mask1, mask2, mask3, test_size=0.2, stratify=label, random_state=seed) 

    (tr_data,va_data, tr_data_mi, va_data_mi, tr_time,va_time, tr_label,va_label, tr_time_orinal,va_time_original,
     tr_time_to_last, va_time_to_last,
     tr_mask1,va_mask1, tr_mask2,va_mask2, tr_mask3,va_mask3) = \
            train_test_split(tr_data, tr_data_mi, tr_time, tr_label, tr_time_orinal, tr_time_to_last,
                             tr_mask1, tr_mask2, tr_mask3, test_size=0.2, stratify=tr_label, random_state=seed) 

    model_name = 'ddh_eeg_nocr'
    num_Event = 1

    model = \
        DynamicDeepHitTorch(input_dim  = x_dim,
                            output_dim = num_Category,
                            layers_rnn = param['layers_rnn'],
                            hidden_rnn = param['hidden_rnn'],
                            long_param = {'layers': 2*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            att_param  = {'layers': param['hidden_att']*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            cs_param   = {'layers': param['hidden_cs']*[param['hidden_dim_FC']], 'dropout': dropout},
                            risks      = num_Event
                           )
    optimizer = torch.optim.Adam(model.parameters(), lr = lr_train)
    _ = load_checkpoint(f'checkpoints/{model_name}/{best_pt_path}_{seed}.pt', model, optimizer, device)
    model = model.to(device)
    
    df1, df2 = eval_model(model, te_data, pred_time, eval_time, num_Event)
    C_INDEX.append(df2)

2it [00:00,  5.09it/s]

--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.858338      0.888081      0.886744
pred_time 12: event_1      0.845522      0.872904      0.879616
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.048799      0.094871      0.110777
pred_time 12: event_1      0.058169      0.099601      0.112548
--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.840745      0.858889      0.850209
pred_time 12: event_1      0.854919      0.850582      0.853053
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.044879      0.089282       0.10772
pred_time 12

4it [00:00,  5.29it/s]

--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.795208      0.830823      0.822919
pred_time 12: event_1      0.804613      0.836010      0.832424
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.046276      0.098382      0.131416
pred_time 12: event_1      0.055472      0.101865      0.128824
--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.877805      0.855403      0.849194
pred_time 12: event_1      0.849527      0.847369      0.837496
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.046273      0.093488      0.115630
pred_time 12

5it [00:00,  5.24it/s]

--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.878678      0.873653      0.865085
pred_time 12: event_1      0.868376      0.869736      0.870631
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.050029      0.091062      0.106825
pred_time 12: event_1      0.055840      0.095999      0.109115





In [29]:
c_index_mean = np.zeros_like(C_INDEX[0].values)
c_index_std = np.zeros_like(C_INDEX[0].values)
for i in range(c_index_mean.shape[0]):
    for j in range(c_index_mean.shape[1]):
        c_list = [C_INDEX[k].iloc[i,j] for k in range(len(seed_list))]
        c_list = np.array(c_list)
        c_index_mean[i,j] = c_list.mean()
        c_index_std[i,j]  = c_list.std()
        
C_index_mean = C_INDEX[0].copy()
C_index_std = C_INDEX[0].copy()
C_index_mean.iloc[:, :] = c_index_mean
C_index_std.iloc[:, :] = c_index_std

In [31]:
C_index_mean
C_index_std

Unnamed: 0,eval_time 24,eval_time 48,eval_time 72
pred_time 6: event_1,0.001879,0.003144,0.009013
pred_time 12: event_1,0.002665,0.00236,0.007093
