In [1]:
# Seed value; the later cells feed it to torch, numpy and train_test_split.
SEED = 4444
# Small positive constant (not referenced by any cell visible in this notebook excerpt).
_EPSILON = 1e-8

In [2]:
import sys
sys.path.append('../scr/')

import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime

from ddrsa import DDRSA
import import_data as impt
from utils_eval import c_index, brier_score
from utils_helper import f_get_boosted_trainset
from utils_helper import loss_Log_Likelihood, loss_Ranking, loss_RNN_Prediction
from utils_helper import f_get_risk_predictions, _f_get_pred
from utils_helper import save_checkpoint, load_checkpoint

In [3]:
## Reproducibility: seed the numpy and torch global RNGs with the same value
## and ask cuDNN for deterministic kernels with auto-tuning switched off.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
def model_eval(model, dataloader, device):
    """Compute the mean per-batch losses of `model` over `dataloader`.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating.

    Note: besides its arguments this function reads the notebook-level
    globals `alpha`, `beta`, `gamma` (loss weights), `num_Event` and
    `num_Category` at call time.

    Args:
        model: network called as `model(b_data)`, returning
            `(longitudinal_prediction, out)` where `out` is an iterable of
            tensors that get stacked along a new dim 1.
        dataloader: yields 7-item batches
            `(data, data_mi, time, label, mask1, mask2, mask3)`.
        device: device the batch tensors are moved to.

    Returns:
        Tuple `(total, nll, ranking, rnn_prediction)`: the weighted total
        loss and its three components, each summed over batches and divided
        by `len(dataloader)`.
    """
    model.eval()
    # Running sums, in the order: weighted total, loss1, loss2, loss3.
    running = [0, 0, 0, 0]

    with torch.no_grad():
        for batch in dataloader:
            # Cast every element to float32; the second one (data_mi) is not used.
            b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = (t.float() for t in batch)
            b_data = b_data.to(device)
            b_time = b_time.to(device)
            b_label = b_label.to(device)
            b_mask1 = b_mask1.to(device)
            b_mask2 = b_mask2.to(device)
            b_mask3 = b_mask3.to(device)

            longitudinal_prediction, out = model(b_data)
            out = torch.concat([o.unsqueeze(1) for o in out], 1)  # (B, num_Event, num_Category)

            nll = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
            ranking = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
            rnn_pred = loss_RNN_Prediction(longitudinal_prediction, b_data)
            weighted = alpha * nll + beta * ranking + gamma * rnn_pred

            for slot, term in enumerate((weighted, nll, ranking, rnn_pred)):
                running[slot] = running[slot] + term

    n_batches = len(dataloader)
    return tuple(acc / n_batches for acc in running)

### Import Dataset

In [8]:
# Reload the project module `import_data` so that edits made to it on disk are
# picked up without restarting the kernel, then bind it to the alias `impt`
# that the data-loading cell uses.
import importlib

import import_data
importlib.reload(import_data)

import import_data as impt

In [9]:
data_mode                   = 'EEG_cr3'
data_path                   = '../../eeg/competing-risk/EEG_processed_data_long_by_death_cat_exclude_3.csv'
seed                        = SEED

##### IMPORT DATASET
# Quantities produced by this cell (descriptions carried over from the original notes):
#   num_Category            = max event/censoring time * 1.2
#   num_Event               = number of events, i.e. len(np.unique(label)) - 1
#   max_length              = maximum number of measurements
#   x_dim                   = data dimension including delta (1 + num_features)
#   x_dim_cont              = dim of continuous features
#   x_dim_bin               = dim of binary features
#   mask1, mask2, mask3     = used for cause-specific network (FCNet structure)

if data_mode == 'EEG_cr3':
    (x_dim, x_dim_cont, x_dim_bin), (data, time, label, time_original, time_to_last), \
        (mask1, mask2, mask3), (data_mi), trans_discrete_time = \
                    impt.import_dataset(data_path)

    # These must be changed depending on the dataset and on the
    # prediction/evaluation times of interest.
    # NOTE: a later cell reassigns pred_time / eval_time before training.
    pred_time = [6, 12]   # prediction time (in hours)
    eval_time = [12, 24]  # evaluation time in hours (for C-index and Brier-Score)
