In [1]:
# Global random seed shared by torch, numpy and the train/test splits below.
SEED = 4444
# Small positive constant; not referenced in the visible cells
# (presumably a numerical-stability epsilon -- TODO confirm usage).
_EPSILON = 1e-8

In [2]:
import sys
sys.path.append('../auton-survival/')
sys.path.append('../ddh/')

import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime

from ddh_torch import DynamicDeepHitTorch
import import_data as impt
from utils_eval import c_index, brier_score
from utils_helper import f_get_boosted_trainset
from utils_helper import loss_Log_Likelihood, loss_Ranking, loss_RNN_Prediction
from utils_helper import f_get_risk_predictions, _f_get_pred
from utils_helper import save_checkpoint, load_checkpoint

In [3]:
## for reproducibility
# Seed both RNG sources used by this notebook with the same value.
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
def model_eval(model, dataloader, device):
    """Compute the average losses of `model` over every batch of `dataloader`.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()`` so that no autograd graph is built or retained while
    the per-batch losses are accumulated.

    NOTE: this function reads the notebook-level globals ``alpha``, ``beta``,
    ``gamma``, ``num_Event`` and ``num_Category``; they must be defined (and
    hold the values of the configuration being evaluated) before calling it.

    Args:
        model: network called as ``model(b_data)`` and returning
            ``(longitudinal_prediction, out)`` where ``out`` is an iterable of
            per-event tensors.
        dataloader: yields 7-tuples
            ``(data, data_mi, time, label, mask1, mask2, mask3)``.
        device: torch device the batches are moved to.

    Returns:
        Tuple ``(total, loss1, loss2, loss3)`` of per-batch averages, where
        ``total = alpha * loss1 + beta * loss2 + gamma * loss3``. The values
        are 0-dim tensors as produced by the loss helpers.

    Raises:
        ZeroDivisionError: if `dataloader` yields no batches.
    """
    model.eval()
    total_loss = 0
    Loss1, Loss2, Loss3 = 0, 0, 0

    # Evaluation only: disabling autograd avoids keeping one graph per batch
    # alive through the accumulated sums.
    with torch.no_grad():
        for b_data, b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3 in dataloader:
            # b_data_mi (missing indicator) is not used by any loss below.
            b_data = b_data.float().to(device)
            b_time = b_time.float().to(device)
            b_label = b_label.float().to(device)
            b_mask1 = b_mask1.float().to(device)
            b_mask2 = b_mask2.float().to(device)
            b_mask3 = b_mask3.float().to(device)

            longitudinal_prediction, out = model(b_data)
            out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)
            loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
            loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
            loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
            loss = alpha * loss1 + beta * loss2 + gamma * loss3

            total_loss += loss
            Loss1 += loss1
            Loss2 += loss2
            Loss3 += loss3

    n_batches = len(dataloader)
    return total_loss/n_batches, Loss1/n_batches, Loss2/n_batches, Loss3/n_batches

### Import Dataset

In [9]:
data_mode                   = 'EEG_cr3' # competing risk with 3 causes
seed                        = SEED

##### IMPORT DATASET
# Quantities produced by this cell (descriptions carried over from the original notes):
#     num_Category            = max event/censoring time * 1.2
#     num_Event               = number of events, i.e. len(np.unique(label))-1
#     max_length              = maximum number of measurements
#     x_dim                   = data dimension including delta (1 + num_features)
#     x_dim_cont              = dim of continuous features
#     x_dim_bin               = dim of binary features
#     mask1, mask2, mask3     = used for cause-specific network (FCNet structure)

if data_mode == 'EEG_cr3':
    (x_dim, x_dim_cont, x_dim_bin), (data, time, label, time_original, time_to_last), \
        (mask1, mask2, mask3), (data_mi), trans_discrete_time = \
                    impt.import_dataset('../../eeg/competing-risk/EEG_processed_data_long_by_death_cat_exclude_3.csv')

    # This must be changed depending on the dataset and the prediction/evaluation times of interest
    pred_time = [6, 12] # prediction time (in hours)
    eval_time = [12, 24,] # hours evaluation time (for C-index and Brier-Score)
else:
    # Fail fast: every statement below needs the variables defined in the branch
    # above, so continuing would only surface later as an unrelated NameError.
    raise ValueError('ERROR:  DATA_MODE NOT FOUND !!! (got {!r})'.format(data_mode))

# Shape is read from mask1; the original note describes it as the mask3 layout
# [subj, Num_Event, Num_Category] -- TODO confirm both masks share these axes.
_, num_Event, num_Category  = np.shape(mask1)
max_length                  = np.shape(data)[1]


# file_path = '{}_{}'.format(data_mode, seed)

# if not os.path.exists(file_path):
#     os.makedirs(file_path)

# Display the derived dimensions as the cell output.
num_Event, num_Category, max_length



(3, 120, 12)

### Set Hyper-Parameters

In [11]:
from sklearn.model_selection import ParameterGrid

# Name used for TensorBoard run folders and checkpoint paths.
model_name = 'ddh_eeg_cr3'

# Hyper-parameter search space: 2 * 3 * 3 * 3 = 54 combinations.
param_grid = dict(
    dropout=[0.2, 0.4],
    lr_train=[1e-4, 5e-4, 1e-3],
    beta=[0.5, 1, 5],
    gamma=[0.05, 0.1, 0.5],
)

# Materialise every combination so the grid can be counted and iterated.
param_grid_list = [combo for combo in ParameterGrid(param_grid)]
print(len(param_grid_list))
param_grid_list

54


