In [2]:
# for sample in samples:
#     sample["conditions"] = sample["conditions"][0]
#     sample["procedures"] = sample["procedures"][0]
#     sample["drugs"] = sample["drugs"][0]

In [9]:
TASK = "mortality"
DATASET = "mimic3"


In [10]:
from pyhealth.datasets import SampleEHRDataset
import json

with open(f"/shared/eng/pj20/kelpie_exp_data/ehr_data/{DATASET}_{TASK}_samples_train.json", "r") as f:
    samples_train = json.load(f)
with open(f"/shared/eng/pj20/kelpie_exp_data/ehr_data/{DATASET}_{TASK}_samples_test.json", "r") as f:
    samples_test = json.load(f)


dataset_train = SampleEHRDataset(samples_train, dataset_name=DATASET, task_name=TASK)
dataset_test = SampleEHRDataset(samples_test, dataset_name=DATASET, task_name=TASK)


In [11]:
len(dataset_test)

996

In [12]:
from pyhealth.datasets import split_by_patient, get_dataloader

# train_dataset, val_dataset, test_dataset = split_by_patient(
#     dataset, [0.8, 0.1, 0.1], seed=528
# )
train_dataloader = get_dataloader(dataset_train, batch_size=32, shuffle=True)
test_dataloader = get_dataloader(dataset_test, batch_size=32, shuffle=False)

In [18]:
from pyhealth.trainer import Trainer
from pyhealth.models import Deepr, AdaCare, StageNet, GRASP, Transformer, RETAIN, RNN
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

model = GRASP(
    dataset=dataset_train,
    feature_keys=["conditions", "procedures", "drugs"],
    label_key="label",
    mode="binary",
    use_embedding=[True, True, True],
    embedding_dim=128,
)


In [19]:
trainer = Trainer(model=model, metrics=['accuracy', 'f1', 'pr_auc', 'sensitivity', 'specificity'])
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    epochs=15,
    optimizer_params = {"lr": 1e-3},
    monitor="f1",
)

# STEP 5: evaluate
print(trainer.evaluate(test_dataloader))

GRASP(
  (embeddings): ModuleDict(
    (conditions): Embedding(263, 128, padding_idx=0)
    (procedures): Embedding(192, 128, padding_idx=0)
    (drugs): Embedding(194, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (grasp): ModuleDict(
    (conditions): GRASPLayer(
      (backbone): ConCareLayer(
        (PositionalEncoding): PositionalEncoding(
          (dropout): Dropout(p=0, inplace=False)
        )
        (GRUs): ModuleList(
          (0-127): 128 x GRU(1, 128, batch_first=True)
        )
        (LastStepAttentions): ModuleList(
          (0-127): 128 x SingleAttention(
            (tanh): Tanh()
            (softmax): Softmax(dim=1)
            (sigmoid): Sigmoid()
            (relu): ReLU()
          )
        )
        (FinalAttentionQKV): FinalAttentionQKV(
          (W_q): Linear(in_features=128, out_features=128, bias=True)
          (W_k): Linear(in_features=128, out_features=128, bias=True)
          (W_v): Linear(in_features=128, out_features=128, bias=True)

Epoch 0 / 15: 100%|██████████| 242/242 [04:14<00:00,  1.05s/it]

--- Train epoch-0, step-242 ---
loss: 0.2626



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s]

--- Eval epoch-0, step-242 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0718
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2068
New best f1 score (0.4861) at epoch-0, step-242








Epoch 1 / 15: 100%|██████████| 242/242 [03:51<00:00,  1.04it/s]

--- Train epoch-1, step-484 ---
loss: 0.2530



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s]

--- Eval epoch-1, step-484 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0862
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2120




Epoch 2 / 15: 100%|██████████| 242/242 [03:58<00:00,  1.02it/s]

--- Train epoch-2, step-726 ---
loss: 0.2492



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s]

--- Eval epoch-2, step-726 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0861
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2159




Epoch 3 / 15: 100%|██████████| 242/242 [03:53<00:00,  1.04it/s]

--- Train epoch-3, step-968 ---
loss: 0.2434



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s]

--- Eval epoch-3, step-968 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0855
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2057




Epoch 4 / 15: 100%|██████████| 242/242 [03:49<00:00,  1.05it/s]

--- Train epoch-4, step-1210 ---
loss: 0.2342



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.67it/s]

--- Eval epoch-4, step-1210 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0948
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2239




Epoch 5 / 15: 100%|██████████| 242/242 [03:50<00:00,  1.05it/s]

--- Train epoch-5, step-1452 ---
loss: 0.2335



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s]

--- Eval epoch-5, step-1452 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0957
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2145




Epoch 6 / 15: 100%|██████████| 242/242 [03:53<00:00,  1.04it/s]

--- Train epoch-6, step-1694 ---
loss: 0.2251



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.68it/s]

