# 1. Environment Setup
Use this section to configure deterministic behaviour and import the libraries required for the rest of the tutorial.

In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from IPython.display import display

from pyhealth.datasets import MIMIC4Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC4

SEED = 42

# Seed every random number generator the tutorial relies on (Python, NumPy,
# PyTorch CPU and, when present, all CUDA devices) so repeated runs start
# from the same state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Seeding alone does not make GPU convolutions repeatable: cuDNN may auto-tune
# and pick non-deterministic kernels. Pin it to deterministic algorithms and
# disable the auto-tuner. Setting these flags is harmless on CPU-only machines.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer the GPU when one is visible; everything downstream reads `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


# 2. Load MIMIC-IV Sample Extract
Point the loader at the local MIMIC-IV directory and select the tables to load before building a task dataset. `dev=True` restricts processing to a small subset of patients so the tutorial runs quickly.

In [2]:
# Root directory of the local MIMIC-IV download. Set the MIMIC4_ROOT
# environment variable to run the notebook on another machine; the fallback
# keeps the path this tutorial was originally executed with.
MIMIC4_ROOT = os.environ.get("MIMIC4_ROOT", "/home/logic/Github/mimic4")

dataset = MIMIC4Dataset(
    ehr_root=MIMIC4_ROOT,
    # NOTE(review): the recorded run logs "Duplicate table names in tables
    # list" and also scans `icustays`, so some of these tables are presumably
    # loaded by default - confirm before trimming the list.
    ehr_tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "prescriptions",
        "labevents",
    ],
    # Dev mode: the recorded run reports "limiting to 1000 patients".
    dev=True,
)

Memory usage Starting MIMIC4Dataset init: 1407.4 MB
Initializing MIMIC4EHRDataset with tables: ['patients', 'admissions', 'diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: True)
Using default EHR config: /home/logic/miniforge3/envs/pyhealth/lib/python3.12/site-packages/pyhealth/datasets/configs/mimic4_ehr.yaml
Memory usage Before initializing mimic4_ehr: 1407.4 MB
Duplicate table names in tables list. Removing duplicates.
Initializing mimic4_ehr dataset from /home/logic/Github/mimic4 (dev mode: False)
Scanning table: diagnoses_icd from /home/logic/Github/mimic4/hosp/diagnoses_icd.csv.gz
Joining with table: /home/logic/Github/mimic4/hosp/admissions.csv.gz
Scanning table: admissions from /home/logic/Github/mimic4/hosp/admissions.csv.gz
Scanning table: prescriptions from /home/logic/Github/mimic4/hosp/prescriptions.csv.gz
Scanning table: icustays from /home/logic/Github/mimic4/icu/icustays.csv.gz
Scanning table: patients from /home/logic/Github/mimic4/hosp/patien

# 3. Prepare PyHealth Dataset
Leverage the built-in `MortalityPredictionMIMIC4` task to convert patients into labeled visit samples and split them into training, validation, and test subsets.

In [3]:
# Apply the mortality-prediction task to the base dataset to obtain samples.
task = MortalityPredictionMIMIC4()
sample_dataset = dataset.set_task(task)

num_samples = len(sample_dataset)
print(f"Total task samples: {num_samples}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Fail fast: nothing downstream can run without at least one sample.
if not num_samples:
    raise RuntimeError("The task did not produce any samples. Disable dev mode or adjust table selections.")

# 70% / 10% / 20% train / validation / test split, seeded for repeatability.
split_ratios = [0.7, 0.1, 0.2]
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, split_ratios, seed=SEED)
split_sizes = ", ".join(str(len(part)) for part in (train_ds, val_ds, test_ds))
print(f"Train/Val/Test sizes: {split_sizes}")

Setting task MortalityPredictionMIMIC4 for mimic4 base dataset...
Generating samples with 1 worker(s)...
Collecting global event dataframe...
Dev mode enabled: limiting to 1000 patients
Collected dataframe with shape: (131557, 47)


Generating samples for MortalityPredictionMIMIC4 with 1 worker: 100%|██████████| 100/100 [00:02<00:00, 47.07it/s]

Label mortality vocab: {0: 0, 1: 1}



Processing samples: 100%|██████████| 108/108 [00:00<00:00, 20164.92it/s]

Generated 108 samples for task MortalityPredictionMIMIC4
Total task samples: 108
Input schema: {'conditions': 'sequence', 'procedures': 'sequence', 'drugs': 'sequence'}
Output schema: {'mortality': 'binary'}
Train/Val/Test sizes: 76, 12, 20





# 4. Inspect Batch Structure
Build PyHealth dataloaders and quickly verify the keys and tensor shapes emitted before training.

In [4]:
BATCH_SIZE = 32

# Only the training loader is shuffled.
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

# Validation / test loaders are optional: an empty split yields `None`.
val_loader = None
if len(val_ds):
    val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)