else:
    # Fail immediately: continuing would only surface later as a NameError on
    # `mask1` (or silently reuse stale variables from an earlier kernel state).
    raise ValueError('ERROR:  DATA_MODE NOT FOUND !!! (data_mode={!r})'.format(data_mode))

# mask1 is unpacked as [subj, num_Event, num_Category].
_, num_Event, num_Category  = np.shape(mask1)
# Second axis of `data` is taken as the maximum sequence length.
max_length                  = np.shape(data)[1]

# Displayed as the cell output.
num_Event, num_Category, max_length



(3, 120, 12)

### Set Hyper-Parameters

In [10]:
# ddh_eeg_cr4_1234_alpha1.0_beta1_gamma0.1_BSZ32_Lr0.0005_hiddenRNN64_hiddenFC64_layersRNN2_hiddenAtt2_hiddenCS2_dropout0.3_20230307_0141

In [11]:
from sklearn.model_selection import ParameterGrid

# Name used for the checkpoint and TensorBoard sub-directories.
model_name = 'ddrsa_eeg_cr3'

# Candidate values for the hyper-parameters searched over; every combination
# (2 * 3 * 3 * 3 = 54) becomes one entry of `param_grid_list`.
param_grid = dict(
    dropout=[0.2, 0.4],
    lr_train=[1e-4, 5e-4, 1e-3],
    beta=[0.5, 1, 5],
    gamma=[0.05, 0.1, 0.5],
)
param_grid_list = [*ParameterGrid(param_grid)]
print(len(param_grid_list))
param_grid_list

54


