## Training

In [2]:
SEED = 1234
_EPSILON = 1e-08

In [3]:
import sys
sys.path.append('../auton-survival/')
sys.path.append('../scr/')

import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime

from ddh_torch import DynamicDeepHitTorch
import import_data as impt
from utils_eval import c_index, brier_score
from utils_helper import f_get_boosted_trainset
from utils_helper import loss_Log_Likelihood, loss_Ranking, loss_RNN_Prediction
from utils_helper import f_get_risk_predictions, _f_get_pred
from utils_helper import save_checkpoint, load_checkpoint

In [4]:
## for reproducibility
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

In [5]:
def model_eval(model, dataloader, device):
    model.eval()
    total_loss = 0
    Loss1, Loss2, Loss3 = 0, 0, 0

    for i, (b_data, b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3) in enumerate(dataloader):
        b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
            b_data.float(), b_data_mi.float(), b_time.float(), b_label.float(), \
            b_mask1.float(), b_mask2.float(), b_mask3.float()
        b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
            b_data.to(device), b_time.to(device), b_label.to(device), \
            b_mask1.to(device), b_mask2.to(device), b_mask3.to(device)
        
        longitudinal_prediction, out = model(b_data)
        out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)
        loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
        loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
        loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
        loss = alpha * loss1 + beta * loss2 + gamma * loss3
        
        total_loss += loss
        Loss1 += loss1
        Loss2 += loss2
        Loss3 += loss3
        
    return total_loss/len(dataloader), Loss1/len(dataloader), Loss2/len(dataloader), Loss3/len(dataloader)

### Import Dataset

In [9]:
data_mode                   = 'EEG_cr2' 
seed                        = SEED

##### IMPORT DATASET
'''
    num_Category            = max event/censoring time * 1.2
    num_Event               = number of evetns i.e. len(np.unique(label))-1
    max_length              = maximum number of measurements
    x_dim                   = data dimension including delta (1 + num_features)
    x_dim_cont              = dim of continuous features
    x_dim_bin               = dim of binary features
    mask1, mask2, mask3     = used for cause-specific network (FCNet structure)
'''

if data_mode == 'EEG_cr2':
    (x_dim, x_dim_cont, x_dim_bin), (data, time, label, time_original, time_to_last), \
        (mask1, mask2, mask3), (data_mi), trans_discrete_time = \
                    impt.import_dataset('../../eeg/competing-risk/EEG_processed_data_long_cr2_exclude_3.csv')
    
    # This must be changed depending on the datasets, prediction/evaliation times of interest
    pred_time = [6, 12] # prediction time (in hours)
    eval_time = [12, 24,] # hours evaluation time (for C-index and Brier-Score)
else:
    print ('ERROR:  DATA_MODE NOT FOUND !!!')

_, num_Event, num_Category  = np.shape(mask1)  # dim of mask3: [subj, Num_Event, Num_Category]
max_length                  = np.shape(data)[1]


# file_path = '{}_{}'.format(data_mode, seed)

# if not os.path.exists(file_path):
#     os.makedirs(file_path)
    
num_Event, num_Category, max_length



(2, 119, 12)

### Set Hyper-Parameters

In [6]:
from sklearn.model_selection import ParameterGrid
model_name = 'ddh_eeg_cr2'
param_grid = {'dropout': [0.2], 
              'lr_train': [1e-4, 5e-4],
              'beta': [0.5, 1],
              'gamma': [0.05, 0.1, 0.5],
             }
param_grid_list = list(ParameterGrid(param_grid))
print(len(param_grid_list))
param_grid_list

12


[{'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001},
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}]

In [7]:
boost_mode = 'ON' #{'ON', 'OFF'}

param = {'batch_size': 32,

         'num_epoch_burn_in': 20,
         'num_epoch': 100,

         'dropout': 0.2, # 1 - keep_prob
         'lr_train': 1e-4,

         'hidden_rnn': 64,
         'hidden_dim_FC': 64,
         'layers_rnn': 2,
         'hidden_att': 2,
         'hidden_cs' : 2,

         'alpha' :1.0,
         'beta'  :1,
         'gamma' :0.1
        }


batch_size = param['batch_size']
num_epoch  = param['num_epoch']

dropout    = param['dropout']
lr_train   = param['lr_train']

hidden_rnn = param['hidden_rnn']
hidden_dim_FC = param['hidden_dim_FC']
layers_rnn = param['layers_rnn']
hidden_att = param['hidden_att']
hidden_cs = param['hidden_cs']

alpha      = param['alpha']
beta       = param['beta']
gamma      = param['gamma']

device     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### Split Dataset into Train/Valid/Test Sets

In [13]:
### TRAINING-TESTING SPLIT
(tr_data,te_data, tr_data_mi, te_data_mi, tr_time,te_time, tr_label,te_label, tr_time_orinal,te_time_original,
 tr_time_to_last,te_time_to_last,
 tr_mask1,te_mask1, tr_mask2,te_mask2, tr_mask3,te_mask3) = \
        train_test_split(data, data_mi, time, label, time_original, time_to_last,
                         mask1, mask2, mask3, test_size=0.2, stratify=label, random_state=seed) 

