# 1. Environment Setup
Use this section to configure deterministic behaviour and import the libraries required for the rest of the tutorial.

In [None]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from IPython.display import display

from pyhealth.datasets import MIMIC4Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC4

# Seed Python's `random`, NumPy, and PyTorch with one shared value so that
# repeated runs of the notebook draw the same random numbers.
SEED = 42
for seed_everything in (random.seed, np.random.seed, torch.manual_seed):
    seed_everything(SEED)

# Query CUDA availability once and reuse the answer for seeding and device choice.
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if use_cuda else "cpu")
print(f"Running on device: {device}")

# 2. Load MIMIC-IV Sample Extract
Point to the preprocessed MIMIC-IV tables, optionally override individual files, and preview their structure before building a task dataset.

In [None]:
# Location of the MIMIC-IV tables. The MIMIC4_ROOT environment variable takes
# precedence so the notebook can run on another machine without editing this
# cell; the original hard-coded path is kept as the fallback.
EHR_ROOT = os.environ.get("MIMIC4_ROOT", "/home/logic/Github/mimic4")

dataset = MIMIC4Dataset(
    ehr_root=EHR_ROOT,
    # Tables requested from the EHR extract.
    ehr_tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "prescriptions",
        "labevents",
    ],
    # NOTE(review): presumably dev mode loads only a small subset of the data;
    # confirm against the PyHealth documentation before relying on it.
    dev=True,
)

# 3. Prepare PyHealth Dataset
Leverage the built-in `MortalityPredictionMIMIC4` task to convert patients into labeled visit samples and split them into training, validation, and test subsets.

In [None]:
# Apply the mortality-prediction task to the loaded dataset to obtain samples.
task = MortalityPredictionMIMIC4()
sample_dataset = dataset.set_task(task)

num_samples = len(sample_dataset)
print(f"Total task samples: {num_samples}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Stop early with an actionable message instead of failing later on an empty split.
if num_samples == 0:
    raise RuntimeError("The task did not produce any samples. Disable dev mode or adjust table selections.")

# 70% / 10% / 20% train / validation / test split, seeded for reproducibility.
split_ratios = [0.7, 0.1, 0.2]
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, split_ratios, seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# 4. Inspect Batch Structure
Build PyHealth dataloaders and quickly verify the keys and tensor shapes emitted before training.

In [None]:
BATCH_SIZE = 32

# Only the training loader is shuffled.
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

# Validation and test loaders are optional: an empty split yields `None`,
# which later cells check for before using the loader.
val_loader = None
if len(val_ds):
    val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)

test_loader = None
if len(test_ds):
    test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the task configuration.")

# Pull one batch so its structure can be inspected below.
first_batch = next(iter(train_loader))

def describe(value):
    """Return a short, human-readable summary of a batch entry.

    Objects exposing a ``shape`` attribute are reported as
    ``TypeName(shape=(...))``, lists and tuples as ``TypeName(len=N)``,
    and anything else as just the type name.
    """
    type_name = type(value).__name__
    if hasattr(value, "shape"):
        detail = f"shape={tuple(value.shape)}"
    elif isinstance(value, (list, tuple)):
        detail = f"len={len(value)}"
    else:
        return type_name
    return f"{type_name}({detail})"

# Summarise every entry of the first batch by type and shape/length.
batch_summary = {name: describe(entry) for name, entry in first_batch.items()}
print(batch_summary)

# Show the first few labels; objects with a `shape` attribute are moved to the
# CPU and converted to a list, anything else is treated as a plain sequence.
mortality_targets = first_batch["mortality"]
preview = (
    mortality_targets[:5].cpu().tolist()
    if hasattr(mortality_targets, "shape")
    else list(mortality_targets)[:5]
)
print(f"Sample mortality labels: {preview}")

# 5. Instantiate CNN Model
Create the PyHealth CNN with custom hyperparameters and inspect the parameter footprint prior to optimisation.

In [None]:
from pyhealth.models import CNN

# Build the CNN from the task dataset, then move it to the selected device.
model = CNN(dataset=sample_dataset, embedding_dim=64, hidden_dim=64, num_layers=2)
model = model.to(device)

# Count every element across all parameter tensors.
total_params = sum(tensor.numel() for tensor in model.parameters())
print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Total parameters: {total_params:,}")

# 6. Configure Trainer
Wrap the model with the PyHealth `Trainer` to handle optimisation, gradient clipping, and metric logging.

In [None]:
from pyhealth.trainer import Trainer

# Trainer bound to the model, reporting ROC-AUC, with logging turned off.
trainer = Trainer(model=model, metrics=["roc_auc"], device=str(device), enable_logging=False)

# Keyword arguments forwarded to `trainer.train` in the training cell.
training_config = dict(
    epochs=5,
    optimizer_params=dict(lr=1e-3),
    max_grad_norm=5.0,
    monitor="roc_auc",
)

# 7. Train the Model
Train for the configured number of epochs with gradient clipping (`max_grad_norm`), using `roc_auc` on the validation loader as the monitored metric.

In [None]:
# Work on a shallow copy so the shared `training_config` dict is not the object
# handed to the trainer.
train_kwargs = {**training_config}
trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, **train_kwargs)

# 8. Evaluate on Validation and Test Splits
Use `trainer.evaluate` to score each available held-out split (validation and test) and print the resulting metrics, such as AUROC and loss.

In [None]:
# Evaluate each held-out split that has a dataloader; splits whose loader is
# `None` are skipped. Metrics are kept in `evaluation_results` keyed by split name.
evaluation_results = {}
held_out_loaders = (("validation", val_loader), ("test", test_loader))
for split_name, loader in held_out_loaders:
    if loader is not None:
        metrics = trainer.evaluate(loader)
        evaluation_results[split_name] = metrics
        formatted = ", ".join(f"{name}={score:.4f}" for name, score in metrics.items())
        print(f"{split_name.title()} metrics: {formatted}")

# 9. Inspect Sample Predictions
Run a quick inference pass on the validation or test split to preview predicted probabilities alongside ground-truth labels.

In [None]:
# Preview predictions on held-out data whenever possible: prefer the validation
# split, then the test split, and fall back to the training loader only when
# neither held-out loader exists.
if val_loader is not None:
    target_loader = val_loader
elif test_loader is not None:
    target_loader = test_loader
else:
    target_loader = train_loader

y_true, y_prob, mean_loss = trainer.inference(target_loader)

# A 1-D probability array is used as-is; otherwise the last entry along the
# final axis is taken as the positive-class probability.
positive_prob = y_prob if y_prob.ndim == 1 else y_prob[..., -1]

# Pair the first five labels with their predicted positive-class probabilities.
preview_pairs = list(zip(y_true[:5].tolist(), positive_prob[:5].tolist()))
print(f"Mean loss: {mean_loss:.4f}")
print(f"Preview (label, positive_prob): {preview_pairs}")