[{'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.1, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.5, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.05, 'lr_train': 0.001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.0001},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.0005},
 {'beta': 0.5, 'dropout': 0.4, 'gamma': 0.1, 'lr_train': 0.001},
 {'beta':

In [12]:
boost_mode = 'ON'  # one of {'ON', 'OFF'}

# Default hyper-parameters. The grid-search cell overrides dropout, lr_train,
# beta and gamma per configuration; the remaining entries are used as-is.
param = dict(
    batch_size=32,

    num_epoch_burn_in=20,
    num_epoch=100,

    dropout=0.2,  # 1 - keep_prob
    lr_train=1e-4,

    hidden_rnn=64,
    hidden_dim_FC=64,
    layers_rnn=2,
    hidden_att=2,
    hidden_cs=2,

    alpha=1.0,
    beta=1,
    gamma=0.1,
)

# Unpack the entries into notebook-level names used by the later cells
# ('num_epoch_burn_in' is read from `param` directly, so it is not unpacked).
batch_size, num_epoch = param['batch_size'], param['num_epoch']
dropout, lr_train = param['dropout'], param['lr_train']
hidden_rnn, hidden_dim_FC, layers_rnn = (
    param['hidden_rnn'], param['hidden_dim_FC'], param['layers_rnn'])
hidden_att, hidden_cs = param['hidden_att'], param['hidden_cs']
alpha, beta, gamma = param['alpha'], param['beta'], param['gamma']

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


### Split Dataset into Train/Valid/Test Sets

In [13]:
### TRAINING-TESTING SPLIT
# Two successive splits, each holding out 20% and stratified on the label
# array with the same seed: first train+valid vs. test, then train vs. valid.
_split_options = dict(test_size=0.2, random_state=seed)

(tr_data, te_data,
 tr_data_mi, te_data_mi,
 tr_time, te_time,
 tr_label, te_label,
 tr_time_orinal, te_time_original,
 tr_time_to_last, te_time_to_last,
 tr_mask1, te_mask1,
 tr_mask2, te_mask2,
 tr_mask3, te_mask3) = train_test_split(
    data, data_mi, time, label, time_original, time_to_last,
    mask1, mask2, mask3,
    stratify=label, **_split_options)

(tr_data, va_data,
 tr_data_mi, va_data_mi,
 tr_time, va_time,
 tr_label, va_label,
 tr_time_orinal, va_time_original,
 tr_time_to_last, va_time_to_last,
 tr_mask1, va_mask1,
 tr_mask2, va_mask2,
 tr_mask3, va_mask3) = train_test_split(
    tr_data, tr_data_mi, tr_time, tr_label, tr_time_orinal, tr_time_to_last,
    tr_mask1, tr_mask2, tr_mask3,
    stratify=tr_label, **_split_options)

if boost_mode == 'ON':
    # NOTE(review): tr_time_orinal and tr_time_to_last are not passed through
    # f_get_boosted_trainset; if boosting changes the number or order of
    # training rows they will no longer line up with tr_data - confirm.
    boosted = f_get_boosted_trainset(tr_data, tr_data_mi, tr_time, tr_label,
                                     tr_mask1, tr_mask2, tr_mask3)
    tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3 = boosted

In [14]:
# Sanity check: display the training sample at index 50.
tr_data[50]

array([[ 1.        , 70.35817989, 77.31567078, ...,  1.        ,
         0.        ,  0.        ],
       [ 1.        , 85.97480151, 85.26077385, ...,  1.        ,
         0.        ,  0.        ],
       [ 1.        , 74.4407046 , 82.44836305, ...,  1.        ,
         0.        ,  0.        ],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]])

In [15]:
def _make_loader(arrays, shuffle):
    """Zip the per-subject arrays into 7-tuples and wrap them in a DataLoader."""
    return DataLoader(list(zip(*arrays)), batch_size=batch_size, shuffle=shuffle)


# Only the training loader is shuffled; validation and test keep their order.
tr_loader = _make_loader(
    (tr_data, tr_data_mi, tr_time, tr_label, tr_mask1, tr_mask2, tr_mask3), shuffle=True)

va_loader = _make_loader(
    (va_data, va_data_mi, va_time, va_label, va_mask1, va_mask2, va_mask3), shuffle=False)

te_loader = _make_loader(
    (te_data, te_data_mi, te_time, te_label, te_mask1, te_mask2, te_mask3), shuffle=False)

### Train the Network

This implementation does not consider missing values.

In [16]:
# Display the number of events (second dimension of mask1, set in the data-import cell).
num_Event

3

In [17]:
# Prediction / evaluation horizons used for validation during training.
# These replace the values assigned in the data-import cell
# (eval_time was [12, 24] there).
pred_time, eval_time = [6, 12], [24, 48, 72]

In [None]:

# Grid search over `param_grid_list`. For every configuration a fresh DDRSA
# model is (1) burned in on the longitudinal-prediction loss only, then
# (2) trained on the weighted sum of the three losses with early stopping on
# the mean validation C-index. The best epoch of each run is checkpointed and
# per-epoch metrics are written to TensorBoard.
Best_C_Index = 0

# Create the directory that save_checkpoint actually writes into. (Previously
# "checkpoints/<model_name>" was created while checkpoints were saved under
# "checkpoints_ddrsa/<model_name>", so the target directory could be missing.)
checkpoint_dir = f"checkpoints_ddrsa/{model_name}"
os.makedirs(checkpoint_dir, exist_ok=True)


def _prepare_batch(batch, device):
    """Cast the tensors of a 7-item batch to float32 and move them to `device`.

    The second item (data_mi) is not used by training and is dropped.
    Returns (b_data, b_time, b_label, b_mask1, b_mask2, b_mask3).
    """
    b_data, _, b_time, b_label, b_mask1, b_mask2, b_mask3 = batch
    return tuple(t.float().to(device)
                 for t in (b_data, b_time, b_label, b_mask1, b_mask2, b_mask3))


def _burn_in_epoch(model, optimizer, loader, device):
    """One pass over `loader` optimising only the longitudinal-prediction loss."""
    model.train()
    for batch in loader:
        b_data = batch[0].float().to(device)
        longitudinal_prediction, _ = model(b_data)
        loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)

        optimizer.zero_grad()
        loss3.backward()
        optimizer.step()


def _train_epoch(model, optimizer, loader, device, alpha, beta, gamma):
    """One pass over `loader` optimising alpha*NLL + beta*ranking + gamma*RNN loss."""
    model.train()
    for batch in loader:
        b_data, b_time, b_label, b_mask1, b_mask2, b_mask3 = _prepare_batch(batch, device)

        longitudinal_prediction, out = model(b_data)
        out = torch.concat([o.unsqueeze(1) for o in out], 1)  # (B, num_Event, num_Category)

        loss1 = loss_Log_Likelihood(out, b_label, b_mask1, b_mask2)
        loss2 = loss_Ranking(out, b_time, b_label, b_mask3, num_Event, num_Category)
        loss3 = loss_RNN_Prediction(longitudinal_prediction, b_data)
        batch_loss = alpha * loss1 + beta * loss2 + gamma * loss3

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


def _mean_validation_c_index(model, device):
    """Mean C-index on the validation set over all events, pred_time and eval_time.

    NOTE(review): the original comment says c_index yields -1 when there is no
    comparable event; such values would be averaged in as-is - confirm this is
    intended.
    """
    risk_all = f_get_risk_predictions(model, va_data, pred_time, eval_time, device)

    per_pred_time = []
    for p, p_time in enumerate(pred_time):
        pred_horizon = int(p_time)
        result = np.zeros([num_Event, len(eval_time)])
        for t, t_time in enumerate(eval_time):
            eval_horizon = int(t_time) + pred_horizon
            for k in range(num_Event):
                result[k, t] = c_index(risk_all[k][:, p, t], va_time,
                                       (va_label[:, 0] == k + 1).astype(int), eval_horizon)
        per_pred_time.append(result)

    return np.mean(np.concatenate(per_pred_time, axis=0))


# NOTE(review): the first two grid entries are skipped by the [2:] slice -
# presumably because they were already run in an earlier session; confirm.
for params in tqdm(param_grid_list[2:]):
    # beta / gamma are (re)bound at notebook level because model_eval reads
    # them as globals when computing the logged losses.
    dropout = params['dropout']
    lr_train = params['lr_train']
    beta = params['beta']
    gamma = params['gamma']
    print('='*80, '\n', params)

    model = \
        DDRSA(input_dim  = x_dim,
              output_dim = num_Category,
              layers_rnn = param['layers_rnn'],
              hidden_rnn = param['hidden_rnn'],
              long_param = {'layers': 2*[param['hidden_dim_FC']], 'dropout': dropout},
              att_param  = {'layers': param['hidden_att']*[param['hidden_dim_FC']], 'dropout': dropout},
              cs_param   = {'layers': param['hidden_cs']*[param['hidden_dim_FC']], 'dropout': dropout},
              risks      = num_Event
             )
    # Move the model to the target device before building the optimizers,
    # as recommended by the torch.optim documentation.
    model.to(device)
    optimizer_burn_in = torch.optim.Adam(model.parameters(), lr = lr_train)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr_train)

    # Despite its name, `min_valid` holds the best (highest) validation
    # C-index seen so far in this run; 0.1 is the floor a run must beat.
    min_valid = 0.1
    early_stop_patience = 10
    ep = 0 # epochs since the last improvement (for early stopping)

    experiment_runtime = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    tensorboard_string_extension = f"{model_name}_{SEED}"
    tensorboard_string_extension += f"_alpha{alpha}_beta{beta}_gamma{gamma}_BSZ{batch_size}_Lr{lr_train}"
    tensorboard_string_extension += f"_hiddenRNN{hidden_rnn}_hiddenFC{hidden_dim_FC}"
    tensorboard_string_extension += f"_layersRNN{layers_rnn}_hiddenAtt{hidden_att}_hiddenCS{hidden_cs}"
    tensorboard_string_extension += f"_dropout{dropout}"
    tensorboard_string_extension += "_" + experiment_runtime
    tensorboard_filename = tensorboard_string_extension
    writer = SummaryWriter(f"runs/{model_name}/{tensorboard_filename}")

    try:
        # Phase 1: burn-in on the longitudinal-prediction loss only.
        for epoch in range(param['num_epoch_burn_in']):
            _burn_in_epoch(model, optimizer_burn_in, tr_loader, device)

        # Phase 2: joint training with early stopping.
        for epoch in range(num_epoch):
            _train_epoch(model, optimizer, tr_loader, device, alpha, beta, gamma)

            model.eval()
            tr_loss, tr_loss1, tr_loss2, tr_loss3 = model_eval(model, tr_loader, device)
            va_loss, va_loss1, va_loss2, va_loss3 = model_eval(model, va_loader, device)

            ### VALIDATION  (based on average C-index of our interest)
            tmp_valid = _mean_validation_c_index(model, device)

            if tmp_valid >  min_valid:
                ep = 0
                min_valid = tmp_valid
                save_checkpoint(f"{checkpoint_dir}/{tensorboard_filename}_{SEED}.pt", model, optimizer, va_loss)
                print( 'updated.... average c-index = ' + str('%.4f' %(tmp_valid)))
            else:
                ep += 1
                print("Counter {} of {}".format(ep, early_stop_patience))

            # Tensorboard (logged before the early-stop check so that the
            # final epoch of a run is recorded as well).
            writer.add_scalar("Loss-Training/Loss", tr_loss, epoch)
            writer.add_scalar("Loss-Training/NLL Loss", tr_loss1, epoch)
            writer.add_scalar("Loss-Training/Ranking Loss", tr_loss2, epoch)
            writer.add_scalar("Loss-Training/LSTM MSELoss", tr_loss3, epoch)

            writer.add_scalar("Loss-Validation/Loss", va_loss, epoch)
            writer.add_scalar("Loss-Validation/NLL Loss", va_loss1, epoch)
            writer.add_scalar("Loss-Validation/Ranking Loss", va_loss2, epoch)
            writer.add_scalar("Loss-Validation/LSTM MSELoss", va_loss3, epoch)

            # Log this epoch's C-index under the existing tag (it previously
            # received the running best) and the running best under its own tag.
            writer.add_scalar("C-index/Validation Avg C-index", tmp_valid, epoch)
            writer.add_scalar("C-index/Best Validation Avg C-index", min_valid, epoch)

            print('epoch {}, tr_loss: {:.4f}, va_loss: {:.4f}, va c-index: {:.4f}'.\
                          format(epoch+1, tr_loss, va_loss, tmp_valid))

            if ep >= early_stop_patience:
                print("Early stopping with best_cindex: {:.4f}".format(min_valid),
                      "and val_cindex for this epoch: {:.4f}".format(tmp_valid))
                break
    finally:
        # Flush and release the event file even if the run is interrupted;
        # otherwise one open writer would accumulate per grid configuration.
        writer.close()

    if min_valid > Best_C_Index:
        Best_C_Index = min_valid
        print('Better valudation c-index!!!')
        print(tensorboard_filename)

  0%|                                                                | 0/52 [00:00<?, ?it/s]

 {'beta': 0.5, 'dropout': 0.2, 'gamma': 0.05, 'lr_train': 0.001}