--- Eval epoch-6, step-1694 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0970
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2282




Epoch 7 / 15: 100%|██████████| 242/242 [04:21<00:00,  1.08s/it]

--- Train epoch-7, step-1936 ---
loss: 0.2171



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s]

--- Eval epoch-7, step-1936 ---
accuracy: 0.9367
f1: 0.4990
pr_auc: 0.1022
sensitivity: 0.0185
specificity: 0.9894
loss: 0.2501
New best f1 score (0.4990) at epoch-7, step-1936








Epoch 8 / 15: 100%|██████████| 242/242 [04:21<00:00,  1.08s/it]

--- Train epoch-8, step-2178 ---
loss: 0.2160



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.67it/s]

--- Eval epoch-8, step-2178 ---
accuracy: 0.9458
f1: 0.5039
pr_auc: 0.1006
sensitivity: 0.0185
specificity: 0.9989
loss: 0.2232
New best f1 score (0.5039) at epoch-8, step-2178








Epoch 9 / 15: 100%|██████████| 242/242 [03:59<00:00,  1.01it/s]

--- Train epoch-9, step-2420 ---
loss: 0.2079



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.67it/s]

--- Eval epoch-9, step-2420 ---
accuracy: 0.9468
f1: 0.5214
pr_auc: 0.1123
sensitivity: 0.0370
specificity: 0.9989
loss: 0.2682
New best f1 score (0.5214) at epoch-9, step-2420








Epoch 10 / 15: 100%|██████████| 242/242 [03:48<00:00,  1.06it/s]

--- Train epoch-10, step-2662 ---
loss: 0.2240



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.70it/s]

--- Eval epoch-10, step-2662 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.0929
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2291




Epoch 11 / 15: 100%|██████████| 242/242 [03:49<00:00,  1.06it/s]

--- Train epoch-11, step-2904 ---
loss: 0.2042



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.71it/s]

--- Eval epoch-11, step-2904 ---
accuracy: 0.9448
f1: 0.5197
pr_auc: 0.1086
sensitivity: 0.0370
specificity: 0.9968
loss: 0.2404




Epoch 12 / 15: 100%|██████████| 242/242 [03:49<00:00,  1.05it/s]

--- Train epoch-12, step-3146 ---
loss: 0.2058



Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s]

--- Eval epoch-12, step-3146 ---
accuracy: 0.9448
f1: 0.4858
pr_auc: 0.0991
sensitivity: 0.0000
specificity: 0.9989
loss: 0.2260




Epoch 13 / 15: 100%|██████████| 242/242 [03:54<00:00,  1.03it/s]

--- Train epoch-13, step-3388 ---
loss: 0.2106



Evaluation: 100%|██████████| 32/32 [00:13<00:00,  2.42it/s]

--- Eval epoch-13, step-3388 ---
accuracy: 0.9458
f1: 0.4861
pr_auc: 0.1042
sensitivity: 0.0000
specificity: 1.0000
loss: 0.2170




Epoch 14 / 15: 100%|██████████| 242/242 [03:55<00:00,  1.03it/s]

--- Train epoch-14, step-3630 ---
loss: 0.2121



Evaluation: 100%|██████████| 32/32 [00:15<00:00,  2.13it/s]

--- Eval epoch-14, step-3630 ---
accuracy: 0.9448
f1: 0.4858
pr_auc: 0.1082
sensitivity: 0.0000
specificity: 0.9989
loss: 0.2135
Loaded best model



  state_dict = torch.load(ckpt_path, map_location=self.device)
Evaluation: 100%|██████████| 32/32 [00:11<00:00,  2.71it/s]

{'accuracy': 0.9447791164658634, 'f1': 0.5033319733442132, 'pr_auc': 0.09635838918523011, 'sensitivity': 0.018518518518518517, 'specificity': 0.9978768577494692, 'loss': 0.2662505491171032}



