In [1]:
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import Transformer, RNN, RETAIN, MLP
from pyhealth.tasks import mortality_prediction_mimic3_fn, readmission_prediction_mimic3_fn, drug_recommendation_mimic3_fn, length_of_stay_prediction_mimic3_fn
from pyhealth.trainer import Trainer

# Build the MIMIC-III dataset object from the three event tables used as
# model features. The table list and the code mapping are named up front so
# the constructor call itself stays short.
mimic3_tables = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"]

# Only NDC drug codes are remapped: they are converted to ATC, with `level=3`
# forwarded to the target vocabulary through `target_kwargs`.
# The optional "ICD9CM" -> "CCSCM" and "ICD9PROC" -> "CCSPROC" mappings are
# not enabled here.
ndc_to_atc = ("ATC", {"target_kwargs": {"level": 3}})
code_mapping = {"NDC": ndc_to_atc}

dataset = MIMIC3Dataset(
    root='/data/physionet.org/files/mimiciii/1.4/',
    tables=mimic3_tables,
    code_mapping=code_mapping,
    dev=False,  # presumably loads the full dataset rather than a dev subset — confirm
    refresh_cache=False,  # presumably reuses a previously cached parse — confirm
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Convert the raw dataset into samples for the readmission-prediction task.
mimic3_ds = dataset.set_task(readmission_prediction_mimic3_fn)

# Fixed seed for the patient-level split. Without it the train/val/test
# membership changes on every kernel restart, so metrics from different runs
# (and the two models compared below) cannot be reproduced.
# NOTE: metrics recorded from earlier, unseeded runs will not match exactly.
SPLIT_SEED = 42
# One batch size shared by all three loaders.
BATCH_SIZE = 32

# 80% / 10% / 10% split made by patient.
train_dataset, val_dataset, test_dataset = split_by_patient(
    mimic3_ds, [0.8, 0.1, 0.1], seed=SPLIT_SEED
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Generating samples for readmission_prediction_mimic3_fn: 100%|██████████| 46520/46520 [00:00<00:00, 108702.43it/s]


In [34]:
# Transformer variant that requests the "LM/sapbert" pretrained embeddings.
# The module summary printed by the training cell shows 768-wide embedding
# tables followed by a 768 -> 256 linear projection for each feature key.
pretrained_model_kwargs = {
    "dataset": mimic3_ds,
    "feature_keys": ["conditions", "procedures", "drugs"],
    "label_key": "label",
    "mode": "binary",
    "pretrained_emb": "LM/sapbert",
    "embedding_dim": 256,
}
model_w_pre = Transformer(**pretrained_model_kwargs)

Loading pretrained embedding for conditions...
Loading pretrained embedding for procedures...
Loading pretrained embedding for drugs...


In [35]:
# STEP 4: train the pretrained-embedding model.
# Validation PR-AUC is the monitored metric.
trainer = Trainer(model=model_w_pre)
fit_options_w_pre = {
    "train_dataloader": train_dataloader,
    "val_dataloader": val_dataloader,
    "epochs": 15,
    "optimizer_params": {"lr": 1e-4},
    "monitor": "pr_auc",
}
trainer.train(**fit_options_w_pre)

# STEP 5: score the trained model on the held-out test loader and show the
# resulting metric dictionary.
test_metrics_w_pre = trainer.evaluate(test_dataloader)
print(test_metrics_w_pre)

Transformer(
  (embeddings): ModuleDict(
    (conditions): Embedding(4031, 768, padding_idx=0)
    (procedures): Embedding(1276, 768, padding_idx=0)
    (drugs): Embedding(194, 768, padding_idx=0)
  )
  (linear_layers): ModuleDict(
    (conditions): Linear(in_features=768, out_features=256, bias=True)
    (procedures): Linear(in_features=768, out_features=256, bias=True)
    (drugs): Linear(in_features=768, out_features=256, bias=True)
  )
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadedAttention(
            (linear_layers): ModuleList(
              (0): Linear(in_features=256, out_features=256, bias=False)
              (1): Linear(in_features=256, out_features=256, bias=False)
              (2): Linear(in_features=256, out_features=256, bias=False)
            )
            (output_linear): Linear(in_features=256, out_features=256, bias=False)
            (attention): 

Epoch 0 / 15: 100%|██████████| 243/243 [00:03<00:00, 62.86it/s]

--- Train epoch-0, step-243 ---
loss: 1.8513



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 261.69it/s]

--- Eval epoch-0, step-243 ---
pr_auc: 0.6164
roc_auc: 0.6185
f1: 0.6404
loss: 0.7276
New best pr_auc score (0.6164) at epoch-0, step-243








Epoch 1 / 15: 100%|██████████| 243/243 [00:03<00:00, 79.53it/s]

--- Train epoch-1, step-486 ---
loss: 1.0570



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 207.86it/s]

--- Eval epoch-1, step-486 ---
pr_auc: 0.5964
roc_auc: 0.5799
f1: 0.0682
loss: 1.8373




Epoch 2 / 15: 100%|██████████| 243/243 [00:03<00:00, 67.86it/s]

--- Train epoch-2, step-729 ---
loss: 0.9042



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 223.66it/s]

--- Eval epoch-2, step-729 ---
pr_auc: 0.6408
roc_auc: 0.6332
f1: 0.3031
loss: 0.8568
New best pr_auc score (0.6408) at epoch-2, step-729








Epoch 3 / 15: 100%|██████████| 243/243 [00:04<00:00, 55.07it/s]

--- Train epoch-3, step-972 ---
loss: 0.8901



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 230.67it/s]

