In [1]:
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import Transformer, RNN, RETAIN
from pyhealth.tasks import mortality_prediction_mimic3_fn, drug_recommendation_mimic3_fn
from pyhealth.trainer import Trainer

dataset = MIMIC3Dataset(
    root='/data/physionet.org/files/mimiciii/1.4/',
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    code_mapping={
        # "ICD9CM": "CCSCM", 
        # "ICD9PROC": "CCSPROC",
        "NDC": ("ATC", {"target_kwargs": {"level": 3}})
        },
    dev=False,
    refresh_cache=True
)


  from .autonotebook import tqdm as notebook_tqdm


INFO: Pandarallel will run on 64 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
finish basic patient information parsing : 5.599072456359863s
finish parsing DIAGNOSES_ICD : 9.125769138336182s
finish parsing PROCEDURES_ICD : 6.70231294631958s
finish parsing PRESCRIPTIONS : 58.10236859321594s


Mapping codes: 100%|██████████| 46520/46520 [01:21<00:00, 570.40it/s] 


In [2]:
mimic3_ds = dataset.set_task(drug_recommendation_mimic3_fn)

train_dataset, val_dataset, test_dataset = split_by_patient(
    mimic3_ds, [0.8, 0.1, 0.1]
)
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

Generating samples for drug_recommendation_mimic3_fn: 100%|██████████| 46520/46520 [00:02<00:00, 21575.02it/s]


In [4]:
model_w_pre = RNN(
    dataset=mimic3_ds,
    feature_keys=["conditions", "procedures"],
    label_key="drugs",
    mode="multilabel",
    pretrained_emb="LM/clinicalbert",
    embedding_dim=128,
)

Loading pretrained embedding for conditions...
Loading pretrained embedding for procedures...


In [5]:
# STEP 4: define trainer
trainer = Trainer(model=model_w_pre)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=30,
    monitor="pr_auc_samples",
)

# STEP 5: evaluate
print(trainer.evaluate(test_dataloader))

RNN(
  (embeddings): ModuleDict(
    (conditions): Embedding(4493, 768, padding_idx=0)
    (procedures): Embedding(1414, 768, padding_idx=0)
  )
  (linear_layers): ModuleDict(
    (conditions): Linear(in_features=768, out_features=128, bias=True)
    (procedures): Linear(in_features=768, out_features=128, bias=True)
  )
  (rnn): ModuleDict(
    (conditions): RNNLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
    (procedures): RNNLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
  )
  (fc): Linear(in_features=256, out_features=193, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f25c3cbb4c0>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 30



Epoch 0 / 30: 100%|██████████| 354/354 [00:03<00:00, 106.94it/s]

--- Train epoch-0, step-354 ---
loss: 0.2926



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 255.78it/s]


--- Eval epoch-0, step-354 ---
pr_auc_samples: 0.6669
loss: 0.2808
New best pr_auc_samples score (0.6669) at epoch-0, step-354



Epoch 1 / 30: 100%|██████████| 354/354 [00:03<00:00, 115.41it/s]

--- Train epoch-1, step-708 ---
loss: 0.2850



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 190.11it/s]


--- Eval epoch-1, step-708 ---
pr_auc_samples: 0.6739
loss: 0.2737
New best pr_auc_samples score (0.6739) at epoch-1, step-708



Epoch 2 / 30: 100%|██████████| 354/354 [00:03<00:00, 113.91it/s]

--- Train epoch-2, step-1062 ---
loss: 0.2814



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 251.02it/s]


--- Eval epoch-2, step-1062 ---
pr_auc_samples: 0.6660
loss: 0.2804



Epoch 3 / 30: 100%|██████████| 354/354 [00:03<00:00, 106.96it/s]

--- Train epoch-3, step-1416 ---
loss: 0.2804



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 246.46it/s]


--- Eval epoch-3, step-1416 ---
pr_auc_samples: 0.6767
loss: 0.2708
New best pr_auc_samples score (0.6767) at epoch-3, step-1416



Epoch 4 / 30: 100%|██████████| 354/354 [00:03<00:00, 114.07it/s]

--- Train epoch-4, step-1770 ---
loss: 0.2796



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 249.24it/s]


--- Eval epoch-4, step-1770 ---
pr_auc_samples: 0.6756
loss: 0.2705



Epoch 5 / 30: 100%|██████████| 354/354 [00:03<00:00, 113.81it/s]