test_loader = None
if len(test_ds):
    test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

# Training cannot proceed without at least one batch.
if not len(train_loader):
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the task configuration.")

# Pull a single batch for inspection.
first_batch = next(iter(train_loader))

def describe(value):
    """Return a short, human-readable summary of a batch entry.

    Objects exposing a ``shape`` attribute are rendered as
    ``TypeName(shape=(...))``, lists and tuples as ``TypeName(len=N)``, and
    anything else as just its type name.
    """
    type_name = type(value).__name__
    if hasattr(value, "shape"):
        detail = f"shape={tuple(value.shape)}"
    elif isinstance(value, (list, tuple)):
        detail = f"len={len(value)}"
    else:
        return type_name
    return f"{type_name}({detail})"

# Summarise every entry of the inspected batch by type and size.
batch_summary = dict((key, describe(value)) for key, value in first_batch.items())
print(batch_summary)

# Preview the first five labels; objects with a `shape` are moved to CPU and
# converted to plain Python lists, anything else is treated as an iterable.
mortality_targets = first_batch["mortality"]
preview = (
    mortality_targets[:5].cpu().tolist()
    if hasattr(mortality_targets, "shape")
    else list(mortality_targets)[:5]
)
print(f"Sample mortality labels: {preview}")

{'visit_id': 'list(len=32)', 'patient_id': 'list(len=32)', 'conditions': 'Tensor(shape=(32, 39))', 'procedures': 'Tensor(shape=(32, 16))', 'drugs': 'Tensor(shape=(32, 348))', 'mortality': 'Tensor(shape=(32, 1))'}
Sample mortality labels: [[0.0], [0.0], [0.0], [0.0], [0.0]]


# 5. Instantiate CNN Model
Create the PyHealth CNN with custom hyperparameters and inspect the parameter footprint prior to optimisation.

In [5]:
from pyhealth.models import CNN

# Hyperparameters for the PyHealth CNN used in this tutorial.
cnn_hparams = {"embedding_dim": 64, "hidden_dim": 64, "num_layers": 2}

# Build the model from the task dataset, then move it to the selected device.
model = CNN(dataset=sample_dataset, **cnn_hparams)
model = model.to(device)

# Count all parameter elements.
total_params = 0
for param in model.parameters():
    total_params += param.numel()

print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Total parameters: {total_params:,}")

  from .autonotebook import tqdm as notebook_tqdm


Feature keys: ['conditions', 'procedures', 'drugs']
Label key: mortality
Total parameters: 250,369


# 6. Configure Trainer
Wrap the model with the PyHealth `Trainer` to handle optimisation, gradient clipping, and metric logging.

In [6]:
from pyhealth.trainer import Trainer

# Trainer options: track ROC-AUC, run on the selected device, no logging.
# NOTE(review): in the recorded run the final validation metrics equal the
# last epoch rather than the best monitored epoch, presumably because nothing
# is checkpointed while logging is disabled - confirm.
trainer_options = {
    "metrics": ["roc_auc"],
    "device": str(device),
    "enable_logging": False,
}
trainer = Trainer(model=model, **trainer_options)

# Keyword arguments later forwarded to `trainer.train`.
training_config = dict(
    epochs=5,
    optimizer_params=dict(lr=1e-3),
    max_grad_norm=5.0,
    monitor="roc_auc",
)