--- Eval epoch-3, step-972 ---
pr_auc: 0.6363
roc_auc: 0.6316
f1: 0.6910
loss: 0.9789




Epoch 4 / 15: 100%|██████████| 243/243 [00:03<00:00, 71.67it/s]

--- Train epoch-4, step-1215 ---
loss: 0.8306



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 179.66it/s]

--- Eval epoch-4, step-1215 ---
pr_auc: 0.6267
roc_auc: 0.6126
f1: 0.2055
loss: 0.9268




Epoch 5 / 15: 100%|██████████| 243/243 [00:02<00:00, 85.26it/s]

--- Train epoch-5, step-1458 ---
loss: 0.7842



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 232.22it/s]

--- Eval epoch-5, step-1458 ---
pr_auc: 0.6562
roc_auc: 0.6520
f1: 0.5544
loss: 0.6879
New best pr_auc score (0.6562) at epoch-5, step-1458








Epoch 6 / 15: 100%|██████████| 243/243 [00:02<00:00, 84.32it/s]

--- Train epoch-6, step-1701 ---
loss: 0.7804



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 229.38it/s]

--- Eval epoch-6, step-1701 ---
pr_auc: 0.6571
roc_auc: 0.6552
f1: 0.6567
loss: 0.6736
New best pr_auc score (0.6571) at epoch-6, step-1701








Epoch 7 / 15: 100%|██████████| 243/243 [00:03<00:00, 64.07it/s]

--- Train epoch-7, step-1944 ---
loss: 0.7510



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 214.18it/s]

--- Eval epoch-7, step-1944 ---
pr_auc: 0.6543
roc_auc: 0.6515
f1: 0.6539
loss: 0.6687




Epoch 8 / 15: 100%|██████████| 243/243 [00:03<00:00, 80.98it/s]

--- Train epoch-8, step-2187 ---
loss: 0.7375



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 192.19it/s]

--- Eval epoch-8, step-2187 ---
pr_auc: 0.6565
roc_auc: 0.6471
f1: 0.6957
loss: 0.7565




Epoch 9 / 15: 100%|██████████| 243/243 [00:02<00:00, 83.44it/s]

--- Train epoch-9, step-2430 ---
loss: 0.7167



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 233.96it/s]

--- Eval epoch-9, step-2430 ---
pr_auc: 0.6566
roc_auc: 0.6477
f1: 0.6994
loss: 0.7239




Epoch 10 / 15: 100%|██████████| 243/243 [00:03<00:00, 79.00it/s]

--- Train epoch-10, step-2673 ---
loss: 0.6904



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 174.31it/s]

--- Eval epoch-10, step-2673 ---
pr_auc: 0.6613
roc_auc: 0.6566
f1: 0.6678
loss: 0.6750
New best pr_auc score (0.6613) at epoch-10, step-2673








Epoch 11 / 15: 100%|██████████| 243/243 [00:02<00:00, 87.90it/s]

--- Train epoch-11, step-2916 ---
loss: 0.6895



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 174.85it/s]

--- Eval epoch-11, step-2916 ---
pr_auc: 0.6626
roc_auc: 0.6569
f1: 0.5957
loss: 0.6720
New best pr_auc score (0.6626) at epoch-11, step-2916








Epoch 12 / 15: 100%|██████████| 243/243 [00:02<00:00, 94.81it/s]

--- Train epoch-12, step-3159 ---
loss: 0.6678



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 144.92it/s]

--- Eval epoch-12, step-3159 ---
pr_auc: 0.6602
roc_auc: 0.6487
f1: 0.6946
loss: 0.8150




