# 1. Environment Setup
Use this section to configure deterministic behaviour and import the libraries required for the rest of the tutorial.

In [7]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from IPython.display import display

from pyhealth.datasets import MIMIC4Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC4

# Single seed shared by every random-number generator used in this notebook.
SEED = 42

# Resolve the compute device first; the CUDA generators are only seeded when
# a GPU is actually selected.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's RNG, NumPy's legacy global RNG and PyTorch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

print(f"Running on device: {device}")

Running on device: cuda


# 2. Load MIMIC-IV Sample Extract
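Point to the MIMIC-IV root directory (set the `PYHEALTH_MIMIC4_EHR_ROOT` environment variable to override the default path), choose the EHR tables to load, and report which tables and how many unique patients were loaded before building a task dataset.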

In [8]:
# Root directory of the MIMIC-IV extract. The environment variable, when set,
# takes precedence over the fallback path.
EHR_ROOT = os.getenv("PYHEALTH_MIMIC4_EHR_ROOT", "/home/logic/Github/mimic4")

# EHR tables requested from the extract.
EHR_TABLES = ["patients", "admissions", "diagnoses_icd", "procedures_icd", "prescriptions", "labevents"]

# NOTE(review): dev=True requests PyHealth's development mode -- presumably a
# reduced subset of the data; confirm against the installed PyHealth version.
dataset = MIMIC4Dataset(ehr_root=EHR_ROOT, ehr_tables=EHR_TABLES, dev=True)

# Summarise what was requested versus what was actually loaded.
print(f"Using MIMIC-IV root: {EHR_ROOT}")
print(f"Requested tables: {EHR_TABLES}")
print(f"EHR tables actually loaded: {sorted(dataset.sub_datasets['ehr'].tables)}")
print(f"Unique patients discovered: {len(dataset.unique_patient_ids)}")

Memory usage Starting MIMIC4Dataset init: 1280.3 MB
Initializing MIMIC4EHRDataset with tables: ['patients', 'admissions', 'diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: True)
Using default EHR config: /home/logic/miniforge3/envs/pyhealth/lib/python3.12/site-packages/pyhealth/datasets/configs/mimic4_ehr.yaml
Memory usage Before initializing mimic4_ehr: 1280.3 MB
Duplicate table names in tables list. Removing duplicates.
Initializing mimic4_ehr dataset from /home/logic/Github/mimic4 (dev mode: False)
Scanning table: procedures_icd from /home/logic/Github/mimic4/hosp/procedures_icd.csv.gz
Joining with table: /home/logic/Github/mimic4/hosp/admissions.csv.gz
Scanning table: prescriptions from /home/logic/Github/mimic4/hosp/prescriptions.csv.gz
Initializing MIMIC4EHRDataset with tables: ['patients', 'admissions', 'diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: True)
Using default EHR config: /home/logic/miniforge3/envs/pyhealth/lib/pyt

# 3. Prepare PyHealth Dataset
Instantiate a task that augments each mortality-prediction sample with a lightweight pseudo-lab feature vector (the counts of condition, procedure, and drug codes in the sample), then split the resulting samples by patient into training, validation, and test subsets.

In [None]:
class MortalityWithPseudoLabs(MortalityPredictionMIMIC4):
    """Mortality task whose samples carry an extra three-element ``labs`` vector.

    The vector is not computed from lab events. It holds, as floats, the number
    of entries found under ``conditions``, ``procedures`` and ``drugs`` in each
    sample produced by the parent task (a missing key counts as zero).
    """

    input_schema = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
        "labs": "tensor",
    }

    def __call__(self, patient):
        """Return the parent task's samples, each copied and extended with ``labs``."""
        counted_keys = ("conditions", "procedures", "drugs")
        parent_samples = super().__call__(patient)
        return [
            {
                **sample,
                "labs": [float(len(sample.get(key, []))) for key in counted_keys],
            }
            for sample in parent_samples
        ]

# Apply the task to the loaded dataset to obtain the sample dataset.
task = MortalityWithPseudoLabs()
sample_dataset = dataset.set_task(task)

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Stop early: nothing downstream can work without samples.
if not len(sample_dataset):
    raise RuntimeError("The task did not produce any samples. Disable dev mode or adjust table selections.")

# 70% / 10% / 20% split, seeded for reproducibility.
train_ds, val_ds, test_ds = split_by_patient(
    sample_dataset,
    [0.7, 0.1, 0.2],
    seed=SEED,
)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# 4. Inspect Batch Structure
Build PyHealth dataloaders and verify that the tensors emitted for `conditions` and `labs` have the expected shapes before training.

In [None]:
BATCH_SIZE = 32

# Only the training loader shuffles. An empty validation or test split gets
# no loader at all (``None``).
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = None if len(val_ds) == 0 else get_dataloader(val_ds, batch_size=BATCH_SIZE)
test_loader = None if len(test_ds) == 0 else get_dataloader(test_ds, batch_size=BATCH_SIZE)

if not len(train_loader):
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the task configuration.")

# Pull a single batch to inspect its keys and shapes.
first_batch = next(iter(train_loader))
print(f"Batch keys: {list(first_batch.keys())}")

conditions_batch = first_batch["conditions"]
labs_batch = first_batch["labs"]

print(f"conditions tensor shape: {conditions_batch.shape}")
print(f"labs tensor shape: {labs_batch.shape}")
print(f"mortality tensor shape: {first_batch['mortality'].shape}")

# 5. Instantiate CNN Model
Create the PyHealth CNN with custom hyperparameters and inspect the parameter footprint prior to optimisation.

In [None]:
from pyhealth.models import CNN

# Build the model from the full task dataset rather than the training split.
# ``sample_dataset`` is the object that exposes ``input_schema`` and
# ``output_schema``.
# NOTE(review): the split returned by ``split_by_patient`` may be a plain
# subset wrapper without those attributes -- confirm against the installed
# PyHealth version.
model = CNN(
    dataset=sample_dataset,
    embedding_dim=64,
    hidden_dim=64,
    num_layers=2,
).to(device)

# Report which inputs/labels the model consumes and how large it is.
total_params = sum(p.numel() for p in model.parameters())
print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Total parameters: {total_params:,}")

# 6. Define Training Utilities
Configure the optimiser and learning-rate scheduler, and define helper functions that move a batch to the target device and compute the loss.

In [None]:
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

def build_optimizer(model: torch.nn.Module) -> torch.optim.Optimizer:
    """Create an Adam optimiser over the model's trainable parameters.

    Parameters with ``requires_grad=False`` are left out. The learning rate is
    ``1e-3`` and the weight decay is ``1e-5``.
    """
    trainable = list(filter(lambda param: param.requires_grad, model.parameters()))
    optimizer = torch.optim.Adam(trainable, lr=1e-3, weight_decay=1e-5)
    return optimizer

def build_scheduler(optimizer: torch.optim.Optimizer) -> StepLR:
    """Wrap ``optimizer`` in a ``StepLR`` schedule with ``step_size=5`` and ``gamma=0.5``."""
    schedule_kwargs = {"step_size": 5, "gamma": 0.5}
    return StepLR(optimizer, **schedule_kwargs)

def run_batch(batch, device, label_key="mortality"):
    """Split a collated batch into ``(features, labels)`` on ``device``.

    Two batch layouts are accepted:

    * nested -- ``{"x": {name: value, ...}, "y": labels}``;
    * flat -- ``{name: value, ..., label_key: labels}``, which is the layout
      the batch-inspection cell of this notebook reads
      (``first_batch["conditions"]``, ``first_batch["mortality"]``).

    Args:
        batch: Mapping produced by the dataloader.
        device: Target ``torch.device``.
        label_key: Name of the label entry in a flat batch.

    Returns:
        ``(features, labels)``. ``features`` is a dict in which every value
        that has a ``.to`` method was moved to ``device`` and every other
        value (for example a list of identifiers) is passed through
        unchanged. ``labels`` is the label tensor on ``device``.

    Raises:
        KeyError: If the batch is flat and has no ``label_key`` entry.
    """
    if "x" in batch and "y" in batch:
        raw_features, labels = batch["x"], batch["y"]
    else:
        labels = batch[label_key]
        raw_features = {key: value for key, value in batch.items() if key != label_key}

    # Non-tensor entries cannot be moved to a device; keep them as they are.
    features = {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in raw_features.items()
    }
    return features, labels.to(device)

def compute_loss(model, features, labels):
    """Run a forward pass and return ``(logits, loss)``.

    Models exposing a ``label_key`` attribute (printed for the CNN in the
    model-instantiation cell) are called with keyword arguments, with the
    labels added under that key. Any other model is called with the feature
    dict as its single positional argument.

    If the model returns a dict, the logits are read from its ``"logit"``
    entry.
    NOTE(review): the ``"logit"`` key follows the PyHealth model output
    convention -- confirm against the installed version.

    Args:
        model: The network to evaluate.
        features: Feature dict as returned by ``run_batch``.
        labels: Binary targets.

    Returns:
        ``(logits, loss)`` where ``loss`` is the binary cross-entropy with
        logits between ``logits`` and the float-cast labels.

    Raises:
        KeyError: If the model returns a dict without a ``"logit"`` entry.
    """
    label_key = getattr(model, "label_key", None)
    if label_key is None:
        outputs = model(features)
    else:
        # Merge into one dict so a label already present in ``features``
        # does not produce a duplicate keyword argument.
        outputs = model(**{**features, label_key: labels})

    if isinstance(outputs, dict):
        if "logit" not in outputs:
            raise KeyError("Model output dict has no 'logit' entry; cannot compute the loss.")
        outputs = outputs["logit"]

    targets = labels.float()
    # Align e.g. (batch,) labels with (batch, 1) logits; the loss requires equal shapes.
    if targets.shape != outputs.shape and targets.numel() == outputs.numel():
        targets = targets.reshape(outputs.shape)
    loss = F.binary_cross_entropy_with_logits(outputs, targets)
    return outputs, loss

# 7. Train the Model
Run multiple epochs with gradient clipping, scheduler updates, and logging of loss/metrics per epoch.

In [None]:
from collections import defaultdict
from sklearn.metrics import roc_auc_score

# Re-seed with the notebook-wide SEED instead of repeating the literal value.
torch.manual_seed(SEED)
optimizer = build_optimizer(model)
scheduler = build_scheduler(optimizer)
num_epochs = 10
clip_norm = 5.0
training_history = defaultdict(list)

for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = 0.0
    total_labels = []
    total_preds = []

    for batch in train_loader:
        features, labels = run_batch(batch, device)
        optimizer.zero_grad()
        logits, loss = compute_loss(model, features, labels)
        loss.backward()
        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch figure is a
        # per-sample average even when the last batch is smaller.
        epoch_loss += loss.item() * labels.size(0)
        total_labels.append(labels.detach().cpu())
        total_preds.append(torch.sigmoid(logits).detach().cpu())

    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    labels_tensor = torch.cat(total_labels)
    preds_tensor = torch.cat(total_preds)
    train_loss = epoch_loss / len(labels_tensor)

    # AUROC is undefined when only one class was seen in the epoch (possible
    # on a small development subset); record NaN rather than abort training.
    if labels_tensor.unique().numel() < 2:
        train_auc = float("nan")
    else:
        # ravel() turns (N, 1) columns into the 1-D arrays the metric expects.
        train_auc = roc_auc_score(labels_tensor.numpy().ravel(), preds_tensor.numpy().ravel())

    training_history['epoch'].append(epoch)
    training_history['train_loss'].append(train_loss)
    training_history['train_auc'].append(train_auc)
    print(f"Epoch {epoch:02d} | loss={train_loss:.4f} | auc={train_auc:.4f}")

# 8. Evaluate on Validation Split
Switch to evaluation mode, collect predictions for the validation split, and compute AUROC and loss.

In [None]:
model.eval()
val_labels = []
val_logits = []  # despite the name, this collects sigmoid probabilities
val_loss_total = 0.0

# ``val_loader`` is None when the validation split is empty, so guard the loop.
if val_loader is not None:
    with torch.no_grad():
        for batch in val_loader:
            features, labels = run_batch(batch, device)
            logits, loss = compute_loss(model, features, labels)
            val_labels.append(labels.detach().cpu())
            val_logits.append(torch.sigmoid(logits).detach().cpu())
            # Weight by batch size so the final figure is a per-sample mean.
            val_loss_total += loss.item() * labels.size(0)

if val_labels:
    val_labels = torch.cat(val_labels)
    val_probs = torch.cat(val_logits)
    val_loss = val_loss_total / len(val_labels)
    # AUROC is undefined when the split contains a single class.
    if val_labels.unique().numel() < 2:
        val_auc = float("nan")
    else:
        val_auc = roc_auc_score(val_labels.numpy().ravel(), val_probs.numpy().ravel())
else:
    # Nothing to evaluate: keep downstream cells runnable with empty
    # placeholders and undefined metrics.
    print("Validation split is empty; validation metrics are undefined.")
    val_labels = torch.empty(0)
    val_probs = torch.empty(0)
    val_loss = float("nan")
    val_auc = float("nan")

print(f"Validation loss = {val_loss:.4f}")
print(f"Validation AUROC = {val_auc:.4f}")

# 9. Visualise and Save Artifacts
Plot the training curves and the validation probability histogram, then persist the trained weights and metrics for later analysis.

In [None]:
import json
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd

# Tabulate the per-epoch metrics collected during training.
history_df = pd.DataFrame(training_history)
display(history_df)

# Figure 1: training loss and AUROC per epoch on a shared axis.
fig, ax = plt.subplots(figsize=(8, 4))
history_df.plot(x='epoch', y=['train_loss', 'train_auc'], ax=ax)
ax.set(title='Training Progress', ylabel='Metric Value')
ax.grid(True)
plt.tight_layout()
plt.show()

# Figure 2: distribution of predicted probabilities on the validation split.
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(val_probs.numpy(), bins=20, alpha=0.7)
ax.set(
    title='Validation Probability Distribution',
    xlabel='Positive Class Probability',
    ylabel='Count',
)
plt.tight_layout()
plt.show()

# Persist the weights, the training history and the validation metrics.
artifact_dir = Path('artifacts/cnn_mimic4_demo')
artifact_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), artifact_dir / 'cnn_mimic4_state_dict.pt')
history_df.to_csv(artifact_dir / 'training_history.csv', index=False)
with (artifact_dir / 'validation_metrics.json').open('w', encoding='utf-8') as fp:
    json.dump({'val_loss': float(val_loss), 'val_auc': float(val_auc)}, fp, indent=2)

print(f"Artifacts saved to {artifact_dir.resolve()}")