--- Train epoch-5, step-2124 ---
loss: 0.2778



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 244.54it/s]


--- Eval epoch-5, step-2124 ---
pr_auc_samples: 0.6749
loss: 0.2707



Epoch 6 / 30: 100%|██████████| 354/354 [00:02<00:00, 122.88it/s]

--- Train epoch-6, step-2478 ---
loss: 0.2769



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 253.47it/s]


--- Eval epoch-6, step-2478 ---
pr_auc_samples: 0.6789
loss: 0.2682
New best pr_auc_samples score (0.6789) at epoch-6, step-2478



Epoch 7 / 30: 100%|██████████| 354/354 [00:02<00:00, 122.20it/s]

--- Train epoch-7, step-2832 ---
loss: 0.2758



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 251.57it/s]


--- Eval epoch-7, step-2832 ---
pr_auc_samples: 0.6793
loss: 0.2690
New best pr_auc_samples score (0.6793) at epoch-7, step-2832



Epoch 8 / 30: 100%|██████████| 354/354 [00:02<00:00, 122.07it/s]

--- Train epoch-8, step-3186 ---
loss: 0.2746



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 253.74it/s]


--- Eval epoch-8, step-3186 ---
pr_auc_samples: 0.6771
loss: 0.2682



Epoch 9 / 30: 100%|██████████| 354/354 [00:02<00:00, 122.99it/s]

--- Train epoch-9, step-3540 ---
loss: 0.2746



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 260.03it/s]


--- Eval epoch-9, step-3540 ---
pr_auc_samples: 0.6763
loss: 0.2711



Epoch 10 / 30: 100%|██████████| 354/354 [00:03<00:00, 106.40it/s]

--- Train epoch-10, step-3894 ---
loss: 0.2749



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 252.80it/s]


--- Eval epoch-10, step-3894 ---
pr_auc_samples: 0.6791
loss: 0.2672



Epoch 11 / 30: 100%|██████████| 354/354 [00:03<00:00, 107.26it/s]

--- Train epoch-11, step-4248 ---
loss: 0.2745



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 253.47it/s]


--- Eval epoch-11, step-4248 ---
pr_auc_samples: 0.6748
loss: 0.2695



Epoch 12 / 30: 100%|██████████| 354/354 [00:03<00:00, 104.08it/s]

--- Train epoch-12, step-4602 ---
loss: 0.2754



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 207.45it/s]


--- Eval epoch-12, step-4602 ---
pr_auc_samples: 0.6774
loss: 0.2685



Epoch 13 / 30: 100%|██████████| 354/354 [00:03<00:00, 98.22it/s] 

--- Train epoch-13, step-4956 ---
loss: 0.2748



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 238.87it/s]


--- Eval epoch-13, step-4956 ---
pr_auc_samples: 0.6775
loss: 0.2673



Epoch 14 / 30: 100%|██████████| 354/354 [00:03<00:00, 105.55it/s]

--- Train epoch-14, step-5310 ---
loss: 0.2734



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 237.93it/s]


--- Eval epoch-14, step-5310 ---
pr_auc_samples: 0.6764
loss: 0.2675



Epoch 15 / 30: 100%|██████████| 354/354 [00:03<00:00, 117.46it/s]

--- Train epoch-15, step-5664 ---
loss: 0.2727



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 247.20it/s]


--- Eval epoch-15, step-5664 ---
pr_auc_samples: 0.6792
loss: 0.2668



Epoch 16 / 30: 100%|██████████| 354/354 [00:03<00:00, 115.40it/s]

--- Train epoch-16, step-6018 ---
loss: 0.2727



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 248.82it/s]


--- Eval epoch-16, step-6018 ---
pr_auc_samples: 0.6779
loss: 0.2683



Epoch 17 / 30: 100%|██████████| 354/354 [00:03<00:00, 108.12it/s]

--- Train epoch-17, step-6372 ---
loss: 0.2717



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 201.71it/s]


--- Eval epoch-17, step-6372 ---
pr_auc_samples: 0.6800
loss: 0.2653
New best pr_auc_samples score (0.6800) at epoch-17, step-6372



Epoch 18 / 30: 100%|██████████| 354/354 [00:03<00:00, 100.61it/s]

--- Train epoch-18, step-6726 ---
loss: 0.2714



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 244.36it/s]