[{'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.001},
 {'beta':

In [12]:
# Whether the training set is passed through f_get_boosted_trainset: one of {'ON', 'OFF'}.
boost_mode = 'ON'

# Default hyper-parameters. dropout / lr_train / beta / gamma are later
# overridden per configuration by the grid search.
param = dict(
    batch_size=32,

    num_epoch_burn_in=20,
    num_epoch=100,

    dropout=0.2,  # 1 - keep_prob
    lr_train=1e-4,

    hidden_rnn=64,
    hidden_dim_FC=64,
    layers_rnn=2,
    hidden_att=2,
    hidden_cs=2,

    alpha=1.0,
    beta=1,
    gamma=0.1,
)

# Expose the entries as notebook-level names used by the cells below.
batch_size, num_epoch = param['batch_size'], param['num_epoch']
dropout, lr_train = param['dropout'], param['lr_train']

hidden_rnn, hidden_dim_FC = param['hidden_rnn'], param['hidden_dim_FC']
layers_rnn = param['layers_rnn']
hidden_att, hidden_cs = param['hidden_att'], param['hidden_cs']

# Loss weights: total = alpha * loss1 + beta * loss2 + gamma * loss3.
alpha, beta, gamma = param['alpha'], param['beta'], param['gamma']

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


### Split Dataset into Train/Valid/Test Sets

In [13]:
### TRAINING-TESTING SPLIT
# First split: hold out 20% of all subjects as the test set, stratified on `label`.
# train_test_split returns a (train, test) pair per input array, in input order,
# so the unpacking below must list the targets in exactly that order.
# NOTE(review): `tr_time_orinal` is misspelled (cf. `te_time_original`) but is
# consumed under that same spelling by the second split below; renaming it would
# have to be done in both places and in any later cell that uses it.
(tr_data,te_data, tr_data_mi, te_data_mi, tr_time,te_time, tr_label,te_label, tr_time_orinal,te_time_original,
 tr_time_to_last,te_time_to_last,
 tr_mask1,te_mask1, tr_mask2,te_mask2, tr_mask3,te_mask3) = \
        train_test_split(data, data_mi, time, label, time_original, time_to_last,
                         mask1, mask2, mask3, test_size=0.2, stratify=label, random_state=seed) 

# Second split: carve 20% of the remaining training subjects out as the validation
# set (stratified on the training labels), overwriting the tr_* names in place.
# Re-running this cell alone therefore does not reproduce the same split.
(tr_data,va_data, tr_data_mi, va_data_mi, tr_time,va_time, tr_label,va_label, tr_time_orinal,va_time_original,
 tr_time_to_last, va_time_to_last,
 tr_mask1,va_mask1, tr_mask2,va_mask2, tr_mask3,va_mask3) = \
        train_test_split(tr_data, tr_data_mi, tr_time, tr_label, tr_time_orinal, tr_time_to_last,
                         tr_mask1, tr_mask2, tr_mask3, test_size=0.2, stratify=tr_label, random_state=seed) 

# Optional augmentation of the training set only (validation/test are untouched).
# NOTE(review): tr_time_orinal and tr_time_to_last are not passed to
# f_get_boosted_trainset, so after this step they may no longer line up
# row-for-row with tr_data -- confirm before using them together.
if boost_mode == 'ON':
    tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3 = \
            f_get_boosted_trainset(tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3)

In [15]:
def _build_loader(arrays, shuffle):
    """Zip the per-subject arrays into 7-tuples and wrap them in a DataLoader."""
    samples = list(zip(*arrays))
    return DataLoader(samples, batch_size=batch_size, shuffle=shuffle)


# Only the training loader is shuffled; validation and test keep their order.
tr_loader = _build_loader(
    (tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3), shuffle=True)

va_loader = _build_loader(
    (va_data, va_data_mi, va_time, va_label, va_mask1, va_mask2, va_mask3), shuffle=False)

te_loader = _build_loader(
    (te_data, te_data_mi, te_time, te_label, te_mask1, te_mask2, te_mask3), shuffle=False)

### Train the Network

This implementation does not consider missing values.

In [16]:
# Sanity check: display the number of events derived from the shape of mask1.
num_Event

3

In [17]:
# NOTE: this cell overrides the pred_time / eval_time values assigned in the
# dataset-import cell (there eval_time was [12, 24]); the training cell uses
# whichever assignment ran last.
# Prediction times (hours, per the dataset-import cell's comment).
pred_time = [6, 12]
# Evaluation offsets; the training cell adds each one to the prediction time
# to obtain the evaluation horizon passed to c_index.
eval_time = [24, 48, 72]

In [None]:
# Grid search over `param_grid_list`: for each configuration, burn in the
# longitudinal (RNN prediction) loss, then train on the full objective with
# early stopping on the average validation C-index.

# Best validation C-index over all configurations, plus where it came from.
Best_C_Index = 0
Best_Params = None
Best_Run = None


def _prepare_batch(batch, device):
    """Cast a 7-tuple batch to float and move the tensors the losses use to `device`.

    The second element (missing-value indicator) is dropped because none of the
    training computations below consume it.

    Returns:
        (b_data, b_time, b_label, b_mask1, b_mask2, b_mask3)
    """
    b_data, _b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3 = batch
    return tuple(t.float().to(device)
                 for t in (b_data, b_time, b_label, b_mask1, b_mask2, b_mask3))


for params in tqdm(param_grid_list):
    # These are deliberately notebook-level names: model_eval reads the global
    # beta / gamma, so they must hold the current configuration's values.
    dropout = params['dropout']
    lr_train = params['lr_train']
    beta = params['beta']
    gamma = params['gamma']
    print('='*80, '\n', params)

    model = \
        DynamicDeepHitTorch(input_dim  = x_dim,
                            output_dim = num_Category,
                            layers_rnn = param['layers_rnn'],
                            hidden_rnn = param['hidden_rnn'],
                            long_param = {'layers': 2*[param['hidden_dim_FC']], 'dropout': dropout},
                            att_param  = {'layers': param['hidden_att']*[param['hidden_dim_FC']], 'dropout': dropout},
                            cs_param   = {'layers': param['hidden_cs']*[param['hidden_dim_FC']], 'dropout': dropout},
                            risks      = num_Event
                           )
    # Move the model to its device before building the optimizers, as the
    # torch.optim documentation recommends.
    model.to(device)
    optimizer_burn_in = torch.optim.Adam(model.parameters(), lr = lr_train)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr_train)

    # Best validation C-index for this configuration; a checkpoint is only
    # written once the C-index exceeds this initial threshold.
    min_valid = 0.1
    early_stop_patience = 10
    ep = 0 # epochs since the last improvement (early-stopping counter)

    experiment_runtime = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    tensorboard_string_extension = f"{model_name}_{SEED}"
    tensorboard_string_extension += f"_alpha{alpha}_beta{beta}_gamma{gamma}_BSZ{batch_size}_Lr{lr_train}"
    tensorboard_string_extension += f"_hiddenRNN{hidden_rnn}_hiddenFC{hidden_dim_FC}"
    tensorboard_string_extension += f"_layersRNN{layers_rnn}_hiddenAtt{hidden_att}_hiddenCS{hidden_cs}"
    tensorboard_string_extension += f"_dropout{dropout}"
    tensorboard_string_extension += "_" + experiment_runtime
    tensorboard_filename = tensorboard_string_extension
    writer = SummaryWriter(f"runs/{model_name}/{tensorboard_filename}")
    run_name = writer.get_logdir().split('/')[-1]

    # Make sure the checkpoint folder exists before the first save.
    checkpoint_dir = f"checkpoints/{model_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = f"{checkpoint_dir}/{run_name}_{SEED}.pt"

    try:
        # ---- Burn-in: optimise only the longitudinal prediction loss ----
        for epoch in range(param['num_epoch_burn_in']):
            model.train()
            for batch in tr_loader:
                b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = _prepare_batch(batch, device)

                longitudinal_prediction, _ = model(b_data)
                batch_loss = loss_RNN_Prediction(longitudinal_prediction, b_data)

                optimizer_burn_in.zero_grad()
                batch_loss.backward()
                optimizer_burn_in.step()

        # ---- Main training on the full weighted objective ----
        for epoch in range(num_epoch):
            model.train()
            for batch in tr_loader:
                b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = _prepare_batch(batch, device)

                longitudinal_prediction, out = model(b_data)
                out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)

                loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
                loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
                loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
                batch_loss = alpha * loss1 + beta * loss2 + gamma * loss3

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            model.eval()
            tr_loss, tr_loss1, tr_loss2, tr_loss3 = model_eval(model, tr_loader, device)
            va_loss, va_loss1, va_loss2, va_loss3 = model_eval(model, va_loader, device)

            ### VALIDATION  (based on average C-index of our interest)
            risk_all = f_get_risk_predictions(model, va_data, pred_time, eval_time, device)

            # One (num_Event x len(eval_time)) C-index table per prediction time.
            c_index_tables = []
            for p, p_time in enumerate(pred_time):
                pred_horizon = int(p_time)
                val_result1 = np.zeros([num_Event, len(eval_time)])

                for t, t_time in enumerate(eval_time):
                    eval_horizon = int(t_time) + pred_horizon
                    for k in range(num_Event):
                        # Original note: c_index yields -1 for "no event (not comparable)".
                        val_result1[k, t] = c_index(risk_all[k][:, p, t], va_time,
                                                    (va_label[:,0] == k+1).astype(int), eval_horizon)

                c_index_tables.append(val_result1)

            val_final1 = np.concatenate(c_index_tables, axis=0)
            tmp_valid = np.mean(val_final1)

            if tmp_valid >  min_valid:
                ep = 0
                min_valid = tmp_valid
                save_checkpoint(checkpoint_path, model, optimizer, va_loss)
                print( 'updated.... average c-index = ' + str('%.4f' %(tmp_valid)))
            else:
                ep += 1
                print("Counter {} of {}".format(ep, early_stop_patience))

            # Tensorboard -- logged before the early-stopping check so the final
            # epoch of a run is recorded as well.
            writer.add_scalar("Loss-Training/Loss", tr_loss, epoch)
            writer.add_scalar("Loss-Training/NLL Loss", tr_loss1, epoch)
            writer.add_scalar("Loss-Training/Ranking Loss", tr_loss2, epoch)
            writer.add_scalar("Loss-Training/LSTM MSELoss", tr_loss3, epoch)

            writer.add_scalar("Loss-Validation/Loss", va_loss, epoch)
            writer.add_scalar("Loss-Validation/NLL Loss", va_loss1, epoch)
            writer.add_scalar("Loss-Validation/Ranking Loss", va_loss2, epoch)
            writer.add_scalar("Loss-Validation/LSTM MSELoss", va_loss3, epoch)

            # Per-epoch value (matches the console line) and the running best.
            writer.add_scalar("C-index/Validation Avg C-index", tmp_valid, epoch)
            writer.add_scalar("C-index/Best Validation Avg C-index", min_valid, epoch)

            if ep >= early_stop_patience:
                print("Early stopping with best_cindex: {:.4f}".format(min_valid),
                      "and val_cindex for this epoch: {:.4f}".format(tmp_valid))
                break

            print('epoch {}, tr_loss: {:.4f}, va_loss: {:.4f}, va c-index: {:.4f}'.\
                          format(epoch+1, tr_loss, va_loss, tmp_valid))
    finally:
        # Flush and release the event file even if training is interrupted;
        # otherwise one open writer per configuration accumulates.
        writer.close()

    if min_valid > Best_C_Index:
        Best_C_Index = min_valid
        Best_Params = dict(params)
        Best_Run = run_name
        print('Better valudation c-index!!!')
        print(run_name)

  0%|                                                                | 0/54 [00:00<?, ?it/s]

 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.6208
epoch 1, tr_loss: 31.4881, va_loss: 36.2923, va c-index: 0.6208
Counter 1 of 10
epoch 2, tr_loss: 31.3619, va_loss: 36.2717, va c-index: 0.6153
updated.... average c-index = 0.6245
epoch 3, tr_loss: 29.6774, va_loss: 35.9038, va c-index: 0.6245
updated.... average c-index = 0.6292
epoch 4, tr_loss: 29.3677, va_loss: 35.8742, va c-index: 0.6292
updated.... average c-index = 0.6346
epoch 5, tr_loss: 29.0159, va_loss: 35.6508, va c-index: 0.6346
updated.... average c-index = 0.6393
epoch 6, tr_loss: 28.3746, va_loss: 35.5731, va c-index: 0.6393
updated.... average c-index = 0.6479
epoch 7, tr_loss: 28.6759, va_loss: 35.6991, va c-index: 0.6479
updated.... average c-index = 0.6496
epoch 8, tr_loss: 28.1240, va_loss: 35.6501, va c-index: 0.6496
updated.... average c-index = 0.6568
epoch 9, tr_loss: 27.8617, va_loss: 35.6734, va c-index: 0.6568
Counter 1 of 10
epoch 10, tr_loss: 27.6488, va

  2%|▉                                                    | 1/54 [02:18<2:02:39, 138.86s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6840 and val_cindex for this epoch: 0.6767
Better valudation c-index!!!
ddh_eeg_cr3_4444_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230314_1025
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.5652
epoch 1, tr_loss: 26.6152, va_loss: 39.4735, va c-index: 0.5652
updated.... average c-index = 0.6413
epoch 2, tr_loss: 23.4386, va_loss: 39.3555, va c-index: 0.6413
updated.... average c-index = 0.6473
epoch 3, tr_loss: 21.4149, va_loss: 38.4729, va c-index: 0.6473
updated.... average c-index = 0.6701
epoch 4, tr_loss: 20.1238, va_loss: 38.8832, va c-index: 0.6701
Counter 1 of 10
epoch 5, tr_loss: 20.2581, va_loss: 39.1446, va c-index: 0.6335
Counter 2 of 10
epoch 6, tr_loss: 22.6350, va_loss: 40.9687, va c-index: 0.6456
Counter 3 of 10
epoch 7, tr_loss: 22.8216, va_loss: 36.5716, va c-index: 0.6659
Counter 4 of 10
epoch 8, tr

  4%|█▉                                                   | 2/54 [03:57<1:39:46, 115.13s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6701 and val_cindex for this epoch: 0.5862
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.001}
updated.... average c-index = 0.6731
epoch 1, tr_loss: 27.6562, va_loss: 36.3823, va c-index: 0.6731
Counter 1 of 10
epoch 2, tr_loss: 27.4547, va_loss: 36.5253, va c-index: 0.6299
Counter 2 of 10
epoch 3, tr_loss: 31.1247, va_loss: 40.0557, va c-index: 0.5742
Counter 3 of 10
epoch 4, tr_loss: 27.7220, va_loss: 35.1720, va c-index: 0.6481
Counter 4 of 10
epoch 5, tr_loss: 26.8663, va_loss: 36.7862, va c-index: 0.5916
Counter 5 of 10
epoch 6, tr_loss: 27.3393, va_loss: 38.2245, va c-index: 0.6081
Counter 6 of 10
epoch 7, tr_loss: 25.9059, va_loss: 36.6620, va c-index: 0.6693
Counter 7 of 10
epoch 8, tr_loss: 24.7103, va_loss: 36.7114, va c-index: 0.6585
Counter 8 of 10
epoch 9, tr_loss: 25.5206, va_loss: 37.2442, va c-index: 0.6345
Counter 9 of 10
epoch 10, tr_loss: 23.4405, va_loss: 37.2313, va c-index: 0.6722


  6%|███                                                   | 3/54 [05:18<1:24:33, 99.47s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6731 and val_cindex for this epoch: 0.6683
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.5502
epoch 1, tr_loss: 52.4227, va_loss: 62.9614, va c-index: 0.5502
updated.... average c-index = 0.6379
epoch 2, tr_loss: 51.5424, va_loss: 62.9606, va c-index: 0.6379
Counter 1 of 10
epoch 3, tr_loss: 50.8352, va_loss: 62.3518, va c-index: 0.6235
Counter 2 of 10
epoch 4, tr_loss: 50.1322, va_loss: 62.6174, va c-index: 0.5845
Counter 3 of 10
epoch 5, tr_loss: 50.2574, va_loss: 62.9095, va c-index: 0.5626
Counter 4 of 10
epoch 6, tr_loss: 49.3446, va_loss: 62.8156, va c-index: 0.5829
Counter 5 of 10
epoch 7, tr_loss: 49.1829, va_loss: 62.6899, va c-index: 0.6079
Counter 6 of 10
epoch 8, tr_loss: 48.1104, va_loss: 62.5574, va c-index: 0.6116
Counter 7 of 10
epoch 9, tr_loss: 48.2718, va_loss: 63.0557, va c-index: 0.6197
Counter 8 of 10
epoch 10, tr_loss: 46.9648, va_loss: 62.2535, va c-index: 0.62

  7%|███▉                                                 | 4/54 [10:41<2:36:26, 187.73s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7148 and val_cindex for this epoch: 0.7010
Better valudation c-index!!!
ddh_eeg_cr3_4444_alpha1.0_beta0.5_gamma0.1_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230314_1030
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005}
updated.... average c-index = 0.6007
epoch 1, tr_loss: 43.2925, va_loss: 64.6533, va c-index: 0.6007
updated.... average c-index = 0.6678
epoch 2, tr_loss: 39.4649, va_loss: 64.0966, va c-index: 0.6678
updated.... average c-index = 0.6801
epoch 3, tr_loss: 37.5918, va_loss: 62.8580, va c-index: 0.6801
Counter 1 of 10
epoch 4, tr_loss: 36.3031, va_loss: 63.8553, va c-index: 0.6435
updated.... average c-index = 0.6906
epoch 5, tr_loss: 33.1156, va_loss: 61.7698, va c-index: 0.6906
Counter 1 of 10
epoch 6, tr_loss: 33.2436, va_loss: 63.6097, va c-index: 0.6740
Counter 2 of 10
epoch 7, tr_loss: 32.1694, va_loss: 65.5723, va c-index: 0.6810
Counter 3 of 10
epoch 8, tr_l

  9%|████▉                                                | 5/54 [13:11<2:22:24, 174.38s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7106 and val_cindex for this epoch: 0.6721
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.001}
updated.... average c-index = 0.5981
epoch 1, tr_loss: 57.3612, va_loss: 71.8221, va c-index: 0.5981
updated.... average c-index = 0.6039
epoch 2, tr_loss: 57.7757, va_loss: 72.4681, va c-index: 0.6039
Counter 1 of 10
epoch 3, tr_loss: 57.9489, va_loss: 74.8016, va c-index: 0.5992
Counter 2 of 10
epoch 4, tr_loss: 55.2187, va_loss: 74.0305, va c-index: 0.4909
Counter 3 of 10
epoch 5, tr_loss: 53.7529, va_loss: 73.5928, va c-index: 0.2996
Counter 4 of 10
epoch 6, tr_loss: 54.2789, va_loss: 74.7318, va c-index: 0.3002
Counter 5 of 10
epoch 7, tr_loss: 56.3278, va_loss: 74.4931, va c-index: 0.3008
Counter 6 of 10
epoch 8, tr_loss: 55.2951, va_loss: 74.6806, va c-index: 0.3011
Counter 7 of 10
epoch 9, tr_loss: 52.5243, va_loss: 75.3823, va c-index: 0.3006
Counter 8 of 10
epoch 10, tr_loss: 50.7166, va_loss: 74.2436, va c-index: 0.300

 11%|█████▉                                               | 6/54 [14:38<1:55:34, 144.46s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6039 and val_cindex for this epoch: 0.3035
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001}
updated.... average c-index = 0.5208
epoch 1, tr_loss: 224.0143, va_loss: 277.3434, va c-index: 0.5208
updated.... average c-index = 0.5514
epoch 2, tr_loss: 221.8389, va_loss: 275.9254, va c-index: 0.5514
updated.... average c-index = 0.5744
epoch 3, tr_loss: 216.9831, va_loss: 276.4259, va c-index: 0.5744
updated.... average c-index = 0.5948
epoch 4, tr_loss: 215.5141, va_loss: 275.9234, va c-index: 0.5948
updated.... average c-index = 0.6177
epoch 5, tr_loss: 213.1636, va_loss: 274.3697, va c-index: 0.6177
updated.... average c-index = 0.6263
epoch 6, tr_loss: 210.1973, va_loss: 275.3633, va c-index: 0.6263
updated.... average c-index = 0.6542
epoch 7, tr_loss: 208.5298, va_loss: 275.6359, va c-index: 0.6542
Counter 1 of 10
epoch 8, tr_loss: 206.4594, va_loss: 274.7381, va c-index: 0.6317
Counter 2 of 10
epoch 9, tr_loss: 204

 13%|██████▊                                              | 7/54 [19:25<2:29:40, 191.07s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7112 and val_cindex for this epoch: 0.7037
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}
updated.... average c-index = 0.5474
epoch 1, tr_loss: 183.7102, va_loss: 276.5255, va c-index: 0.5474
updated.... average c-index = 0.5992
epoch 2, tr_loss: 182.5504, va_loss: 294.7361, va c-index: 0.5992
Counter 1 of 10
epoch 3, tr_loss: 177.0708, va_loss: 291.6385, va c-index: 0.5805
Counter 2 of 10
epoch 4, tr_loss: 173.6467, va_loss: 297.6463, va c-index: 0.4944
Counter 3 of 10
epoch 5, tr_loss: 160.3464, va_loss: 287.5805, va c-index: 0.5891
Counter 4 of 10
epoch 6, tr_loss: 177.0634, va_loss: 302.2266, va c-index: 0.5237
Counter 5 of 10
epoch 7, tr_loss: 153.9014, va_loss: 298.8491, va c-index: 0.5456
updated.... average c-index = 0.6333
epoch 8, tr_loss: 145.6192, va_loss: 296.0160, va c-index: 0.6333
Counter 1 of 10
epoch 9, tr_loss: 148.6787, va_loss: 294.1065, va c-index: 0.5819
Counter 2 of 10
epoch 10, tr_loss: 147.

 15%|███████▊                                             | 8/54 [21:56<2:16:44, 178.35s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6823 and val_cindex for this epoch: 0.6721
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.001}
updated.... average c-index = 0.5865
epoch 1, tr_loss: 185.5463, va_loss: 294.3869, va c-index: 0.5865
updated.... average c-index = 0.6075
epoch 2, tr_loss: 201.3366, va_loss: 281.3783, va c-index: 0.6075
Counter 1 of 10
epoch 3, tr_loss: 182.6672, va_loss: 299.2939, va c-index: 0.5642
Counter 2 of 10
epoch 4, tr_loss: 177.4374, va_loss: 339.0493, va c-index: 0.5052
Counter 3 of 10
epoch 5, tr_loss: 182.8284, va_loss: 294.1678, va c-index: 0.5465
Counter 4 of 10
epoch 6, tr_loss: 171.5012, va_loss: 283.8381, va c-index: 0.5134
Counter 5 of 10
epoch 7, tr_loss: 172.0211, va_loss: 287.1907, va c-index: 0.5189
Counter 6 of 10
epoch 8, tr_loss: 214.6126, va_loss: 308.3563, va c-index: 0.5221
Counter 7 of 10
epoch 9, tr_loss: 184.9675, va_loss: 354.3998, va c-index: 0.5216
Counter 8 of 10
epoch 10, tr_loss: 163.0577, va_loss: 385.392

 17%|████████▊                                            | 9/54 [23:24<1:52:31, 150.04s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6075 and val_cindex for this epoch: 0.3440
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.5578
epoch 1, tr_loss: 34.3975, va_loss: 37.3978, va c-index: 0.5578
updated.... average c-index = 0.5657
epoch 2, tr_loss: 31.8717, va_loss: 36.8856, va c-index: 0.5657
updated.... average c-index = 0.6005
epoch 3, tr_loss: 30.9567, va_loss: 36.4042, va c-index: 0.6005
Counter 1 of 10
epoch 4, tr_loss: 31.4228, va_loss: 36.2806, va c-index: 0.5667
Counter 2 of 10
epoch 5, tr_loss: 30.9816, va_loss: 35.9218, va c-index: 0.5746
Counter 3 of 10
epoch 6, tr_loss: 29.9882, va_loss: 35.8731, va c-index: 0.5865
updated.... average c-index = 0.6019
epoch 7, tr_loss: 30.2705, va_loss: 36.3676, va c-index: 0.6019
Counter 1 of 10
epoch 8, tr_loss: 30.1848, va_loss: 35.9459, va c-index: 0.5969
updated.... average c-index = 0.6076
epoch 9, tr_loss: 29.9958, va_loss: 35.6708, va c-index: 0.6076
Counter 1 of 1

 19%|█████████▋                                          | 10/54 [28:14<2:21:48, 193.37s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6681 and val_cindex for this epoch: 0.6603
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.5503
epoch 1, tr_loss: 31.3562, va_loss: 38.6575, va c-index: 0.5503
updated.... average c-index = 0.5522
epoch 2, tr_loss: 30.0861, va_loss: 38.3940, va c-index: 0.5522
Counter 1 of 10
epoch 3, tr_loss: 30.2624, va_loss: 39.6784, va c-index: 0.5466
updated.... average c-index = 0.5603
epoch 4, tr_loss: 27.7134, va_loss: 39.3443, va c-index: 0.5603
updated.... average c-index = 0.5673
epoch 5, tr_loss: 25.7346, va_loss: 36.9231, va c-index: 0.5673
updated.... average c-index = 0.5955
epoch 6, tr_loss: 24.7736, va_loss: 35.7928, va c-index: 0.5955
updated.... average c-index = 0.6158
epoch 7, tr_loss: 25.3030, va_loss: 37.0404, va c-index: 0.6158
Counter 1 of 10
epoch 8, tr_loss: 24.3275, va_loss: 36.6374, va c-index: 0.5872
updated.... average c-index = 0.6319
epoch 9, tr_loss: 24.3460, va_loss: 

 20%|██████████▌                                         | 11/54 [31:52<2:24:02, 200.98s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7032 and val_cindex for this epoch: 0.6989
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.001}
updated.... average c-index = 0.5680
epoch 1, tr_loss: 29.9795, va_loss: 36.2351, va c-index: 0.5680
Counter 1 of 10
epoch 2, tr_loss: 30.0637, va_loss: 36.8382, va c-index: 0.5328
Counter 2 of 10
epoch 3, tr_loss: 30.9675, va_loss: 37.0804, va c-index: 0.5608
updated.... average c-index = 0.5925
epoch 4, tr_loss: 28.0847, va_loss: 35.3005, va c-index: 0.5925
updated.... average c-index = 0.6395
epoch 5, tr_loss: 28.4998, va_loss: 35.5441, va c-index: 0.6395
Counter 1 of 10
epoch 6, tr_loss: 27.2722, va_loss: 34.5902, va c-index: 0.6056
Counter 2 of 10
epoch 7, tr_loss: 29.6923, va_loss: 37.9956, va c-index: 0.5555
updated.... average c-index = 0.6533
epoch 8, tr_loss: 26.7117, va_loss: 34.9916, va c-index: 0.6533
Counter 1 of 10
epoch 9, tr_loss: 29.6618, va_loss: 35.1613, va c-index: 0.6532
Counter 2 of 10
epoch 10, tr_loss: 2

 22%|███████████▌                                        | 12/54 [35:13<2:20:40, 200.98s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7233 and val_cindex for this epoch: 0.7053
Better valudation c-index!!!
ddh_eeg_cr3_4444_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.4_20230314_1057
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.5092
epoch 1, tr_loss: 55.4419, va_loss: 63.4466, va c-index: 0.5092
updated.... average c-index = 0.5768
epoch 2, tr_loss: 52.6662, va_loss: 62.7995, va c-index: 0.5768
updated.... average c-index = 0.5953
epoch 3, tr_loss: 52.4902, va_loss: 62.5245, va c-index: 0.5953
Counter 1 of 10
epoch 4, tr_loss: 52.4294, va_loss: 62.7761, va c-index: 0.5858
Counter 2 of 10
epoch 5, tr_loss: 51.4651, va_loss: 62.6441, va c-index: 0.5857
Counter 3 of 10
epoch 6, tr_loss: 50.3353, va_loss: 62.7277, va c-index: 0.5913
updated.... average c-index = 0.6002
epoch 7, tr_loss: 50.5201, va_loss: 62.8509, va c-index: 0.6002
Counter 1 of 10
epoch 8, tr_l

 24%|████████████▌                                       | 13/54 [38:50<2:20:35, 205.74s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6596 and val_cindex for this epoch: 0.6572
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.0005}
updated.... average c-index = 0.5441
epoch 1, tr_loss: 50.2813, va_loss: 64.6702, va c-index: 0.5441
updated.... average c-index = 0.5848
epoch 2, tr_loss: 47.6517, va_loss: 64.0453, va c-index: 0.5848
updated.... average c-index = 0.6300
epoch 3, tr_loss: 46.5595, va_loss: 63.8572, va c-index: 0.6300
updated.... average c-index = 0.6445
epoch 4, tr_loss: 45.8733, va_loss: 64.0065, va c-index: 0.6445
updated.... average c-index = 0.6502
epoch 5, tr_loss: 43.0613, va_loss: 65.2085, va c-index: 0.6502
updated.... average c-index = 0.6504
epoch 6, tr_loss: 43.4562, va_loss: 64.6345, va c-index: 0.6504
updated.... average c-index = 0.6740
epoch 7, tr_loss: 43.9364, va_loss: 64.6217, va c-index: 0.6740
updated.... average c-index = 0.6788
epoch 8, tr_loss: 38.2162, va_loss: 65.9783, va c-index: 0.6788
Counter 1 of 10
epoch 9, tr_loss

 26%|█████████████▍                                      | 14/54 [41:58<2:13:28, 200.21s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6944 and val_cindex for this epoch: 0.6774
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.001}
updated.... average c-index = 0.6522
epoch 1, tr_loss: 51.2963, va_loss: 63.6856, va c-index: 0.6522
Counter 1 of 10
epoch 2, tr_loss: 48.3548, va_loss: 62.1348, va c-index: 0.6137
Counter 2 of 10
epoch 3, tr_loss: 46.8988, va_loss: 62.8582, va c-index: 0.6299
Counter 3 of 10
epoch 4, tr_loss: 45.5321, va_loss: 60.8956, va c-index: 0.6320
updated.... average c-index = 0.6848
epoch 5, tr_loss: 45.3175, va_loss: 66.1968, va c-index: 0.6848
Counter 1 of 10
epoch 6, tr_loss: 47.1833, va_loss: 60.2790, va c-index: 0.6664
Counter 2 of 10
epoch 7, tr_loss: 43.1989, va_loss: 62.7338, va c-index: 0.6842
Counter 3 of 10
epoch 8, tr_loss: 43.9542, va_loss: 64.0769, va c-index: 0.6545
Counter 4 of 10
epoch 9, tr_loss: 42.3450, va_loss: 62.5165, va c-index: 0.6763
Counter 5 of 10
epoch 10, tr_loss: 48.3135, va_loss: 62.0628, va c-index: 0.662

 28%|██████████████▍                                     | 15/54 [44:41<2:03:00, 189.24s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7109 and val_cindex for this epoch: 0.6852
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.5, 'lr_train': 0.0001}
updated.... average c-index = 0.5553
epoch 1, tr_loss: 288.7724, va_loss: 342.4261, va c-index: 0.5553
updated.... average c-index = 0.5688
epoch 2, tr_loss: 279.1596, va_loss: 330.3333, va c-index: 0.5688
updated.... average c-index = 0.5796
epoch 3, tr_loss: 259.5881, va_loss: 309.8677, va c-index: 0.5796
Counter 1 of 10
epoch 4, tr_loss: 245.3879, va_loss: 295.8973, va c-index: 0.5773
updated.... average c-index = 0.5827
epoch 5, tr_loss: 236.4909, va_loss: 288.4579, va c-index: 0.5827
updated.... average c-index = 0.5914
epoch 6, tr_loss: 227.3066, va_loss: 282.9779, va c-index: 0.5914
updated.... average c-index = 0.5992
epoch 7, tr_loss: 220.2635, va_loss: 282.3904, va c-index: 0.5992
updated.... average c-index = 0.6029
epoch 8, tr_loss: 219.9887, va_loss: 280.3968, va c-index: 0.6029
Counter 1 of 10
epoch 9, tr_loss: 220

 30%|███████████████▍                                    | 16/54 [50:16<2:27:33, 232.98s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6892 and val_cindex for this epoch: 0.6846
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.5, 'lr_train': 0.0005}
updated.... average c-index = 0.5849
epoch 1, tr_loss: 191.0362, va_loss: 280.5488, va c-index: 0.5849
updated.... average c-index = 0.6027
epoch 2, tr_loss: 184.3867, va_loss: 281.2785, va c-index: 0.6027
Counter 1 of 10
epoch 3, tr_loss: 182.5152, va_loss: 281.8851, va c-index: 0.5894
updated.... average c-index = 0.6618
epoch 4, tr_loss: 173.7420, va_loss: 291.0594, va c-index: 0.6618
Counter 1 of 10
epoch 5, tr_loss: 175.5743, va_loss: 292.0378, va c-index: 0.6201
Counter 2 of 10
epoch 6, tr_loss: 165.0801, va_loss: 293.4055, va c-index: 0.6499
updated.... average c-index = 0.6667
epoch 7, tr_loss: 148.6917, va_loss: 285.6953, va c-index: 0.6667
Counter 1 of 10
epoch 8, tr_loss: 183.9829, va_loss: 316.4873, va c-index: 0.6604
updated.... average c-index = 0.6826
epoch 9, tr_loss: 147.1881, va_loss: 296.5170, va c-index: 0.68

 31%|████████████████▎                                   | 17/54 [52:27<2:04:45, 202.30s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6921 and val_cindex for this epoch: 0.6597
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.5, 'lr_train': 0.001}
updated.... average c-index = 0.5977
epoch 1, tr_loss: 207.9098, va_loss: 277.9919, va c-index: 0.5977
Counter 1 of 10
epoch 2, tr_loss: 220.2438, va_loss: 281.3165, va c-index: 0.5480
Counter 2 of 10
epoch 3, tr_loss: 201.8204, va_loss: 282.5824, va c-index: 0.5343
updated.... average c-index = 0.6519
epoch 4, tr_loss: 198.7457, va_loss: 281.4623, va c-index: 0.6519
updated.... average c-index = 0.6563
epoch 5, tr_loss: 194.6953, va_loss: 269.7964, va c-index: 0.6563
updated.... average c-index = 0.6861
epoch 6, tr_loss: 191.9758, va_loss: 277.6926, va c-index: 0.6861
Counter 1 of 10
epoch 7, tr_loss: 204.5944, va_loss: 275.5511, va c-index: 0.6742
Counter 2 of 10
epoch 8, tr_loss: 191.9402, va_loss: 274.0439, va c-index: 0.6755
Counter 3 of 10
epoch 9, tr_loss: 197.5170, va_loss: 269.7621, va c-index: 0.6615
Counter 4 of 10
epo

 33%|█████████████████▎                                  | 18/54 [54:14<1:44:14, 173.74s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6861 and val_cindex for this epoch: 0.6558
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.5492
epoch 1, tr_loss: 36.9693, va_loss: 41.3358, va c-index: 0.5492
updated.... average c-index = 0.5593
epoch 2, tr_loss: 35.2968, va_loss: 39.9931, va c-index: 0.5593
updated.... average c-index = 0.5853
epoch 3, tr_loss: 34.8540, va_loss: 39.3857, va c-index: 0.5853
Counter 1 of 10
epoch 4, tr_loss: 34.5551, va_loss: 39.4116, va c-index: 0.5851
updated.... average c-index = 0.6000
epoch 5, tr_loss: 33.6416, va_loss: 39.0673, va c-index: 0.6000
updated.... average c-index = 0.6101
epoch 6, tr_loss: 32.9520, va_loss: 38.7500, va c-index: 0.6101
updated.... average c-index = 0.6172
epoch 7, tr_loss: 32.3275, va_loss: 38.5477, va c-index: 0.6172
updated.... average c-index = 0.6244
epoch 8, tr_loss: 32.1521, va_loss: 38.4888, va c-index: 0.6244
Counter 1 of 10
epoch 9, tr_loss: 31.7349, va_loss: 38

 35%|██████████████████▎                                 | 19/54 [56:37<1:36:00, 164.58s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6613 and val_cindex for this epoch: 0.6490
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.5501
epoch 1, tr_loss: 33.6549, va_loss: 44.3394, va c-index: 0.5501
Counter 1 of 10
epoch 2, tr_loss: 39.2834, va_loss: 42.5495, va c-index: 0.5363
updated.... average c-index = 0.5581
epoch 3, tr_loss: 30.6019, va_loss: 43.4143, va c-index: 0.5581
updated.... average c-index = 0.6084
epoch 4, tr_loss: 26.4277, va_loss: 40.1801, va c-index: 0.6084
updated.... average c-index = 0.6085
epoch 5, tr_loss: 26.0565, va_loss: 39.0791, va c-index: 0.6085
Counter 1 of 10
epoch 6, tr_loss: 29.1588, va_loss: 43.1283, va c-index: 0.5908
updated.... average c-index = 0.6185
epoch 7, tr_loss: 29.0757, va_loss: 42.6911, va c-index: 0.6185
updated.... average c-index = 0.6371
epoch 8, tr_loss: 26.9243, va_loss: 39.9058, va c-index: 0.6371
updated.... average c-index = 0.6493
epoch 9, tr_loss: 26.0831, va_loss: 40

 37%|███████████████████▎                                | 20/54 [59:10<1:31:11, 160.93s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6617 and val_cindex for this epoch: 0.5282
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.001}
updated.... average c-index = 0.6106
epoch 1, tr_loss: 36.8136, va_loss: 45.8663, va c-index: 0.6106
updated.... average c-index = 0.6431
epoch 2, tr_loss: 30.3669, va_loss: 41.0799, va c-index: 0.6431
Counter 1 of 10
epoch 3, tr_loss: 32.2490, va_loss: 43.5243, va c-index: 0.5492
Counter 2 of 10
epoch 4, tr_loss: 39.5966, va_loss: 49.7404, va c-index: 0.5656
Counter 3 of 10
epoch 5, tr_loss: 42.7789, va_loss: 48.9673, va c-index: 0.5593
Counter 4 of 10
epoch 6, tr_loss: 39.0064, va_loss: 51.0817, va c-index: 0.5576
Counter 5 of 10
epoch 7, tr_loss: 38.9318, va_loss: 49.6955, va c-index: 0.5570
Counter 6 of 10
epoch 8, tr_loss: 38.5569, va_loss: 51.3047, va c-index: 0.5499
Counter 7 of 10
epoch 9, tr_loss: 38.1219, va_loss: 50.2404, va c-index: 0.4826
Counter 8 of 10
epoch 10, tr_loss: 37.9283, va_loss: 49.8135, va c-index: 0.4827

 39%|███████████████████▍                              | 21/54 [1:00:35<1:16:00, 138.20s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6431 and val_cindex for this epoch: 0.4828
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.6104
epoch 1, tr_loss: 59.0532, va_loss: 69.0956, va c-index: 0.6104
Counter 1 of 10
epoch 2, tr_loss: 58.3412, va_loss: 69.1352, va c-index: 0.6008
Counter 2 of 10
epoch 3, tr_loss: 58.1469, va_loss: 69.8055, va c-index: 0.5873
Counter 3 of 10
epoch 4, tr_loss: 57.7375, va_loss: 69.5284, va c-index: 0.5913
Counter 4 of 10
epoch 5, tr_loss: 57.4199, va_loss: 69.3469, va c-index: 0.6031
Counter 5 of 10
epoch 6, tr_loss: 56.3147, va_loss: 69.0272, va c-index: 0.6047
Counter 6 of 10
epoch 7, tr_loss: 55.8737, va_loss: 69.0977, va c-index: 0.6095
updated.... average c-index = 0.6111
epoch 8, tr_loss: 55.4849, va_loss: 68.8724, va c-index: 0.6111
updated.... average c-index = 0.6219
epoch 9, tr_loss: 54.0323, va_loss: 68.2313, va c-index: 0.6219
Counter 1 of 10
epoch 10, tr_loss: 54.5109, va_loss: 68.208

 41%|████████████████████▎                             | 22/54 [1:04:11<1:26:08, 161.50s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6757 and val_cindex for this epoch: 0.6581
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005}
updated.... average c-index = 0.5563
epoch 1, tr_loss: 49.1415, va_loss: 71.1113, va c-index: 0.5563
updated.... average c-index = 0.6157
epoch 2, tr_loss: 45.4115, va_loss: 66.7374, va c-index: 0.6157
updated.... average c-index = 0.6323
epoch 3, tr_loss: 43.3803, va_loss: 69.6213, va c-index: 0.6323
updated.... average c-index = 0.6415
epoch 4, tr_loss: 41.8593, va_loss: 65.8270, va c-index: 0.6415
updated.... average c-index = 0.6445
epoch 5, tr_loss: 38.0905, va_loss: 65.0547, va c-index: 0.6445
Counter 1 of 10
epoch 6, tr_loss: 36.2054, va_loss: 70.2015, va c-index: 0.6080
Counter 2 of 10
epoch 7, tr_loss: 36.4016, va_loss: 66.3333, va c-index: 0.5884
updated.... average c-index = 0.6464
epoch 8, tr_loss: 35.0411, va_loss: 67.5008, va c-index: 0.6464
updated.... average c-index = 0.6579
epoch 9, tr_loss: 33.3616, va_loss: 64.

 43%|█████████████████████▎                            | 23/54 [1:06:38<1:21:13, 157.21s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6803 and val_cindex for this epoch: 0.6766
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.001}
updated.... average c-index = 0.4739
epoch 1, tr_loss: 60.0645, va_loss: 72.5033, va c-index: 0.4739
updated.... average c-index = 0.5239
epoch 2, tr_loss: 58.7150, va_loss: 74.2702, va c-index: 0.5239
updated.... average c-index = 0.5303
epoch 3, tr_loss: 59.3875, va_loss: 75.1802, va c-index: 0.5303
Counter 1 of 10
epoch 4, tr_loss: 57.7719, va_loss: 75.9527, va c-index: 0.5135
Counter 2 of 10
epoch 5, tr_loss: 56.5085, va_loss: 78.5017, va c-index: 0.5272
updated.... average c-index = 0.5481
epoch 6, tr_loss: 56.2699, va_loss: 77.9029, va c-index: 0.5481
Counter 1 of 10
epoch 7, tr_loss: 56.3052, va_loss: 76.0401, va c-index: 0.4077
Counter 2 of 10
epoch 8, tr_loss: 56.0845, va_loss: 79.9107, va c-index: 0.4106
Counter 3 of 10
epoch 9, tr_loss: 54.7949, va_loss: 80.1521, va c-index: 0.4224
Counter 4 of 10
epoch 10, tr_loss: 55.9

 44%|██████████████████████▏                           | 24/54 [1:08:22<1:10:37, 141.25s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.5481 and val_cindex for this epoch: 0.1712
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001}
updated.... average c-index = 0.5527
epoch 1, tr_loss: 229.9800, va_loss: 280.9272, va c-index: 0.5527
Counter 1 of 10
epoch 2, tr_loss: 221.7777, va_loss: 281.5533, va c-index: 0.5501
Counter 2 of 10
epoch 3, tr_loss: 218.4233, va_loss: 282.6547, va c-index: 0.5435
Counter 3 of 10
epoch 4, tr_loss: 215.8605, va_loss: 282.3095, va c-index: 0.5507
updated.... average c-index = 0.5713
epoch 5, tr_loss: 213.7785, va_loss: 283.0789, va c-index: 0.5713
updated.... average c-index = 0.5723
epoch 6, tr_loss: 215.0329, va_loss: 281.6309, va c-index: 0.5723
updated.... average c-index = 0.5757
epoch 7, tr_loss: 211.6124, va_loss: 284.2629, va c-index: 0.5757
updated.... average c-index = 0.5922
epoch 8, tr_loss: 210.9591, va_loss: 284.1594, va c-index: 0.5922
updated.... average c-index = 0.5983
epoch 9, tr_loss: 209.6112, va_loss: 283.998

 46%|███████████████████████▏                          | 25/54 [1:12:47<1:26:09, 178.27s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6793 and val_cindex for this epoch: 0.6723
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}
updated.... average c-index = 0.5759
epoch 1, tr_loss: 151.7689, va_loss: 282.9630, va c-index: 0.5759
updated.... average c-index = 0.5995
epoch 2, tr_loss: 142.6212, va_loss: 312.6950, va c-index: 0.5995
Counter 1 of 10
epoch 3, tr_loss: 150.5140, va_loss: 283.3591, va c-index: 0.5363
Counter 2 of 10
epoch 4, tr_loss: 136.4515, va_loss: 301.4830, va c-index: 0.5530
Counter 3 of 10
epoch 5, tr_loss: 137.6825, va_loss: 330.9066, va c-index: 0.5529
Counter 4 of 10
epoch 6, tr_loss: 138.8768, va_loss: 309.0941, va c-index: 0.5511
Counter 5 of 10
epoch 7, tr_loss: 128.6332, va_loss: 304.3136, va c-index: 0.5405
Counter 6 of 10
epoch 8, tr_loss: 152.2822, va_loss: 292.7333, va c-index: 0.5218
Counter 7 of 10
epoch 9, tr_loss: 146.4389, va_loss: 299.0457, va c-index: 0.5265
Counter 8 of 10
epoch 10, tr_loss: 159.9199, va_loss: 331.2216

 48%|████████████████████████                          | 26/54 [1:14:14<1:10:26, 150.96s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.5995 and val_cindex for this epoch: 0.5225
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.001}
updated.... average c-index = 0.6299
epoch 1, tr_loss: 205.1813, va_loss: 284.4098, va c-index: 0.6299
Counter 1 of 10
epoch 2, tr_loss: 214.8070, va_loss: 299.9201, va c-index: 0.5345
Counter 2 of 10
epoch 3, tr_loss: 208.3268, va_loss: 280.7289, va c-index: 0.3594
Counter 3 of 10
epoch 4, tr_loss: 213.5549, va_loss: 284.4718, va c-index: 0.3616
Counter 4 of 10
epoch 5, tr_loss: 211.2365, va_loss: 293.0811, va c-index: 0.3614
Counter 5 of 10
epoch 6, tr_loss: 201.4305, va_loss: 287.2241, va c-index: 0.3621
Counter 6 of 10
epoch 7, tr_loss: 204.9607, va_loss: 300.8508, va c-index: 0.3633
Counter 7 of 10
epoch 8, tr_loss: 194.3884, va_loss: 296.6799, va c-index: 0.3884
Counter 8 of 10
epoch 9, tr_loss: 190.6970, va_loss: 315.7277, va c-index: 0.3538
Counter 9 of 10
epoch 10, tr_loss: 185.3269, va_loss: 322.3084, va c-index: 0.3577


 50%|██████████████████████████                          | 27/54 [1:15:37<58:49, 130.74s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6299 and val_cindex for this epoch: 0.3654
 {'beta': 1, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.5897
epoch 1, tr_loss: 45.2333, va_loss: 47.3541, va c-index: 0.5897
Counter 1 of 10
epoch 2, tr_loss: 40.1580, va_loss: 44.2779, va c-index: 0.5751
Counter 2 of 10
epoch 3, tr_loss: 38.4982, va_loss: 41.5433, va c-index: 0.5538
Counter 3 of 10
epoch 4, tr_loss: 36.6623, va_loss: 40.6518, va c-index: 0.5564
Counter 4 of 10
epoch 5, tr_loss: 36.2480, va_loss: 40.8331, va c-index: 0.5372
Counter 5 of 10
epoch 6, tr_loss: 35.6411, va_loss: 40.3953, va c-index: 0.5438
Counter 6 of 10
epoch 7, tr_loss: 35.1649, va_loss: 40.1289, va c-index: 0.5504
Counter 7 of 10
epoch 8, tr_loss: 34.6219, va_loss: 39.8959, va c-index: 0.5583
Counter 8 of 10
epoch 9, tr_loss: 34.1414, va_loss: 39.5539, va c-index: 0.5707
Counter 9 of 10
epoch 10, tr_loss: 33.8205, va_loss: 39.2817, va c-index: 0.5864


 52%|██████████████████████████▉                         | 28/54 [1:17:02<50:35, 116.75s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.5897 and val_cindex for this epoch: 0.5775
 {'beta': 1, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.6385
epoch 1, tr_loss: 33.8969, va_loss: 43.6298, va c-index: 0.6385
Counter 1 of 10
epoch 2, tr_loss: 29.7902, va_loss: 39.8372, va c-index: 0.6180
updated.... average c-index = 0.6395
epoch 3, tr_loss: 29.3989, va_loss: 39.5397, va c-index: 0.6395
Counter 1 of 10
epoch 4, tr_loss: 28.7069, va_loss: 38.8091, va c-index: 0.6378
Counter 2 of 10
epoch 5, tr_loss: 28.7981, va_loss: 39.5408, va c-index: 0.6034
Counter 3 of 10
epoch 6, tr_loss: 29.7172, va_loss: 40.8482, va c-index: 0.5708
Counter 4 of 10
epoch 7, tr_loss: 28.3005, va_loss: 39.7005, va c-index: 0.6088
Counter 5 of 10
epoch 8, tr_loss: 26.8548, va_loss: 39.4779, va c-index: 0.6112
Counter 6 of 10
epoch 9, tr_loss: 26.1046, va_loss: 41.2546, va c-index: 0.6291
Counter 7 of 10
epoch 10, tr_loss: 26.8416, va_loss: 42.5482, va c-index: 0.600

 54%|██████████████████████████▊                       | 29/54 [1:21:45<1:09:29, 166.76s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7022 and val_cindex for this epoch: 0.6810
 {'beta': 1, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.001}
updated.... average c-index = 0.5753
epoch 1, tr_loss: 37.3326, va_loss: 42.7602, va c-index: 0.5753
updated.... average c-index = 0.6328
epoch 2, tr_loss: 34.3897, va_loss: 40.4844, va c-index: 0.6328
Counter 1 of 10
epoch 3, tr_loss: 34.1918, va_loss: 39.5611, va c-index: 0.5941
Counter 2 of 10
epoch 4, tr_loss: 31.7399, va_loss: 38.6338, va c-index: 0.6258
Counter 3 of 10
epoch 5, tr_loss: 32.0524, va_loss: 38.6383, va c-index: 0.5962
Counter 4 of 10
epoch 6, tr_loss: 31.2655, va_loss: 38.1040, va c-index: 0.6324
updated.... average c-index = 0.6347
epoch 7, tr_loss: 31.5855, va_loss: 37.9615, va c-index: 0.6347
Counter 1 of 10
epoch 8, tr_loss: 31.9071, va_loss: 38.4120, va c-index: 0.6308
updated.... average c-index = 0.6569
epoch 9, tr_loss: 29.9938, va_loss: 37.4148, va c-index: 0.6569
updated.... average c-index = 0.6761
e

 56%|███████████████████████████▊                      | 30/54 [1:23:51<1:01:52, 154.67s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6811 and val_cindex for this epoch: 0.2930
 {'beta': 1, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.5018
epoch 1, tr_loss: 61.0642, va_loss: 67.7218, va c-index: 0.5018
updated.... average c-index = 0.5414
epoch 2, tr_loss: 56.8409, va_loss: 67.2079, va c-index: 0.5414
Counter 1 of 10
epoch 3, tr_loss: 56.2519, va_loss: 66.4977, va c-index: 0.5386
updated.... average c-index = 0.5494
epoch 4, tr_loss: 54.9938, va_loss: 66.2892, va c-index: 0.5494
updated.... average c-index = 0.5577
epoch 5, tr_loss: 54.6680, va_loss: 65.8139, va c-index: 0.5577
updated.... average c-index = 0.5586
epoch 6, tr_loss: 54.3779, va_loss: 65.5380, va c-index: 0.5586
Counter 1 of 10
epoch 7, tr_loss: 54.9518, va_loss: 65.8667, va c-index: 0.5553
updated.... average c-index = 0.5610
epoch 8, tr_loss: 53.9554, va_loss: 66.1899, va c-index: 0.5610
Counter 1 of 10
epoch 9, tr_loss: 54.2507, va_loss: 66.6523, va c-index: 0.5

In [19]:
# Show the repr of the `model` object currently bound in the kernel (its layer structure).
# NOTE(review): `model` is not assigned in this cell; presumably it is the instance left over
# from the last iteration of the preceding hyperparameter search, which is not necessarily
# the best-scoring configuration -- confirm (or reload the best checkpoint) before relying on it.
model

DynamicDeepHitTorch(
  (embedding): LSTM(116, 64, num_layers=2, bias=False, batch_first=True)
  (longitudinal): Sequential(
    (0): Dropout(p=0.4, inplace=False)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=116, bias=True)
  )
  (attention): Sequential(
    (0): Dropout(p=0.4, inplace=False)
    (1): Linear(in_features=180, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=1, bias=True)
  )
  (attention_soft): Softmax(dim=1)
  (cause_specific): ModuleList(
    (0): Sequential(
      (0): Dropout(p=0.4, inplace=False)
      (1): Linear(in_features=180, out_features=64, bias=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=64, bias=True)
      (4): ReLU()
      (5): Linear(in_features=64, out_features=1

### Test the Trained Network
See the `eval-ddh.ipynb` notebook for the code that tests the trained network.