In [1]:
import os

# The transEHR package and ./data presumably live one level above this notebook's
# directory (the original cell always stepped up before importing transEHR and before
# the ./data/... paths used below). Only step up when the package is not already in
# the working directory. Re-running this cell, or Run All after a partial run, then
# does not keep climbing the directory tree and break every relative path.
if not os.path.isdir('transEHR'):
    os.chdir('../')

import transEHR
from transEHR.utils import random_seed

# Fix the random seed via transEHR's helper before any data loading or model
# construction, so the later cells are reproducible.
random_seed(42)



In [2]:
# NOTE(review): `load_data` imported here is not used in this cell, which calls
# transEHR.load_data instead. Confirm they are the same function before dropping either.
from transEHR.dataset import load_data

# MIMIC-III sepsis task: keep the test split (third return value) and the categorical /
# numerical / binary column lists; the other return values are discarded.
MIMIC_DATA_DIR = './data/mimic-iii/'
MIMIC_TASK = 'sepsis'
(_, _, test_dataset, _,
 cat_cols, num_cols, bin_cols) = transEHR.load_data(MIMIC_DATA_DIR, MIMIC_TASK)

########################################
load from local data dir ./data/mimic-iii/ for sepsis task
# data: 27250, # feat: 19, # cate: 3,  # bin: 0, # numerical: 16


In [3]:
# Build the TransEHR transformer classifier over the column lists returned by load_data.
from transEHR.modeling_transtab import TransEHRClassifier
import torch

# The original run used the second GPU. Keep that when it exists, but fall back instead of
# crashing on machines with a single GPU or none.
# NOTE(review): confirm that Testing/predict places batches on model's device when
# running on 'cpu'.
if torch.cuda.device_count() > 1:
    DEVICE = 'cuda:1'
elif torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

model = TransEHRClassifier(
    categorical_columns=cat_cols,
    numerical_columns=num_cols,
    binary_columns=bin_cols,
    num_class=2,              # binary sepsis label
    hidden_dim=128,
    num_layer=3,
    num_attention_head=8,
    hidden_dropout_prob=0,
    ffn_dim=256,
    activation='relu',
    device=DEVICE,
)

In [5]:
from transEHR.testing import Testing

# Evaluate the epoch-6 sepsis checkpoint on the MIMIC-III test split loaded above.
# The log output shows the model weights and feature extractor being loaded from
# ckpt_dir. The confusion matrix and metrics are then printed.
eval_kwargs = dict(
    ckpt_dir='./sepsis_checkpoint/epoch_6/',
    batch_size=8,
    num_workers=3,
)
test = Testing(model=model, test_set_list=test_dataset, **eval_kwargs)
test.predict()

2023-05-22 09:16:07.555 | INFO     | transEHR.modeling_transtab:load:812 - missing keys: []
2023-05-22 09:16:07.557 | INFO     | transEHR.modeling_transtab:load:813 - unexpected keys: []
2023-05-22 09:16:07.558 | INFO     | transEHR.modeling_transtab:load:814 - load model from ./sepsis_checkpoint/epoch_6/
2023-05-22 09:16:07.593 | INFO     | transEHR.modeling_transtab:load:229 - load feature extractor from ./sepsis_checkpoint/epoch_6/extractor/extractor.json
100%|██████████| 349/349 [10:35<00:00,  1.82s/it]

confusion matrix:
[[2673    8]
 [  49   60]]
accuracy = 0.9795699119567871
precision class 0 = 0.9819985032081604
precision class 1 = 0.8823529481887817
recall class 0 = 0.9970160126686096
recall class 1 = 0.5504587292671204
AUC of ROC = 0.9727439781814946
AUC of PRC = 0.8408114719967309
min(+P, Se) = 0.7155963302752294
f1_score = 0.677966085930478
END Predict, cost time: 635.3 secs





### Zero-shot learning (ZSL) evaluation on the PhysioNet Challenge 2019 sepsis dataset

In [7]:
# Zero-shot setting: load the PhysioNet 2019 test split for evaluation with the existing model.
# NOTE(review): cat_cols/num_cols/bin_cols are reassigned here, but `model` is neither rebuilt
# nor updated with them. Confirm whether PhysioNet-only columns reach the model at all.
PHYSIONET_DATA_DIR = './data/physionet_sepsis/'
PHYSIONET_TASK = 'binary_sepsis_predict'
(_, _, test_dataset, _,
 cat_cols, num_cols, bin_cols) = transEHR.load_data(PHYSIONET_DATA_DIR, PHYSIONET_TASK)

########################################
load from local data dir ./data/physionet_sepsis/ for binary_sepsis_predict task
# data: 40336, # feat: 39, # cate: 3,  # bin: 0, # numerical: 36


In [8]:
# Score the same epoch-6 sepsis checkpoint (presumably trained on MIMIC-III) on the PhysioNet
# test split, without any further training.
# NOTE(review): the log shows the feature extractor being reloaded from the checkpoint's
# extractor.json. Columns it does not know may be dropped. Verify against the Testing /
# extractor implementation before reading these numbers as true zero-shot performance.
zsl_kwargs = dict(
    ckpt_dir='./sepsis_checkpoint/epoch_6/',
    batch_size=8,
    num_workers=3,
)
test = Testing(model=model, test_set_list=test_dataset, **zsl_kwargs)
test.predict()

2023-05-22 09:42:29.976 | INFO     | transEHR.modeling_transtab:load:812 - missing keys: []
2023-05-22 09:42:29.978 | INFO     | transEHR.modeling_transtab:load:813 - unexpected keys: []
2023-05-22 09:42:29.979 | INFO     | transEHR.modeling_transtab:load:814 - load model from ./sepsis_checkpoint/epoch_6/
2023-05-22 09:42:30.018 | INFO     | transEHR.modeling_transtab:load:229 - load feature extractor from ./sepsis_checkpoint/epoch_6/extractor/extractor.json
100%|██████████| 518/518 [05:22<00:00,  1.61it/s]

confusion matrix:
[[3731  103]
 [ 283   20]]
accuracy = 0.9066956639289856
precision class 0 = 0.9294967651367188
precision class 1 = 0.16260161995887756
recall class 0 = 0.9731351137161255
recall class 1 = 0.066006600856781
AUC of ROC = 0.6477384045133778
AUC of PRC = 0.12095818894699423
min(+P, Se) = 0.15822784810126583
f1_score = 0.093896712804096
END Predict, cost time: 322.8 secs





### Train on MIMIC-III and PhysioNet 2019 together

In [2]:
# Load the MIMIC-III and PhysioNet 2019 sepsis tasks in one call. load_data is given
# parallel lists of data directories and task names (one entry per dataset).
JOINT_DATA_DIRS = ['./data/mimic-iii/', './data/physionet_sepsis/']
JOINT_TASKS = ['sepsis', 'binary_sepsis_predict']
(_, _, test_dataset, _,
 cat_cols, num_cols, bin_cols) = transEHR.load_data(JOINT_DATA_DIRS, JOINT_TASKS)

########################################
load from local data dir ./data/mimic-iii/ for sepsis task
# data: 27250, # feat: 49, # cate: 3,  # bin: 0, # numerical: 46
########################################
load from local data dir ./data/physionet_sepsis/ for binary_sepsis_predict task
# data: 40336, # feat: 39, # cate: 3,  # bin: 0, # numerical: 36
