# Transformer on MIMIC-IV (Demo)
This minimal walkthrough uses PyHealth's built-in utilities to train PyHealth's `Transformer` model on the MIMIC-IV in-hospital mortality prediction task (`InHospitalMortalityMIMIC4`).

In [1]:
from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_sample
from pyhealth.tasks import InHospitalMortalityMIMIC4
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

In [2]:
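dataset = MIMIC4Dataset(
    ehr_root="/home/logic/Github/mimic4",  # local MIMIC-IV root; point this at your own copy
    ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
    dev=True,  # dev mode loads only a small patient subset for a quick demo
)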

Memory usage Starting MIMIC4Dataset init: 921.5 MB
Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: True)
Using default EHR config: /home/logic/miniforge3/envs/pyhealth/lib/python3.12/site-packages/pyhealth/datasets/configs/mimic4_ehr.yaml
Memory usage Before initializing mimic4_ehr: 921.5 MB
Initializing mimic4_ehr dataset from /home/logic/Github/mimic4 (dev mode: False)
Scanning table: diagnoses_icd from /home/logic/Github/mimic4/hosp/diagnoses_icd.csv.gz

In [3]:
task = InHospitalMortalityMIMIC4()
sample_dataset = dataset.set_task(
    task,
    cache_dir="../../test_cache_transformer_m4",
)
# 70% train / 10% validation / 20% test, split at the sample level
train_dataset, val_dataset, test_dataset = split_by_sample(sample_dataset, ratios=[0.7, 0.1, 0.2])

Setting task InHospitalMortalityMIMIC4 for mimic4 base dataset...
Generating samples with 1 worker(s)...
Collecting global event dataframe...
Dev mode enabled: limiting to 1000 patients
Collected dataframe with shape: (131557, 47)

Generating samples for InHospitalMortalityMIMIC4 with 1 worker: 100%|██████████| 100/100 [00:01<00:00, 95.23it/s]

Caching samples to ../../test_cache_transformer_m4/InHospitalMortalityMIMIC4.parquet
Failed to cache samples: failed to determine supertype of list[datetime[μs]] and object
Label mortality vocab: {0: 0, 1: 1}

Processing samples: 100%|██████████| 216/216 [00:00<00:00, 1793.08it/s]

Generated 216 samples for task InHospitalMortalityMIMIC4
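
`split_by_sample` partitions the 216 generated samples according to `ratios=[0.7, 0.1, 0.2]`. The sketch below illustrates one way a sample-level split can work: shuffle the sample indices, then cut them at the cumulative ratio boundaries. `sample_level_split` is a hypothetical helper, not PyHealth's implementation; the floor-based cut points and the fixed seed are assumptions.

```python
import random
from typing import List, Sequence, Tuple


def sample_level_split(
    n_samples: int, ratios: Sequence[float], seed: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical sample-level split: shuffle indices, then cut at cumulative ratios.

    Assumes cut points are floor(n * cumulative_ratio); PyHealth may round differently.
    """
    if abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError("ratios must sum to 1")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # individual samples are shuffled, not patients
    train_end = int(n_samples * ratios[0])
    val_end = int(n_samples * (ratios[0] + ratios[1]))
    return indices[:train_end], indices[train_end:val_end], indices[val_end:]


train_idx, val_idx, test_idx = sample_level_split(216, [0.7, 0.1, 0.2])
print(len(train_idx), len(val_idx), len(test_idx))  # 151 21 44
```

Because the 216 samples come from fewer patients, a sample-level split can place samples from the same patient in both training and test sets. PyHealth also offers `split_by_patient`, which keeps each patient's samples together.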

In [4]:
train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
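
With `batch_size=32`, the number of batches per pass is the split size divided by 32, rounded up. The arithmetic below assumes floor-based split boundaries (an assumption about `split_by_sample`, not confirmed here) and PyTorch's default `drop_last=False`; under those assumptions it matches the `5/5`, `1/1`, and `2/2` progress bars in the training and evaluation logs.

```python
import math

N_SAMPLES = 216  # "Generated 216 samples" in the task log
RATIOS = [0.7, 0.1, 0.2]
BATCH_SIZE = 32

# Assumption: split boundaries are floor(N * cumulative ratio).
train_end = int(N_SAMPLES * RATIOS[0])
val_end = int(N_SAMPLES * (RATIOS[0] + RATIOS[1]))
sizes = {"train": train_end, "val": val_end - train_end, "test": N_SAMPLES - val_end}

# Keeping the last partial batch (drop_last=False) gives ceil(size / batch_size) batches.
batches = {name: math.ceil(size / BATCH_SIZE) for name, size in sizes.items()}
print(sizes)    # {'train': 151, 'val': 21, 'test': 44}
print(batches)  # {'train': 5, 'val': 1, 'test': 2}
```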

In [6]:
model = Transformer(dataset=sample_dataset, embedding_dim=128, heads=2, num_layers=2, dropout=0.1)
trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"])

Transformer(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (labs): Linear(in_features=27, out_features=128, bias=True)
  ))
  (transformer): ModuleDict(
    (labs): TransformerLayer(
      (transformer): ModuleList(
        (0-1): 2 x TransformerBlock(
          (attention): MultiHeadedAttention(
            (linear_layers): ModuleList(
              (0-2): 3 x Linear(in_features=128, out_features=128, bias=False)
            )
            (output_linear): Linear(in_features=128, out_features=128, bias=False)
            (attention): Attention()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (feed_forward): PositionwiseFeedForward(
            (w_1): Linear(in_features=128, out_features=512, bias=True)
            (w_2): Linear(in_features=512, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (activation): GELU(approximate='none')
          )
          (input_sublayer): SublayerConnection(
 
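
The printed model has a single `labs` feature branch: a `Linear(27 -> 128)` embedding followed by two `TransformerBlock`s, each with three bias-free 128x128 projections (presumably query, key, and value, split across `heads=2` heads of 64 dimensions), a bias-free output projection, and a 128 -> 512 -> 128 feed-forward network. The sketch below counts the parameters of only those visible `Linear` layers. The printout is truncated, so layer norms inside `SublayerConnection` and the final classifier are deliberately excluded.

```python
def linear_params(in_features: int, out_features: int, bias: bool) -> int:
    """Parameter count of a torch.nn.Linear layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


EMBED_DIM, FF_DIM, NUM_LAYERS = 128, 512, 2

embedding = linear_params(27, EMBED_DIM, bias=True)  # (labs): Linear(27 -> 128)
attention = 3 * linear_params(EMBED_DIM, EMBED_DIM, bias=False)  # linear_layers (0-2)
attention += linear_params(EMBED_DIM, EMBED_DIM, bias=False)  # output_linear
feed_forward = linear_params(EMBED_DIM, FF_DIM, bias=True)  # w_1
feed_forward += linear_params(FF_DIM, EMBED_DIM, bias=True)  # w_2

visible_total = embedding + NUM_LAYERS * (attention + feed_forward)
print(embedding, attention, feed_forward, visible_total)  # 3584 65536 131712 398080
```

To get the full count, including the layers hidden by the truncated printout, run `sum(p.numel() for p in model.parameters())` in the notebook.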

In [7]:
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-4},
)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f5e0a8b1fd0>
Monitor: roc_auc
Monitor criterion: max
Epochs: 5

Epoch 0 / 5: 100%|██████████| 5/5 [00:00<00:00,  7.38it/s]

--- Train epoch-0, step-5 ---
loss: 9.4410

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 389.26it/s]

--- Eval epoch-0, step-5 ---
roc_auc: 0.0000
pr_auc: 0.0476
loss: 7.3290
New best roc_auc score (0.0000) at epoch-0, step-5

Epoch 1 / 5: 100%|██████████| 5/5 [00:00<00:00, 184.46it/s]

--- Train epoch-1, step-10 ---
loss: 8.0077

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 378.68it/s]

--- Eval epoch-1, step-10 ---
roc_auc: 0.1500
pr_auc: 0.0556
loss: 4.7464
New best roc_auc score (0.1500) at epoch-1, step-10

Epoch 2 / 5: 100%|██████████| 5/5 [00:00<00:00, 185.98it/s]

--- Train epoch-2, step-15 ---
loss: 4.9490

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 454.91it/s]

--- Eval epoch-2, step-15 ---
roc_auc: 0.1500
pr_auc: 0.0556
loss: 2.3103

Epoch 3 / 5: 100%|██████████| 5/5 [00:00<00:00, 192.87it/s]

--- Train epoch-3, step-20 ---
loss: 3.4749

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 445.30it/s]

--- Eval epoch-3, step-20 ---
roc_auc: 0.2500
pr_auc: 0.0625
loss: 0.7218
New best roc_auc score (0.2500) at epoch-3, step-20

Epoch 4 / 5: 100%|██████████| 5/5 [00:00<00:00, 195.50it/s]

--- Train epoch-4, step-25 ---
loss: 2.0244

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 438.87it/s]

--- Eval epoch-4, step-25 ---
roc_auc: 0.7500
pr_auc: 0.1667
loss: 0.2822
New best roc_auc score (0.7500) at epoch-4, step-25
Loaded best model
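
With `monitor="roc_auc"` and criterion `max`, the trainer reports "New best" whenever validation ROC-AUC improves and reloads that checkpoint at the end ("Loaded best model"). Replaying the logged validation scores through a strict greater-than rule, starting from negative infinity, reproduces which epochs were flagged: epoch 0 counts as a new best even at 0.0, and epoch 2 only ties epoch 1 at 0.15, so it is not reported. `track_best` is a hypothetical helper, not PyHealth's `Trainer` code.

```python
import math
from typing import List, Tuple

# Validation ROC-AUC per epoch, copied from the "--- Eval epoch-k" logs.
VAL_ROC_AUC = [0.0000, 0.1500, 0.1500, 0.2500, 0.7500]


def track_best(scores: List[float]) -> Tuple[List[int], int, float]:
    """Return the epochs that set a new best, plus the final best epoch and score.

    Assumes a strict "greater than" comparison (monitor criterion "max").
    """
    best_score = -math.inf
    best_epoch = -1
    improved = []
    for epoch, score in enumerate(scores):
        if score > best_score:  # ties do not count as an improvement
            best_score, best_epoch = score, epoch
            improved.append(epoch)
    return improved, best_epoch, best_score


improved, best_epoch, best_score = track_best(VAL_ROC_AUC)
print(improved, best_epoch, best_score)  # [0, 1, 3, 4] 4 0.75
```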

In [8]:
trainer.evaluate(test_loader)

Evaluation: 100%|██████████| 2/2 [00:00<00:00, 381.40it/s]

{'roc_auc': 0.4227642276422764,
 'pr_auc': 0.07558781088192854,
 'loss': 0.4880794435739517}
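
These metrics come from very small splits, so they move in large discrete steps. The sketch below computes ROC-AUC (the share of positive/negative pairs ranked correctly, ties counted as half) and average precision for a hypothetical 21-sample validation set containing a single positive at rank `r`. Ranks 21, 18, 16, and 6 give roughly (0.0, 0.0476), (0.15, 0.0556), (0.25, 0.0625), and (0.75, 0.1667). These are the same (ROC-AUC, PR-AUC) pairs logged during training, which is consistent with, but does not prove, a validation split holding one positive. Treating `pr_auc` as average precision is also an assumption. Read the validation scores and the test ROC-AUC of about 0.42 as noisy demo numbers, not stable estimates.

```python
from typing import List, Tuple


def roc_auc(labels: List[int], scores: List[float]) -> float:
    """Pairwise ROC-AUC: share of (positive, negative) pairs ranked correctly, ties = 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels: List[int], scores: List[float]) -> float:
    """Mean of precision@k over the ranks k where a positive appears (assumes no tied scores)."""
    ranked = [y for _, y in sorted(zip(scores, labels), key=lambda t: t[0], reverse=True)]
    hits, total = 0, 0.0
    for k, y in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / k
    return total / hits


def single_positive_case(n: int, rank: int) -> Tuple[float, float]:
    """Hypothetical split of n samples with one positive placed at 1-based `rank`."""
    scores = [float(n - i) for i in range(n)]  # strictly decreasing model scores
    labels = [1 if i == rank - 1 else 0 for i in range(n)]
    return roc_auc(labels, scores), average_precision(labels, scores)


for rank in (21, 18, 16, 6):
    auc, ap = single_positive_case(21, rank)
    print(rank, round(auc, 4), round(ap, 4))
```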