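# EHRMamba on MIMIC-IV (Demo)

This notebook trains EHRMamba, a foundation model for electronic health records (EHR) built on Mamba-style state-space model (SSM) layers, on the MIMIC-IV in-hospital mortality prediction task. The pipeline is the same as in the Transformer example; only the model class changes. On the bundled demo data the task yields only four samples, so this run is a smoke test of the pipeline and its metrics are not meaningful.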
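
Mamba replaces attention with a state-space model (SSM) recurrence whose cost grows linearly with sequence length. For intuition only, the sketch below runs a discretized, diagonal linear SSM over a scalar sequence, using `a_bar = exp(delta * a)` and a simplified Euler-style input term `b_bar = delta * b`. It is not EHRMamba's implementation: real Mamba blocks make `delta`, `B`, and `C` input-dependent ("selective") and operate on vector channels. `ssm_scan` and its arguments are hypothetical; the length of `a` plays the role that `state_size=16` presumably plays in the model cell.

```python
import math
from typing import List


def ssm_scan(x: List[float], a: List[float], b: List[float], c: List[float],
             delta: float) -> List[float]:
    """Run a diagonal linear state-space model over a scalar input sequence.

    Continuous dynamics h'(t) = A h(t) + B x(t), y(t) = C h(t), with A diagonal
    (entries in `a`, negative for stability), are discretized with step `delta`:
        a_bar = exp(delta * a),  b_bar = delta * b   (simplified input term)
    and then unrolled as a recurrence over time.
    """
    if not (len(a) == len(b) == len(c)):
        raise ValueError("a, b and c must have the same state size")
    a_bar = [math.exp(delta * ai) for ai in a]
    b_bar = [delta * bi for bi in b]
    h = [0.0] * len(a)  # hidden state, one entry per state dimension
    ys = []
    for xt in x:
        # h_t = a_bar * h_{t-1} + b_bar * x_t  (elementwise, since A is diagonal)
        h = [ab * hi + bb * xt for ab, hi, bb in zip(a_bar, h, b_bar)]
        # y_t = C . h_t
        ys.append(sum(ci * hi for ci, hi in zip(c, h)))
    return ys
```

For example, feeding an impulse `[1.0, 0.0, 0.0]` with `a=[-1.0]` produces an exponentially decaying response, which is how the state carries information forward in time without attending over all previous positions.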

In [4]:
from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_sample
from pyhealth.tasks import InHospitalMortalityMIMIC4
from pyhealth.models import EHRMamba
from pyhealth.trainer import Trainer

In [5]:
# Use the bundled MIMIC-IV demo (test-resources/core/mimic4demo). For full MIMIC-IV,
# download from PhysioNet and set ehr_root to that path (e.g. /app/data/mimic4 in Docker).
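from pathlib import Path
# Resolve the repo root: use the parent directory if it contains test-resources
# (e.g. when the notebook runs from a subfolder), otherwise the current directory.
_repo = Path.cwd().parent if (Path.cwd().parent / "test-resources").exists() else Path.cwd()
ehr_root = str(_repo / "test-resources" / "core" / "mimic4demo")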

dataset = MIMIC4Dataset(
    ehr_root=ehr_root,
    ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
    dev=True,
)

Memory usage Starting MIMIC4Dataset init: 924.2 MB
Initializing mimic4 dataset from /app/test-resources/core/mimic4demo|None|None (dev mode: True)
Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: True)
No cache_dir provided. Using default cache dir: /root/.cache/pyhealth/b8f6701e-811b-5f3b-87fd-8a6e9a4e078a
Using default EHR config: /app/pyhealth/datasets/configs/mimic4_ehr.yaml
Memory usage Before initializing mimic4_ehr: 924.2 MB
Initializing mimic4_ehr dataset from /app/test-resources/core/mimic4demo (dev mode: True)
Memory usage After initializing mimic4_ehr: 924.2 MB
Memory usage After EHR dataset initialization: 924.2 MB
Memory usage Completed MIMIC4Dataset init: 924.2 MB
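
The path lookup at the top of this cell can be factored into small helpers and tested in isolation. `resolve_ehr_root` and `require_dir` are hypothetical names, not part of PyHealth; failing fast on a missing directory gives a clearer error than a loader failing later on a missing table.

```python
from pathlib import Path


def resolve_ehr_root(cwd: Path) -> Path:
    """Return the MIMIC-IV demo directory, mirroring the notebook's lookup.

    If the parent of `cwd` contains `test-resources` (for example, when the
    notebook runs from a subfolder of the repo), the parent is treated as the
    repo root; otherwise `cwd` itself is.
    """
    repo = cwd.parent if (cwd.parent / "test-resources").exists() else cwd
    return repo / "test-resources" / "core" / "mimic4demo"


def require_dir(path: Path) -> Path:
    """Fail fast with a clear message if the data directory is missing."""
    if not path.is_dir():
        raise FileNotFoundError(f"MIMIC-IV data not found at {path}")
    return path
```

Usage would be `ehr_root = str(require_dir(resolve_ehr_root(Path.cwd())))` in place of the two inline lines.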


In [6]:
task = InHospitalMortalityMIMIC4()
sample_dataset = dataset.set_task(task)
train_dataset, val_dataset, test_dataset = split_by_sample(sample_dataset, ratios=[0.7, 0.1, 0.2])

Setting task InHospitalMortalityMIMIC4 for mimic4 base dataset...
Applying task transformations on data with 1 workers...
Combining data from ehr dataset
Scanning table: diagnoses_icd from /app/test-resources/core/mimic4demo/hosp/diagnoses_icd.csv.gz
Joining with table: /app/test-resources/core/mimic4demo/hosp/admissions.csv.gz
Scanning table: procedures_icd from /app/test-resources/core/mimic4demo/hosp/procedures_icd.csv.gz
Joining with table: /app/test-resources/core/mimic4demo/hosp/admissions.csv.gz
Scanning table: prescriptions from /app/test-resources/core/mimic4demo/hosp/prescriptions.csv.gz
Scanning table: labevents from /app/test-resources/core/mimic4demo/hosp/labevents.csv.gz
Joining with table: /app/test-resources/core/mimic4demo/hosp/d_labitems.csv.gz
Scanning table: patients from /app/test-resources/core/mimic4demo/hosp/patients.csv.gz
Scanning table: admissions from /app/test-resources/core/mimic4demo/hosp/admissions.csv.gz
Scanning table: icustays from /app/test-resources

Detected Jupyter notebook environment, setting num_workers to 1
Single worker mode, processing sequentially
Worker 0 started processing 2 patients. (Polars threads: 8)
Rank 0 inferred the following `['bytes']` data format.
100%|███████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.80it/s]

Worker 0 finished processing patients.
Fitting processors on the dataset...
Label mortality vocab: {0: 0, 1: 1}
Processing samples and saving to /root/.cache/pyhealth/b8f6701e-811b-5f3b-87fd-8a6e9a4e078a/tasks/InHospitalMortalityMIMIC4_f8cedbe4-72a8-53c3-922d-4cc8730f4c2d/samples_6e5c4c84-6592-5ced-9ee0-7774a09e9d36.ld...
Applying processors on data with 1 workers...
Detected Jupyter notebook environment, setting num_workers to 1
Single worker mode, processing sequentially
Worker 0 started processing 4 samples. (0 to 4)
Rank 0 inferred the following `['str', 'str', 'tensor', 'no_header_tensor:1']` data format.
100%|███████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 88.18it/s]

Worker 0 finished processing samples.
Cached processed samples to /root/.cache/pyhealth/b8f6701e-811b-5f3b-87fd-8a6e9a4e078a/tasks/InHospitalMortalityMIMIC4_f8cedbe4-72a8-53c3-922d-4cc8730f4c2d/samples_6e5c4c84-6592-5ced-9ee0-7774a09e9d36.ld
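
The log shows 2 patients producing only 4 samples, so the 0.7/0.1/0.2 split leaves almost nothing for validation and testing. The sketch below computes split sizes assuming the cumulative boundaries are truncated with `int()`; that rounding rule is an assumption about `split_by_sample`, and `split_sizes` is a hypothetical helper. Under that assumption, 4 samples split into 2 train, 1 validation and 1 test sample.

```python
from typing import Sequence, Tuple


def split_sizes(n: int, ratios: Sequence[float]) -> Tuple[int, int, int]:
    """Train/val/test sizes when cumulative boundaries are truncated with int().

    Assumption: boundaries are int(n * r0) and int(n * (r0 + r1)); the test
    split takes every sample after the second boundary.
    """
    if len(ratios) != 3 or abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("expected three ratios summing to 1")
    first = int(n * ratios[0])
    second = int(n * (ratios[0] + ratios[1]))
    return first, second - first, n - second
```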

In [7]:
train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
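
With `batch_size=32` and only a handful of samples per split, each loader produces a single batch, which is why every epoch in the training log below runs exactly one step. Assuming `get_dataloader` wraps a standard PyTorch `DataLoader` with its default `drop_last=False`, the batch count is the ceiling of samples divided by batch size; `num_batches` is a hypothetical helper.

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style DataLoader over n_samples items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        # Incomplete final batch is discarded.
        return n_samples // batch_size
    # Integer ceiling division: the final, smaller batch is kept.
    return (n_samples + batch_size - 1) // batch_size
```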

In [8]:
model = EHRMamba(
    dataset=sample_dataset,
    embedding_dim=128,
    num_layers=2,
    state_size=16,
    conv_kernel=4,
    dropout=0.1,
)
trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"])

EHRMamba(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (labs): Linear(in_features=27, out_features=128, bias=True)
  ))
  (blocks): ModuleDict(
    (labs): ModuleList(
      (0-1): 2 x MambaBlock(
        (norm): RMSNorm()
        (in_proj): Linear(in_features=128, out_features=512, bias=True)
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,))
        (out_proj): Linear(in_features=256, out_features=128, bias=True)
      )
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (fc): Linear(in_features=128, out_features=1, bias=True)
)
Metrics: ['roc_auc', 'pr_auc']
Device: cpu
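
The printed architecture shows a single input stream, `labs` (27 lab features projected to `embedding_dim=128`), feeding two stacked `MambaBlock`s. Inside each block, `in_proj` maps 128 to 512 features and `conv1d` runs over 256 channels with `kernel_size=4` and `padding=3`. This is consistent with the common Mamba layout of expanding the width by a factor of 2 (128 to 256) and splitting `in_proj` into two 256-wide branches, though that reading is an assumption rather than something the source states. Padding by `kernel_size - 1` is the standard way to build a causal convolution: after trimming the output back to the input length, position `t` only sees inputs at or before `t`. A single-channel sketch with the hypothetical name `causal_conv1d`:

```python
from typing import List


def causal_conv1d(x: List[float], weights: List[float]) -> List[float]:
    """Single-channel causal 1D convolution (cross-correlation, as in PyTorch).

    Left-pads the input with len(weights) - 1 zeros so that output[t] depends
    only on x[0..t]; the output has the same length as the input.
    """
    k = len(weights)
    padded = [0.0] * (k - 1) + list(x)
    return [
        sum(w * padded[t + j] for j, w in enumerate(weights))
        for t in range(len(x))
    ]
```

With `weights=[1.0, 1.0, 1.0, 1.0]` this computes a running sum over the last four positions, so a kernel of size 4 mixes each event with its three predecessors only.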

In [9]:
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=5,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-4},
)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f859297d970>
Monitor: roc_auc
Monitor criterion: max
Epochs: 5
Patience: None

Epoch 0 / 5:   0%|          | 0/1 [00:00<?, ?it/s]

--- Train epoch-0, step-1 ---
loss: 16.7915


Evaluation: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 20.95it/s]

--- Eval epoch-0, step-1 ---
roc_auc: nan
pr_auc: 1.0000
loss: 0.0000

Epoch 1 / 5:   0%|          | 0/1 [00:00<?, ?it/s]

--- Train epoch-1, step-2 ---
loss: 13.0706


Evaluation: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 66.47it/s]


--- Eval epoch-1, step-2 ---
roc_auc: nan
pr_auc: 1.0000
loss: 0.0000

Epoch 2 / 5:   0%|          | 0/1 [00:00<?, ?it/s]

--- Train epoch-2, step-3 ---
loss: 12.1309


Evaluation: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 84.14it/s]


--- Eval epoch-2, step-3 ---
roc_auc: nan
pr_auc: 1.0000
loss: 0.0000

Epoch 3 / 5:   0%|          | 0/1 [00:00<?, ?it/s]

--- Train epoch-3, step-4 ---
loss: 12.1212


Evaluation: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15.01it/s]

--- Eval epoch-3, step-4 ---
roc_auc: nan
pr_auc: 1.0000
loss: 0.0000

Epoch 4 / 5:   0%|          | 0/1 [00:00<?, ?it/s]

--- Train epoch-4, step-5 ---
loss: 7.6293


Evaluation: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 52.43it/s]

--- Eval epoch-4, step-5 ---
roc_auc: nan
pr_auc: 1.0000
loss: 0.0000
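
Validation `roc_auc` is `nan` at every epoch, while the Trainer monitors `roc_auc` with a `max` criterion. In Python every ordered comparison involving `nan` is `False`, so a plain `score > best` check can never register an improvement; whether PyHealth's `Trainer` special-cases this is not visible in the log. A NaN-aware comparison could look like the hypothetical helper below.

```python
import math
from typing import Optional


def is_improvement(score: float, best: Optional[float], criterion: str = "max") -> bool:
    """Return True if `score` beats `best`, treating NaN scores as never better."""
    if criterion not in ("max", "min"):
        raise ValueError("criterion must be 'max' or 'min'")
    if math.isnan(score):
        return False  # an undefined metric should not trigger checkpointing
    if best is None or math.isnan(best):
        return True  # the first valid score always counts as an improvement
    return score > best if criterion == "max" else score < best
```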

In [10]:
trainer.evaluate(test_loader)

Evaluation: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 89.67it/s]


{'roc_auc': nan, 'pr_auc': 0.0, 'loss': 10.178070068359375}
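
On the test split, ROC-AUC is again `nan` and PR-AUC is 0.0. ROC-AUC is the probability that a randomly chosen positive example is scored above a randomly chosen negative one, so it is undefined unless the evaluation labels contain both classes; with splits this small (possibly a single sample each), one class is missing. The pure-Python sketch below computes ROC-AUC from that pairwise definition, counting ties as one half, and returns `nan` when a class is absent. `roc_auc_pairwise` is a hypothetical helper for sanity checks, not PyHealth's metric code.

```python
import math
from typing import Sequence


def roc_auc_pairwise(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC-AUC as P(score_pos > score_neg), counting ties as 1/2.

    Returns NaN when y_true contains only one class, since no
    positive/negative pair exists. O(P * N), fine for small sanity checks.
    """
    if len(y_true) != len(y_score):
        raise ValueError("y_true and y_score must have the same length")
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        return math.nan
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
```

Running such a check on the labels of each split before training shows immediately whether ROC-AUC can be defined at all.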