Epoch 13 / 15: 100%|██████████| 243/243 [00:02<00:00, 87.42it/s]

--- Train epoch-13, step-3402 ---
loss: 0.6650



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 213.11it/s]

--- Eval epoch-13, step-3402 ---
pr_auc: 0.6657
roc_auc: 0.6558
f1: 0.6344
loss: 0.6663
New best pr_auc score (0.6657) at epoch-13, step-3402








Epoch 14 / 15: 100%|██████████| 243/243 [00:03<00:00, 71.42it/s]

--- Train epoch-14, step-3645 ---
loss: 0.6534



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 234.73it/s]

--- Eval epoch-14, step-3645 ---
pr_auc: 0.6619
roc_auc: 0.6526
f1: 0.3598
loss: 0.7386
Loaded best model



Evaluation: 100%|██████████| 32/32 [00:00<00:00, 256.61it/s]

{'pr_auc': 0.7164389859666139, 'roc_auc': 0.6767631286457192, 'f1': 0.6512488436632747, 'loss': 0.6382618751376867}





In [32]:
# Baseline Transformer: same dataset, feature keys, label and embedding size
# as the pretrained variant, but no `pretrained_emb` argument. The module
# summary printed by the training cell shows 256-wide embedding tables and an
# empty `linear_layers` dict.
baseline_model_kwargs = {
    "dataset": mimic3_ds,
    "feature_keys": ["conditions", "procedures", "drugs"],
    "label_key": "label",
    "mode": "binary",
    "embedding_dim": 256,
}
model_no_pre = Transformer(**baseline_model_kwargs)

In [33]:
# STEP 4: train the baseline model (no pretrained embeddings) with the same
# schedule as the pretrained run: 15 epochs, lr 1e-4, monitored on PR-AUC.
# NOTE: this rebinds `trainer`, replacing any trainer created earlier.
trainer = Trainer(model=model_no_pre)
fit_options_no_pre = {
    "train_dataloader": train_dataloader,
    "val_dataloader": val_dataloader,
    "epochs": 15,
    "optimizer_params": {"lr": 1e-4},
    "monitor": "pr_auc",
}
trainer.train(**fit_options_no_pre)

# STEP 5: score the trained model on the held-out test loader and show the
# resulting metric dictionary.
test_metrics_no_pre = trainer.evaluate(test_dataloader)
print(test_metrics_no_pre)

Transformer(
  (embeddings): ModuleDict(
    (conditions): Embedding(4031, 256, padding_idx=0)
    (procedures): Embedding(1276, 256, padding_idx=0)
    (drugs): Embedding(194, 256, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadedAttention(
            (linear_layers): ModuleList(
              (0): Linear(in_features=256, out_features=256, bias=False)
              (1): Linear(in_features=256, out_features=256, bias=False)
              (2): Linear(in_features=256, out_features=256, bias=False)
            )
            (output_linear): Linear(in_features=256, out_features=256, bias=False)
            (attention): Attention()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (feed_forward): PositionwiseFeedForward(
            (w_1): Linear(in_features=256, out_features=1024, bias=True)
          

Epoch 0 / 15: 100%|██████████| 243/243 [00:03<00:00, 66.17it/s]

--- Train epoch-0, step-243 ---
loss: 1.2175



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 227.09it/s]

--- Eval epoch-0, step-243 ---
pr_auc: 0.6090
roc_auc: 0.5884
f1: 0.6103
loss: 0.7941
New best pr_auc score (0.6090) at epoch-0, step-243




Epoch 1 / 15: 100%|██████████| 243/243 [00:02<00:00, 95.22it/s]

--- Train epoch-1, step-486 ---
loss: 1.0263



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 270.01it/s]

--- Eval epoch-1, step-486 ---
pr_auc: 0.6337
roc_auc: 0.6155
f1: 0.6277
loss: 0.7491
New best pr_auc score (0.6337) at epoch-1, step-486




Epoch 2 / 15: 100%|██████████| 243/243 [00:02<00:00, 93.65it/s]

--- Train epoch-2, step-729 ---
loss: 0.9085



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 267.19it/s]

--- Eval epoch-2, step-729 ---
pr_auc: 0.6457
roc_auc: 0.6255
f1: 0.6238
loss: 0.7093
New best pr_auc score (0.6457) at epoch-2, step-729




Epoch 3 / 15: 100%|██████████| 243/243 [00:03<00:00, 76.83it/s]

--- Train epoch-3, step-972 ---
loss: 0.8412



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 198.09it/s]

--- Eval epoch-3, step-972 ---
pr_auc: 0.6532
roc_auc: 0.6369
f1: 0.6521
loss: 0.7020
New best pr_auc score (0.6532) at epoch-3, step-972








Epoch 4 / 15: 100%|██████████| 243/243 [00:02<00:00, 91.48it/s]

--- Train epoch-4, step-1215 ---
loss: 0.7811



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 276.12it/s]