updated.... average c-index = 0.6010
epoch 1, tr_loss: 26.7280, va_loss: 35.3778, va c-index: 0.6010
updated.... average c-index = 0.6107
epoch 2, tr_loss: 24.9008, va_loss: 35.4974, va c-index: 0.6107
Counter 1 of 10
epoch 3, tr_loss: 24.5915, va_loss: 35.2764, va c-index: 0.5594
Counter 2 of 10
epoch 4, tr_loss: 24.1867, va_loss: 35.8731, va c-index: 0.5579
Counter 3 of 10
epoch 5, tr_loss: 24.5262, va_loss: 36.2500, va c-index: 0.5904
Counter 4 of 10
epoch 6, tr_loss: 25.9468, va_loss: 35.3984, va c-index: 0.5872
updated.... average c-index = 0.6207
epoch 7, tr_loss: 24.5841, va_loss: 35.0238, va c-index: 0.6207
Counter 1 of 10
epoch 8, tr_loss: 24.5827, va_loss: 35.7382, va c-index: 0.5923
Counter 2 of 10
epoch 9, tr_loss: 22.7727, va_loss: 36.8780, va c-index: 0.5886
Counter 3 of 10
epoch 10, tr_loss: 23.7698, va_loss: 36.4176, va c-index: 0.6024
Counter 4 of 10
epoch 11, tr_loss: 26.5165, va_loss: 37.6422, va c-inde

In [36]:
# Display the layer structure of whatever `model` currently refers to in the kernel.
model

DDRSA(
  (embedding): LSTM(116, 64, bias=False, batch_first=True)
  (longitudinal): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=116, bias=True)
  )
  (attention): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=180, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=1, bias=True)
  )
  (attention_soft): Softmax(dim=1)
  (cause_specific_rnn): ModuleList(
    (0): LSTMCell(180, 180, bias=False)
    (1): LSTMCell(180, 180, bias=False)
    (2): LSTMCell(180, 180, bias=False)
  )
  (cause_specific): ModuleList(
    (0): Sequential(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=180, out_features=64, bias=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=64, bias=True)
      (4):

### Test the Trained Network