--- Eval epoch-18, step-6726 ---
pr_auc_samples: 0.6803
loss: 0.2652
New best pr_auc_samples score (0.6803) at epoch-18, step-6726



Epoch 19 / 30: 100%|██████████| 354/354 [00:03<00:00, 106.16it/s]

--- Train epoch-19, step-7080 ---
loss: 0.2709



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 251.52it/s]


--- Eval epoch-19, step-7080 ---
pr_auc_samples: 0.6787
loss: 0.2655



Epoch 20 / 30: 100%|██████████| 354/354 [00:03<00:00, 111.55it/s]

--- Train epoch-20, step-7434 ---
loss: 0.2706



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 244.58it/s]


--- Eval epoch-20, step-7434 ---
pr_auc_samples: 0.6799
loss: 0.2649



Epoch 21 / 30: 100%|██████████| 354/354 [00:02<00:00, 120.27it/s]

--- Train epoch-21, step-7788 ---
loss: 0.2702



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 254.74it/s]


--- Eval epoch-21, step-7788 ---
pr_auc_samples: 0.6789
loss: 0.2667



Epoch 22 / 30: 100%|██████████| 354/354 [00:03<00:00, 117.05it/s]

--- Train epoch-22, step-8142 ---
loss: 0.2701



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 233.34it/s]


--- Eval epoch-22, step-8142 ---
pr_auc_samples: 0.6821
loss: 0.2636
New best pr_auc_samples score (0.6821) at epoch-22, step-8142



Epoch 23 / 30: 100%|██████████| 354/354 [00:02<00:00, 126.18it/s]

--- Train epoch-23, step-8496 ---
loss: 0.2706



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 255.52it/s]


--- Eval epoch-23, step-8496 ---
pr_auc_samples: 0.6807
loss: 0.2647



Epoch 24 / 30: 100%|██████████| 354/354 [00:02<00:00, 124.11it/s]

--- Train epoch-24, step-8850 ---
loss: 0.2714



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 256.42it/s]


--- Eval epoch-24, step-8850 ---
pr_auc_samples: 0.6791
loss: 0.2671



Epoch 25 / 30: 100%|██████████| 354/354 [00:02<00:00, 126.17it/s]

--- Train epoch-25, step-9204 ---
loss: 0.2698



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 262.67it/s]


--- Eval epoch-25, step-9204 ---
pr_auc_samples: 0.6800
loss: 0.2639



Epoch 26 / 30: 100%|██████████| 354/354 [00:02<00:00, 128.08it/s]

--- Train epoch-26, step-9558 ---
loss: 0.2703



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 260.66it/s]


--- Eval epoch-26, step-9558 ---
pr_auc_samples: 0.6785
loss: 0.2669



Epoch 27 / 30: 100%|██████████| 354/354 [00:02<00:00, 126.07it/s]

--- Train epoch-27, step-9912 ---
loss: 0.2702



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 256.24it/s]


--- Eval epoch-27, step-9912 ---
pr_auc_samples: 0.6809
loss: 0.2640



Epoch 28 / 30: 100%|██████████| 354/354 [00:02<00:00, 125.92it/s]

--- Train epoch-28, step-10266 ---
loss: 0.2701



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 215.82it/s]


--- Eval epoch-28, step-10266 ---
pr_auc_samples: 0.6825
loss: 0.2629
New best pr_auc_samples score (0.6825) at epoch-28, step-10266



Epoch 29 / 30: 100%|██████████| 354/354 [00:03<00:00, 101.93it/s]

--- Train epoch-29, step-10620 ---
loss: 0.2696



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 248.80it/s]


--- Eval epoch-29, step-10620 ---
pr_auc_samples: 0.6822
loss: 0.2633
Loaded best model


Evaluation: 100%|██████████| 44/44 [00:00<00:00, 257.54it/s]


{'pr_auc_samples': 0.6888944454447027, 'loss': 0.2677140161395073}


In [5]:
model_no_pre = RNN(
    dataset=mimic3_ds,
    feature_keys=["conditions", "procedures"],
    label_key="drugs",
    mode="multilabel",
    embedding_dim=128,
)

In [6]:
# STEP 4: define trainer
trainer = Trainer(model=model_no_pre)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
    monitor="pr_auc_samples",
)

# STEP 5: evaluate
print(trainer.evaluate(test_dataloader))