--- Eval epoch-4, step-1215 ---
pr_auc: 0.6603
roc_auc: 0.6385
f1: 0.6323
loss: 0.6859
New best pr_auc score (0.6603) at epoch-4, step-1215




Epoch 5 / 15: 100%|██████████| 243/243 [00:02<00:00, 99.31it/s] 

--- Train epoch-5, step-1458 ---
loss: 0.7451



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 270.40it/s]

--- Eval epoch-5, step-1458 ---
pr_auc: 0.6580
roc_auc: 0.6382
f1: 0.6291
loss: 0.6849




Epoch 6 / 15: 100%|██████████| 243/243 [00:02<00:00, 93.74it/s]

--- Train epoch-6, step-1701 ---
loss: 0.7158



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 224.82it/s]

--- Eval epoch-6, step-1701 ---
pr_auc: 0.6562
roc_auc: 0.6378
f1: 0.6371
loss: 0.6838




Epoch 7 / 15: 100%|██████████| 243/243 [00:02<00:00, 90.81it/s]


--- Train epoch-7, step-1944 ---
loss: 0.6777


Evaluation: 100%|██████████| 31/31 [00:00<00:00, 278.12it/s]

--- Eval epoch-7, step-1944 ---
pr_auc: 0.6590
roc_auc: 0.6417
f1: 0.6172
loss: 0.6782




Epoch 8 / 15: 100%|██████████| 243/243 [00:02<00:00, 95.73it/s]

--- Train epoch-8, step-2187 ---
loss: 0.6581



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 281.33it/s]

--- Eval epoch-8, step-2187 ---
pr_auc: 0.6592
roc_auc: 0.6462
f1: 0.6437
loss: 0.6810




Epoch 9 / 15: 100%|██████████| 243/243 [00:03<00:00, 73.40it/s]

--- Train epoch-9, step-2430 ---
loss: 0.6387



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 272.67it/s]

--- Eval epoch-9, step-2430 ---
pr_auc: 0.6576
roc_auc: 0.6428
f1: 0.6589
loss: 0.6868




Epoch 10 / 15: 100%|██████████| 243/243 [00:02<00:00, 84.12it/s]

--- Train epoch-10, step-2673 ---
loss: 0.6261



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 267.80it/s]

--- Eval epoch-10, step-2673 ---
pr_auc: 0.6570
roc_auc: 0.6419
f1: 0.6217
loss: 0.6812




Epoch 11 / 15: 100%|██████████| 243/243 [00:03<00:00, 73.52it/s]

--- Train epoch-11, step-2916 ---
loss: 0.6116



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 276.41it/s]

--- Eval epoch-11, step-2916 ---
pr_auc: 0.6522
roc_auc: 0.6367
f1: 0.5841
loss: 0.6851




Epoch 12 / 15: 100%|██████████| 243/243 [00:02<00:00, 92.59it/s]

--- Train epoch-12, step-3159 ---
loss: 0.6066



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 278.50it/s]

--- Eval epoch-12, step-3159 ---
pr_auc: 0.6551
roc_auc: 0.6330
f1: 0.5935
loss: 0.6867




Epoch 13 / 15: 100%|██████████| 243/243 [00:03<00:00, 77.78it/s]

--- Train epoch-13, step-3402 ---
loss: 0.5941



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 226.16it/s]

--- Eval epoch-13, step-3402 ---
pr_auc: 0.6570
roc_auc: 0.6378
f1: 0.6110
loss: 0.6855




Epoch 14 / 15: 100%|██████████| 243/243 [00:02<00:00, 90.75it/s]

--- Train epoch-14, step-3645 ---
loss: 0.5834



Evaluation: 100%|██████████| 31/31 [00:00<00:00, 269.48it/s]

--- Eval epoch-14, step-3645 ---
pr_auc: 0.6611
roc_auc: 0.6421
f1: 0.6327
loss: 0.6925
New best pr_auc score (0.6611) at epoch-14, step-3645
Loaded best model



Evaluation: 100%|██████████| 32/32 [00:00<00:00, 264.14it/s]

{'pr_auc': 0.697333603874223, 'roc_auc': 0.6722280364429949, 'f1': 0.6672694394213382, 'loss': 0.6479178862646222}



