# 1. Environment Setup
Use this section to configure deterministic behaviour and import the libraries required for the rest of the tutorial.

In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from IPython.display import display

from pyhealth.datasets import MIMIC4Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC4

SEED = 42


def seed_everything(seed: int) -> None:
    """Seed every RNG this tutorial touches so reruns are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and,
    when CUDA is available, every GPU). Also asks cuDNN for deterministic
    kernels: per PyTorch's reproducibility notes, seeding alone does not make
    cuDNN-backed ops deterministic.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade a little GPU speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


# 2. Load MIMIC-IV Sample Extract
Point PyHealth at the MIMIC-IV demo extract (v2.2) and choose which EHR tables to load. Dev mode stays enabled to keep the run small before building a task dataset.

In [4]:
# Root of the MIMIC-IV demo extract. Override it with the MIMIC4_DEMO_ROOT
# environment variable rather than editing a machine-specific absolute path.
EHR_ROOT = Path(
    os.environ.get(
        "MIMIC4_DEMO_ROOT", "/home/logic/physionet.org/files/mimic-iv-demo/2.2"
    )
).expanduser()

# NOTE(review): the loader logs "Duplicate table names" for this list, so
# patients/admissions may already be loaded by default — confirm before trimming.
EHR_TABLES = [
    "patients",
    "admissions",
    "diagnoses_icd",
    "procedures_icd",
    "prescriptions",
    "labevents",
]
# Dev mode keeps the run small (the loader limits the number of patients).
DEV_MODE = True

# Fail fast with an actionable message instead of a deep loader traceback.
if not EHR_ROOT.is_dir():
    raise FileNotFoundError(
        f"MIMIC-IV demo root not found: {EHR_ROOT}. "
        "Set MIMIC4_DEMO_ROOT to the directory that contains hosp/."
    )

dataset = MIMIC4Dataset(
    ehr_root=str(EHR_ROOT),
    ehr_tables=EHR_TABLES,
    dev=DEV_MODE,
)

Memory usage Starting MIMIC4Dataset init: 1415.5 MB
Initializing MIMIC4EHRDataset with tables: ['patients', 'admissions', 'diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: True)
Using default EHR config: /home/logic/miniforge3/envs/pyhealth/lib/python3.12/site-packages/pyhealth/datasets/configs/mimic4_ehr.yaml
Memory usage Before initializing mimic4_ehr: 1415.5 MB
Duplicate table names in tables list. Removing duplicates.
Initializing mimic4_ehr dataset from /home/logic/physionet.org/files/mimic-iv-demo/2.2 (dev mode: False)
Scanning table: admissions from /home/logic/physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz
Scanning table: labevents from /home/logic/physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz
Joining with table: /home/logic/physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz
Scanning table: diagnoses_icd from /home/logic/physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz
Joining with table: /home/logic

# 3. Prepare PyHealth Dataset
Leverage the built-in `MortalityPredictionMIMIC4` task to convert patients into labeled visit samples and split them into training, validation, and test subsets.

In [5]:
# Convert patients into labelled samples with PyHealth's built-in mortality task.
task = MortalityPredictionMIMIC4()
sample_dataset = dataset.set_task(task)

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if len(sample_dataset) == 0:
    raise RuntimeError("The task did not produce any samples. Disable dev mode or adjust table selections.")

# Train / validation / test fractions, split at the patient level.
SPLIT_RATIOS = [0.7, 0.1, 0.2]
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, SPLIT_RATIOS, seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# Invariant: every sample lands in exactly one split.
assert len(train_ds) + len(val_ds) + len(test_ds) == len(sample_dataset), (
    "split_by_patient lost or duplicated samples"
)

Setting task MortalityPredictionMIMIC4 for mimic4 base dataset...
Generating samples with 1 worker(s)...
Collecting global event dataframe...
Dev mode enabled: limiting to 1000 patients
Collected dataframe with shape: (131557, 47)


Generating samples for MortalityPredictionMIMIC4 with 1 worker: 100%|██████████| 100/100 [00:00<00:00, 133.69it/s]

Label mortality vocab: {0: 0, 1: 1}



Processing samples: 100%|██████████| 108/108 [00:00<00:00, 33546.98it/s]

Generated 108 samples for task MortalityPredictionMIMIC4
Total task samples: 108
Input schema: {'conditions': 'sequence', 'procedures': 'sequence', 'drugs': 'sequence'}
Output schema: {'mortality': 'binary'}
Train/Val/Test sizes: 79, 7, 22





# 4. Inspect Batch Structure
Build PyHealth dataloaders and quickly verify the keys and tensor shapes emitted before training.

In [6]:
BATCH_SIZE = 32

# Only the training split is built with shuffle=True. Evaluation loaders are
# left as None when their split is empty, so later cells can skip them.
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

val_loader = None
if len(val_ds):
    val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)

test_loader = None
if len(test_ds):
    test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

if not len(train_loader):
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the task configuration.")

first_batch = next(iter(train_loader))

def describe(value):
    """Return a compact type/shape summary for one batch entry.

    Objects exposing ``.shape`` (tensors, arrays) report their shape, lists and
    tuples report their length, and anything else reports only its type name.
    """
    kind = type(value).__name__
    if hasattr(value, "shape"):
        detail = f"shape={tuple(value.shape)}"
    elif isinstance(value, (list, tuple)):
        detail = f"len={len(value)}"
    else:
        return kind
    return f"{kind}({detail})"

batch_summary = {key: describe(value) for key, value in first_batch.items()}
print(batch_summary)

# Read the label key from the task's output schema instead of hard-coding it,
# so the preview stays in sync with the task (the model reads the same key).
label_key = next(iter(sample_dataset.output_schema))
label_targets = first_batch[label_key]
if torch.is_tensor(label_targets):
    preview = label_targets[:5].cpu().tolist()
elif hasattr(label_targets, "tolist"):
    # Array-like labels (e.g. NumPy) have .shape but no .cpu().
    preview = label_targets[:5].tolist()
else:
    preview = list(label_targets)[:5]
print(f"Sample {label_key} labels: {preview}")

{'visit_id': 'list(len=32)', 'patient_id': 'list(len=32)', 'conditions': 'Tensor(shape=(32, 38))', 'procedures': 'Tensor(shape=(32, 21))', 'drugs': 'Tensor(shape=(32, 314))', 'mortality': 'Tensor(shape=(32, 1))'}
Sample mortality labels: [[1.0], [0.0], [0.0], [0.0], [1.0]]


# 5. Instantiate GCN Model
Create the PyHealth GCN with custom hyperparameters and inspect the parameter footprint prior to optimisation.

In [7]:
from pyhealth.models import GCN

# GCN hyperparameters, kept together so they are easy to tune.
# NOTE(review): the module printed in Section 7 lists three GraphConvolution
# layers for num_layers=2, so the output projection seems to be extra — confirm.
gcn_hparams = {"embedding_dim": 64, "nhid": 64, "num_layers": 2}
model = GCN(dataset=sample_dataset, **gcn_hparams).to(device)

total_params = sum(param.numel() for param in model.parameters())
print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Total parameters: {total_params:,}")

  import pkg_resources
  from .autonotebook import tqdm as notebook_tqdm


| Xavier Initialization
| Xavier Initialization
| Xavier Initialization
Feature keys: ['conditions', 'procedures', 'drugs']
Label key: mortality
Total parameters: 116,993


# 6. Visit-level adjacency during training
We wrap each dataloader so every batch carries a `visit_adj` matrix. Here we use an identity graph (patients connect only to themselves), but you can replace it with any patient-similarity graph to encode cohort structure.


In [8]:
class VisitAdjacencyLoader:
    """Dataloader wrapper that attaches a ``visit_adj`` matrix to every batch.

    Each dict batch yielded by ``base_loader`` is shallow-copied and given a
    ``visit_adj`` entry. That entry is a float32 identity matrix whose size is
    the number of samples in the batch, so every sample connects only to
    itself. Attributes not defined here are forwarded to ``base_loader``.

    Args:
        base_loader: Iterable of dict batches (e.g. a PyHealth dataloader).
        label_key: Batch key whose first dimension (tensors) or length
            (lists) gives the number of samples in the batch.
    """

    def __init__(self, base_loader, label_key):
        self.base_loader = base_loader
        self.label_key = label_key
        self.batch_size = getattr(base_loader, "batch_size", None)
        self.dataset = getattr(base_loader, "dataset", None)

    def __len__(self):
        return len(self.base_loader)

    def __iter__(self):
        for batch in self.base_loader:
            # Copy so the base loader's batch dict is never mutated.
            batch = dict(batch)
            labels = batch[self.label_key]
            # Tensors expose .shape; fall back to len() for list-valued labels.
            n_samples = labels.shape[0] if hasattr(labels, "shape") else len(labels)
            # Created on CPU. NOTE(review): training ran on cuda, so the model
            # presumably moves visit_adj to its device — confirm.
            batch["visit_adj"] = torch.eye(n_samples, dtype=torch.float32)
            yield batch

    def __getattr__(self, name):
        # Only called when normal lookup fails. Read base_loader straight from
        # __dict__. During copy/unpickle the instance exists before its state
        # is restored, and ``self.base_loader`` would recurse forever.
        try:
            base = self.__dict__["base_loader"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(base, name)

def _attach_visit_adj(loader):
    """Wrap ``loader`` in VisitAdjacencyLoader; return None if it is None.

    Existing wrappers are peeled off first through their ``base_loader``
    attribute. Re-running this cell therefore re-wraps the original dataloader
    instead of stacking wrappers, even after the class has been redefined.
    """
    if loader is None:
        return None
    while hasattr(loader, "base_loader"):
        loader = loader.base_loader
    return VisitAdjacencyLoader(loader, model.label_key)


train_loader = _attach_visit_adj(train_loader)
val_loader = _attach_visit_adj(val_loader)
test_loader = _attach_visit_adj(test_loader)

visit_adj = next(iter(train_loader))["visit_adj"]
print(f"Visit adjacency injected with shape {tuple(visit_adj.shape)}")


Visit adjacency injected with shape (32, 32)


# 7. Configure Trainer
Wrap the model with the PyHealth `Trainer` to handle optimisation, gradient clipping, and metric logging.

In [9]:
from pyhealth.trainer import Trainer

# AUROC is the reported metric; enable_logging=False turns off PyHealth's
# experiment logging for this tutorial.
trainer = Trainer(
    model=model,
    metrics=["roc_auc"],
    device=str(device),
    enable_logging=False,
)

# Keyword arguments forwarded to Trainer.train in the next section.
training_config = dict(
    epochs=5,
    optimizer_params=dict(lr=1e-3),
    max_grad_norm=5.0,  # gradient-norm clipping threshold
    monitor="roc_auc",  # validation metric used to track the best epoch
)

GCN(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(865, 64, padding_idx=0)
    (procedures): Embedding(218, 64, padding_idx=0)
    (drugs): Embedding(486, 64, padding_idx=0)
  ))
  (gcn_layers): ModuleList(
    (0): GraphConvolution (192 -> 64)
    (1): GraphConvolution (64 -> 64)
    (2): GraphConvolution (64 -> 1)
  )
)
Metrics: ['roc_auc']
Device: cuda



# 8. Train the Model
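Train for the configured number of epochs with gradient-norm clipping, logging the training loss and validation metrics after each epoch. The `monitor` metric (validation AUROC) tracks the best epoch; with only a handful of validation samples it is very noisy.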

In [10]:
# Shallow copy of the config, forwarded to the trainer as keyword arguments.
train_kwargs = {**training_config}
trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, **train_kwargs)

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: 5.0
Val dataloader: <__main__.VisitAdjacencyLoader object at 0x7d140c852c30>
Monitor: roc_auc
Monitor criterion: max
Epochs: 5



Epoch 0 / 5: 100%|██████████| 3/3 [00:00<00:00,  5.68it/s]

--- Train epoch-0, step-3 ---
loss: 0.6926



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 710.18it/s]

--- Eval epoch-0, step-3 ---
roc_auc: 1.0000
loss: 0.6918
New best roc_auc score (1.0000) at epoch-0, step-3




Epoch 1 / 5: 100%|██████████| 3/3 [00:00<00:00, 445.90it/s]

--- Train epoch-1, step-6 ---
loss: 0.6908



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 786.04it/s]

--- Eval epoch-1, step-6 ---
roc_auc: 0.8333
loss: 0.6899




Epoch 2 / 5: 100%|██████████| 3/3 [00:00<00:00, 455.33it/s]

--- Train epoch-2, step-9 ---
loss: 0.6881



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 688.83it/s]

--- Eval epoch-2, step-9 ---
roc_auc: 0.8333
loss: 0.6871




Epoch 3 / 5: 100%|██████████| 3/3 [00:00<00:00, 373.08it/s]

--- Train epoch-3, step-12 ---
loss: 0.6844



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 826.79it/s]

--- Eval epoch-3, step-12 ---
roc_auc: 0.8333
loss: 0.6826




Epoch 4 / 5: 100%|██████████| 3/3 [00:00<00:00, 433.89it/s]

--- Train epoch-4, step-15 ---
loss: 0.6783



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 719.43it/s]

--- Eval epoch-4, step-15 ---
roc_auc: 0.8333
loss: 0.6760





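# 9. Evaluate on Validation and Test Splits
Evaluate the trained model on each non-empty held-out split (validation and test) and report AUROC and loss for each. With so few samples per split, treat these numbers as a smoke test rather than a performance estimate.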

In [11]:
# Score each held-out split; a split whose loader is None (empty) is skipped.
evaluation_results = {}
held_out_splits = (("validation", val_loader), ("test", test_loader))
for split_name, loader in held_out_splits:
    if loader is None:
        continue
    split_metrics = trainer.evaluate(loader)
    evaluation_results[split_name] = split_metrics
    formatted = ", ".join(f"{name}={score:.4f}" for name, score in split_metrics.items())
    print(f"{split_name.title()} metrics: {formatted}")

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 377.66it/s]


Validation metrics: roc_auc=0.8333, loss=0.6760


Evaluation: 100%|██████████| 1/1 [00:00<00:00, 628.45it/s]

Test metrics: roc_auc=0.6389, loss=0.6795





# 10. Inspect Sample Predictions
Run a quick inference pass (on the validation split when it is available) to preview predicted positive-class probabilities alongside the ground-truth labels.

In [12]:
# Prefer the validation split, then the test split; fall back to the training
# split only when neither held-out split has data.
target_loader = next(
    (candidate for candidate in (val_loader, test_loader) if candidate is not None),
    train_loader,
)

y_true, y_prob, mean_loss = trainer.inference(target_loader)
# Take the last column as the positive-class probability (meaningful for
# binary tasks such as this one).
positive_prob = y_prob if y_prob.ndim == 1 else y_prob[..., -1]
# Labels come back as (n, 1) here (the earlier preview showed [0.0]-style
# entries). Drop that trailing axis so labels pair with the flat probabilities.
labels = y_true[..., 0] if y_true.ndim > 1 and y_true.shape[-1] == 1 else y_true
preview_pairs = list(zip(labels[:5].tolist(), positive_prob[:5].tolist()))
print(f"Mean loss: {mean_loss:.4f}")
print(f"Preview (label, positive_prob): {preview_pairs}")

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 358.95it/s]

Mean loss: 0.6760
Preview (label, positive_prob): [([0.0], 0.4901101589202881), ([0.0], 0.47850850224494934), ([0.0], 0.49114376306533813), ([0.0], 0.4918002784252167), ([1.0], 0.49170568585395813)]





# 11. Using custom adjacency matrices
During training, the wrapped dataloaders already pass the identity visit graph built in Section 6. The `GCN` model also accepts these optional adjacency arguments:
- `feature_adj`: describes relationships between feature streams (e.g., diagnoses vs. procedures). Supply either a single `[num_features, num_features]` matrix shared across the batch or a `[batch_size, num_features, num_features]` tensor for patient-specific feature graphs.
- `visit_adj`: captures patient/patient (visit-level) connectivity with shape `[batch_size, batch_size]` and defaults to a fully connected graph when omitted.

The example below shows how to build dense identity graphs, but you can replace them with domain-specific structures (e.g., cosine similarity, co-occurrence statistics, etc.).


In [None]:
# Shallow copy so the extra adjacency keys do not touch the loader's batch.
demo_batch = dict(next(iter(train_loader)))
batch_size = demo_batch[model.label_key].shape[0]
num_features = len(model.feature_keys)

# Feature graph: one identity matrix shared by the whole batch. A
# [batch_size, num_features, num_features] tensor gives per-patient graphs.
feature_adj = torch.eye(num_features, device=device)
demo_batch["feature_adj"] = feature_adj

# Visit graph: each sample connects only to itself; substitute a similarity
# matrix to encode cohort structure. This replaces the CPU identity graph the
# wrapped loader already attached.
visit_adj = torch.eye(batch_size, device=device)
demo_batch["visit_adj"] = visit_adj

# Inference only: eval mode disables train-time layers such as dropout, and
# no_grad skips building the autograd graph.
model.eval()
with torch.no_grad():
    custom_graph_output = model(**demo_batch)

print("Custom graph demo loss:", float(custom_graph_output["loss"]))
