In [8]:
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import Transformer, RNN
from pyhealth.tasks import mortality_prediction_mimic3_fn, drug_recommendation_mimic3_fn
from pyhealth.trainer import Trainer

dataset = MIMIC3Dataset(
    root='/data/physionet.org/files/mimiciii/1.4/',
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    code_mapping={
        "ICD9CM": "CCSCM", 
        "ICD9PROC": "CCSPROC",
        "NDC": ("ATC", {"target_kwargs": {"level": 3}})
        },
    dev=False,
    refresh_cache=False
)


In [10]:
mimic3_ds = dataset.set_task(drug_recommendation_mimic3_fn)

train_dataset, val_dataset, test_dataset = split_by_patient(
    mimic3_ds, [0.8, 0.1, 0.1]
)
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

Generating samples for drug_recommendation_mimic3_fn: 100%|██████████| 46520/46520 [00:02<00:00, 19741.61it/s]


In [17]:
model_w_pre = Transformer(
    dataset=mimic3_ds,
    feature_keys=["conditions", "procedures"],
    label_key="drugs",
    mode="multilabel",
    pretrained_emb="LM",
    embedding_dim=768,
)

Loading pretrained embedding for conditions...
Loading pretrained embedding for procedures...


In [18]:
# STEP 4: define trainer
trainer = Trainer(model=model_w_pre)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
    monitor="pr_auc_samples",
)

# STEP 5: evaluate
print(trainer.evaluate(test_dataloader))

Transformer(
  (embeddings): ModuleDict(
    (conditions): Embedding(266, 1536, padding_idx=0)
    (procedures): Embedding(207, 1536, padding_idx=0)
  )
  (linear_layers): ModuleDict(
    (conditions): Linear(in_features=1536, out_features=768, bias=True)
    (procedures): Linear(in_features=1536, out_features=768, bias=True)
  )
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadedAttention(
            (linear_layers): ModuleList(
              (0): Linear(in_features=768, out_features=768, bias=False)
              (1): Linear(in_features=768, out_features=768, bias=False)
              (2): Linear(in_features=768, out_features=768, bias=False)
            )
            (output_linear): Linear(in_features=768, out_features=768, bias=False)
            (attention): Attention()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (feed_forward): Positionw

Epoch 0 / 10: 100%|██████████| 353/353 [00:03<00:00, 97.17it/s]

--- Train epoch-0, step-353 ---
loss: 0.3051



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 260.33it/s]


--- Eval epoch-0, step-353 ---
pr_auc_samples: 0.7006
loss: 0.2761
New best pr_auc_samples score (0.7006) at epoch-0, step-353



Epoch 1 / 10: 100%|██████████| 353/353 [00:03<00:00, 98.74it/s] 

--- Train epoch-1, step-706 ---
loss: 0.2750



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 219.24it/s]


--- Eval epoch-1, step-706 ---
pr_auc_samples: 0.7134
loss: 0.2650
New best pr_auc_samples score (0.7134) at epoch-1, step-706



Epoch 2 / 10: 100%|██████████| 353/353 [00:04<00:00, 85.47it/s]

--- Train epoch-2, step-1059 ---
loss: 0.2693



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 235.00it/s]


--- Eval epoch-2, step-1059 ---
pr_auc_samples: 0.7078
loss: 0.2710



Epoch 3 / 10: 100%|██████████| 353/353 [00:03<00:00, 90.52it/s]

--- Train epoch-3, step-1412 ---
loss: 0.4866



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 209.78it/s]


--- Eval epoch-3, step-1412 ---
pr_auc_samples: 0.6921
loss: 0.3054



Epoch 4 / 10: 100%|██████████| 353/353 [00:03<00:00, 90.88it/s]

--- Train epoch-4, step-1765 ---
loss: 0.2895



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 237.98it/s]


--- Eval epoch-4, step-1765 ---
pr_auc_samples: 0.7137
loss: 0.2705
New best pr_auc_samples score (0.7137) at epoch-4, step-1765



Epoch 5 / 10: 100%|██████████| 353/353 [00:03<00:00, 94.59it/s]

--- Train epoch-5, step-2118 ---
loss: 0.2728



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 140.14it/s]


--- Eval epoch-5, step-2118 ---
pr_auc_samples: 0.7227
loss: 0.2653
New best pr_auc_samples score (0.7227) at epoch-5, step-2118



Epoch 6 / 10: 100%|██████████| 353/353 [00:03<00:00, 101.64it/s]

--- Train epoch-6, step-2471 ---
loss: 0.2668



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 243.41it/s]


--- Eval epoch-6, step-2471 ---
pr_auc_samples: 0.7171
loss: 0.2678



Epoch 7 / 10: 100%|██████████| 353/353 [00:03<00:00, 91.85it/s]

--- Train epoch-7, step-2824 ---
loss: 0.2674



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 227.47it/s]


--- Eval epoch-7, step-2824 ---
pr_auc_samples: 0.7206
loss: 0.2648



Epoch 8 / 10: 100%|██████████| 353/353 [00:03<00:00, 91.21it/s]

--- Train epoch-8, step-3177 ---
loss: 0.2682



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 156.23it/s]


