
# Getting Started

## Installation

Install the PyHealth alpha release that includes the modernized StageNet:

```bash
pip install pyhealth==2.0a8
```

## Loading Data

Load the PyHealth dataset for mortality prediction.

PyHealth datasets use a `config.yaml` file to define:
- Input tables (.csv, .tsv, etc.)
- Features to extract
- Aggregation methods

The result is a single dataframe where each row represents one patient and their features.

For more details on PyHealth datasets, see [this resource](https://colab.research.google.com/drive/1voSx7wEfzXfEf2sIfW6b-8p1KqMyuWxK#scrollTo=NSrb2PGFqUgS).

In [None]:
"""
Example of using StageNet for mortality prediction on MIMIC-IV.

This example demonstrates:
1. Loading MIMIC-IV data
2. Applying the MortalityPredictionStageNetMIMIC4 task
3. Creating a SampleDataset with StageNet processors
4. Training a StageNet model
"""

from pyhealth.datasets import (
    MIMIC4Dataset,
    get_dataloader,
    split_by_patient,
)
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
from pyhealth.trainer import Trainer
import torch

# STEP 1: Load MIMIC-IV base dataset
base_dataset = MIMIC4Dataset(
    ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
    ehr_tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "labevents",
    ],
)

## Defining a PyHealth Task

We'll predict patient mortality using StageNet across time-series data from multiple visits. Each visit includes:

- Diagnosis codes
- Procedure codes
- Lab events

To define a task, specify the `__call__` method, input schema, and output schema. For a detailed explanation, see [this tutorial](https://colab.research.google.com/drive/1kKKBVS_GclHoYTbnOtjyYnSee79hsyT?usp=sharing).

### Helper Functions

Use `patient.get_events()` to retrieve all events from a specific table, with optional filtering. See the [MIMIC-IV YAML file](https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/configs/mimic4_ehr.yaml) for available tables.
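
To make the filtering idea concrete, here is a minimal, self-contained sketch. It is a toy stand-in and not PyHealth's implementation: `toy_get_events` and the dictionary-based events are hypothetical, and only the argument names (`event_type`, `filters`, `start`, `end`) mirror the calls used in the task in this tutorial. We assume the time window is inclusive at both ends.

```python
from datetime import datetime
from operator import eq, ne
from typing import Any, Dict, List, Optional, Tuple

# Toy stand-in for illustration only; NOT PyHealth's implementation.
OPS = {"==": eq, "!=": ne}


def toy_get_events(
    events: List[Dict[str, Any]],
    event_type: str,
    filters: Optional[List[Tuple[str, str, Any]]] = None,
    start: Optional[datetime] = None,
    end: Optional[datetime] = None,
) -> List[Dict[str, Any]]:
    """Return events of one type, narrowed by attribute filters and a time window."""
    selected = []
    for event in events:
        if event["event_type"] != event_type:
            continue
        # Assumption: the window is inclusive at both ends.
        if start is not None and event["timestamp"] < start:
            continue
        if end is not None and event["timestamp"] > end:
            continue
        # Every (attribute, operator, value) filter must hold.
        if all(OPS[op](event.get(attr), value) for attr, op, value in (filters or [])):
            selected.append(event)
    return selected


# Made-up events for a single patient.
events = [
    {"event_type": "diagnoses_icd", "timestamp": datetime(2150, 1, 1), "hadm_id": "100", "icd_code": "I10"},
    {"event_type": "diagnoses_icd", "timestamp": datetime(2150, 6, 1), "hadm_id": "200", "icd_code": "E119"},
    {"event_type": "labevents", "timestamp": datetime(2150, 1, 1, 8), "hadm_id": "100", "itemid": "50824"},
]

first_visit_dx = toy_get_events(events, "diagnoses_icd", filters=[("hadm_id", "==", "100")])
print([e["icd_code"] for e in first_visit_dx])  # ['I10']
```

The real method also accepts `return_df=True` to return a Polars dataframe instead of a list of events, as the task code does for lab events.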

In [None]:
from datetime import datetime
from typing import Any, ClassVar, Dict, List

import polars as pl

from pyhealth.tasks.base_task import BaseTask


class MortalityPredictionStageNetMIMIC4(BaseTask):
    """Task for predicting mortality using MIMIC-IV with StageNet format.

    This task creates PATIENT-LEVEL samples (not visit-level) by aggregating
    all admissions for each patient. ICD codes (diagnoses + procedures) and
    lab results across all visits are combined, each paired with the time
    values described below.

    Time Calculation:
        - ICD codes: Hours from previous admission (0 for first visit,
          then time intervals between consecutive visits)
        - Labs: Hours from admission start (within-visit measurements)

    Lab Processing:
        - 10-dimensional vectors (one per lab category)
        - Multiple itemids per category → take first observed value
        - Missing categories → None/NaN in vector

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for input data:
            - icd_codes: Combined diagnosis + procedure ICD codes
              (stagenet format, nested by visit)
            - labs: Lab results (stagenet_tensor, 10D vectors per timestamp)
        output_schema (Dict[str, str]): The schema for output data:
            - mortality: Binary indicator (1 if any admission had mortality)
    """

    task_name: str = "MortalityPredictionStageNetMIMIC4"
    input_schema: Dict[str, str] = {
        "icd_codes": "stagenet",
        "labs": "stagenet_tensor",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    # Organize lab items by category
    # Each category will map to ONE dimension in the output vector
    LAB_CATEGORIES: ClassVar[Dict[str, List[str]]] = {
        "Sodium": ["50824", "52455", "50983", "52623"],
        "Potassium": ["50822", "52452", "50971", "52610"],
        "Chloride": ["50806", "52434", "50902", "52535"],
        "Bicarbonate": ["50803", "50804"],
        "Glucose": ["50809", "52027", "50931", "52569"],
        "Calcium": ["50808", "51624"],
        "Magnesium": ["50960"],
        "Anion Gap": ["50868", "52500"],
        "Osmolality": ["52031", "50964", "51701"],
        "Phosphate": ["50970"],
    }

    # Ordered list of category names (defines vector dimension order)
    LAB_CATEGORY_NAMES: ClassVar[List[str]] = [
        "Sodium",
        "Potassium",
        "Chloride",
        "Bicarbonate",
        "Glucose",
        "Calcium",
        "Magnesium",
        "Anion Gap",
        "Osmolality",
        "Phosphate",
    ]

    # Flat list of all lab item IDs for filtering
    LABITEMS: ClassVar[List[str]] = [
        item for itemids in LAB_CATEGORIES.values() for item in itemids
    ]

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a patient to create mortality prediction samples.

        Creates ONE sample per patient with all admissions aggregated.
        Time intervals are calculated between consecutive admissions.

        Args:
            patient: Patient object with get_events method

        Returns:
            List with single sample containing patient_id, all conditions,
            procedures, labs across visits, and final mortality label
        """
        # Filter patients by age (>= 18)
        demographics = patient.get_events(event_type="patients")
        if not demographics:
            return []

        demographics = demographics[0]
        try:
            anchor_age = int(demographics.anchor_age)
            if anchor_age < 18:
                return []
        except (ValueError, TypeError, AttributeError):
            # If age can't be determined, skip patient
            return []

        # Get all admissions
        admissions = patient.get_events(event_type="admissions")
        if len(admissions) < 1:
            return []

        # Initialize aggregated data structures
        # List of ICD codes (diagnoses + procedures) per visit
        all_icd_codes = []
        all_icd_times = []  # Time from previous admission per visit
        all_lab_values = []  # List of 10D lab vectors
        all_lab_times = []  # Time from admission start per measurement

        # Track previous admission timestamp for interval calculation
        previous_admission_time = None

        # Track if patient had any mortality event
        final_mortality = 0

        # Process each admission
        for admission in admissions:
            # Parse admission and discharge times
            try:
                admission_time = admission.timestamp
                admission_dischtime = datetime.strptime(
                    admission.dischtime, "%Y-%m-%d %H:%M:%S"
                )
            except (ValueError, AttributeError):
                # Skip if timestamps invalid
                continue

            # Skip if discharge is before admission (data quality issue)
            if admission_dischtime < admission_time:
                continue

            # Calculate time from previous admission (in hours)
            # First admission will have time = 0
            if previous_admission_time is None:
                time_from_previous = 0.0
            else:
                time_from_previous = (
                    admission_time - previous_admission_time
                ).total_seconds() / 3600.0

            # Update previous admission time for next iteration
            previous_admission_time = admission_time

            # Update mortality label if this admission had mortality
            try:
                if int(admission.hospital_expire_flag) == 1:
                    final_mortality = 1
            except (ValueError, TypeError, AttributeError):
                pass

            # Get diagnosis codes for this admission using hadm_id
            diagnoses_icd = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            visit_diagnoses = [
                event.icd_code
                for event in diagnoses_icd
                if hasattr(event, "icd_code") and event.icd_code
            ]

            # Get procedure codes for this admission using hadm_id
            procedures_icd = patient.get_events(
                event_type="procedures_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            visit_procedures = [
                event.icd_code
                for event in procedures_icd
                if hasattr(event, "icd_code") and event.icd_code
            ]

            # Combine diagnoses and procedures into single ICD code list
            visit_icd_codes = visit_diagnoses + visit_procedures

            if visit_icd_codes:
                all_icd_codes.append(visit_icd_codes)
                all_icd_times.append(time_from_previous)

            # Get lab events for this admission
            labevents_df = patient.get_events(
                event_type="labevents",
                start=admission_time,
                end=admission_dischtime,
                return_df=True,
            )

            # Filter to relevant lab items
            labevents_df = labevents_df.filter(
                pl.col("labevents/itemid").is_in(self.LABITEMS)
            )

            # Parse storetime and filter
            if labevents_df.height > 0:
                labevents_df = labevents_df.with_columns(
                    pl.col("labevents/storetime").str.strptime(
                        pl.Datetime, "%Y-%m-%d %H:%M:%S"
                    )
                )
                labevents_df = labevents_df.filter(
                    pl.col("labevents/storetime") <= admission_dischtime
                )

                if labevents_df.height > 0:
                    # Select relevant columns
                    labevents_df = labevents_df.select(
                        pl.col("timestamp"),
                        pl.col("labevents/itemid"),
                        pl.col("labevents/valuenum").cast(pl.Float64),
                    )

                    # Group by timestamp and aggregate into 10D vectors
                    # For each timestamp, create vector of lab categories
                    unique_timestamps = sorted(
                        labevents_df["timestamp"].unique().to_list()
                    )

                    for lab_ts in unique_timestamps:
                        # Get all lab events at this timestamp
                        ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts)

                        # Create 10-dimensional vector (one per category)
                        lab_vector = []
                        for category_name in self.LAB_CATEGORY_NAMES:
                            category_itemids = self.LAB_CATEGORIES[category_name]

                            # Find first matching value for this category
                            category_value = None
                            for itemid in category_itemids:
                                matching = ts_labs.filter(
                                    pl.col("labevents/itemid") == itemid
                                )
                                if matching.height > 0:
                                    category_value = matching["labevents/valuenum"][0]
                                    break

                            lab_vector.append(category_value)

                        # Calculate time from admission start (hours)
                        time_from_admission = (
                            lab_ts - admission_time
                        ).total_seconds() / 3600.0

                        all_lab_values.append(lab_vector)
                        all_lab_times.append(time_from_admission)

        # Skip if no lab events (required for this task)
        if len(all_lab_values) == 0:
            return []

        # Also skip if no ICD codes across all admissions
        if len(all_icd_codes) == 0:
            return []

        # Format as tuples: (time, values)
        # ICD codes: nested list with times
        icd_codes_data = (all_icd_times, all_icd_codes)

        # Labs: list of 10D vectors with times
        labs_data = (all_lab_times, all_lab_values)

        # Create single patient-level sample
        sample = {
            "patient_id": patient.patient_id,
            "icd_codes": icd_codes_data,
            "labs": labs_data,
            "mortality": final_mortality,
        }
        return [sample]
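
The two calculations described in the task docstring can be tried in isolation. The sketch below is a simplified illustration, not the task code itself: `hours_between_admissions` and `lab_vector` are hypothetical helper names, only two of the ten lab categories are included, and the data-quality checks (invalid timestamps, discharge before admission) are omitted.

```python
from datetime import datetime
from typing import Dict, List, Optional


def hours_between_admissions(admit_times: List[datetime]) -> List[float]:
    """Hours since the previous admission; 0.0 for the first admission."""
    intervals = []
    previous = None
    for current in admit_times:
        if previous is None:
            intervals.append(0.0)
        else:
            intervals.append((current - previous).total_seconds() / 3600.0)
        previous = current
    return intervals


# Two-category subset of the task's LAB_CATEGORIES, to keep the output short.
CATEGORIES: Dict[str, List[str]] = {
    "Sodium": ["50824", "50983"],
    "Potassium": ["50822", "50971"],
}


def lab_vector(labs_at_timestamp: Dict[str, float]) -> List[Optional[float]]:
    """One slot per category: the first itemid with a value wins, else None."""
    vector = []
    for itemids in CATEGORIES.values():
        value = next(
            (labs_at_timestamp[i] for i in itemids if i in labs_at_timestamp),
            None,
        )
        vector.append(value)
    return vector


admits = [datetime(2150, 1, 1, 0, 0), datetime(2150, 1, 3, 12, 0)]
print(hours_between_admissions(admits))  # [0.0, 60.0]
print(lab_vector({"50983": 140.0}))      # [140.0, None]
```

Missing categories stay as `None` at this stage; imputation is left to the numeric processor.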


## Input and Output Schemas
Input and output schemas map feature keys (e.g., "labs", "icd_codes") to StageNet processors. Each processor converts features into `StageNetFeature` objects used for training and inference.

**Required format:** Each feature returned by the task's `__call__` must be a `(times, values)` tuple:
```python
"feature": (my_times_list, my_values_list)
```

There are two StageNet processors: `StageNetProcessor` (registered as `"stagenet"`) for categorical codes, and `StageNetTensorProcessor` (registered as `"stagenet_tensor"`) for numeric features. In both cases the goal is to represent each feature as a time/value pair that can later be passed to StageNet.
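
As an illustration of the `(times, values)` layout, the sketch below builds one categorical and one numeric feature by hand. All values are made up, and `check_feature` is a hypothetical helper that is not part of PyHealth; it only verifies that the two lists are aligned.

```python
from typing import Any, List, Tuple

# Categorical feature: one list of codes per visit, one time value per visit.
icd_times: List[float] = [0.0, 60.0]
icd_values: List[List[str]] = [["I10", "E119"], ["N179"]]
icd_codes = (icd_times, icd_values)

# Numeric feature: one fixed-length vector per measurement time.
# None marks a lab category that was not measured at that time.
lab_times: List[float] = [0.5, 6.0, 30.0]
lab_values = [[140.0, None], [138.0, 4.1], [None, 4.4]]
labs = (lab_times, lab_values)


def check_feature(feature: Tuple[List[float], List[Any]]) -> int:
    """Return the sequence length, or raise if times and values are misaligned."""
    times, values = feature
    if len(times) != len(values):
        raise ValueError(f"{len(times)} time values for {len(values)} entries")
    return len(times)


sample = {
    "patient_id": "p-001",
    "icd_codes": icd_codes,
    "labs": labs,
    "mortality": 0,
}

print(check_feature(sample["icd_codes"]))  # 2
print(check_feature(sample["labs"]))       # 3
```

Note that the two features have different sequence lengths: codes are indexed by visit, labs by measurement time.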

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from pyhealth.processors import register_processor
from pyhealth.processors.base_processor import FeatureProcessor


@dataclass
class StageNetFeature:
    """Container for StageNet feature with values and optional time intervals.

    Attributes:
        value: The feature tensor (1D for sequences, 2D for nested sequences, 3D for feature vectors)
        time: Optional time interval tensor (1D, matching the sequence length of value)
    """

    value: torch.Tensor
    time: Optional[torch.Tensor] = None


@register_processor("stagenet")
class StageNetProcessor(FeatureProcessor):
    """
    Feature processor for StageNet CODE inputs with coupled value/time data.

    This processor handles categorical code sequences (flat or nested).
    For numeric features, use StageNetTensorProcessor instead.

    Format:
    {
        "value": ["code1", "code2"] or [["A", "B"], ["C"]],
        "time": [0.0, 2.0] or None
    }

    The processor automatically detects:
    - List of strings -> flat code sequences
    - List of lists of strings -> nested code sequences

    Time intervals should be simple lists of scalars, one per sequence position.

    Examples:
        >>> # Case 1: Code sequence with time
        >>> processor = StageNetProcessor()
        >>> data = {"value": ["code1", "code2", "code3"], "time": [0.0, 1.5, 2.3]}
        >>> result = processor.process(data)
        >>> result.value.shape  # (3,) - sequence of code indices
        >>> result.time.shape   # (3,) - time intervals

        >>> # Case 2: Nested codes with time
        >>> data = {"value": [["A", "B"], ["C"]], "time": [0.0, 1.5]}
        >>> result = processor.process(data)
        >>> result.value.shape  # (2, max_inner_len) - padded nested sequences
        >>> result.time.shape   # (2,)

        >>> # Numeric feature vectors such as [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
        >>> # are not handled here; use StageNetTensorProcessor for those.
    """

    def __init__(self):
        self.code_vocab: Dict[Any, int] = {"<unk>": -1, "<pad>": 0}
        self._next_index = 1
        self._is_nested = None  # Will be determined during fit
        self._max_nested_len = None  # Max inner sequence length for nested codes

    def fit(self, samples: List[Dict], key: str) -> None:
        """Build vocabulary and determine input structure.

        Args:
            samples: List of sample dictionaries
            key: The key in samples that contains StageNet format data
        """
        # Examine first non-None sample to determine structure
        for sample in samples:
            if key in sample and sample[key] is not None:
                value_data = sample[key]["value"]

                # Determine nesting level for codes
                if isinstance(value_data, list) and len(value_data) > 0:
                    first_elem = value_data[0]

                    if isinstance(first_elem, str):
                        # Case 1: ["code1", "code2", ...]
                        self._is_nested = False
                    elif isinstance(first_elem, list):
                        if len(first_elem) > 0 and isinstance(first_elem[0], str):
                            # Case 2: [["A", "B"], ["C"], ...]
                            self._is_nested = True
                break

        # Build vocabulary for codes and find max nested length
        max_inner_len = 0
        for sample in samples:
            if key in sample and sample[key] is not None:
                value_data = sample[key]["value"]

                if self._is_nested:
                    # Nested codes
                    for inner_list in value_data:
                        # Track max inner length
                        max_inner_len = max(max_inner_len, len(inner_list))
                        for code in inner_list:
                            if code is not None and code not in self.code_vocab:
                                self.code_vocab[code] = self._next_index
                                self._next_index += 1
                else:
                    # Flat codes
                    for code in value_data:
                        if code is not None and code not in self.code_vocab:
                            self.code_vocab[code] = self._next_index
                            self._next_index += 1

        # Store max nested length (at least 1 for empty sequences)
        if self._is_nested:
            self._max_nested_len = max(1, max_inner_len)

    def process(self, value: Dict[str, Any]) -> StageNetFeature:
        """Process StageNet format data into tensors.

        Args:
            value: Dictionary with "value" and optional "time" keys

        Returns:
            StageNetFeature with value and time tensors
        """
        value_data = value["value"]
        time_data = value.get("time", None)

        # Encode codes to indices
        if self._is_nested:
            # Nested codes: [["A", "B"], ["C"]]
            value_tensor = self._encode_nested_codes(value_data)
        else:
            # Flat codes: ["code1", "code2"]
            value_tensor = self._encode_codes(value_data)

        # Process time if present
        time_tensor = None
        if time_data is not None and len(time_data) > 0:
            # Handle both [0.0, 1.5] and [[0.0], [1.5]] formats
            if isinstance(time_data[0], list):
                # Flatten [[0.0], [1.5]] -> [0.0, 1.5]
                time_data = [t[0] if isinstance(t, list) else t for t in time_data]
            time_tensor = torch.tensor(time_data, dtype=torch.float)

        return StageNetFeature(value=value_tensor, time=time_tensor)

    def _encode_codes(self, codes: List[str]) -> torch.Tensor:
        """Encode flat code list to indices."""
        # Handle empty code list - return single padding token
        if len(codes) == 0:
            return torch.tensor([self.code_vocab["<pad>"]], dtype=torch.long)

        indices = []
        for code in codes:
            if code is None or code not in self.code_vocab:
                indices.append(self.code_vocab["<unk>"])
            else:
                indices.append(self.code_vocab[code])
        return torch.tensor(indices, dtype=torch.long)

    def _encode_nested_codes(self, nested_codes: List[List[str]]) -> torch.Tensor:
        """Encode nested code lists to padded 2D tensor.

        Pads all inner sequences to self._max_nested_len (global max).
        """
        # Handle empty nested codes (no visits/events)
        # Return single padding token with shape (1, max_len)
        if len(nested_codes) == 0:
            pad_token = self.code_vocab["<pad>"]
            return torch.tensor([[pad_token] * self._max_nested_len], dtype=torch.long)

        encoded_sequences = []
        # Use global max length determined during fit
        max_len = self._max_nested_len

        for inner_codes in nested_codes:
            indices = []
            for code in inner_codes:
                if code is None or code not in self.code_vocab:
                    indices.append(self.code_vocab["<unk>"])
                else:
                    indices.append(self.code_vocab[code])
            # Pad to GLOBAL max_len
            while len(indices) < max_len:
                indices.append(self.code_vocab["<pad>"])
            encoded_sequences.append(indices)

        return torch.tensor(encoded_sequences, dtype=torch.long)

    def size(self) -> int:
        """Return vocabulary size."""
        return len(self.code_vocab)

    def __repr__(self):
        if self._is_nested:
            return (
                f"StageNetProcessor(is_nested={self._is_nested}, "
                f"vocab_size={len(self.code_vocab)}, "
                f"max_nested_len={self._max_nested_len})"
            )
        else:
            return (
                f"StageNetProcessor(is_nested={self._is_nested}, "
                f"vocab_size={len(self.code_vocab)})"
            )


@register_processor("stagenet_tensor")
class StageNetTensorProcessor(FeatureProcessor):
    """
    Feature processor for StageNet NUMERIC inputs with coupled value/time data.

    This processor handles numeric feature sequences (flat or nested) and applies
    forward-fill imputation to handle missing values (NaN/None).
    For categorical codes, use StageNetProcessor instead.

    Format:
    {
        "value": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # nested numerics
        "time": [0.0, 1.5] or None
    }

    The processor automatically detects:
    - List of numbers -> flat numeric sequences
    - List of lists of numbers -> nested numeric sequences (feature vectors)

    Imputation Strategy:
    - Forward-fill: Missing values (NaN/None) are filled with the last observed
      value for that feature dimension. If no prior value exists, 0.0 is used.
    - Applied per feature dimension independently

    Examples:
        >>> # Case 1: Feature vectors with missing values
        >>> processor = StageNetTensorProcessor()
        >>> data = {
        ...     "value": [[1.0, None, 3.0], [None, 5.0, 6.0], [7.0, 8.0, None]],
        ...     "time": [0.0, 1.5, 3.0]
        ... }
        >>> result = processor.process(data)
        >>> result.value  # [[1.0, 0.0, 3.0], [1.0, 5.0, 6.0], [7.0, 8.0, 6.0]]
        >>> result.value.dtype  # torch.float32
        >>> result.time.shape   # (3,)
    """

    def __init__(self):
        self._size = None  # Feature dimension (set during fit)
        self._is_nested = None

    def fit(self, samples: List[Dict], key: str) -> None:
        """Determine input structure.

        Args:
            samples: List of sample dictionaries
            key: The key in samples that contains StageNet format data
        """
        # Examine first non-None sample to determine structure
        for sample in samples:
            if key in sample and sample[key] is not None:
                value_data = sample[key]["value"]

                # Determine nesting level for numerics
                if isinstance(value_data, list) and len(value_data) > 0:
                    first_elem = value_data[0]

                    if isinstance(first_elem, (int, float)):
                        # Flat numeric: [1.5, 2.0, ...]
                        self._is_nested = False
                        self._size = 1
                    elif isinstance(first_elem, list):
                        if len(first_elem) > 0:
                            if isinstance(first_elem[0], (int, float)):
                                # Nested numerics: [[1.0, 2.0], [3.0, 4.0]]
                                self._is_nested = True
                                self._size = len(first_elem)
                break

    def process(self, value: Dict[str, Any]) -> StageNetFeature:
        """Process StageNet format numeric data into tensors.

        Applies forward-fill imputation to handle NaN/None values in the data.
        For each feature dimension, missing values are filled with the last
        observed value (or 0.0 if no prior value exists).

        Args:
            value: Dictionary with "value" and optional "time" keys

        Returns:
            StageNetFeature with value and time tensors (imputed)
        """
        value_data = value["value"]
        time_data = value.get("time", None)

        # Convert to numpy for easier imputation handling
        import numpy as np

        value_array = np.array(value_data, dtype=float)

        # Apply forward-fill imputation
        if value_array.ndim == 1:
            # Flat numeric: [1.5, 2.0, nan, 3.0, ...]
            last_value = 0.0
            for i in range(len(value_array)):
                if not np.isnan(value_array[i]):
                    last_value = value_array[i]
                else:
                    value_array[i] = last_value
        elif value_array.ndim == 2:
            # Feature vectors: [[1.0, nan, 3.0], [nan, 5.0, 6.0], ...]
            num_features = value_array.shape[1]
            for f in range(num_features):
                last_value = 0.0
                for t in range(value_array.shape[0]):
                    if not np.isnan(value_array[t, f]):
                        last_value = value_array[t, f]
                    else:
                        value_array[t, f] = last_value

        # Convert to float tensor
        value_tensor = torch.tensor(value_array, dtype=torch.float)

        # Process time if present
        time_tensor = None
        if time_data is not None and len(time_data) > 0:
            # Handle both [0.0, 1.5] and [[0.0], [1.5]] formats
            if isinstance(time_data[0], list):
                # Flatten [[0.0], [1.5]] -> [0.0, 1.5]
                time_data = [t[0] if isinstance(t, list) else t for t in time_data]
            time_tensor = torch.tensor(time_data, dtype=torch.float)

        return StageNetFeature(value=value_tensor, time=time_tensor)

    @property
    def size(self):
        """Return feature dimension."""
        return self._size

    def __repr__(self):
        return (
            f"StageNetTensorProcessor(is_nested={self._is_nested}, "
            f"feature_dim={self._size})"
        )
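
The two core transformations performed by the processors above are vocabulary encoding with padding (categorical codes) and forward-fill imputation (numeric vectors). The sketch below reproduces both in plain Python so they can be tried without PyHealth or PyTorch. It is a simplification under stated assumptions: the function names are hypothetical, unknown-code handling (`<unk>`) is left out, and lists are returned instead of tensors.

```python
import math
from typing import Dict, List, Optional


def build_vocab(visits: List[List[str]]) -> Dict[str, int]:
    """Assign indices in first-seen order; index 0 is reserved for padding."""
    vocab = {"<pad>": 0}
    for codes in visits:
        for code in codes:
            vocab.setdefault(code, len(vocab))
    return vocab


def pad_visits(
    visits: List[List[str]], vocab: Dict[str, int], max_len: int
) -> List[List[int]]:
    """Encode each visit and right-pad it with the <pad> index to max_len."""
    pad = vocab["<pad>"]
    return [
        [vocab[code] for code in codes] + [pad] * (max_len - len(codes))
        for codes in visits
    ]


def forward_fill(vectors: List[List[Optional[float]]]) -> List[List[float]]:
    """Fill None/NaN with the last value seen in that column, else 0.0."""
    last_seen = [0.0] * len(vectors[0]) if vectors else []
    filled = []
    for row in vectors:
        new_row = []
        for col, value in enumerate(row):
            if value is None or math.isnan(value):
                value = last_seen[col]
            else:
                last_seen[col] = value
            new_row.append(value)
        filled.append(new_row)
    return filled


visits = [["A", "B"], ["C"]]
vocab = build_vocab(visits)
print(pad_visits(visits, vocab, max_len=2))  # [[1, 2], [3, 0]]

labs = [[1.0, None, 3.0], [None, 5.0, 6.0], [7.0, 8.0, None]]
print(forward_fill(labs))  # [[1.0, 0.0, 3.0], [1.0, 5.0, 6.0], [7.0, 8.0, 6.0]]
```

The second output matches the example in the `StageNetTensorProcessor` docstring: a gap at the start of a column becomes 0.0, and later gaps repeat the most recent observation.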


## Setting the Task and Caching the Data
We can now set the task and build the sample dataset. The processed dataset is saved as .parquet files under `cache_dir`, so later runs can reuse it instead of reprocessing. `num_workers` sets the number of parallel workers; more workers speed up processing, but values that are too high can be unstable.

In [None]:
# STEP 2: Apply StageNet mortality prediction task
sample_dataset = base_dataset.set_task(
    MortalityPredictionStageNetMIMIC4(),
    num_workers=4,
    cache_dir="../../mimic4_stagenet_cache",
)

In [None]:
# Inspect a sample
sample = sample_dataset.samples[0]
print("\nSample structure:")
print(f"  Patient ID: {sample['patient_id']}")
print(f"  ICD codes: {sample['icd_codes']}")
print(f"  Labs shape: {len(sample['labs'].value)} timesteps")
print(f"  Mortality: {sample['mortality']}")



## Train, Validation, Test Splits and Training

This section follows a typical training pipeline. We recommend the PyHealth `Trainer` only for testing baselines, but the code you write here should carry over with little change to more full-featured training frameworks such as PyTorch Lightning.

In [None]:
# STEP 3: Split dataset
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset, [0.8, 0.1, 0.1]
)

# Create dataloaders
train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# STEP 4: Initialize StageNet model
model = StageNet(
    dataset=sample_dataset,
    embedding_dim=128,
    chunk_size=128,
    levels=3,
    dropout=0.3,
)

num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel initialized with {num_params} parameters")

# STEP 5: Train the model
trainer = Trainer(
    model=model,
    device="cuda:2",  # change to an available GPU index, or use "cpu"
    metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
)

trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-5},
)

# STEP 6: Evaluate on test set
results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
    print(f"  {metric}: {value:.4f}")

# STEP 7: Inspect model predictions
sample_batch = next(iter(test_loader))
with torch.no_grad():
    output = model(**sample_batch)

print("\nSample predictions:")
print(f"  Predicted probabilities: {output['y_prob'][:5]}")
print(f"  True labels: {output['y_true'][:5]}")
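
Splitting by patient keeps all samples from one patient in the same split, which avoids leakage between training and evaluation data. The sketch below shows the idea in plain Python. It is a toy version under stated assumptions, not PyHealth's `split_by_patient`: the function name, the `seed` argument, and the rounding of split sizes are our own choices.

```python
import random
from typing import Any, Dict, List, Sequence, Tuple

Sample = Dict[str, Any]


def toy_split_by_patient(
    samples: List[Sample], ratios: Sequence[float], seed: int = 0
) -> Tuple[List[Sample], List[Sample], List[Sample]]:
    """Shuffle patient IDs, then cut them into train/val/test groups."""
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)

    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    train_ids = set(patient_ids[:n_train])
    val_ids = set(patient_ids[n_train:n_train + n_val])
    test_ids = set(patient_ids[n_train + n_val:])  # remainder goes to test

    train = [s for s in samples if s["patient_id"] in train_ids]
    val = [s for s in samples if s["patient_id"] in val_ids]
    test = [s for s in samples if s["patient_id"] in test_ids]
    return train, val, test


# 10 made-up patients with 3 samples each.
samples = [{"patient_id": f"p{i % 10}", "visit": i} for i in range(30)]
train, val, test = toy_split_by_patient(samples, [0.8, 0.1, 0.1])
print(len(train), len(val), len(test))  # 24 3 3
```

In this tutorial the task emits one sample per patient, so a patient-level split and a sample-level split coincide; the distinction matters for tasks that emit several samples per patient.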

## Post-hoc ML processing (TBD)
Once the model is trained and evaluation metrics are computed, you may be interested in post-hoc analyses such as interpretability or uncertainty quantification (UQ).

This is still a work in progress for PyHealth 2.0, but the roadmap includes the following:

- Layer-wise relevance propagation (deep NN-based interpretability)
- Conformal prediction, in addition to the UQ techniques already available [here](https://pyhealth.readthedocs.io/en/latest/api/calib.html)