
# Getting Started
This tutorial walks through the following steps with StageNet, using utility modules from across PyHealth:

1. Loading the data
2. Task Processing (with padding to ensure compatibility)
3. ML Model Initialization 
4. Model training
5. Holdout Inference on Sets of Codes Not in Vocabulary
6. Interpretability Example with DeepLift

## Installation

Install the PyHealth alpha release that includes the modernized StageNet:

```bash
pip install pyhealth==2.0a10
```

## Loading Data

Load the PyHealth dataset for mortality prediction.

PyHealth datasets use a `config.yaml` file to define:
- Input tables (.csv, .tsv, etc.)
- Features to extract
- Aggregation methods

The result is a single dataframe where each row represents one patient and their features.

For more details on PyHealth datasets, see [this resource](https://colab.research.google.com/drive/1voSx7wEfzXfEf2sIfW6b-8p1KqMyuWxK#scrollTo=NSrb2PGFqUgS).

In [None]:
"""
Example of using StageNet for mortality prediction on MIMIC-IV.

This example demonstrates:
1. Loading MIMIC-IV data
2. Applying the MortalityPredictionStageNetMIMIC4 task
3. Creating a SampleDataset with StageNet processors
4. Training a StageNet model
"""

from pyhealth.datasets import (
    MIMIC4Dataset,
    get_dataloader,
    split_by_patient,
)
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
from pyhealth.trainer import Trainer

# STEP 1: Load MIMIC-IV base dataset
base_dataset = MIMIC4Dataset(
    ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
    ehr_tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "labevents",
    ],
    dev=True
)

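## Input and Output Schemas
The input schema maps each feature key (e.g., "labs", "icd_codes") to a StageNet processor, and the output schema maps the label key (here "mortality") to its processor. Each StageNet processor takes a `(time, values)` tuple and converts it into a pair of tensors used for training and inference.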

**Required format:** Each feature processed in our task call must follow this structure:
```python
"feature": (my_times_list, my_values_list)
```
PyHealth provides two StageNet processors: one for categorical variables and one for numerical features. In both cases, each feature is represented as a `(time, values)` tuple that is later passed to StageNet.
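As a minimal sketch with made-up patient data, the sample below shows the `(time, values)` layout for one categorical feature and one numeric feature. The `check_stagenet_feature` helper is hypothetical (not part of PyHealth); it only checks that, when times are present, each time entry pairs with exactly one value entry.

```python
from typing import Any, List, Optional, Tuple

# Hypothetical sample: one patient, values are illustrative only
sample = {
    "patient_id": "p001",
    # Categorical feature: (times, list of codes per visit)
    "icd_codes": ([0.0, 48.0], [["E119", "I10"], ["N183"]]),
    # Numeric feature: (times, one feature vector per measurement)
    "labs": ([0.0, 6.0, 12.0], [[140.0, 4.1], [None, 4.3], [138.0, None]]),
    "mortality": 0,
}


def check_stagenet_feature(feature: Tuple[Optional[List[float]], List[Any]]) -> bool:
    """Hypothetical helper: True if `feature` is a (time, values) tuple whose
    time list (when present) has one entry per value entry."""
    if not isinstance(feature, tuple) or len(feature) != 2:
        return False
    times, values = feature
    if times is None:
        return isinstance(values, list)
    return len(times) == len(values)


for key in ("icd_codes", "labs"):
    print(key, check_stagenet_feature(sample[key]))
```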


## What are these processors?

A processor takes raw feature values and turns them into tensors. Here, we define custom processors so that StageNet can consume a time series of paired time intervals and feature values.

## StageNetProcessor - For Categories (Labels)

**What it handles:** Text labels like diagnosis codes, medication names, or lab test types.

**What it does:**
- Takes lists of codes (like `["diabetes", "hypertension"]`)
- Converts each code into a unique integer index (like `{"diabetes": 1, "hypertension": 2}`; index 0 is reserved for padding)
- Keeps track of when things happened (timestamps)
- Can handle nested lists (like multiple codes per visit)

**Example:** If a patient had 3 doctor visits with different diagnoses, this processor remembers what diagnosis happened at each visit and when.
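The encoding step can be sketched in plain Python. This follows the convention used by the processor code later in this notebook (index 0 for `<pad>`, codes numbered from 1 in first-seen order, `<unk>` assigned the next free index); `build_vocab` and `encode` are illustrative names, not PyHealth functions.

```python
from typing import Dict, List


def build_vocab(visits: List[List[str]]) -> Dict[str, int]:
    """Assign 0 to <pad>, 1..N to codes in first-seen order, N+1 to <unk>."""
    vocab: Dict[str, int] = {"<pad>": 0}
    next_index = 1
    for visit in visits:
        for code in visit:
            if code not in vocab:
                vocab[code] = next_index
                next_index += 1
    vocab["<unk>"] = next_index  # unseen codes map here at inference time
    return vocab


def encode(codes: List[str], vocab: Dict[str, int]) -> List[int]:
    """Map codes to indices, falling back to <unk> for unseen codes."""
    return [vocab.get(code, vocab["<unk>"]) for code in codes]


vocab = build_vocab([["diabetes", "hypertension"], ["diabetes"]])
print(vocab)  # {'<pad>': 0, 'diabetes': 1, 'hypertension': 2, '<unk>': 3}
print(encode(["hypertension", "asthma"], vocab))  # [2, 3]
```

Reserving `<unk>` is what lets a fitted processor handle codes that never appeared during fitting.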

## StageNetTensorProcessor - For Numbers (Measurements)

**What it handles:** Actual measurements like blood pressure, temperature, or lab values.

**What it does:**
- Takes lists of numbers (like `[98.6, 99.1, 98.8]` for temperatures)
- Fills in missing measurements using the last known value (forward-fill)
- Keeps track of when measurements were taken
- Can handle multiple measurements at once (like blood pressure AND heart rate)

**Example:** If a patient's heart rate was measured as `[72, None, 68]`, it fills in the missing value as `[72, 72, 68]` (copying the last known value).
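A minimal sketch of the per-feature forward-fill described above, written with plain lists (the processor later in this notebook does the same with NumPy). Values before the first observation default to 0.0, matching that processor; `forward_fill` is an illustrative name.

```python
from typing import List, Optional


def forward_fill(rows: List[List[Optional[float]]]) -> List[List[float]]:
    """Forward-fill each feature column independently; leading gaps become 0.0."""
    if not rows:
        return []
    num_features = len(rows[0])
    last = [0.0] * num_features  # default used before any value is observed
    filled = []
    for row in rows:
        new_row = []
        for f, value in enumerate(row):
            if value is not None:
                last[f] = value
            new_row.append(last[f])
        filled.append(new_row)
    return filled


# Heart rate (column 0) and temperature (column 1) over three readings
print(forward_fill([[72.0, None], [None, 98.6], [68.0, None]]))
# [[72.0, 0.0], [72.0, 98.6], [68.0, 98.6]]
```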

## How Time Processing Works

Both processors handle time information in a flexible way:

**Input formats accepted:**
- Simple list: `[0.0, 1.5, 3.0]` - time intervals in hours/days
- Nested list: `[[0.0], [1.5], [3.0]]` - flattened by keeping the first element of each inner list
- No time: `None` - when timing doesn't matter

**What the time means:**
- Times represent intervals or delays between events
- For example: `[0.0, 2.5, 1.0]` could mean "first event at start, second event 2.5 hours later, third event 1 hour after that"
- Times are converted to float tensors so the model can learn temporal patterns

**Example:**
```python
# Patient temperature readings, in the required (time, values) format
data = (
    [0.0, 2.0, 1.0],     # hours since the previous reading
    [98.6, 99.1, 98.8],  # temperatures in °F
)
```

The processor keeps the times and values paired together, so the model knows that 99.1°F was recorded 2 hours after the previous reading.
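To make the interval interpretation concrete: if each entry is the delay since the previous event, a running sum recovers the elapsed time since the first event. `intervals_to_elapsed` is an illustrative helper; the processors themselves store the times as given and do not perform this conversion.

```python
from itertools import accumulate
from typing import List


def intervals_to_elapsed(intervals: List[float]) -> List[float]:
    """Convert per-event delays into time elapsed since the first event."""
    return list(accumulate(intervals))


# "first event at start, 2.5 hours later, then 1 hour after that"
print(intervals_to_elapsed([0.0, 2.5, 1.0]))  # [0.0, 2.5, 3.5]
```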

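These processors already ship with PyHealth, so the copies below are registered under the names `stagenet_ex` and `stagenet_tensor_ex` to avoid clashing with the built-in `stagenet` and `stagenet_tensor` registrations. They are shown here only to illustrate what happens under the hood; the task defined later uses the built-in versions.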

In [None]:
from typing import Any, Dict, List, Optional, Tuple

import torch

from pyhealth.processors import register_processor
from pyhealth.processors.base_processor import FeatureProcessor

@register_processor("stagenet_ex")
class StageNetProcessor(FeatureProcessor):
    """
    Feature processor for StageNet CODE inputs with coupled value/time data.

    This processor handles categorical code sequences (flat or nested).
    For numeric features, use StageNetTensorProcessor instead.

    Input Format (tuple):
        (time, values) where:
        - time: List of scalars [0.0, 2.0, 1.3] or None
        - values: ["code1", "code2"] or [["A", "B"], ["C"]]

    The processor automatically detects:
    - List of strings -> flat code sequences
    - List of lists of strings -> nested code sequences

    Args:
        padding: Additional padding to add on top of the observed maximum nested
            sequence length. The actual padding length will be observed_max + padding.
            This ensures the processor can handle sequences longer than those in the
            training data. Default: 0 (no extra padding). Only applies to nested sequences.

    Returns:
        Tuple of (time_tensor, value_tensor) where time_tensor can be None

    Examples:
        >>> # Case 1: Code sequence with time (fit before process)
        >>> flat_proc = StageNetProcessor()
        >>> data = ([0.0, 1.5, 2.3], ["code1", "code2", "code3"])
        >>> flat_proc.fit([{"codes": data}], "codes")
        >>> time, values = flat_proc.process(data)
        >>> values.shape  # (3,) - sequence of code indices
        >>> time.shape    # (3,) - time intervals

        >>> # Case 2: Nested codes with time (with custom padding for extra capacity)
        >>> nested_proc = StageNetProcessor(padding=20)
        >>> data = ([0.0, 1.5], [["A", "B"], ["C"]])
        >>> nested_proc.fit([{"codes": data}], "codes")
        >>> time, values = nested_proc.process(data)
        >>> values.shape  # (2, 22) - observed_max (2) + padding (20)
        >>> time.shape    # (2,)

        >>> # Case 3: Flat codes without time (reusing the fitted flat_proc)
        >>> data = (None, ["code1", "code2"])
        >>> time, values = flat_proc.process(data)
        >>> values.shape  # (2,)
        >>> time          # None
    """

    def __init__(self, padding: int = 0):
        # <unk> is assigned the next free index at the end of fit()
        self.code_vocab: Dict[Any, int] = {"<unk>": None, "<pad>": 0}
        self._next_index = 1
        self._is_nested = None  # Will be determined during fit
        # Max inner sequence length for nested codes
        self._max_nested_len = None
        self._padding = padding  # Additional padding beyond observed max

    def fit(self, samples: List[Dict], key: str) -> None:
        """Build vocabulary and determine input structure.

        Args:
            samples: List of sample dictionaries
            key: The key in samples that contains tuple (time, values)
        """
        # Examine first non-None sample to determine structure
        for sample in samples:
            if key in sample and sample[key] is not None:
                # Unpack tuple: (time, values)
                time_data, value_data = sample[key]

                # Determine nesting level for codes
                if isinstance(value_data, list) and len(value_data) > 0:
                    first_elem = value_data[0]

                    if isinstance(first_elem, str):
                        # Case 1: ["code1", "code2", ...]
                        self._is_nested = False
                    elif isinstance(first_elem, list):
                        if len(first_elem) > 0 and isinstance(first_elem[0], str):
                            # Case 2: [["A", "B"], ["C"], ...]
                            self._is_nested = True
                break

        # Build vocabulary for codes and find max nested length
        max_inner_len = 0
        for sample in samples:
            if key in sample and sample[key] is not None:
                # Unpack tuple: (time, values)
                time_data, value_data = sample[key]

                if self._is_nested:
                    # Nested codes
                    for inner_list in value_data:
                        # Track max inner length
                        max_inner_len = max(max_inner_len, len(inner_list))
                        for code in inner_list:
                            if code is not None and code not in self.code_vocab:
                                self.code_vocab[code] = self._next_index
                                self._next_index += 1
                else:
                    # Flat codes
                    for code in value_data:
                        if code is not None and code not in self.code_vocab:
                            self.code_vocab[code] = self._next_index
                            self._next_index += 1

        # Store max nested length: add user-specified padding to observed maximum
        # This ensures the processor can handle sequences longer than those in training data
        if self._is_nested:
            observed_max = max(1, max_inner_len)
            self._max_nested_len = observed_max + self._padding

        # Set <unk> token to the next available index
        # Since <unk> is already in the vocab dict, we use _next_index
        self.code_vocab["<unk>"] = self._next_index

    def process(
        self, value: Tuple[Optional[List], List]
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Process tuple format data into tensors.

        Args:
            value: Tuple of (time, values) where values are codes

        Returns:
            Tuple of (time_tensor, value_tensor), time can be None
        """
        # Unpack tuple: (time, values)
        time_data, value_data = value

        # Encode codes to indices
        if self._is_nested:
            # Nested codes: [["A", "B"], ["C"]]
            value_tensor = self._encode_nested_codes(value_data)
        else:
            # Flat codes: ["code1", "code2"]
            value_tensor = self._encode_codes(value_data)

        # Process time if present
        time_tensor = None
        if time_data is not None and len(time_data) > 0:
            # Handle both [0.0, 1.5] and [[0.0], [1.5]] formats
            if isinstance(time_data[0], list):
                # Flatten [[0.0], [1.5]] -> [0.0, 1.5]
                time_data = [t[0] if isinstance(t, list) else t for t in time_data]
            time_tensor = torch.tensor(time_data, dtype=torch.float)

        return (time_tensor, value_tensor)

    def _encode_codes(self, codes: List[str]) -> torch.Tensor:
        """Encode flat code list to indices."""
        # Handle empty code list - return single padding token
        if len(codes) == 0:
            return torch.tensor([self.code_vocab["<pad>"]], dtype=torch.long)

        indices = []
        for code in codes:
            if code is None or code not in self.code_vocab:
                indices.append(self.code_vocab["<unk>"])
            else:
                indices.append(self.code_vocab[code])
        return torch.tensor(indices, dtype=torch.long)

    def _encode_nested_codes(self, nested_codes: List[List[str]]) -> torch.Tensor:
        """Encode nested code lists to padded 2D tensor.

        Pads all inner sequences to self._max_nested_len (global max).
        """
        # Handle empty nested codes (no visits/events)
        # Return single padding token with shape (1, max_len)
        if len(nested_codes) == 0:
            pad_token = self.code_vocab["<pad>"]
            return torch.tensor([[pad_token] * self._max_nested_len], dtype=torch.long)

        encoded_sequences = []
        # Use global max length determined during fit
        max_len = self._max_nested_len

        for inner_codes in nested_codes:
            indices = []
            for code in inner_codes:
                if code is None or code not in self.code_vocab:
                    indices.append(self.code_vocab["<unk>"])
                else:
                    indices.append(self.code_vocab[code])
            # Pad to GLOBAL max_len
            while len(indices) < max_len:
                indices.append(self.code_vocab["<pad>"])
            encoded_sequences.append(indices)

        return torch.tensor(encoded_sequences, dtype=torch.long)

    def size(self) -> int:
        """Return vocabulary size."""
        return len(self.code_vocab)

    def __repr__(self):
        if self._is_nested:
            return (
                f"StageNetProcessor(is_nested={self._is_nested}, "
                f"vocab_size={len(self.code_vocab)}, "
                f"max_nested_len={self._max_nested_len}, "
                f"padding={self._padding})"
            )
        else:
            return (
                f"StageNetProcessor(is_nested={self._is_nested}, "
                f"vocab_size={len(self.code_vocab)}, "
                f"padding={self._padding})"
            )


@register_processor("stagenet_tensor_ex")
class StageNetTensorProcessor(FeatureProcessor):
    """
    Feature processor for StageNet NUMERIC inputs with coupled value/time data.

    This processor handles numeric feature sequences (flat or nested) and applies
    forward-fill imputation to handle missing values (NaN/None).
    For categorical codes, use StageNetProcessor instead.

    Input Format (tuple):
        (time, values) where:
        - time: List of scalars [0.0, 1.5] or None
        - values: [1.0, 2.0] or [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

    The processor automatically detects:
    - List of numbers -> flat numeric sequences
    - List of lists of numbers -> nested numeric sequences (feature vectors)

    Imputation Strategy:
    - Forward-fill: Missing values (NaN/None) are filled with the last observed
      value for that feature dimension. If no prior value exists, 0.0 is used.
    - Applied per feature dimension independently

    Returns:
        Tuple of (time_tensor, value_tensor) where time_tensor can be None

    Examples:
        >>> # Case 1: Feature vectors with missing values
        >>> processor = StageNetTensorProcessor()
        >>> data = (
        ...     [0.0, 1.5, 3.0],  # time
        ...     [[1.0, None, 3.0], [None, 5.0, 6.0], [7.0, 8.0, None]],  # values
        ... )
        >>> time, values = processor.process(data)
        >>> values  # [[1.0, 0.0, 3.0], [1.0, 5.0, 6.0], [7.0, 8.0, 6.0]]
        >>> values.dtype  # torch.float32
        >>> time.shape    # (3,)
    """

    def __init__(self):
        self._size = None  # Feature dimension (set during fit)
        self._is_nested = None

    def fit(self, samples: List[Dict], key: str) -> None:
        """Determine input structure.

        Args:
            samples: List of sample dictionaries
            key: The key in samples that contains tuple (time, values)
        """
        # Examine first non-None sample to determine structure
        for sample in samples:
            if key in sample and sample[key] is not None:
                # Unpack tuple: (time, values)
                time_data, value_data = sample[key]

                # Determine nesting level for numerics
                if isinstance(value_data, list) and len(value_data) > 0:
                    first_elem = value_data[0]

                    if first_elem is None or isinstance(first_elem, (int, float)):
                        # Flat numeric: [1.5, 2.0, ...] (first value may be missing)
                        self._is_nested = False
                        self._size = 1
                    elif isinstance(first_elem, list) and len(first_elem) > 0:
                        # Nested numerics: [[1.0, 2.0], [None, 4.0]]
                        # Don't type-check first_elem[0]: it can be None when
                        # that feature is missing at the first timestamp
                        self._is_nested = True
                        self._size = len(first_elem)
                break

    def process(
        self, value: Tuple[Optional[List], List]
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Process tuple format numeric data into tensors.

        Applies forward-fill imputation to handle NaN/None values.
        For each feature dimension, missing values are filled with the
        last observed value (or 0.0 if no prior value exists).

        Args:
            value: Tuple of (time, values) where values are numerics

        Returns:
            Tuple of (time_tensor, value_tensor), time can be None
        """
        # Unpack tuple: (time, values)
        time_data, value_data = value

        # Convert to numpy for easier imputation handling
        import numpy as np

        value_array = np.array(value_data, dtype=float)

        # Apply forward-fill imputation
        if value_array.ndim == 1:
            # Flat numeric: [1.5, 2.0, nan, 3.0, ...]
            last_value = 0.0
            for i in range(len(value_array)):
                if not np.isnan(value_array[i]):
                    last_value = value_array[i]
                else:
                    value_array[i] = last_value
        elif value_array.ndim == 2:
            # Feature vectors: [[1.0, nan, 3.0], [nan, 5.0, 6.0]]
            num_features = value_array.shape[1]
            for f in range(num_features):
                last_value = 0.0
                for t in range(value_array.shape[0]):
                    if not np.isnan(value_array[t, f]):
                        last_value = value_array[t, f]
                    else:
                        value_array[t, f] = last_value

        # Convert to float tensor
        value_tensor = torch.tensor(value_array, dtype=torch.float)

        # Process time if present
        time_tensor = None
        if time_data is not None and len(time_data) > 0:
            # Handle both [0.0, 1.5] and [[0.0], [1.5]] formats
            if isinstance(time_data[0], list):
                # Flatten [[0.0], [1.5]] -> [0.0, 1.5]
                time_data = [t[0] if isinstance(t, list) else t for t in time_data]
            time_tensor = torch.tensor(time_data, dtype=torch.float)

        return (time_tensor, value_tensor)

    @property
    def size(self):
        """Return feature dimension."""
        return self._size

    def __repr__(self):
        return (
            f"StageNetTensorProcessor(is_nested={self._is_nested}, "
            f"feature_dim={self._size})"
        )
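To see why the `padding` argument matters, here is a plain-Python sketch of the padding rule in `_encode_nested_codes`: every visit is right-padded with the `<pad>` index (0) to `observed_max + padding`. The explicit length check is added for illustration only; the processor above has no such check, and a visit longer than that width would instead produce a ragged list that `torch.tensor` cannot convert. `pad_visits` is an illustrative name.

```python
from typing import List

PAD = 0  # index of <pad>, as in StageNetProcessor


def pad_visits(encoded: List[List[int]], observed_max: int, padding: int) -> List[List[int]]:
    """Right-pad each encoded visit to observed_max + padding."""
    width = observed_max + padding
    if any(len(visit) > width for visit in encoded):
        # Without padding headroom, a longer visit would yield a ragged list
        raise ValueError(f"visit longer than padded width {width}")
    return [visit + [PAD] * (width - len(visit)) for visit in encoded]


# Training data had at most 2 codes per visit; reserve 2 extra slots
train_max = 2
print(pad_visits([[1, 2], [3]], train_max, padding=2))
# [[1, 2, 0, 0], [3, 0, 0, 0]]
print(pad_visits([[1, 2, 3, 4]], train_max, padding=2))  # fits thanks to padding
# [[1, 2, 3, 4]]
```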


## Defining Our StageNet-Specific Task

We'll predict patient mortality using StageNet across time-series data from multiple visits. Each visit includes:

- Diagnosis codes
- Procedure codes
- Lab events

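Each feature also needs its own corresponding time intervals. Following the StageNet paper, a time interval is the difference in time between the current visit and the previous visit. In the task below, ICD-code times follow this definition (in hours), while lab times are measured in hours from the start of the admission in which they were recorded.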
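A minimal sketch of the visit-interval rule, mirroring the `time_from_previous` calculation in the task below: 0.0 for the first admission, then hours since the previous one. The timestamps are made up, and `visit_intervals_hours` is an illustrative helper. It sorts the timestamps first as a precaution; the task itself processes admissions in the order returned by `get_events`.

```python
from datetime import datetime
from typing import List


def visit_intervals_hours(admit_times: List[datetime]) -> List[float]:
    """Hours since the previous admission; 0.0 for the first one."""
    intervals = []
    previous = None
    for t in sorted(admit_times):
        if previous is None:
            intervals.append(0.0)
        else:
            intervals.append((t - previous).total_seconds() / 3600.0)
        previous = t
    return intervals


admits = [
    datetime(2150, 1, 1, 8, 0),
    datetime(2150, 1, 3, 8, 0),   # 48 hours later
    datetime(2150, 1, 3, 20, 0),  # 12 hours after that
]
print(visit_intervals_hours(admits))  # [0.0, 48.0, 12.0]
```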

To define a task, specify the `__call__` method, input schema, and output schema. For a detailed explanation, see [this tutorial](https://colab.research.google.com/drive/1kKKBVS_GclHoYTbnOtjyYnSee79hsyT?usp=sharing).

### Helper Functions

Use `patient.get_events()` to retrieve all events from a specific table, with optional filtering. See the [MIMIC-IV YAML file](https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/configs/mimic4_ehr.yaml) for available tables.
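For intuition only, here is a toy stand-in for `get_events` that selects events by table and by `(attribute, "==", value)` filter tuples, the filter form used in the task below. `ToyEvent` and `toy_get_events` are hypothetical and only support `==`; the real method also accepts time windows (`start`/`end`) and can return a Polars DataFrame (`return_df=True`), which this sketch omits.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class ToyEvent:
    """Hypothetical stand-in for an event: a table name plus attributes."""
    event_type: str
    attr_dict: Dict[str, Any] = field(default_factory=dict)


def toy_get_events(
    events: List[ToyEvent],
    event_type: str,
    filters: Optional[List[Tuple[str, str, Any]]] = None,
) -> List[ToyEvent]:
    """Return events from `event_type` whose attributes pass every filter."""
    filters = filters or []
    for _, op, _ in filters:
        if op != "==":
            raise ValueError(f"operator {op!r} is not supported in this sketch")
    return [
        event
        for event in events
        if event.event_type == event_type
        and all(event.attr_dict.get(attr) == value for attr, _, value in filters)
    ]


events = [
    ToyEvent("diagnoses_icd", {"hadm_id": "100", "icd_code": "I10"}),
    ToyEvent("diagnoses_icd", {"hadm_id": "200", "icd_code": "E119"}),
    ToyEvent("procedures_icd", {"hadm_id": "100", "icd_code": "0040"}),
]
hits = toy_get_events(events, "diagnoses_icd", filters=[("hadm_id", "==", "100")])
print([e.attr_dict["icd_code"] for e in hits])  # ['I10']
```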

In [None]:
from datetime import datetime
from typing import Any, ClassVar, Dict, List, Tuple

import polars as pl

from pyhealth.tasks.base_task import BaseTask


class MortalityPredictionStageNetMIMIC4(BaseTask):
    """Task for predicting mortality using MIMIC-IV with StageNet format.

    This task creates PATIENT-LEVEL samples (not visit-level) by aggregating
    all admissions for each patient. ICD codes (diagnoses + procedures) and
    lab results across all visits are combined, each paired with the time
    information described under Time Calculation below.

    Time Calculation:
        - ICD codes: Hours from previous admission (0 for first visit,
          then time intervals between consecutive visits)
        - Labs: Hours from admission start (within-visit measurements)

    Lab Processing:
        - 10-dimensional vectors (one per lab category)
        - Multiple itemids per category → take first observed value
        - Missing categories → None/NaN in vector

    Args:
        padding: Additional padding for StageNet processor to handle
            sequences longer than observed during training. Default: 0.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for input data:
            - icd_codes: Combined diagnosis + procedure ICD codes
              (stagenet format, nested by visit)
            - labs: Lab results (stagenet_tensor, 10D vectors per timestamp)
        output_schema (Dict[str, str]): The schema for output data:
            - mortality: Binary indicator (1 if any admission had mortality)
    """

    task_name: str = "MortalityPredictionStageNetMIMIC4"

    def __init__(self, padding: int = 0):
        """Initialize task with optional padding parameter.

        Args:
            padding: Additional padding for nested sequences. Default: 0.
        """
        self.padding = padding
        # Use tuple format to pass kwargs to processor
        self.input_schema: Dict[str, Tuple[str, Dict[str, Any]]] = {
            "icd_codes": ("stagenet", {"padding": padding}),
            "labs": ("stagenet_tensor", {}),
        }
        self.output_schema: Dict[str, str] = {"mortality": "binary"}

    # Organize lab items by category
    # Each category will map to ONE dimension in the output vector
    LAB_CATEGORIES: ClassVar[Dict[str, List[str]]] = {
        "Sodium": ["50824", "52455", "50983", "52623"],
        "Potassium": ["50822", "52452", "50971", "52610"],
        "Chloride": ["50806", "52434", "50902", "52535"],
        "Bicarbonate": ["50803", "50804"],
        "Glucose": ["50809", "52027", "50931", "52569"],
        "Calcium": ["50808", "51624"],
        "Magnesium": ["50960"],
        "Anion Gap": ["50868", "52500"],
        "Osmolality": ["52031", "50964", "51701"],
        "Phosphate": ["50970"],
    }

    # Ordered list of category names (defines vector dimension order)
    LAB_CATEGORY_NAMES: ClassVar[List[str]] = [
        "Sodium",
        "Potassium",
        "Chloride",
        "Bicarbonate",
        "Glucose",
        "Calcium",
        "Magnesium",
        "Anion Gap",
        "Osmolality",
        "Phosphate",
    ]

    # Flat list of all lab item IDs for filtering
    LABITEMS: ClassVar[List[str]] = [
        item for itemids in LAB_CATEGORIES.values() for item in itemids
    ]

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a patient to create mortality prediction samples.

        Creates ONE sample per patient with all admissions aggregated.
        Time intervals are calculated between consecutive admissions.

        Args:
            patient: Patient object with get_events method

        Returns:
            List with single sample containing patient_id, all conditions,
            procedures, labs across visits, and final mortality label
        """
        # Filter patients by age (>= 18)
        demographics = patient.get_events(event_type="patients")
        if not demographics:
            return []

        demographics = demographics[0]
        try:
            anchor_age = int(demographics.anchor_age)
            if anchor_age < 18:
                return []
        except (ValueError, TypeError, AttributeError):
            # If age can't be determined, skip patient
            return []

        # Get all admissions
        admissions = patient.get_events(event_type="admissions")
        if len(admissions) < 1:
            return []

        # Initialize aggregated data structures
        # List of ICD codes (diagnoses + procedures) per visit
        all_icd_codes = []
        all_icd_times = []  # Time from previous admission per visit
        all_lab_values = []  # List of 10D lab vectors
        all_lab_times = []  # Time from admission start per measurement

        # Track previous admission timestamp for interval calculation
        previous_admission_time = None

        # Track if patient had any mortality event
        final_mortality = 0

        # Process each admission
        for i, admission in enumerate(admissions):
            # Parse admission and discharge times
            try:
                admission_time = admission.timestamp
                admission_dischtime = datetime.strptime(
                    admission.dischtime, "%Y-%m-%d %H:%M:%S"
                )
            except (ValueError, AttributeError):
                # Skip if timestamps invalid
                continue

            # Skip if discharge is before admission (data quality issue)
            if admission_dischtime < admission_time:
                continue

            # Calculate time from previous admission (in hours)
            # First admission will have time = 0
            if previous_admission_time is None:
                time_from_previous = 0.0
            else:
                time_from_previous = (
                    admission_time - previous_admission_time
                ).total_seconds() / 3600.0

            # Update previous admission time for next iteration
            previous_admission_time = admission_time

            # Update mortality label if this admission had mortality
            try:
                if int(admission.hospital_expire_flag) == 1:
                    final_mortality = 1
            except (ValueError, TypeError, AttributeError):
                pass

            # Get diagnosis codes for this admission using hadm_id
            diagnoses_icd = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            visit_diagnoses = [
                event.icd_code
                for event in diagnoses_icd
                if hasattr(event, "icd_code") and event.icd_code
            ]

            # Get procedure codes for this admission using hadm_id
            procedures_icd = patient.get_events(
                event_type="procedures_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            visit_procedures = [
                event.icd_code
                for event in procedures_icd
                if hasattr(event, "icd_code") and event.icd_code
            ]

            # Combine diagnoses and procedures into single ICD code list
            visit_icd_codes = visit_diagnoses + visit_procedures

            if visit_icd_codes:
                all_icd_codes.append(visit_icd_codes)
                all_icd_times.append(time_from_previous)

            # Get lab events for this admission
            labevents_df = patient.get_events(
                event_type="labevents",
                start=admission_time,
                end=admission_dischtime,
                return_df=True,
            )

            # Filter to relevant lab items
            labevents_df = labevents_df.filter(
                pl.col("labevents/itemid").is_in(self.LABITEMS)
            )

            # Parse storetime and filter
            if labevents_df.height > 0:
                labevents_df = labevents_df.with_columns(
                    pl.col("labevents/storetime").str.strptime(
                        pl.Datetime, "%Y-%m-%d %H:%M:%S"
                    )
                )
                labevents_df = labevents_df.filter(
                    pl.col("labevents/storetime") <= admission_dischtime
                )

                if labevents_df.height > 0:
                    # Select relevant columns
                    labevents_df = labevents_df.select(
                        pl.col("timestamp"),
                        pl.col("labevents/itemid"),
                        pl.col("labevents/valuenum").cast(pl.Float64),
                    )

                    # Group by timestamp and aggregate into 10D vectors
                    # For each timestamp, create vector of lab categories
                    unique_timestamps = sorted(
                        labevents_df["timestamp"].unique().to_list()
                    )

                    for lab_ts in unique_timestamps:
                        # Get all lab events at this timestamp
                        ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts)

                        # Create 10-dimensional vector (one per category)
                        lab_vector = []
                        for category_name in self.LAB_CATEGORY_NAMES:
                            category_itemids = self.LAB_CATEGORIES[category_name]

                            # Find first matching value for this category
                            category_value = None
                            for itemid in category_itemids:
                                matching = ts_labs.filter(
                                    pl.col("labevents/itemid") == itemid
                                )
                                if matching.height > 0:
                                    category_value = matching["labevents/valuenum"][0]
                                    break

                            lab_vector.append(category_value)

                        # Calculate time from admission start (hours)
                        time_from_admission = (
                            lab_ts - admission_time
                        ).total_seconds() / 3600.0

                        all_lab_values.append(lab_vector)
                        all_lab_times.append(time_from_admission)

        # Skip if no lab events (required for this task)
        if len(all_lab_values) == 0:
            return []

        # Also skip if no ICD codes across all admissions
        if len(all_icd_codes) == 0:
            return []

        # Format as tuples: (time, values)
        # ICD codes: nested list with times
        icd_codes_data = (all_icd_times, all_icd_codes)

        # Labs: list of 10D vectors with times
        labs_data = (all_lab_times, all_lab_values)

        # Create single patient-level sample
        sample = {
            "patient_id": patient.patient_id,
            "icd_codes": icd_codes_data,
            "labs": labs_data,
            "mortality": final_mortality,
        }
        return [sample]
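The per-timestamp lab aggregation inside `__call__` can be restated with plain dictionaries, which makes the "first observed value" rule easier to see: within a category, itemids are tried in the order listed in `LAB_CATEGORIES`, so priority comes from that list rather than from row order. This sketch uses a hypothetical two-category subset and `(itemid, value)` tuples in place of the Polars rows.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical two-category subset of LAB_CATEGORIES (itemids copied from the task)
CATEGORIES: Dict[str, List[str]] = {
    "Sodium": ["50824", "52455", "50983", "52623"],
    "Potassium": ["50822", "52452", "50971", "52610"],
}
CATEGORY_ORDER = ["Sodium", "Potassium"]


def lab_vector(rows: List[Tuple[str, float]]) -> List[Optional[float]]:
    """Build one vector for a single timestamp from (itemid, value) rows.

    For each category, itemids are tried in the listed order and the first
    one present supplies the value; absent categories stay None.
    """
    first_value: Dict[str, float] = {}
    for itemid, value in rows:
        first_value.setdefault(itemid, value)  # keep the first row per itemid
    vector: List[Optional[float]] = []
    for name in CATEGORY_ORDER:
        value = None
        for itemid in CATEGORIES[name]:
            if itemid in first_value:
                value = first_value[itemid]
                break
        vector.append(value)
    return vector


# All rows share one timestamp; no potassium was measured
print(lab_vector([("50983", 139.0), ("50824", 141.0)]))  # [141.0, None]
```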


## Setting the Task and Caching the Processed Data (with Padding)
We can now apply our task to the base dataset to build the sample dataset. A processed version of the dataset is saved as .parquet files in `cache_dir`, so later runs can reuse it. The `num_workers` argument sets the number of parallel workers for faster processing (note that this can be unstable if the value is too high). Passing `padding=20` to the task lets the ICD-code processor reserve 20 extra slots per visit beyond the longest visit seen during fitting.

Fitted processors can also be saved and reloaded, so they don't need to be refit on later runs; the same fitted processors can also be reused across different sample datasets.
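The load-or-fit pattern in the next cell can be sketched generically with `pickle`, assuming a vocabulary dict as the fitted state. This illustrates the idea only; it is not how `save_processors` and `load_processors` are implemented. The key property is that a reloaded processor keeps its original vocabulary, so code indices stay consistent across runs.

```python
import os
import pickle
import tempfile
from typing import Dict, List


def fit_vocab(codes: List[str]) -> Dict[str, int]:
    """Stand-in for an expensive fit: index codes from 1, reserving 0 for <pad>."""
    vocab = {"<pad>": 0}
    for code in codes:
        vocab.setdefault(code, len(vocab))
    return vocab


def load_or_fit(path: str, codes: List[str]) -> Dict[str, int]:
    """Reuse a saved vocabulary if present; otherwise fit and save it."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    vocab = fit_vocab(codes)
    with open(path, "wb") as f:
        pickle.dump(vocab, f)
    return vocab


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vocab.pkl")
    first = load_or_fit(path, ["A", "B"])
    # Second call ignores the new codes and reloads the saved state
    second = load_or_fit(path, ["C"])
    print(first == second)  # True
```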

In [None]:
from pyhealth.datasets.utils import save_processors, load_processors
import os 
processor_dir = "../../output/processors/stagenet_mortality_mimic4"
cache_dir = "../../mimic4_stagenet_cache_v3"

if os.path.exists(os.path.join(processor_dir, "input_processors.pkl")):
    print("\n=== Loading Pre-fitted Processors ===")
    input_processors, output_processors = load_processors(processor_dir)

    sample_dataset = base_dataset.set_task(
        MortalityPredictionStageNetMIMIC4(padding=20),
        num_workers=1,
        cache_dir=cache_dir,
        input_processors=input_processors,
        output_processors=output_processors,
    )
else:
    print("\n=== Fitting New Processors ===")
    sample_dataset = base_dataset.set_task(
        MortalityPredictionStageNetMIMIC4(padding=20),
        num_workers=1,
        cache_dir=cache_dir,
    )

    # Save processors for future runs
    print("\n=== Saving Processors ===")
    save_processors(sample_dataset, processor_dir)

print(f"Total samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

In [None]:
# Inspect a sample
sample = sample_dataset.samples[0]
print("\nSample structure:")
print(f"  Patient ID: {sample['patient_id']}")
print(f"  ICD codes: {sample['icd_codes']}")
print(f"  Lab timesteps: {len(sample['labs'][0])}")
print(f"  Mortality: {sample['mortality']}")



## Train, Validation, Test Splits and Training

This section follows a standard training pipeline. We recommend the PyHealth `Trainer` only for quickly testing baselines; the dataset splits, dataloaders, and model built here should carry over to more advanced deep learning training frameworks such as PyTorch Lightning.
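
`split_by_patient` assigns whole patients, rather than individual samples, to the train, validation, and test sets, so no patient's data appears in more than one split. The sketch below is a simplified, hypothetical re-implementation of that idea for illustration; PyHealth's own function may differ in shuffling and rounding details.

```python
import random
from typing import Dict, List, Sequence, Tuple


def split_ids_by_patient(
    samples: Sequence[Dict], ratios: Sequence[float], seed: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical sketch: return sample indices for train/val/test,
    grouped so each patient lands in exactly one split."""
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1"
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)

    n = len(patient_ids)
    n_train = int(n * ratios[0])
    n_val = int(n * ratios[1])
    train_p = set(patient_ids[:n_train])
    val_p = set(patient_ids[n_train:n_train + n_val])

    train, val, test = [], [], []
    for i, s in enumerate(samples):
        pid = s["patient_id"]
        if pid in train_p:
            train.append(i)
        elif pid in val_p:
            val.append(i)
        else:
            test.append(i)  # remaining patients go to the test split
    return train, val, test


toy = [{"patient_id": f"p{i // 2}"} for i in range(20)]  # 10 patients, 2 samples each
train_idx, val_idx, test_idx = split_ids_by_patient(toy, [0.8, 0.1, 0.1])
print(len(train_idx), len(val_idx), len(test_idx))  # 16 2 2
```

Splitting by sample instead would let the same patient appear in both training and test data, which can inflate test metrics.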

In [None]:
sample_dataset.input_schema

In [None]:
import torch

# STEP 3: Split dataset
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset, [0.8, 0.1, 0.1]
)

# Create dataloaders
train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# STEP 4: Initialize StageNet model
model = StageNet(
    dataset=sample_dataset,
    embedding_dim=128,
    chunk_size=128,
    levels=3,
    dropout=0.3,
)

num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel initialized with {num_params} parameters")

# STEP 5: Train the model
trainer = Trainer(
    model=model,
    device="cuda:4",  # set to an available GPU (e.g. "cuda:0") or "cpu"
    metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
)

# 1 epoch for demonstration only; for real training use more (roughly 50 epochs should work well)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=1,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-5},
)

# STEP 6: Evaluate on test set
results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
    print(f"  {metric}: {value:.4f}")

# STEP 7: Inspect model predictions
model.eval()  # disable dropout for inference
sample_batch = next(iter(test_loader))
with torch.no_grad():
    output = model(**sample_batch)

print("\nSample predictions:")
print(f"  Predicted probabilities: {output['y_prob'][:5]}")
print(f"  True labels: {output['y_true'][:5]}")

## Inference on a Holdout Set
Below, we generate synthetic samples whose ICD codes are all unseen tokens and whose visits can contain more codes than any visit observed during fitting, while staying within the padded capacity.
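
This test relies on the fitted ICD processor mapping any token it never saw during fitting to `<unk>`, as the `generate_holdout_set` docstring notes. A minimal sketch of that lookup with a plain dictionary; the vocabulary contents and index values are hypothetical, not read from the processor.

```python
from typing import Dict, List

# Hypothetical fitted vocabulary: special tokens first, then training codes
vocab: Dict[str, int] = {"<pad>": 0, "<unk>": 1, "4019": 2, "42832": 3}


def encode_visit(codes: List[str], vocab: Dict[str, int]) -> List[int]:
    """Map each code to its index; codes unseen during fitting become <unk>."""
    unk = vocab["<unk>"]
    return [vocab.get(code, unk) for code in codes]


print(encode_visit(["4019", "NEWCODE_0_0_0", "42832"], vocab))  # [2, 1, 3]
```

Since every synthetic code below is new, each hold-out visit should encode to a run of `<unk>` indices followed by padding.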

In [None]:
from pyhealth.datasets.base_dataset import SampleDataset
import random
import numpy as np

def generate_holdout_set(
    sample_dataset: SampleDataset, num_samples: int = 10, seed: int = 42
) -> SampleDataset:
    """Generate synthetic hold-out set with unseen codes and varying lengths.

    This function creates synthetic samples to test the processor's ability to:
    1. Handle completely unseen tokens (mapped to <unk>)
    2. Handle sequence lengths larger than training but within padding

    Args:
        sample_dataset: Original SampleDataset with fitted processors
        num_samples: Number of synthetic samples to generate
        seed: Random seed for reproducibility

    Returns:
        SampleDataset with synthetic samples using fitted processors
    """
    random.seed(seed)
    np.random.seed(seed)

    # Get the fitted processors
    icd_processor = sample_dataset.input_processors["icd_codes"]

    # Get max nested length from ICD processor
    max_icd_len = icd_processor._max_nested_len
    # Handle both old and new processor versions
    padding = getattr(icd_processor, "_padding", 0)

    print("\n=== Hold-out Set Generation ===")
    print(f"Processor attributes: {dir(icd_processor)}")
    print(f"Has _padding attribute: {hasattr(icd_processor, '_padding')}")
    print(f"ICD max nested length: {max_icd_len}")
    print(f"Padding (via getattr): {padding}")
    if hasattr(icd_processor, "_padding"):
        print(f"Padding (direct access): {icd_processor._padding}")
    print(f"Observed max (without padding): {max_icd_len - padding}")

    synthetic_samples = []

    for i in range(num_samples):
        # Generate random number of visits (1-5)
        num_visits = random.randint(1, 5)

        # Generate ICD codes with unseen tokens
        icd_codes_list = []
        icd_times_list = []

        for visit_idx in range(num_visits):
            # Codes per visit: from just below the longest visit seen during
            # fitting up to the padded capacity (max_icd_len - 1)
            observed_max = max_icd_len - padding
            seq_len = random.randint(max(1, observed_max - 2), max_icd_len - 1)

            # Generate unseen codes
            visit_codes = [f"NEWCODE_{i}_{visit_idx}_{j}" for j in range(seq_len)]
            icd_codes_list.append(visit_codes)

            # Generate time intervals (hours from previous visit)
            if visit_idx == 0:
                icd_times_list.append(0.0)
            else:
                icd_times_list.append(random.uniform(24.0, 720.0))

        # Generate lab data (10-dimensional vectors)
        num_lab_timestamps = random.randint(5, 15)
        lab_values_list = []
        lab_times_list = []

        for ts_idx in range(num_lab_timestamps):
            # Generate 10D vector with some random values and some None
            lab_vector = []
            for dim in range(10):
                if random.random() < 0.8:  # 80% chance of value
                    lab_vector.append(random.uniform(50.0, 150.0))
                else:
                    lab_vector.append(None)

            lab_values_list.append(lab_vector)
            lab_times_list.append(random.uniform(0.0, 48.0))

        # Create sample in the expected format (before processing)
        synthetic_sample = {
            "patient_id": f"HOLDOUT_PATIENT_{i}",
            "icd_codes": (icd_times_list, icd_codes_list),
            "labs": (lab_times_list, lab_values_list),
            "mortality": random.randint(0, 1),
        }

        synthetic_samples.append(synthetic_sample)

    # Create a new SampleDataset with the FITTED processors
    holdout_dataset = SampleDataset(
        samples=synthetic_samples,
        input_schema=sample_dataset.input_schema,
        output_schema=sample_dataset.output_schema,
        dataset_name=f"{sample_dataset.dataset_name}_holdout",
        task_name=sample_dataset.task_name,
        input_processors=sample_dataset.input_processors,
        output_processors=sample_dataset.output_processors,
    )

    print(f"Generated {len(holdout_dataset)} synthetic samples")
    visits_per_sample = [len(s["icd_codes"][1]) for s in synthetic_samples[:3]]
    print(f"Visits per sample (first 3): {visits_per_sample}")
    sample_codes_per_visit = [
        [len(visit) for visit in s["icd_codes"][1]] for s in synthetic_samples[:3]
    ]
    print(f"Sample codes per visit: {sample_codes_per_visit}")

    return holdout_dataset



In [None]:
holdout_dataset = generate_holdout_set(sample_dataset, num_samples=10, seed=42)
# Create dataloader for hold-out set
holdout_loader = get_dataloader(holdout_dataset, batch_size=16, shuffle=False)
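# Run the trained model on the processed hold-out samples
print("\n=== Inspecting Processed Hold-out Samples ===")
model.eval()
holdout_batch = next(iter(holdout_loader))
with torch.no_grad():
    holdout_output = model(**holdout_batch)

print(f"  Predicted probabilities: {holdout_output['y_prob'][:5]}")
print(f"  True labels: {holdout_output['y_true'][:5]}")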

## Post-hoc ML Processing (Interpretability)
Once the model is trained and evaluation metrics are computed, you may want to run post-hoc analyses such as interpretability or uncertainty quantification.

Support for this is still a work in progress in PyHealth 2.0; the roadmap includes:

- Integrated Gradients (gradient-based interpretability for deep neural networks)
- Conformal prediction (many other uncertainty quantification techniques are already available [here](https://pyhealth.readthedocs.io/en/latest/api/calib.html))
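
Integrated Gradients attributes a prediction $F(x)$ to each input feature by integrating the gradient along the straight line from a baseline $x'$ to $x$: $\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F(x' + \alpha(x - x'))}{\partial x_i}\,d\alpha$. The NumPy sketch below approximates the integral with a midpoint Riemann sum on a toy model whose gradient is known in closed form. It illustrates the technique only and is not PyHealth's `IntegratedGradients` implementation.

```python
import numpy as np


def toy_model(x: np.ndarray, w: np.ndarray) -> float:
    """Toy differentiable model: F(x) = sigmoid(w . x)."""
    return float(1.0 / (1.0 + np.exp(-x @ w)))


def toy_grad(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Closed-form gradient of the toy model: sigma(z) * (1 - sigma(z)) * w."""
    s = toy_model(x, w)
    return s * (1.0 - s) * w


def integrated_gradients(
    x: np.ndarray, baseline: np.ndarray, w: np.ndarray, steps: int = 200
) -> np.ndarray:
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    alphas = (np.arange(steps) + 0.5) / steps
    total = np.zeros_like(x)
    for a in alphas:
        # Gradient at a point on the straight path from baseline to x
        total += toy_grad(baseline + a * (x - baseline), w)
    return (x - baseline) * total / steps


w = np.array([2.0, -1.0, 0.5])
x = np.array([1.0, 3.0, -2.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(x, baseline, w)
# Completeness: attributions sum (approximately) to F(x) - F(baseline)
print(attr.sum(), toy_model(x, w) - toy_model(baseline, w))
```

The completeness check on the last line is a quick sanity test for any Integrated Gradients implementation; it also shows that attributions are always relative to the chosen baseline.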


In [18]:
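from functools import lru_cache

import torch
from pyhealth.medcode import InnerMap

LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES


def unravel(flat_index: int, shape: torch.Size):
    """Convert a flat index into per-dimension coordinates (row-major order)."""
    coords = []
    remaining = flat_index
    for dim in reversed(shape):
        coords.append(remaining % dim)
        remaining //= dim
    return list(reversed(coords))


@lru_cache(maxsize=1)
def get_icd9cm():
    """Load the ICD-9-CM ontology once and reuse it across calls."""
    return InnerMap.load("ICD9CM")


def decode_token(idx: int, processor, feature_key: str):
    """Map a vocabulary index back to its token, adding an ICD-9-CM
    description when one is available."""
    if processor is None or not hasattr(processor, "code_vocab"):
        return str(idx)
    reverse_vocab = {index: token for token, index in processor.code_vocab.items()}
    token = reverse_vocab.get(idx, f"<UNK:{idx}>")

    if feature_key == "icd_codes" and token not in {"<unk>", "<pad>"}:
        try:
            desc = get_icd9cm().lookup(token)
        except KeyError:
            # e.g. ICD-10 codes (MIMIC-IV contains both) are not in ICD-9-CM
            desc = None
        if desc:
            return f"{token}: {desc}"

    return token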


def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        feature_input = batch[feature_key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        feature_input = feature_input.detach().cpu()

        flattened = attr_cpu[0].flatten()
        if flattened.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        k = min(top_k, flattened.numel())
        top_values, top_indices = torch.topk(flattened.abs(), k=k)
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(feature_input)

        for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):
            attribution_value = flattened[flat_idx].item()
            coords = unravel(flat_idx.item(), attr_cpu[0].shape)

            if is_continuous:
                actual_value = feature_input[0][tuple(coords)].item()
                label = ""
                if feature_key == "labs" and len(coords) >= 1:
                    lab_idx = coords[-1]
                    if lab_idx < len(LAB_CATEGORY_NAMES):
                        label = f"{LAB_CATEGORY_NAMES[lab_idx]} "
                print(
                    f"  {rank:2d}. idx={coords} {label}value={actual_value:.4f} "
                    f"attr={attribution_value:+.6f}"
                )
            else:
                token_idx = int(feature_input[0][tuple(coords)].item())
                token = decode_token(token_idx, processor, feature_key)
                print(
                    f"  {rank:2d}. idx={coords} token='{token}' "
                    f"attr={attribution_value:+.6f}"
                )


In [15]:
from pyhealth.interpret.methods import DeepLift, IntegratedGradients
def move_batch_to_device(batch, target_device):
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(target_device)
        elif isinstance(value, tuple):
            moved[key] = tuple(v.to(target_device) for v in value)
        else:
            moved[key] = value
    return moved

device = torch.device("cpu")
model.to(device)
model.eval()  # disable dropout so predictions and attributions are deterministic
ig = IntegratedGradients(model)


sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

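with torch.no_grad():
    output = model(**sample_batch_device)
    probs = output["y_prob"]
    if probs.shape[-1] == 1:
        # Binary task: y_prob holds P(mortality = 1), so threshold at 0.5
        preds = (probs.squeeze(-1) > 0.5).long()
    else:
        preds = torch.argmax(probs, dim=-1)
    label_key = model.label_key
    true_label = sample_batch_device[label_key]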

    print("\nModel prediction for the sampled patient:")
    print(f"  True label: {int(true_label.cpu()[0].item())}")
    print(f"  Predicted class: {int(preds.cpu()[0].item())}")
    print(f"  Probabilities: {probs[0].cpu().numpy()}")


attributions = ig.attribute(**sample_batch_device)
print_top_attributions(
    attributions, sample_batch_device, sample_dataset.input_processors, top_k=10
)



Feature: icd_codes
   1. idx=[24, 48] token='<pad>' attr=-0.000597
   2. idx=[24, 62] token='<pad>' attr=-0.000513
   3. idx=[24, 41] token='<pad>' attr=+0.000425
   4. idx=[24, 32] token='<pad>' attr=-0.000415
   5. idx=[24, 12] token='<pad>' attr=+0.000386
   6. idx=[24, 50] token='<pad>' attr=+0.000382
   7. idx=[24, 42] token='<pad>' attr=-0.000380
   8. idx=[24, 28] token='<pad>' attr=+0.000370
   9. idx=[24, 57] token='<pad>' attr=-0.000350
  10. idx=[24, 38] token='<pad>' attr=-0.000348

Feature: labs
   1. idx=[401, 5] Calcium value=0.0000 attr=+0.001794
   2. idx=[401, 3] Bicarbonate value=0.0000 attr=+0.001794
   3. idx=[401, 4] Glucose value=0.0000 attr=+0.001794
   4. idx=[401, 6] Magnesium value=0.0000 attr=+0.001794
   5. idx=[401, 1] Potassium value=0.0000 attr=+0.001794
   6. idx=[401, 7] Anion Gap value=0.0000 attr=+0.001794
   7. idx=[401, 8] Osmolality value=0.0000 attr=+0.001794
   8. idx=[401, 9] Phosphate value=0.0000 attr=+0.001794
   9. idx=[401, 2] Chloride va
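
In the Integrated Gradients output above, every top-ranked ICD attribution falls on a `<pad>` token, and the top lab attributions all sit on a single timestep whose values are 0.0, which appears to be padding as well. One option (our assumption, not something `print_top_attributions` does) is to drop padding positions before taking the top k. A NumPy sketch, with a hypothetical pad index of 0; in practice the index could be read from the processor's `code_vocab`, assuming it contains a `<pad>` entry as `decode_token` suggests.

```python
import numpy as np


def top_k_non_pad(attr: np.ndarray, tokens: np.ndarray, pad_idx: int, k: int):
    """Rank positions by |attribution|, ignoring positions that hold pad_idx.

    attr and tokens share a shape (e.g. visits x codes); returns a list of
    (coords, attribution) pairs for at most k non-padding positions.
    """
    flat_attr = attr.ravel()
    valid = np.flatnonzero(tokens.ravel() != pad_idx)
    order = valid[np.argsort(-np.abs(flat_attr[valid]), kind="stable")][:k]
    return [
        (tuple(int(c) for c in np.unravel_index(i, attr.shape)), float(flat_attr[i]))
        for i in order
    ]


tokens = np.array([[5, 9, 0], [7, 0, 0]])  # 0 marks padding in this sketch
attr = np.array([[0.10, -0.40, 0.90], [0.05, -0.80, 0.30]])
print(top_k_non_pad(attr, tokens, pad_idx=0, k=2))  # [((0, 1), -0.4), ((0, 0), 0.1)]
```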

In [16]:
def build_random_embedding_baseline(
    model: StageNet,
    batch: dict,
    scale: float = 0.01,
    seed: int = 42,
) -> dict:
    """Construct a non-empty baseline directly in embedding space.

    DeepLIFT subtracts the baseline embedding from the actual embedding.
    Using pure zeros collapses StageNet masks (all visits become padding),
    so we add small random noise to keep at least one timestep active.
    """

    torch.manual_seed(seed)
    feature_inputs = {}
    for key in model.feature_keys:
        value = batch[key]
        if isinstance(value, tuple):
            value = value[1]
        feature_inputs[key] = value.to(model.device)

    embedded = model.embedding_model(feature_inputs)
    baseline = {}
    for key, emb in embedded.items():
        baseline[key] = torch.randn_like(emb) * scale
    return baseline
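
The docstring's point about zero baselines can be seen with a toy mask. Suppose, hypothetically, that a timestep counts as padding whenever its embedding vector is entirely zero (the exact rule StageNet uses is not shown in this notebook). An all-zero baseline then masks out every timestep, while a small Gaussian baseline like the one built above keeps them active.

```python
import numpy as np


def active_timesteps(emb: np.ndarray) -> np.ndarray:
    """Hypothetical mask: a timestep is active if any embedding entry is nonzero."""
    return np.any(emb != 0, axis=-1)


rng = np.random.default_rng(42)
real = rng.normal(size=(4, 8))  # 4 timesteps, embedding dim 8
zero_baseline = np.zeros_like(real)
noisy_baseline = rng.normal(size=real.shape) * 0.01  # mirrors scale=0.01 above

print(active_timesteps(real).sum())  # 4
print(active_timesteps(zero_baseline).sum())  # 0: every step looks like padding
print(active_timesteps(noisy_baseline).sum())  # 4
```

Because DeepLIFT explains the difference between the input and the baseline, a baseline with every timestep masked gives it little meaningful to compare against; small noise keeps the baseline structurally valid while staying close to zero.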


In [19]:
deeplift = DeepLift(model)

random_baseline = build_random_embedding_baseline(model, sample_batch_device)
attributions = deeplift.attribute(
    baseline=random_baseline,
    **sample_batch_device,
)
print_top_attributions(
    attributions, sample_batch_device, sample_dataset.input_processors, top_k=10
)



Feature: icd_codes
   1. idx=[0, 1] token='42832: Chronic diastolic heart failure' attr=+0.079825
   2. idx=[0, 6] token='V5861: Long-term (current) use of anticoagulants' attr=-0.070667
   3. idx=[0, 5] token='V4501: Cardiac pacemaker in situ' attr=-0.058043
   4. idx=[0, 10] token='370: Keratitis' attr=+0.056914
   5. idx=[2, 10] token='V4501: Cardiac pacemaker in situ' attr=-0.050888
   6. idx=[0, 7] token='4019: Unspecified essential hypertension' attr=-0.048502
   7. idx=[0, 3] token='4280: Congestive heart failure, unspecified' attr=+0.045676
   8. idx=[0, 2] token='4233: Cardiac tamponade' attr=+0.037603
   9. idx=[2, 13] token='4019: Unspecified essential hypertension' attr=-0.031371
  10. idx=[2, 5] token='4280: Congestive heart failure, unspecified' attr=-0.025716

Feature: labs
   1. idx=[400, 5] Calcium value=0.0000 attr=+0.004160
   2. idx=[400, 3] Bicarbonate value=0.0000 attr=+0.004160
   3. idx=[400, 4] Glucose value=0.0000 attr=+0.004160
   4. idx=[400, 6] Magnesium v