--- Eval epoch-8, step-3177 ---
pr_auc_samples: 0.7158
loss: 0.2710



Epoch 9 / 10: 100%|██████████| 353/353 [00:03<00:00, 96.24it/s] 

--- Train epoch-9, step-3530 ---
loss: 0.2710



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 247.03it/s]


--- Eval epoch-9, step-3530 ---
pr_auc_samples: 0.7140
loss: 0.2669
Loaded best model


Evaluation: 100%|██████████| 45/45 [00:00<00:00, 242.77it/s]


{'pr_auc_samples': 0.7158427348069376, 'loss': 0.2679338147242864}


In [13]:
model_no_pre = Transformer(
    dataset=mimic3_ds,
    feature_keys=["conditions", "procedures"],
    label_key="drugs",
    mode="multilabel",
)

In [16]:
# STEP 4: define trainer
trainer = Trainer(model=model_no_pre)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
    monitor="pr_auc_samples",
)

# STEP 5: evaluate
print(trainer.evaluate(test_dataloader))

Transformer(
  (embeddings): ModuleDict(
    (conditions): Embedding(266, 128, padding_idx=0)
    (procedures): Embedding(207, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadedAttention(
            (linear_layers): ModuleList(
              (0): Linear(in_features=128, out_features=128, bias=False)
              (1): Linear(in_features=128, out_features=128, bias=False)
              (2): Linear(in_features=128, out_features=128, bias=False)
            )
            (output_linear): Linear(in_features=128, out_features=128, bias=False)
            (attention): Attention()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (feed_forward): PositionwiseFeedForward(
            (w_1): Linear(in_features=128, out_features=512, bias=True)
            (w_2): Linear(in_features=512, out_features=128, 

Epoch 0 / 10: 100%|██████████| 353/353 [00:03<00:00, 96.92it/s]

--- Train epoch-0, step-353 ---
loss: 0.2501



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 239.97it/s]


--- Eval epoch-0, step-353 ---
pr_auc_samples: 0.7319
loss: 0.2561
New best pr_auc_samples score (0.7319) at epoch-0, step-353



Epoch 1 / 10: 100%|██████████| 353/353 [00:03<00:00, 88.91it/s]

--- Train epoch-1, step-706 ---
loss: 0.2476



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 269.82it/s]


--- Eval epoch-1, step-706 ---
pr_auc_samples: 0.7312
loss: 0.2569



Epoch 2 / 10: 100%|██████████| 353/353 [00:04<00:00, 85.80it/s]

--- Train epoch-2, step-1059 ---
loss: 0.2469



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 199.83it/s]


--- Eval epoch-2, step-1059 ---
pr_auc_samples: 0.7291
loss: 0.2575



Epoch 3 / 10: 100%|██████████| 353/353 [00:04<00:00, 86.24it/s]

--- Train epoch-3, step-1412 ---
loss: 0.2454



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 276.57it/s]


--- Eval epoch-3, step-1412 ---
pr_auc_samples: 0.7289
loss: 0.2563



Epoch 4 / 10: 100%|██████████| 353/353 [00:04<00:00, 86.50it/s]

--- Train epoch-4, step-1765 ---
loss: 0.2444



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 278.52it/s]


--- Eval epoch-4, step-1765 ---
pr_auc_samples: 0.7301
loss: 0.2616



Epoch 5 / 10: 100%|██████████| 353/353 [00:03<00:00, 89.25it/s]

--- Train epoch-5, step-2118 ---
loss: 0.2443



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 277.80it/s]


--- Eval epoch-5, step-2118 ---
pr_auc_samples: 0.7309
loss: 0.2589



Epoch 6 / 10: 100%|██████████| 353/353 [00:03<00:00, 93.74it/s]

--- Train epoch-6, step-2471 ---
loss: 0.2429



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 188.87it/s]


--- Eval epoch-6, step-2471 ---
pr_auc_samples: 0.7296
loss: 0.2577



Epoch 7 / 10: 100%|██████████| 353/353 [00:03<00:00, 90.26it/s]

--- Train epoch-7, step-2824 ---
loss: 0.2409



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 183.14it/s]


--- Eval epoch-7, step-2824 ---
pr_auc_samples: 0.7305
loss: 0.2624



Epoch 8 / 10: 100%|██████████| 353/353 [00:03<00:00, 88.51it/s]

--- Train epoch-8, step-3177 ---
loss: 0.2407



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 274.85it/s]


--- Eval epoch-8, step-3177 ---
pr_auc_samples: 0.7282
loss: 0.2611



Epoch 9 / 10: 100%|██████████| 353/353 [00:03<00:00, 90.44it/s]

--- Train epoch-9, step-3530 ---
loss: 0.2395



Evaluation: 100%|██████████| 45/45 [00:00<00:00, 177.23it/s]


--- Eval epoch-9, step-3530 ---
pr_auc_samples: 0.7295
loss: 0.2651
Loaded best model


Evaluation: 100%|██████████| 45/45 [00:00<00:00, 269.94it/s]


{'pr_auc_samples': 0.7230120046162672, 'loss': 0.2581981142361959}
