# Dr. Agent Model Tutorial with MIMIC-IV

This notebook demonstrates how to use the Dr. Agent model for mortality prediction on MIMIC-IV data. Dr. Agent uses two reinforcement learning agents with dynamic skip connections to capture long-term dependencies in patient EHR sequences.

**Paper:** Gao et al. "Dr. Agent: Clinical predictive model via mimicked second opinions" (JAMIA 2020)


# 1. Environment Setup

Configure deterministic behaviour and import the libraries required for the tutorial.

In [None]:
import os
import random
from pathlib import Path

import numpy as np
import torch
from IPython.display import display

from pyhealth.datasets import MIMIC4Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC4

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Seeding alone does not make GPU runs reproducible: cuDNN may pick
# non-deterministic kernels and autotune between runs. Pin both behaviours so
# the "deterministic behaviour" promised above actually holds.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

# 2. Load MIMIC-IV Dataset
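Point `ehr_root` at your local MIMIC-IV directory and load the tables used in this tutorial. `dev=True` loads a reduced development subset for quick iteration; set it to `False` to use the full dataset.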

In [None]:
# Load the MIMIC-IV EHR tables used in this tutorial. The root directory can be
# supplied via the MIMIC4_ROOT environment variable so the notebook runs
# without editing; otherwise the placeholder below is used and must be updated.
dataset = MIMIC4Dataset(
    ehr_root=os.environ.get("MIMIC4_ROOT", "/path/to/mimic4"),  # Update this path
    ehr_tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "prescriptions",
    ],
    dev=True,  # Set to False for full dataset
)

# 3. Prepare PyHealth Dataset
Use the built-in `MortalityPredictionMIMIC4` task to convert patients into labeled visit samples and split them into training, validation, and test subsets.

In [None]:
# Turn each patient record into labeled samples using the built-in task.
task = MortalityPredictionMIMIC4()
sample_dataset = dataset.set_task(task)
n_samples = len(sample_dataset)

print(f"Total task samples: {n_samples}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Fail fast: an empty task output would only surface later as a less
# helpful error in the dataloader or training cells.
if not n_samples:
    raise RuntimeError(
        "The task did not produce any samples. "
        "Disable dev mode or adjust table selections."
    )

# Train / validation / test fractions. `split_by_patient` presumably keeps all
# samples of a patient in one subset (per its name) — confirm in PyHealth docs.
split_ratios = [0.7, 0.1, 0.2]
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, split_ratios, seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# 4. Inspect Batch Structure

Build PyHealth dataloaders and verify the keys and tensor shapes before training.

In [None]:
BATCH_SIZE = 32

# Only the training loader is shuffled explicitly. An empty evaluation split
# gets no loader at all (None) so later cells can skip it.
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

val_loader = None
if len(val_ds):
    val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE)

test_loader = None
if len(test_ds):
    test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE)

# Abort here rather than fail later inside the training loop.
if not len(train_loader):
    raise RuntimeError(
        "The training loader is empty. "
        "Increase the dataset size or adjust the task configuration."
    )

# One batch, kept around to inspect its keys and shapes below.
first_batch = next(iter(train_loader))


def describe(value):
    """Return a short, human-readable summary of one batch entry.

    Objects exposing ``shape`` (tensors, arrays) are shown with their shape.
    Lists and tuples are shown with their length. Anything else is shown by
    its type name only.
    """
    kind = type(value).__name__
    if hasattr(value, "shape"):
        detail = f"shape={tuple(value.shape)}"
    elif isinstance(value, (list, tuple)):
        detail = f"len={len(value)}"
    else:
        return kind
    return f"{kind}({detail})"


# Summarise every field of the first batch, then print one line per field.
batch_summary = {}
for key, value in first_batch.items():
    batch_summary[key] = describe(value)

print("Batch structure:")
for name, summary in batch_summary.items():
    print(f"  {name}: {summary}")

# 5. Instantiate Dr. Agent Model

Create the PyHealth Agent model with custom hyperparameters. The model uses:
- **Primary agent**: Observes current visit to learn current health status
- **Second-opinion agent**: Considers entire patient history for a global view
- **Dynamic skip connections**: Selects optimal historical states via policy gradient

In [None]:
from pyhealth.models import Agent

# Agent-specific hyperparameters, kept apart from the shared embedding/hidden
# sizes passed directly to the constructor.
agent_hparams = dict(
    n_actions=10,       # History window size (K in paper)
    n_units=64,         # Agent MLP hidden units
    dropout=0.5,        # Dropout rate
    lamda=0.5,          # Skip connection weight
    cell="gru",         # RNN cell type: "gru" or "lstm"
    use_baseline=True,  # Use baseline for variance reduction
)

model = Agent(
    dataset=sample_dataset,
    embedding_dim=128,
    hidden_dim=128,
    **agent_hparams,
).to(device)

# Count parameters once, recording whether each tensor is trainable.
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, trainable in param_stats if trainable)

print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Mode: {model.mode}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# 6. Verify Forward Pass

Run a single forward pass to ensure the model works correctly before training.

In [None]:
# Smoke-test a single forward pass without tracking gradients.
model.eval()
with torch.no_grad():
    # Move anything that supports `.to` onto the model's device; leave
    # other values (e.g. plain Python lists) untouched.
    sample_batch = {}
    for name, value in first_batch.items():
        sample_batch[name] = value.to(device) if hasattr(value, "to") else value
    output = model(**sample_batch)

print("Forward pass output keys:", list(output.keys()))
loss_value = output['loss'].item()
print(f"Loss: {loss_value:.4f}")
print(f"y_prob shape: {output['y_prob'].shape}")
print(f"y_true shape: {output['y_true'].shape}")

# 7. Configure Trainer

Wrap the model with the PyHealth `Trainer` to handle optimisation, gradient clipping, and metric logging.

In [None]:
from pyhealth.trainer import Trainer

# Metrics computed at every evaluation; "roc_auc" is also monitored below.
trainer = Trainer(
    model=model,
    device=str(device),
    metrics=["roc_auc", "pr_auc"],
    enable_logging=True,
)

# Keyword arguments forwarded to `trainer.train` (and saved with the checkpoint).
training_config = dict(
    epochs=20,
    optimizer_params=dict(lr=1e-3),
    max_grad_norm=5.0,
    monitor="roc_auc",
)

# 8. Train the Model

Train for the configured number of epochs with gradient clipping, logging the loss and validation metrics each epoch; `roc_auc` on the validation split is the monitored metric.

In [None]:
# Train using the settings in `training_config` (epochs, optimizer parameters,
# gradient-clipping norm, monitored metric). `val_loader` is None when the
# validation split is empty.
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    **training_config,
)

# 9. Evaluate on Test Split

Evaluate the trained model on the validation and test splits (skipping any split without data) and record the resulting metrics.

In [None]:
# Metrics per split name, reused when saving the checkpoint.
evaluation_results = {}

for split_name, loader in (("validation", val_loader), ("test", test_loader)):
    if loader is not None:
        metrics = trainer.evaluate(loader)
        evaluation_results[split_name] = metrics
        formatted = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
        print(f"{split_name.title()} metrics: {formatted}")
    else:
        # Empty splits were given no loader in the dataloader cell.
        print(f"Skipping {split_name} (no data)")

# 10. Inspect Sample Predictions

Run inference to preview predicted probabilities alongside ground-truth labels.

In [None]:
# Preview predictions on the test split, falling back to validation and then
# training data when a split has no loader.
target_loader = test_loader if test_loader is not None else val_loader
if target_loader is None:
    target_loader = train_loader

y_true, y_prob, mean_loss = trainer.inference(target_loader)

# Reduce labels and probabilities to one value per sample. For binary tasks,
# both arrays may carry a trailing singleton column (compare the shapes printed
# in the forward-pass cell). Taking the last column yields the positive-class
# value and leaves 1-D arrays untouched. Without reducing `y_true`, `.tolist()`
# would yield one-element lists and `int(label)` would raise TypeError.
positive_prob = y_prob if y_prob.ndim == 1 else y_prob[..., -1]
true_labels = y_true if y_true.ndim == 1 else y_true[..., -1]

print(f"Mean inference loss: {mean_loss:.4f}")
print("\nSample predictions (first 10):")
print("-" * 40)
print(f"{'True Label':<12} {'Pred Prob':<12}")
print("-" * 40)

for label, prob in zip(true_labels[:10].tolist(), positive_prob[:10].tolist()):
    print(f"{int(label):<12} {prob:.4f}")

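# 11. Save Model Checkpoint

Save the trained weights together with the training configuration and evaluation results so the run can be reloaded later.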

In [None]:
checkpoint_path = "agent_mimic4_mortality.pt"

# Store metrics as built-in floats. Metric values may be NumPy scalars, which
# `torch.load(..., weights_only=True)` (the default since PyTorch 2.6) refuses
# to unpickle unless explicitly allow-listed.
serializable_results = {
    split_name: {name: float(value) for name, value in split_metrics.items()}
    for split_name, split_metrics in evaluation_results.items()
}

torch.save(
    {
        "model_state_dict": model.state_dict(),
        "training_config": training_config,
        "evaluation_results": serializable_results,
    },
    checkpoint_path,
)

print(f"Model saved to: {checkpoint_path}")