CNN(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(865, 64, padding_idx=0)
    (procedures): Embedding(218, 64, padding_idx=0)
    (drugs): Embedding(486, 64, padding_idx=0)
  ))
  (cnn): ModuleDict(
    (conditions): CNNLayer(
      (cnn): ModuleDict(
        (CNN-0): CNNBlock(
          (conv1): Sequential(
            (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
            (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
          (conv2): Sequential(
            (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
            (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (relu): ReLU()
        )
        (CNN-1): CNNBlock(
          (conv1): Sequential(
            (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
            (1): BatchNorm1d(64, eps=1e-05, m

# 7. Train the Model
Run multiple epochs with gradient clipping, reporting the training loss and the validation loss/metrics after each epoch.

In [7]:
# Work on a shallow copy so the shared `training_config` dict is not the
# object handed to the trainer.
train_kwargs = {**training_config}

trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, **train_kwargs)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: 5.0
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f5e1c120f50>
Monitor: roc_auc
Monitor criterion: max
Epochs: 5



Epoch 0 / 5: 100%|██████████| 3/3 [00:04<00:00,  1.45s/it]

--- Train epoch-0, step-3 ---
loss: 0.4153



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 11.32it/s]

--- Eval epoch-0, step-3 ---
roc_auc: 0.8182
loss: 0.5464
New best roc_auc score (0.8182) at epoch-0, step-3




Epoch 1 / 5: 100%|██████████| 3/3 [00:01<00:00,  2.24it/s]

--- Train epoch-1, step-6 ---
loss: 0.2351



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 238.48it/s]

--- Eval epoch-1, step-6 ---
roc_auc: 0.7273
loss: 0.4530




Epoch 2 / 5: 100%|██████████| 3/3 [00:01<00:00,  2.29it/s]

--- Train epoch-2, step-9 ---
loss: 0.2450



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 159.13it/s]

--- Eval epoch-2, step-9 ---
roc_auc: 0.3636
loss: 0.3684




Epoch 3 / 5: 100%|██████████| 3/3 [00:01<00:00,  2.15it/s]

--- Train epoch-3, step-12 ---
loss: 0.1937



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 231.51it/s]

--- Eval epoch-3, step-12 ---
roc_auc: 0.3636
loss: 0.3210




Epoch 4 / 5: 100%|██████████| 3/3 [00:02<00:00,  1.49it/s]

--- Train epoch-4, step-15 ---
loss: 0.1397



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 154.21it/s]

--- Eval epoch-4, step-15 ---
roc_auc: 0.3636
loss: 0.3027





# 8. Evaluate on Validation and Test Splits
Evaluate the trained model on the validation and test splits (skipping any split that has no loader) and report AUROC and loss for each.

In [8]:
# Metrics per split, keyed by split name; splits without a loader are skipped.
evaluation_results = {}
named_loaders = (("validation", val_loader), ("test", test_loader))

for split_name, loader in named_loaders:
    if loader is not None:
        metrics = trainer.evaluate(loader)
        evaluation_results[split_name] = metrics
        # Render each metric to four decimal places, e.g. "roc_auc=0.1234".
        formatted = ", ".join("{}={:.4f}".format(k, v) for k, v in metrics.items())
        print(f"{split_name.title()} metrics: {formatted}")

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 152.18it/s]


Validation metrics: roc_auc=0.3636, loss=0.3027


Evaluation: 100%|██████████| 1/1 [00:00<00:00,  7.05it/s]

Test metrics: roc_auc=0.5882, loss=0.4141





# 9. Inspect Sample Predictions
Run a quick inference pass on the validation split (falling back to the training split when no validation loader exists) to preview predicted probabilities alongside ground-truth labels.

In [9]:
# Use the validation loader when there is one, otherwise the training loader.
target_loader = train_loader if val_loader is None else val_loader

y_true, y_prob, mean_loss = trainer.inference(target_loader)

# A 1-D probability array is used as-is; otherwise take the last column.
if y_prob.ndim == 1:
    positive_prob = y_prob
else:
    positive_prob = y_prob[..., -1]

# Pair the first five labels with their predicted probabilities.
# NOTE(review): labels print as one-element lists in the recorded run, so
# `y_true` presumably has a trailing singleton axis - confirm.
preview_pairs = [
    (label, prob)
    for label, prob in zip(y_true[:5].tolist(), positive_prob[:5].tolist())
]
print(f"Mean loss: {mean_loss:.4f}")
print(f"Preview (label, positive_prob): {preview_pairs}")

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 131.31it/s]

Mean loss: 0.3027
Preview (label, positive_prob): [([0.0], 0.15316888689994812), ([0.0], 0.14590689539909363), ([0.0], 0.09844371676445007), ([0.0], 0.12233827263116837), ([0.0], 0.1447029411792755)]



