In [1]:
import torch.multiprocessing

# Share tensors between processes via named files in shared memory instead of
# passing file descriptors; this is commonly done to avoid "Too many open files"
# errors when DataLoader workers exchange many tensors.
SHARING_STRATEGY = 'file_system'
torch.multiprocessing.set_sharing_strategy(SHARING_STRATEGY)



In [2]:
import os

# The notebook lives one level below the project root. Move to the root so that
# `import transEHR` and the relative './data/...' paths resolve.
# The starting directory is remembered on the first run, so re-running this cell
# always lands in the same place. A bare os.chdir('../') would climb one level
# further on every re-run.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
REPO_ROOT = os.path.dirname(os.path.abspath(NOTEBOOK_DIR))
os.chdir(REPO_ROOT)

import transEHR
from transEHR.utils import random_seed

# Fix the random seed up front so data shuffling and weight init are reproducible.
SEED = 42
random_seed(SEED)

In [3]:
# Load the train / validation / test splits for the chosen task. `load_data` also
# returns `num_train_set` (presumably the training-set size, forwarded to the
# Trainer below) and the categorical / numerical / binary column lists that the
# collator and model are configured with.
DATA_DIR = './data/mimic-iii/'
TASK = 'in-hospital-mortality'
(train_dataset, val_dataset, test_dataset,
 num_train_set, cat_cols, num_cols, bin_cols) = transEHR.load_data(DATA_DIR, TASK)

########################################
load from local data dir ./data/mimic-iii/ for in-hospital-mortality task
# data: 21139, # feat: 21, # cate: 6,  # bin: 0, # numerical: 15


In [3]:
# build patient dataset
# from transEHR.dataset import PatientDataset

# train_dataset = PatientDataset(root_dir='./data/mimic-iii/in-hospital-mortality/', mode="train", 
#                                         label_file="./data/mimic-iii/in-hospital-mortality/train_listfile.csv")

# val_dataset = PatientDataset(root_dir='./data/mimic-iii/in-hospital-mortality/', mode="val", 
#                                         label_file="./data/mimic-iii/in-hospital-mortality/val_listfile.csv")

# test_dataset = PatientDataset(root_dir='./data/mimic-iii/in-hospital-mortality/', mode="test", 
#                                         label_file="./data/mimic-iii/in-hospital-mortality/test_listfile.csv")

In [4]:
# Build DataLoaders for inspecting batches.
from torch.utils.data import DataLoader
from transEHR.utils import SupervisedTrainCollator

# The collator batches patient records. It is told which columns are categorical,
# numerical and binary (the schema returned by `load_data`).
patient_collate_fn = SupervisedTrainCollator(
    categorical_columns=cat_cols,
    numerical_columns=num_cols,
    binary_columns=bin_cols,
    ignore_duplicate_cols=False,
)

# Small, unshuffled loaders, one per split. Training below passes the datasets
# (not these loaders) to the Trainer.
LOADER_BATCH_SIZE = 2
train_loader, val_loader, test_loader = (
    DataLoader(split, collate_fn=patient_collate_fn, batch_size=LOADER_BATCH_SIZE, shuffle=False)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [5]:
# Draw the first batch from the training loader and unpack it into:
#   all_inputs - list of per-step feature dicts (element 0 = static features)
#   X_t_lens   - presumably the number of time steps per patient in the batch
#   y          - the labels for the batch
data = next(iter(train_loader))
all_inputs, X_t_lens, y = data

for shown in (len(all_inputs), X_t_lens, y):
    print(shown)

76
[75, 69]
0    0
1    0
Name: y_true, dtype: int64


In [6]:
# Element 0 of the batch inputs holds the static features.
static_inputs = all_inputs[0]
print(static_inputs)

{'x_num': tensor([[ 88.0654,   0.0000,  49.5000],
        [ 82.3834,   0.0000, 106.7756]], dtype=torch.float64), 'num_col_input_ids': tensor([[2287],
        [4578],
        [3635]]), 'x_cat_input_ids': tensor([[18240,  1018,  1012,  1014,  5907,  1015,  1012,  1014],
        [18240,  1018,  1012,  1014,  5907,  1016,  1012,  1014]]), 'x_bin_input_ids': None, 'num_att_mask': tensor([[1],
        [1],
        [1]]), 'cat_att_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])}


In [8]:
# Element 1 onward holds the time-series features (presumably one element per time step).
time_step_inputs = all_inputs[1]
print(time_step_inputs)

{'x_num': tensor([[  2.0000,  36.0000,   0.0000,   0.0000,   0.0000, 108.0000,   0.0000,
          60.6667,  95.0000,  43.0000, 110.0000,  39.1111,   0.0000,   0.0000],
        [  1.1628,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,  36.7778,   0.0000,   0.0000]],
       dtype=torch.float64), 'num_col_input_ids': tensor([[ 2847,     0,     0,     0,     0,     0,     0],
        [22939, 16033, 10415,  2668,  3778,     0,     0],
        [12884,  4427,  7722,     0,     0,     0,     0],
        [ 1043,  8523,  3597,  2860, 16571,  4094,  2561],
        [18423,     0,     0,     0,     0,     0,     0],
        [ 2540,  3446,     0,     0,     0,     0,     0],
        [ 4578,     0,     0,     0,     0,     0,     0],
        [ 2812,  2668,  3778,     0,     0,     0,     0],
        [ 7722,  2938, 18924,     0,     0,     0,     0],
        [16464,  3446,     0,     0,     0,     0,     0],
        [25353, 16033, 10415,  

In [4]:
# Build the TransEHR classifier over the column schema returned by `load_data`.
import torch
from transEHR.modeling_transtab import TransEHRClassifier

# Use the first GPU when one is visible and fall back to CPU otherwise. A
# hard-coded 'cuda:0' would fail on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = TransEHRClassifier(
    categorical_columns=cat_cols,
    numerical_columns=num_cols,
    binary_columns=bin_cols,
    num_class=2,              # binary outcome (in-hospital mortality task)
    hidden_dim=128,
    num_layer=2,
    num_attention_head=8,
    hidden_dropout_prob=0,
    ffn_dim=256,
    activation='relu',
    device=DEVICE,
)

In [5]:
# Train the model.
from transEHR.train import Trainer

# Training configuration. The metric monitored on the validation set (presumably
# for early stopping / checkpoint selection) is AUC with higher-is-better
# (`eval_less_is_better=False`), not the validation loss.
# NOTE(review): batch_size=2 looks like a debugging value (an earlier value of 16
# was commented out). Confirm before a full run.
training_arguments = dict(
    num_epoch=10,
    batch_size=2,
    lr=2e-5,
    eval_metric='auc',
    eval_less_is_better=False,
    output_dir='../checkpoint',   # relative to the current working directory set in the setup cell
    num_workers=0,
    warmup_ratio=None,
    warmup_steps=None,
    num_train_set=num_train_set,
)

trainer = Trainer(model, train_dataset, val_dataset, **training_arguments)
trainer.train()


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

epoch: 0, ite: 1, train loss: 0.6640
epoch: 0, ite: 2, train loss: 0.5552
epoch: 0, ite: 3, train loss: 0.5483
epoch: 0, ite: 4, train loss: 0.6722
epoch: 0, ite: 5, train loss: 0.4657
epoch: 0, ite: 6, train loss: 0.5016
epoch: 0, ite: 7, train loss: 0.8826
epoch: 0, ite: 8, train loss: 0.4901
epoch: 0, ite: 9, train loss: 0.4603


Epoch:   0%|          | 0/10 [06:54<?, ?it/s]


KeyboardInterrupt: 