# Mamba on MIMIC-IV (Demo)
This minimal walkthrough uses PyHealth's built-in utilities to train the PyHealth Mamba model on two new MIMIC-IV tasks: mortality prediction within 31 days, and binary length-of-stay (LOS) prediction over one week. It also demonstrates how to use LR schedulers with PyHealth's Trainer.

In [None]:
from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_sample
from pyhealth.tasks import MortalityPrediction31DaysMIMIC4, BinaryLengthOfStayPredictionMIMIC4
from pyhealth.trainer import Trainer
from pyhealth.models import Mamba

MIMIC4_PATH = "../datasets/mimic-iv-2.2"

In [None]:
dataset = MIMIC4Dataset(
    ehr_root=MIMIC4_PATH,
    ehr_tables=["patients", "admissions", "diagnoses_icd", "procedures_icd", "prescriptions"],
    dev=True
)

In [4]:
TASK = "length_of_stay_prediction"  # or "mortality_prediction"

if TASK == "length_of_stay_prediction":
    sample_dataset = dataset.set_task(
        task=BinaryLengthOfStayPredictionMIMIC4(),
        cache_dir="../test_cache_mamba_los",
    )
elif TASK == "mortality_prediction":
    sample_dataset = dataset.set_task(
        task=MortalityPrediction31DaysMIMIC4(),
        cache_dir="../test_cache_mamba_mort",
    )
else:
    # Fail early instead of hitting an undefined sample_dataset below
    raise ValueError(f"Unknown TASK: {TASK!r}")

# Split the samples into 70% train, 10% validation, 20% test
train_dataset, val_dataset, test_dataset = split_by_sample(sample_dataset, ratios=[0.7, 0.1, 0.2])

Setting task BinaryLengthOfStayPredictionMIMIC4 for mimic4 base dataset...
Generating samples with 1 worker(s)...
Collecting global event dataframe...
Dev mode enabled: limiting to 1000 patients
Collected dataframe with shape: (76949, 38)


Generating samples for BinaryLengthOfStayPredictionMIMIC4 with 1 worker: 100%|██████████| 1000/1000 [00:02<00:00, 368.10it/s]

Caching samples to ../test_cache_mamba_los/BinaryLengthOfStayPredictionMIMIC4.parquet
Successfully cached 452 samples
Label length_of_stay vocab: {0: 0, 1: 1}



Processing samples: 100%|██████████| 452/452 [00:00<00:00, 56674.70it/s]

Generated 452 samples for task BinaryLengthOfStayPredictionMIMIC4
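
The 452 samples and the 0.7/0.1/0.2 ratios determine how many batches each loader yields, which in turn fixes the `n_steps` value passed to the scheduler later. The sketch below works through that arithmetic. `split_sizes` and `num_batches` are hypothetical helpers, and the slicing rule (cut points at `int(n * cumulative_ratio)`) is an assumption about how a shuffle-then-slice split behaves, not something taken from `split_by_sample` itself.

```python
import math


def split_sizes(n_samples, ratios):
    """Return (train, val, test) sizes for a shuffle-then-slice split.

    Hypothetical helper: assumes cut points at int(n * cumulative_ratio),
    which may differ from what split_by_sample actually does.
    """
    train_end = int(n_samples * ratios[0])
    val_end = int(n_samples * (ratios[0] + ratios[1]))
    return train_end, val_end - train_end, n_samples - val_end


def num_batches(n_samples, batch_size):
    """Number of batches when the last, smaller batch is kept."""
    return math.ceil(n_samples / batch_size)


sizes = split_sizes(452, [0.7, 0.1, 0.2])
batches = [num_batches(s, 32) for s in sizes]
print(sizes)    # (316, 45, 91)
print(batches)  # [10, 2, 3]
```

Under these assumptions the batch counts agree with the progress bars in the logs (10 training, 2 validation, and 3 test batches), and 10 batches over 5 epochs gives the 50 scheduler steps reported by the Trainer. A validation set of about 45 samples is small, so the validation metrics should be read with caution.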





In [5]:
train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

In [6]:
model = Mamba(dataset=sample_dataset, embedding_dim=128, num_layers=2, dropout=0.1)
trainer = Trainer(model=model, metrics=["roc_auc", "pr_auc"])


Mamba(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(1766, 128, padding_idx=0)
    (procedures): Embedding(562, 128, padding_idx=0)
    (drugs): Embedding(768, 128, padding_idx=0)
  ))
  (mamba_layers): ModuleDict(
    (conditions): MambaLayers(
      (layers): ModuleList(
        (0-1): 2 x ResidualBlock(
          (mixer): MambaBlock(
            (in_proj): Linear(in_features=128, out_features=512, bias=False)
            (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
            (x_proj): Linear(in_features=256, out_features=40, bias=False)
            (dt_proj): Linear(in_features=8, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=128, bias=False)
          )
          (norm): RMSNorm()
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (procedures): MambaLayers(
      (layers): ModuleList(
        (0-1): 2 x ResidualBlock(
          

In [7]:
# Builder function for creating a SequentialLR scheduler with linear warmup and decay
from torch.optim.lr_scheduler import LinearLR, SequentialLR

def build_linear_scheduler_with_warmup_and_decay(optimizer, n_steps, warmup_ratio=0.1, decay_ratio=0.9):
    """Ramp the LR from 1% to 100% of its base value, then linearly back down to 1%."""
    n_warmup_steps = int(warmup_ratio * n_steps)
    n_decay_steps = int(decay_ratio * n_steps)

    # Warmup phase: scale the base LR from 0.01x up to 1.0x
    warmup = LinearLR(
        optimizer,
        start_factor=0.01,
        end_factor=1.0,
        total_iters=n_warmup_steps,
    )
    # Decay phase: scale the base LR from 1.0x back down to 0.01x
    decay = LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.01,
        total_iters=n_decay_steps,
    )
    # Hand over from warmup to decay after n_warmup_steps scheduler steps
    scheduler = SequentialLR(
        optimizer=optimizer,
        schedulers=[warmup, decay],
        milestones=[n_warmup_steps],
    )

    return scheduler
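
To see the shape of the schedule this builder is meant to produce, the following pure-Python sketch evaluates the intended LR multiplier in closed form. `lr_factor` is a hypothetical helper, not part of PyTorch or PyHealth. It follows the linear interpolation that `LinearLR` is documented to perform and does not call PyTorch, so the exact values emitted by `SequentialLR` may differ slightly between versions. It also assumes `n_decay` is at least 1.

```python
def lr_factor(step, n_steps, warmup_ratio=0.1, decay_ratio=0.9, low=0.01, high=1.0):
    """Intended multiplier on the base LR after `step` scheduler steps.

    Hypothetical closed-form helper for illustration only.
    """
    n_warmup = int(warmup_ratio * n_steps)
    n_decay = int(decay_ratio * n_steps)
    if step < n_warmup:
        # Warmup: linear ramp from `low` towards `high`
        return low + (high - low) * step / n_warmup
    # Decay: linear ramp from `high` down to `low`, then held at `low`
    progress = min(step - n_warmup, n_decay) / n_decay
    return high + (low - high) * progress


base_lr = 1e-4
for step in (0, 1, 4, 5, 25, 50):
    print(f"step {step:2d}: lr = {base_lr * lr_factor(step, n_steps=50):.2e}")
```

With `n_steps=50` the multiplier starts at 0.01, reaches 1.0 at step 5 when the decay phase takes over, and returns to 0.01 by step 50.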


In [8]:
epochs = 5
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=epochs,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-4},
    scheduler_class_or_fn=build_linear_scheduler_with_warmup_and_decay,
    # n_steps spans the whole run: batches per epoch times number of epochs
    scheduler_params={"n_steps": len(train_loader) * epochs, "warmup_ratio": 0.1, "decay_ratio": 0.9},
)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.0001}
Scheduler: <function build_linear_scheduler_with_warmup_and_decay at 0x7ff012e880e0>
Scheduler params: {'n_steps': 50, 'warmup_ratio': 0.1, 'decay_ratio': 0.9}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7ff01354e180>
Monitor: roc_auc
Monitor criterion: max
Epochs: 5
Patience: None



Epoch 0 / 5: 100%|██████████| 10/10 [00:00<00:00, 15.37it/s]

--- Train epoch-0, step-10 ---
loss: 0.7576



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 96.55it/s]

--- Eval epoch-0, step-10 ---
roc_auc: 0.5000
pr_auc: 0.3516
loss: 0.7141
New best roc_auc score (0.5000) at epoch-0, step-10




Epoch 1 / 5: 100%|██████████| 10/10 [00:00<00:00, 20.39it/s]

--- Train epoch-1, step-20 ---
loss: 0.7315



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 111.19it/s]

--- Eval epoch-1, step-20 ---
roc_auc: 0.5086
pr_auc: 0.3547
loss: 0.7094
New best roc_auc score (0.5086) at epoch-1, step-20




Epoch 2 / 5: 100%|██████████| 10/10 [00:00<00:00, 21.93it/s]

--- Train epoch-2, step-30 ---
loss: 0.7209



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 104.57it/s]

--- Eval epoch-2, step-30 ---
roc_auc: 0.5114
pr_auc: 0.3561
loss: 0.7059
New best roc_auc score (0.5114) at epoch-2, step-30




Epoch 3 / 5: 100%|██████████| 10/10 [00:00<00:00, 19.74it/s]

--- Train epoch-3, step-40 ---
loss: 0.7052



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 94.92it/s]

--- Eval epoch-3, step-40 ---
roc_auc: 0.5114
pr_auc: 0.3561
loss: 0.7040




Epoch 4 / 5: 100%|██████████| 10/10 [00:00<00:00, 20.19it/s]

--- Train epoch-4, step-50 ---
loss: 0.7094



Evaluation: 100%|██████████| 2/2 [00:00<00:00, 98.90it/s]

--- Eval epoch-4, step-50 ---
roc_auc: 0.5114
pr_auc: 0.3561
loss: 0.7032
Loaded best model





In [9]:
trainer.evaluate(test_loader)

Evaluation: 100%|██████████| 3/3 [00:00<00:00, 101.74it/s]


{'roc_auc': 0.37758346581875996,
 'pr_auc': 0.16746778796751344,
 'loss': 0.7669918338457743}
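
For reading the scores above: ROC AUC can be interpreted as the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one, with ties counted as half. The sketch below computes it that way on made-up toy labels and scores. `roc_auc_pairwise` is a hypothetical helper written for illustration and is not how PyHealth computes its metrics.

```python
def roc_auc_pairwise(y_true, y_score):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative")
    # A correctly ordered pair counts 1, a tie counts 0.5, a reversed pair counts 0
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [0, 0, 1, 1]  # toy data, not taken from MIMIC-IV
print(roc_auc_pairwise(labels, [0.1, 0.4, 0.35, 0.8]))  # 0.75
print(roc_auc_pairwise(labels, [0.5, 0.5, 0.5, 0.5]))   # 0.5
print(roc_auc_pairwise(labels, [0.9, 0.8, 0.2, 0.1]))   # 0.0
```

A model that gives every sample the same score lands at exactly 0.5, and a value below 0.5, like the test score of about 0.378 here, means positives were on average ranked below negatives on that split. On a dev-mode split this small, such numbers are likely to be noisy and should not be read as a measure of the model's quality.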