(tr_data,va_data, tr_data_mi, va_data_mi, tr_time,va_time, tr_label,va_label, tr_time_orinal,va_time_original,
 tr_time_to_last, va_time_to_last,
 tr_mask1,va_mask1, tr_mask2,va_mask2, tr_mask3,va_mask3) = \
        train_test_split(tr_data, tr_data_mi, tr_time, tr_label, tr_time_orinal, tr_time_to_last,
                         tr_mask1, tr_mask2, tr_mask3, test_size=0.2, stratify=tr_label, random_state=seed) 

if boost_mode == 'ON':
    tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3 = \
            f_get_boosted_trainset(tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3)

In [15]:
tr_loader = DataLoader(list(zip(tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3)), 
                       batch_size=batch_size, shuffle=True)

va_loader = DataLoader(list(zip(va_data, va_data_mi, va_time, va_label, va_mask1, va_mask2, va_mask3)), 
                       batch_size=batch_size, shuffle=False)

te_loader = DataLoader(list(zip(te_data, te_data_mi, te_time, te_label, te_mask1, te_mask2, te_mask3)), 
                       batch_size=batch_size, shuffle=False)

### Train the Network

This implementation does not consider missing values.

In [16]:
num_Event

2

In [17]:
pred_time = [6, 12]
eval_time = [24, 48, 72]

In [18]:
Best_C_Index = 0