RNN(
  (rand_init_embedding): ModuleDict()
  (pretrained_embedding): ModuleDict()
  (embeddings): ModuleDict(
    (conditions): Embedding(4493, 128, padding_idx=0)
    (procedures): Embedding(1414, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (rnn): ModuleDict(
    (conditions): RNNLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
    (procedures): RNNLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (rnn): GRU(128, 128, batch_first=True)
    )
  )
  (fc): Linear(in_features=256, out_features=193, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7ff51064a8b0>
Monitor: pr_auc_samples
Monitor criterion: max
Epochs: 10



Epoch 0 / 10: 100%|██████████| 354/354 [00:02<00:00, 125.89it/s]

--- Train epoch-0, step-354 ---
loss: 0.3037



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 244.18it/s]


--- Eval epoch-0, step-354 ---
pr_auc_samples: 0.7173
loss: 0.2600
New best pr_auc_samples score (0.7173) at epoch-0, step-354



Epoch 1 / 10: 100%|██████████| 354/354 [00:02<00:00, 124.47it/s]

--- Train epoch-1, step-708 ---
loss: 0.2580



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 246.94it/s]


--- Eval epoch-1, step-708 ---
pr_auc_samples: 0.7307
loss: 0.2504
New best pr_auc_samples score (0.7307) at epoch-1, step-708



Epoch 2 / 10: 100%|██████████| 354/354 [00:02<00:00, 120.08it/s]

--- Train epoch-2, step-1062 ---
loss: 0.2485



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 259.72it/s]


--- Eval epoch-2, step-1062 ---
pr_auc_samples: 0.7343
loss: 0.2470
New best pr_auc_samples score (0.7343) at epoch-2, step-1062



Epoch 3 / 10: 100%|██████████| 354/354 [00:02<00:00, 122.50it/s]

--- Train epoch-3, step-1416 ---
loss: 0.2427



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 263.01it/s]


--- Eval epoch-3, step-1416 ---
pr_auc_samples: 0.7376
loss: 0.2454
New best pr_auc_samples score (0.7376) at epoch-3, step-1416



Epoch 4 / 10: 100%|██████████| 354/354 [00:02<00:00, 122.97it/s]

--- Train epoch-4, step-1770 ---
loss: 0.2391



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 271.60it/s]


--- Eval epoch-4, step-1770 ---
pr_auc_samples: 0.7384
loss: 0.2446
New best pr_auc_samples score (0.7384) at epoch-4, step-1770



Epoch 5 / 10: 100%|██████████| 354/354 [00:02<00:00, 124.45it/s]

--- Train epoch-5, step-2124 ---
loss: 0.2360



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 267.90it/s]


--- Eval epoch-5, step-2124 ---
pr_auc_samples: 0.7398
loss: 0.2439
New best pr_auc_samples score (0.7398) at epoch-5, step-2124



Epoch 6 / 10: 100%|██████████| 354/354 [00:03<00:00, 107.20it/s]

--- Train epoch-6, step-2478 ---
loss: 0.2329



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 255.18it/s]


--- Eval epoch-6, step-2478 ---
pr_auc_samples: 0.7430
loss: 0.2432
New best pr_auc_samples score (0.7430) at epoch-6, step-2478



Epoch 7 / 10: 100%|██████████| 354/354 [00:03<00:00, 106.04it/s]

--- Train epoch-7, step-2832 ---
loss: 0.2307



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 258.11it/s]


--- Eval epoch-7, step-2832 ---
pr_auc_samples: 0.7428
loss: 0.2426



Epoch 8 / 10: 100%|██████████| 354/354 [00:03<00:00, 110.50it/s]

--- Train epoch-8, step-3186 ---
loss: 0.2285



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 219.87it/s]


--- Eval epoch-8, step-3186 ---
pr_auc_samples: 0.7439
loss: 0.2426
New best pr_auc_samples score (0.7439) at epoch-8, step-3186



Epoch 9 / 10: 100%|██████████| 354/354 [00:03<00:00, 112.81it/s]

--- Train epoch-9, step-3540 ---
loss: 0.2262



Evaluation: 100%|██████████| 44/44 [00:00<00:00, 230.73it/s]


--- Eval epoch-9, step-3540 ---
pr_auc_samples: 0.7431
loss: 0.2434
Loaded best model


Evaluation: 100%|██████████| 45/45 [00:00<00:00, 243.22it/s]


{'pr_auc_samples': 0.7394630422926343, 'loss': 0.2355364070998298}