for params in tqdm(param_grid_list):
    dropout = params['dropout']
    lr_train = params['lr_train']
    beta = params['beta']
    gamma = params['gamma']
    print('='*80, '\n', params)

    model = \
        DynamicDeepHitTorch(input_dim  = x_dim,
                            output_dim = num_Category,
                            layers_rnn = param['layers_rnn'],
                            hidden_rnn = param['hidden_rnn'],
                            long_param = {'layers': 2*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            att_param  = {'layers': param['hidden_att']*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            cs_param   = {'layers': param['hidden_cs']*[param['hidden_dim_FC']], 'dropout': dropout},
                            risks      = num_Event
                           )
    optimizer_burn_in = torch.optim.Adam(model.parameters(), lr = lr_train)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr_train)


    model.train()
    model.to(device)

    min_valid = 0.1
    early_stop_patience = 10
    ep = 0 # for tracking early stopping

    experiment_runtime = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    tensorboard_string_extension = f"{model_name}_{SEED}"
    tensorboard_string_extension += f"_alpha{alpha}_beta{beta}_gamma{gamma}_BSZ{batch_size}_Lr{lr_train}"
    tensorboard_string_extension += f"_hiddenRNN{hidden_rnn}_hiddenFC{hidden_dim_FC}"
    tensorboard_string_extension += f"_layersRNN{layers_rnn}_hiddenAtt{hidden_att}_hiddenCS{hidden_cs}"
    tensorboard_string_extension += f"_dropout{dropout}"
    tensorboard_string_extension += "_" + experiment_runtime
    tensorboard_filename = tensorboard_string_extension
    writer = SummaryWriter(f"runs/{model_name}/{tensorboard_filename}")


    for epoch in range(param['num_epoch_burn_in']):
        model.train()
        for i, (b_data, b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3) in enumerate(tr_loader):
            b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.float(), b_data_mi.float(), b_time.float(), b_label.float(), \
                b_mask1.float(), b_mask2.float(), b_mask3.float()
            b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.to(device), b_time.to(device), b_label.to(device), \
                b_mask1.to(device), b_mask2.to(device), b_mask3.to(device)

            longitudinal_prediction, out = model(b_data)
            out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)

            # loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
            # loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
            loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
            tr_loss = loss3

            optimizer_burn_in.zero_grad()
            tr_loss.backward()
            optimizer_burn_in.step()


    for epoch in range(num_epoch):
        model.train()
        for i, (b_data, b_data_mi, b_time, b_label, b_mask1, b_mask2, b_mask3) in enumerate(tr_loader):
            b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.float(), b_data_mi.float(), b_time.float(), b_label.float(), \
                b_mask1.float(), b_mask2.float(), b_mask3.float()
            b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = \
                b_data.to(device), b_time.to(device), b_label.to(device), \
                b_mask1.to(device), b_mask2.to(device), b_mask3.to(device)

            longitudinal_prediction, out = model(b_data)
            out = torch.concat([o.unsqueeze(1) for o in out], 1) # (B, num_Event, num_Category)

            loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
            loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
            loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
            tr_loss = alpha * loss1 + beta * loss2 + gamma * loss3

            optimizer.zero_grad()
            tr_loss.backward()
            optimizer.step()

        model.eval()
        tr_loss, tr_loss1, tr_loss2, tr_loss3 = model_eval(model, tr_loader, device)
        va_loss, va_loss1, va_loss2, va_loss3 = model_eval(model, va_loader, device)

        ### VALIDATION  (based on average C-index of our interest)
        risk_all = f_get_risk_predictions(model, va_data, pred_time, eval_time, device)

        for p, p_time in enumerate(pred_time):
            pred_horizon = int(p_time)
            val_result1 = np.zeros([num_Event, len(eval_time)])

            for t, t_time in enumerate(eval_time):                
                eval_horizon = int(t_time) + pred_horizon
                for k in range(num_Event):
                    val_result1[k, t] = c_index(risk_all[k][:, p, t], va_time, (va_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

            if p == 0:
                val_final1 = val_result1
            else:
                val_final1 = np.append(val_final1, val_result1, axis=0)

        tmp_valid = np.mean(val_final1)

        if tmp_valid >  min_valid:
            ep = 0
            min_valid = tmp_valid
            save_checkpoint(f"checkpoints/{model_name}/{writer.get_logdir().split('/')[-1]}_{SEED}.pt", model, optimizer, va_loss)
            print( 'updated.... average c-index = ' + str('%.4f' %(tmp_valid)))
        else:
            ep += 1  
            print("Counter {} of {}".format(ep, early_stop_patience))

        if ep >= early_stop_patience:
            print("Early stopping with best_cindex: {:.4f}".format(min_valid), 
                  "and val_cindex for this epoch: {:.4f}".format(tmp_valid))
            break

        # Tensorboard
        writer.add_scalar("Loss-Training/Loss", tr_loss, epoch)
        writer.add_scalar("Loss-Training/NLL Loss", tr_loss1, epoch)
        writer.add_scalar("Loss-Training/Ranking Loss", tr_loss2, epoch)
        writer.add_scalar("Loss-Training/LSTM MSELoss", tr_loss3, epoch)

        writer.add_scalar("Loss-Validation/Loss", va_loss, epoch)
        writer.add_scalar("Loss-Validation/NLL Loss", va_loss1, epoch)
        writer.add_scalar("Loss-Validation/Ranking Loss", va_loss2, epoch)
        writer.add_scalar("Loss-Validation/LSTM MSELoss", va_loss3, epoch)

        writer.add_scalar("C-index/Validation Avg C-index", min_valid, epoch)

        print('epoch {}, tr_loss: {:.4f}, va_loss: {:.4f}, va c-index: {:.4f}'.\
                      format(epoch+1, tr_loss, va_loss, tmp_valid))
        
    if min_valid > Best_C_Index:
        Best_C_Index = min_valid
        print('Better valudation c-index!!!')
        print(writer.get_logdir().split('/')[-1])

  0%|                                                              | 0/12 [00:00<?, ?it/s]

 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.6241
epoch 1, tr_loss: 29.2581, va_loss: 32.4258, va c-index: 0.6241
updated.... average c-index = 0.6324
epoch 2, tr_loss: 29.1473, va_loss: 32.3743, va c-index: 0.6324
updated.... average c-index = 0.6518
epoch 3, tr_loss: 28.7799, va_loss: 32.0537, va c-index: 0.6518
Counter 1 of 10
epoch 4, tr_loss: 28.3154, va_loss: 31.9413, va c-index: 0.6405
updated.... average c-index = 0.6682
epoch 5, tr_loss: 27.7768, va_loss: 31.7710, va c-index: 0.6682
Counter 1 of 10
epoch 6, tr_loss: 27.8035, va_loss: 31.7104, va c-index: 0.6562
updated.... average c-index = 0.6718
epoch 7, tr_loss: 27.2197, va_loss: 31.6794, va c-index: 0.6718
Counter 1 of 10
epoch 8, tr_loss: 27.1096, va_loss: 31.9652, va c-index: 0.6586
Counter 2 of 10
epoch 9, tr_loss: 26.7812, va_loss: 31.9652, va c-index: 0.6658
Counter 3 of 10
epoch 10, tr_loss: 26.3769, va_loss: 31.8220, va c-index: 0.6679
updated.... average c-index 

  8%|████▎                                              | 1/12 [10:04<1:50:44, 604.05s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7341 and val_cindex for this epoch: 0.7077
Better valudation c-index!!!
ddh_eeg_cr2_4444_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2319
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.6400
epoch 1, tr_loss: 26.2491, va_loss: 66.0308, va c-index: 0.6400
Counter 1 of 10
epoch 2, tr_loss: 25.4019, va_loss: 44.8905, va c-index: 0.4827
Counter 2 of 10
epoch 3, tr_loss: 22.7927, va_loss: 41.5435, va c-index: 0.5364
Counter 3 of 10
epoch 4, tr_loss: 23.9146, va_loss: 41.6518, va c-index: 0.5231
Counter 4 of 10
epoch 5, tr_loss: 23.5055, va_loss: 42.9859, va c-index: 0.5879
updated.... average c-index = 0.6451
epoch 6, tr_loss: 22.4418, va_loss: 40.1904, va c-index: 0.6451
Counter 1 of 10
epoch 7, tr_loss: 24.1952, va_loss: 42.1015, va c-index: 0.4802
Counter 2 of 10
epoch 8, tr_loss: 22.4827, va_loss: 40.5448, va c-ind

 17%|████████▌                                          | 2/12 [14:23<1:06:52, 401.28s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6451 and val_cindex for this epoch: 0.4455
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.5926
epoch 1, tr_loss: 51.1691, va_loss: 59.2319, va c-index: 0.5926
Counter 1 of 10
epoch 2, tr_loss: 50.7063, va_loss: 59.5248, va c-index: 0.5669
Counter 2 of 10
epoch 3, tr_loss: 50.0586, va_loss: 59.1433, va c-index: 0.5864
Counter 3 of 10
epoch 4, tr_loss: 49.6490, va_loss: 58.9142, va c-index: 0.5918
updated.... average c-index = 0.6061
epoch 5, tr_loss: 48.8513, va_loss: 59.1100, va c-index: 0.6061
Counter 1 of 10
epoch 6, tr_loss: 48.4984, va_loss: 59.1984, va c-index: 0.6001
updated.... average c-index = 0.6079
epoch 7, tr_loss: 47.9488, va_loss: 59.1113, va c-index: 0.6079
updated.... average c-index = 0.6179
epoch 8, tr_loss: 47.6652, va_loss: 59.5855, va c-index: 0.6179
updated.... average c-index = 0.6330
epoch 9, tr_loss: 47.0779, va_loss: 59.7210, va c-index: 0.6330
Counter 1 of 10

 25%|████████████▊                                      | 3/12 [26:32<1:22:38, 550.92s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7213 and val_cindex for this epoch: 0.6932
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005}
updated.... average c-index = 0.5728
epoch 1, tr_loss: 40.3775, va_loss: 60.4529, va c-index: 0.5728
updated.... average c-index = 0.6274
epoch 2, tr_loss: 37.2916, va_loss: 60.5394, va c-index: 0.6274
updated.... average c-index = 0.6321
epoch 3, tr_loss: 37.4905, va_loss: 63.7424, va c-index: 0.6321
updated.... average c-index = 0.6615
epoch 4, tr_loss: 36.5922, va_loss: 61.4195, va c-index: 0.6615
updated.... average c-index = 0.6628
epoch 5, tr_loss: 35.3061, va_loss: 62.6683, va c-index: 0.6628
updated.... average c-index = 0.6880
epoch 6, tr_loss: 32.8819, va_loss: 65.7274, va c-index: 0.6880
Counter 1 of 10
epoch 7, tr_loss: 30.9986, va_loss: 63.1559, va c-index: 0.6456
Counter 2 of 10
epoch 8, tr_loss: 30.6615, va_loss: 62.3812, va c-index: 0.6573
Counter 3 of 10
epoch 9, tr_loss: 32.0870, va_loss: 65.4162, va c-index: 0

 33%|█████████████████▋                                   | 4/12 [30:40<57:32, 431.52s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6880 and val_cindex for this epoch: 0.6291
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001}
updated.... average c-index = 0.6500
epoch 1, tr_loss: 233.2632, va_loss: 275.2896, va c-index: 0.6500
Counter 1 of 10
epoch 2, tr_loss: 228.8344, va_loss: 274.4848, va c-index: 0.6316
Counter 2 of 10
epoch 3, tr_loss: 226.9907, va_loss: 274.9715, va c-index: 0.6187
Counter 3 of 10
epoch 4, tr_loss: 224.4840, va_loss: 274.1079, va c-index: 0.6229
Counter 4 of 10
epoch 5, tr_loss: 221.8472, va_loss: 274.7260, va c-index: 0.6304
Counter 5 of 10
epoch 6, tr_loss: 219.3192, va_loss: 275.1768, va c-index: 0.6397
updated.... average c-index = 0.6615
epoch 7, tr_loss: 217.7158, va_loss: 273.3860, va c-index: 0.6615
updated.... average c-index = 0.6705
epoch 8, tr_loss: 217.3849, va_loss: 275.5955, va c-index: 0.6705
Counter 1 of 10
epoch 9, tr_loss: 213.6604, va_loss: 274.7610, va c-index: 0.6690
Counter 2 of 10
epoch 10, tr_loss: 210.

 42%|██████████████████████                               | 5/12 [39:46<55:09, 472.85s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7018 and val_cindex for this epoch: 0.6759
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}
updated.... average c-index = 0.6401
epoch 1, tr_loss: 162.1290, va_loss: 277.8737, va c-index: 0.6401
Counter 1 of 10
epoch 2, tr_loss: 159.7836, va_loss: 279.7472, va c-index: 0.6157
updated.... average c-index = 0.6633
epoch 3, tr_loss: 145.2635, va_loss: 293.2363, va c-index: 0.6633
updated.... average c-index = 0.7030
epoch 4, tr_loss: 147.2704, va_loss: 285.4144, va c-index: 0.7030
Counter 1 of 10
epoch 5, tr_loss: 201.0717, va_loss: 275.6272, va c-index: 0.6506
Counter 2 of 10
epoch 6, tr_loss: 147.2321, va_loss: 292.2567, va c-index: 0.6567
Counter 3 of 10
epoch 7, tr_loss: 142.5302, va_loss: 289.7346, va c-index: 0.6221
Counter 4 of 10
epoch 8, tr_loss: 143.6475, va_loss: 280.7945, va c-index: 0.6556
Counter 5 of 10
epoch 9, tr_loss: 154.7733, va_loss: 301.1443, va c-index: 0.6742
Counter 6 of 10
epoch 10, tr_loss: 188.

 50%|██████████████████████████▌                          | 6/12 [43:41<39:10, 391.72s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7030 and val_cindex for this epoch: 0.6594
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001}
updated.... average c-index = 0.5675
epoch 1, tr_loss: 32.4302, va_loss: 35.3441, va c-index: 0.5675
Counter 1 of 10
epoch 2, tr_loss: 32.2770, va_loss: 35.1112, va c-index: 0.5496
Counter 2 of 10
epoch 3, tr_loss: 31.4362, va_loss: 35.0999, va c-index: 0.5633
updated.... average c-index = 0.5731
epoch 4, tr_loss: 31.0298, va_loss: 35.0130, va c-index: 0.5731
updated.... average c-index = 0.5807
epoch 5, tr_loss: 30.8558, va_loss: 35.0022, va c-index: 0.5807
updated.... average c-index = 0.5939
epoch 6, tr_loss: 30.2736, va_loss: 34.7656, va c-index: 0.5939
updated.... average c-index = 0.6070
epoch 7, tr_loss: 29.7656, va_loss: 34.6279, va c-index: 0.6070
updated.... average c-index = 0.6110
epoch 8, tr_loss: 29.8436, va_loss: 34.5217, va c-index: 0.6110
updated.... average c-index = 0.6246
epoch 9, tr_loss: 29.1994, va_loss: 34

 58%|██████████████████████████████▉                      | 7/12 [53:07<37:23, 448.71s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6957 and val_cindex for this epoch: 0.6795
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005}
updated.... average c-index = 0.5343
epoch 1, tr_loss: 27.8226, va_loss: 39.1814, va c-index: 0.5343
updated.... average c-index = 0.5383
epoch 2, tr_loss: 28.9220, va_loss: 38.3002, va c-index: 0.5383
updated.... average c-index = 0.6394
epoch 3, tr_loss: 22.7964, va_loss: 36.1436, va c-index: 0.6394
updated.... average c-index = 0.6631
epoch 4, tr_loss: 21.4909, va_loss: 36.3039, va c-index: 0.6631
Counter 1 of 10
epoch 5, tr_loss: 20.6686, va_loss: 35.5040, va c-index: 0.6341
Counter 2 of 10
epoch 6, tr_loss: 23.3110, va_loss: 40.4191, va c-index: 0.6507
Counter 3 of 10
epoch 7, tr_loss: 24.6861, va_loss: 38.6751, va c-index: 0.6138
Counter 4 of 10
epoch 8, tr_loss: 25.2184, va_loss: 37.2571, va c-index: 0.6354
updated.... average c-index = 0.6667
epoch 9, tr_loss: 20.5597, va_loss: 36.6627, va c-index: 0.6667
updated.... aver

 67%|███████████████████████████████████▎                 | 8/12 [58:12<26:51, 402.98s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6868 and val_cindex for this epoch: 0.5804
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001}
updated.... average c-index = 0.6081
epoch 1, tr_loss: 55.4460, va_loss: 61.1326, va c-index: 0.6081
Counter 1 of 10
epoch 2, tr_loss: 53.9487, va_loss: 61.4523, va c-index: 0.5644
Counter 2 of 10
epoch 3, tr_loss: 52.3394, va_loss: 60.8174, va c-index: 0.5790
Counter 3 of 10
epoch 4, tr_loss: 52.0392, va_loss: 60.9556, va c-index: 0.5894
Counter 4 of 10
epoch 5, tr_loss: 50.8215, va_loss: 60.4553, va c-index: 0.6024
Counter 5 of 10
epoch 6, tr_loss: 50.1867, va_loss: 60.5229, va c-index: 0.6024
Counter 6 of 10
epoch 7, tr_loss: 50.2616, va_loss: 60.9943, va c-index: 0.6044
updated.... average c-index = 0.6178
epoch 8, tr_loss: 50.0815, va_loss: 60.4253, va c-index: 0.6178
updated.... average c-index = 0.6351
epoch 9, tr_loss: 49.1372, va_loss: 60.6757, va c-index: 0.6351
Counter 1 of 10
epoch 10, tr_loss: 48.2027, va_loss: 61.150

 75%|██████████████████████████████████████▎            | 9/12 [1:06:59<22:05, 441.96s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.7350 and val_cindex for this epoch: 0.6870
Better valudation c-index!!!
ddh_eeg_cr2_4444_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230418_0017
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005}
updated.... average c-index = 0.6139
epoch 1, tr_loss: 44.6632, va_loss: 66.2044, va c-index: 0.6139
Counter 1 of 10
epoch 2, tr_loss: 42.0538, va_loss: 67.5795, va c-index: 0.5906
updated.... average c-index = 0.6215
epoch 3, tr_loss: 40.7763, va_loss: 71.5852, va c-index: 0.6215
updated.... average c-index = 0.6445
epoch 4, tr_loss: 37.9863, va_loss: 69.3890, va c-index: 0.6445
updated.... average c-index = 0.6448
epoch 5, tr_loss: 39.2986, va_loss: 69.5982, va c-index: 0.6448
updated.... average c-index = 0.6759
epoch 6, tr_loss: 36.3953, va_loss: 67.4125, va c-index: 0.6759
Counter 1 of 10
epoch 7, tr_loss: 41.3592, va_loss: 70.8313, va c-index: 0.4682
Counter 2 of

 83%|█████████████████████████████████████████▋        | 10/12 [1:10:12<12:09, 364.83s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6759 and val_cindex for this epoch: 0.4807
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001}
updated.... average c-index = 0.5629
epoch 1, tr_loss: 235.0476, va_loss: 274.0296, va c-index: 0.5629
updated.... average c-index = 0.5676
epoch 2, tr_loss: 233.3145, va_loss: 274.0872, va c-index: 0.5676
updated.... average c-index = 0.5716
epoch 3, tr_loss: 228.2397, va_loss: 273.4285, va c-index: 0.5716
Counter 1 of 10
epoch 4, tr_loss: 222.5135, va_loss: 273.0786, va c-index: 0.5696
Counter 2 of 10
epoch 5, tr_loss: 222.2067, va_loss: 273.3729, va c-index: 0.5613
Counter 3 of 10
epoch 6, tr_loss: 220.7585, va_loss: 272.5696, va c-index: 0.5659
updated.... average c-index = 0.5744
epoch 7, tr_loss: 218.0659, va_loss: 272.6150, va c-index: 0.5744
updated.... average c-index = 0.5811
epoch 8, tr_loss: 216.9417, va_loss: 272.9284, va c-index: 0.5811
updated.... average c-index = 0.5859
epoch 9, tr_loss: 209.6428, va_loss: 273.069

 92%|█████████████████████████████████████████████▊    | 11/12 [1:14:44<05:36, 336.64s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6696 and val_cindex for this epoch: 0.6636
 {'beta': 1, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005}
updated.... average c-index = 0.5645
epoch 1, tr_loss: 171.0744, va_loss: 291.4052, va c-index: 0.5645
updated.... average c-index = 0.5940
epoch 2, tr_loss: 162.1983, va_loss: 288.5358, va c-index: 0.5940
Counter 1 of 10
epoch 3, tr_loss: 164.1179, va_loss: 293.5221, va c-index: 0.4916
Counter 2 of 10
epoch 4, tr_loss: 153.6421, va_loss: 299.4476, va c-index: 0.5916
updated.... average c-index = 0.6455
epoch 5, tr_loss: 140.3203, va_loss: 302.1269, va c-index: 0.6455
updated.... average c-index = 0.6601
epoch 6, tr_loss: 127.2505, va_loss: 314.2252, va c-index: 0.6601
Counter 1 of 10
epoch 7, tr_loss: 138.5870, va_loss: 290.7132, va c-index: 0.6116
Counter 2 of 10
epoch 8, tr_loss: 132.2811, va_loss: 300.1582, va c-index: 0.6181
updated.... average c-index = 0.6775
epoch 9, tr_loss: 133.7939, va_loss: 289.4071, va c-index: 0.6775

100%|██████████████████████████████████████████████████| 12/12 [1:20:34<00:00, 402.84s/it]

Counter 10 of 10
Early stopping with best_cindex: 0.6998 and val_cindex for this epoch: 0.6450





In [19]:
model

DynamicDeepHitTorch(
  (embedding): LSTM(116, 64, num_layers=2, bias=False, batch_first=True)
  (longitudinal): Sequential(
    (0): Dropout(p=0.4, inplace=False)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=116, bias=True)
  )
  (attention): Sequential(
    (0): Dropout(p=0.4, inplace=False)
    (1): Linear(in_features=180, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=1, bias=True)
  )
  (attention_soft): Softmax(dim=1)
  (cause_specific): ModuleList(
    (0): Sequential(
      (0): Dropout(p=0.4, inplace=False)
      (1): Linear(in_features=180, out_features=64, bias=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=64, bias=True)
      (4): ReLU()
      (5): Linear(in_features=64, out_features=1

### Test the Trained Network

In [19]:
_ = load_checkpoint(f'checkpoints/EEG_cr2_{SEED}.pt', model, optimizer, 'cpu')

pred_time = [4, 8, 12]
eval_time = [4, 8, 12, 24, 36, 48, 60, 72]

risk_all = f_get_risk_predictions(model, te_data, pred_time, eval_time, device=device)

for p, p_time in enumerate(pred_time):
    pred_horizon = int(p_time)
    result1, result2 = np.zeros([num_Event, len(eval_time)]), np.zeros([num_Event, len(eval_time)])

    for t, t_time in enumerate(eval_time):                
        eval_horizon = int(t_time) + pred_horizon
        for k in range(num_Event):
            result1[k, t] = c_index(risk_all[k][:, p, t], te_time, (te_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)
            result2[k, t] = brier_score(risk_all[k][:, p, t], te_time, (te_label[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)
    
    if p == 0:
        final1, final2 = result1, result2
    else:
        final1, final2 = np.append(final1, result1, axis=0), np.append(final2, result2, axis=0)
        
        
row_header = []
for p_time in pred_time:
    for k in range(num_Event):
        row_header.append('pred_time {}: event_{}'.format(p_time, k+1))
            
col_header = []
for t_time in eval_time:
    col_header.append('eval_time {}'.format(t_time))

# c-index result
df1 = pd.DataFrame(final1, index = row_header, columns=col_header)

# brier-score result
df2 = pd.DataFrame(final2, index = row_header, columns=col_header)

### PRINT RESULTS
print('========================================================')
print('--------------------------------------------------------')
print('- C-INDEX: ')
print(df1)
print('--------------------------------------------------------')
print('- BRIER-SCORE: ')
print(df2)
print('========================================================')

--------------------------------------------------------
- C-INDEX: 
                       eval_time 4  eval_time 8  eval_time 12  eval_time 24  \
pred_time 4: event_1     -1.000000    -1.000000      0.750427      0.626849   
pred_time 4: event_2     -1.000000     0.885287      0.918117      0.910617   
pred_time 8: event_1     -1.000000     0.770085      0.681452      0.606364   
pred_time 8: event_2      0.910224     0.850563      0.841351      0.852697   
pred_time 12: event_1     0.814530     0.668779      0.717061      0.616516   
pred_time 12: event_2     0.906858     0.901740      0.885681      0.836158   

                       eval_time 36  eval_time 48  eval_time 60  eval_time 72  
pred_time 4: event_1       0.609666      0.585309      0.596739      0.622711  
pred_time 4: event_2       0.891424      0.886954      0.883024      0.876623  
pred_time 8: event_1       0.596252      0.573063      0.571680      0.595925  
pred_time 8: event_2       0.817641      0.791897      0.

In [20]:
out = _f_get_pred(model, te_data, 12, device)

In [21]:
df1[::2]

Unnamed: 0,eval_time 4,eval_time 8,eval_time 12,eval_time 24,eval_time 36,eval_time 48,eval_time 60,eval_time 72
pred_time 4: event_1,-1.0,-1.0,0.750427,0.626849,0.609666,0.585309,0.596739,0.622711
pred_time 8: event_1,-1.0,0.770085,0.681452,0.606364,0.596252,0.573063,0.57168,0.595925
pred_time 12: event_1,0.81453,0.668779,0.717061,0.616516,0.6288,0.61104,0.623035,0.642319


In [22]:
df1[1::2]

Unnamed: 0,eval_time 4,eval_time 8,eval_time 12,eval_time 24,eval_time 36,eval_time 48,eval_time 60,eval_time 72
pred_time 4: event_2,-1.0,0.885287,0.918117,0.910617,0.891424,0.886954,0.883024,0.876623
pred_time 8: event_2,0.910224,0.850563,0.841351,0.852697,0.817641,0.791897,0.806404,0.83088
pred_time 12: event_2,0.906858,0.90174,0.885681,0.836158,0.834582,0.830669,0.805195,0.823982


In [268]:
te_data[0, :, 0]

array([1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.4, 0. ])

## Evaluation Metrics

### c-index

In [29]:
def eval_model(model, te_data, pred_time, eval_time, num_Event, device='cuda', te_label_cp=None):
    # pred_time = [4, 8, 12]
    # eval_time = [12, 24, 36, 48, 60, 72]
    
    if te_label_cp is None:
        te_label_cp = te_label
    
    risk_all = f_get_risk_predictions(model, te_data, pred_time, eval_time, device=device)

    for p, p_time in enumerate(pred_time):
        pred_horizon = int(p_time)
        result1, result2 = np.zeros([num_Event, len(eval_time)]), np.zeros([num_Event, len(eval_time)])

        for t, t_time in enumerate(eval_time):                
            eval_horizon = int(t_time) + pred_horizon
            for k in range(num_Event):
                result1[k, t] = c_index(risk_all[k][:, p, t], te_time, (te_label_cp[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)
                result2[k, t] = brier_score(risk_all[k][:, p, t], te_time, (te_label_cp[:,0] == k+1).astype(int), eval_horizon) #-1 for no event (not comparable)

        if p == 0:
            final1, final2 = result1, result2
        else:
            final1, final2 = np.append(final1, result1, axis=0), np.append(final2, result2, axis=0)


    row_header = []
    for p_time in pred_time:
        for k in range(num_Event):
            row_header.append('pred_time {}: event_{}'.format(p_time, k+1))

    col_header = []
    for t_time in eval_time:
        col_header.append('eval_time {}'.format(t_time))

    # c-index result
    df1 = pd.DataFrame(final1, index = row_header, columns=col_header)

    # brier-score result
    df2 = pd.DataFrame(final2, index = row_header, columns=col_header)

    ### PRINT RESULTS
    print('========================================================')
    print('--------------------------------------------------------')
    print('- C-INDEX: ')
    print(df1)
    print('--------------------------------------------------------')
    print('- BRIER-SCORE: ')
    print(df2)
    print('========================================================')
    
    return df1, df2

In [50]:
pred_time = [6,12]
eval_time = [24, 48, 72]

seed_list = [1234, 1111, 2222, 3333, 4444]
best_pt_path_list = ['ddh_eeg_cr2_1234_alpha1.0_beta0.5_gamma0.5_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_1704',
                     'ddh_eeg_cr2_1111_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_1825',
                     'ddh_eeg_cr2_2222_alpha1.0_beta1_gamma0.5_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2120',
                     'ddh_eeg_cr2_3333_alpha1.0_beta0.5_gamma0.05_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230417_2140',
                     'ddh_eeg_cr2_4444_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0001_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.2_20230418_0017'
                    ]

C_INDEX = []

for r, seed in tqdm(enumerate(seed_list)):
    best_pt_path = best_pt_path_list[r]

    ### TRAINING-TESTING SPLIT
    (tr_data,te_data, tr_data_mi, te_data_mi, tr_time,te_time, tr_label,te_label, tr_time_orinal,te_time_original,
     tr_time_to_last,te_time_to_last,
     tr_mask1,te_mask1, tr_mask2,te_mask2, tr_mask3,te_mask3) = \
            train_test_split(data, data_mi, time, label, time_original, time_to_last,
                             mask1, mask2, mask3, test_size=0.2, stratify=label, random_state=seed) 

    (tr_data,va_data, tr_data_mi, va_data_mi, tr_time,va_time, tr_label,va_label, tr_time_orinal,va_time_original,
     tr_time_to_last, va_time_to_last,
     tr_mask1,va_mask1, tr_mask2,va_mask2, tr_mask3,va_mask3) = \
            train_test_split(tr_data, tr_data_mi, tr_time, tr_label, tr_time_orinal, tr_time_to_last,
                             tr_mask1, tr_mask2, tr_mask3, test_size=0.2, stratify=tr_label, random_state=seed) 

    model_name = 'ddh_eeg_cr2'
    num_Event = 2

    model = \
        DynamicDeepHitTorch(input_dim  = x_dim,
                            output_dim = num_Category,
                            layers_rnn = param['layers_rnn'],
                            hidden_rnn = param['hidden_rnn'],
                            long_param = {'layers': 2*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            att_param  = {'layers': param['hidden_att']*[param['hidden_dim_FC']], 'dropout': dropout}, 
                            cs_param   = {'layers': param['hidden_cs']*[param['hidden_dim_FC']], 'dropout': dropout},
                            risks      = num_Event
                           )
    optimizer = torch.optim.Adam(model.parameters(), lr = lr_train)
    _ = load_checkpoint(f'checkpoints/{model_name}/{best_pt_path}_{seed}.pt', model, optimizer, device)
    model = model.to(device)
    
    df1, df2 = eval_model(model, te_data, pred_time, eval_time, num_Event)
    C_INDEX.append(df2)

1it [00:00,  3.74it/s]

--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.496344      0.503831      0.558388
pred_time 6: event_2       0.874130      0.857475      0.855550
pred_time 12: event_1      0.460449      0.510670      0.554730
pred_time 12: event_2      0.850216      0.862753      0.858834
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.027177      0.072449      0.111484
pred_time 6: event_2       0.037428      0.086354      0.115500
pred_time 12: event_1      0.038222      0.082698      0.114480
pred_time 12: event_2      0.050286      0.096415      0.120787


2it [00:00,  4.05it/s]

--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.640678      0.641945      0.639076
pred_time 6: event_2       0.822979      0.848123      0.856087
pred_time 12: event_1      0.675402      0.676223      0.686912
pred_time 12: event_2      0.813622      0.850089      0.865605
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.031956      0.076812      0.102032
pred_time 6: event_2       0.042457      0.089546      0.110210
pred_time 12: event_1      0.043372      0.080063      0.104756
pred_time 12: event_2      0.055686      0.093626      0.112952


3it [00:00,  4.18it/s]

--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.667736      0.666565      0.701347
pred_time 6: event_2       0.846583      0.876269      0.866750
pred_time 12: event_1      0.696252      0.746027      0.770207
pred_time 12: event_2      0.832753      0.861936      0.862984
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.037117      0.084240      0.106983
pred_time 6: event_2       0.047793      0.093553      0.110139
pred_time 12: event_1      0.048314      0.089795      0.108699
pred_time 12: event_2      0.059579      0.098336      0.112768


5it [00:01,  4.15it/s]

--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.684823      0.582977      0.626028
pred_time 6: event_2       0.873071      0.847532      0.848385
pred_time 12: event_1      0.520256      0.542955      0.577828
pred_time 12: event_2      0.856154      0.854307      0.848460
--------------------------------------------------------
- BRIER-SCORE: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.030180      0.078372      0.103766
pred_time 6: event_2       0.040222      0.091988      0.113595
pred_time 12: event_1      0.045223      0.084674      0.109637
pred_time 12: event_2      0.057722      0.097862      0.118471
--------------------------------------------------------
- C-INDEX: 
                       eval_time 24  eval_time 48  eval_time 72
pred_time 6: event_1       0.701132      0.715106      0.683059
pred_time 6: event_2 




In [34]:
C_INDEX[0]

Unnamed: 0,eval_time 24,eval_time 48,eval_time 72
pred_time 6: event_1,0.496344,0.503831,0.558388
pred_time 6: event_2,0.87413,0.857475,0.85555
pred_time 12: event_1,0.460449,0.51067,0.55473
pred_time 12: event_2,0.850216,0.862753,0.858834


In [51]:
c_index_mean = np.zeros_like(C_INDEX[0].values)
c_index_std = np.zeros_like(C_INDEX[0].values)
for i in range(c_index_mean.shape[0]):
    for j in range(c_index_mean.shape[1]):
        c_list = [C_INDEX[k].iloc[i,j] for k in range(len(seed_list))]
        c_list = np.array(c_list)
        c_index_mean[i,j] = c_list.mean()
        c_index_std[i,j]  = c_list.std()
        
C_index_mean = C_INDEX[0].copy()
C_index_std = C_INDEX[0].copy()
C_index_mean.iloc[:, :] = c_index_mean
C_index_std.iloc[:, :] = c_index_std

In [53]:
C_index_std

Unnamed: 0,eval_time 24,eval_time 48,eval_time 72
pred_time 6: event_1,0.003733,0.004876,0.004385
pred_time 6: event_2,0.00398,0.003447,0.002712
pred_time 12: event_1,0.003278,0.004544,0.003439
pred_time 12: event_2,0.003122,0.0035,0.003575
