# SHAP Interpretability for StageNet on MIMIC-IV

This notebook demonstrates how to use the SHAP (SHapley Additive exPlanations) interpretability method with a StageNet model trained on MIMIC-IV data for mortality prediction.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/naveenkcb/PyHealth/blob/master/examples/shap_stagenet_mimic4.ipynb)

## Setup: Install PyHealth from Your Forked Repository

First, install PyHealth directly from the forked GitHub repository. Replace the URL below if you are working from your own fork.

In [None]:
# Install PyHealth from forked repository
!pip install git+https://github.com/naveenkcb/PyHealth.git -q

# Install additional required dependencies
!pip install polars -q

print("✓ Installation complete!")

## Import Required Libraries

In [None]:
from pathlib import Path
import polars as pl
import torch

from pyhealth.datasets import (
    MIMIC4EHRDataset,
    get_dataloader,
    load_processors,
    split_by_patient,
)
from pyhealth.interpret.methods import ShapExplainer
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4

print("✓ All libraries imported successfully!")

## Setup MIMIC-IV Dataset Path

**Note**: You'll need to:
1. Have access to the MIMIC-IV dataset (requires PhysioNet credentialing)
2. Update the `dataset_root` path below to point to your MIMIC-IV data location
3. Mount Google Drive or upload the data if you are running on Colab

In [None]:
# Option 1: For local MIMIC-IV data
dataset_root = "/home/logic/physionet.org/files/mimic-iv-demo/2.2/"

# Option 2: For Google Drive (uncomment if using Colab with Drive)
# from google.colab import drive
# drive.mount('/content/drive')
# dataset_root = "/content/drive/MyDrive/mimic-iv-demo/2.2/"

# Option 3: For demo data (update path as needed)
# dataset_root = "/path/to/your/mimic-iv-demo/"

print(f"Dataset root: {dataset_root}")
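
Before loading the dataset, it can help to confirm that the expected files exist under `dataset_root`. The sketch below is a minimal check run against a throwaway directory. The helper name `find_missing_files` and the `hosp/<table>.csv.gz` layout are assumptions made for illustration; adjust the relative paths to match your copy of the data.

```python
from pathlib import Path
import tempfile


def find_missing_files(root, relative_paths):
    """Return the relative paths that do not exist under root (hypothetical helper)."""
    root_path = Path(root).expanduser()
    return [rel for rel in relative_paths if not (root_path / rel).exists()]


# Demonstration on a temporary directory so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "hosp").mkdir()
    (Path(tmp) / "hosp" / "patients.csv.gz").touch()

    # Assumed layout: one gzipped CSV per table under hosp/.
    expected = ["hosp/patients.csv.gz", "hosp/admissions.csv.gz"]
    missing = find_missing_files(tmp, expected)

print(missing)  # ['hosp/admissions.csv.gz']
```

In the notebook, the same helper could be called with `dataset_root` and the files for the tables you plan to load.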

## Load MIMIC-IV Dataset

In [None]:
# Load the MIMIC-IV tables needed for the mortality prediction task
dataset = MIMIC4EHRDataset(
    root=dataset_root,
    tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "labevents",
    ],
)

print(f"✓ Dataset loaded with {len(dataset.patients)} patients")

## Setup ICD Code Description Mapping

In [None]:
def load_icd_description_map(dataset_root: str) -> dict:
    """Load ICD code → long title mappings from MIMIC-IV reference tables."""
    mapping = {}
    root_path = Path(dataset_root).expanduser()
    diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz"
    proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz"

    icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8}

    if diag_path.exists():
        diag_df = pl.read_csv(
            diag_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list())
        )

    if proc_path.exists():
        proc_df = pl.read_csv(
            proc_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list())
        )

    return mapping


ICD_CODE_TO_DESC = load_icd_description_map(dataset.root)
print(f"✓ Loaded {len(ICD_CODE_TO_DESC)} ICD code descriptions")

## Setup Mortality Prediction Task

**Note**: This step loads cached processors if they are available; otherwise new ones are created when the task is set. The next section also needs a trained model checkpoint.
Update the paths below or train a model first using the PyHealth training pipeline.

In [None]:
# Path to cached processors (update this path)
processors_path = "../resources/"

# Load or create processors
try:
    input_processors, output_processors = load_processors(processors_path)
    print("✓ Loaded cached processors")
except Exception as exc:  # a bare except would also swallow KeyboardInterrupt
    print(f"⚠ Could not load processors ({exc}). Will create new ones.")
    input_processors = None
    output_processors = None

# Set up the task
sample_dataset = dataset.set_task(
    MortalityPredictionStageNetMIMIC4(),
    cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality",
    input_processors=input_processors,
    output_processors=output_processors,
)

print(f"✓ Total samples: {len(sample_dataset)}")

## Load Pre-trained StageNet Model

**Note**: You need a trained model checkpoint. Update the path below or train a model first.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model
model = StageNet(
    dataset=sample_dataset,
    embedding_dim=128,
    chunk_size=128,
    levels=3,
    dropout=0.3,
)

# Load trained weights (update this path)
checkpoint_path = "../resources/best.ckpt"

try:
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    print("✓ Loaded pre-trained model")
except Exception as exc:  # a bare except would also swallow KeyboardInterrupt
    print(f"⚠ Could not load checkpoint ({exc}). Using randomly initialized model.")
    print("   (Results will not be meaningful without a trained model)")

model = model.to(device)
model.eval()
print(model)

## Prepare Test Data

In [None]:
# Split dataset
_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42)
test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)

print(f"✓ Test set: {len(test_data)} samples")
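
Splitting by patient keeps all samples from one patient in the same partition, so the test set contains no patient seen during training. The sketch below illustrates the idea in plain Python. It is not PyHealth's implementation of `split_by_patient`; the helper `split_patient_ids` and the toy sample list are hypothetical.

```python
import random


def split_patient_ids(patient_ids, ratios, seed):
    """Shuffle unique patient IDs and cut them into train/val/test groups."""
    unique_ids = sorted(set(patient_ids))
    rng = random.Random(seed)
    rng.shuffle(unique_ids)
    n_train = round(len(unique_ids) * ratios[0])
    n_val = round(len(unique_ids) * ratios[1])
    train = set(unique_ids[:n_train])
    val = set(unique_ids[n_train:n_train + n_val])
    test = set(unique_ids[n_train + n_val:])
    return train, val, test


# Hypothetical samples: (patient_id, sample_index); each patient has three samples.
toy_samples = [(f"p{i % 10}", i) for i in range(30)]
train_ids, val_ids, test_ids = split_patient_ids(
    [pid for pid, _ in toy_samples], [0.7, 0.1, 0.2], seed=42
)

# Samples follow their patient, so no patient straddles two partitions.
toy_test_samples = [s for s in toy_samples if s[0] in test_ids]
print(len(train_ids), len(val_ids), len(test_ids))  # 7 1 2
```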

## Helper Functions for Attribution Analysis

In [None]:
def move_batch_to_device(batch, target_device):
    """Move all tensors in batch to target device."""
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(target_device)
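        elif isinstance(value, tuple):
            # Tuple features may hold non-tensor entries (for example None); leave those as-is
            moved[key] = tuple(
                v.to(target_device) if isinstance(v, torch.Tensor) else v
                for v in value
            )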
        else:
            moved[key] = value
    return moved


LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES


def decode_token(idx: int, processor, feature_key: str):
    """Decode token index to human-readable string."""
    if processor is None or not hasattr(processor, "code_vocab"):
        return str(idx)
    reverse_vocab = {index: token for token, index in processor.code_vocab.items()}
    token = reverse_vocab.get(idx, f"<UNK:{idx}>")

    if feature_key == "icd_codes" and token not in {"<unk>", "<pad>"}:
        desc = ICD_CODE_TO_DESC.get(token)
        if desc:
            return f"{token}: {desc}"

    return token


def unravel(flat_index: int, shape: torch.Size):
    """Convert flat index to multi-dimensional coordinates."""
    coords = []
    remaining = flat_index
    for dim in reversed(shape):
        coords.append(remaining % dim)
        remaining //= dim
    return list(reversed(coords))


def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    """Print top-k most important features from SHAP attributions."""
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        feature_input = batch[feature_key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        feature_input = feature_input.detach().cpu()

        flattened = attr_cpu[0].flatten()
        if flattened.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        print(f"  Shape: {attr_cpu[0].shape}")
        print(f"  Total attribution sum: {flattened.sum().item():+.6f}")
        print(f"  Mean attribution: {flattened.mean().item():+.6f}")
        
        k = min(top_k, flattened.numel())
        top_values, top_indices = torch.topk(flattened.abs(), k=k)
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(feature_input)

        print(f"\n  Top {k} most important features:")
        for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):
            attribution_value = flattened[flat_idx].item()
            coords = unravel(flat_idx.item(), attr_cpu[0].shape)

            if is_continuous:
                actual_value = feature_input[0][tuple(coords)].item()
                label = ""
                if feature_key == "labs" and len(coords) >= 1:
                    lab_idx = coords[-1]
                    if lab_idx < len(LAB_CATEGORY_NAMES):
                        label = f"{LAB_CATEGORY_NAMES[lab_idx]} "
                print(
                    f"    {rank:2d}. idx={coords} {label}value={actual_value:.4f} "
                    f"SHAP={attribution_value:+.6f}"
                )
            else:
                token_idx = int(feature_input[0][tuple(coords)].item())
                token = decode_token(token_idx, processor, feature_key)
                print(
                    f"    {rank:2d}. idx={coords} token='{token}' "
                    f"SHAP={attribution_value:+.6f}"
                )

print("✓ Helper functions defined")
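
To see how the ranking and index recovery in `print_top_attributions` fit together, here is a small pure-Python sketch. It re-implements the same row-major logic as `unravel` under a different name (`unravel_row_major`) so it does not overwrite the notebook helper. The 2 x 3 attribution grid is made up for illustration.

```python
def unravel_row_major(flat_index, shape):
    """Convert a flat index into coordinates, assuming row-major (C-style) layout."""
    coords = []
    remaining = flat_index
    for dim in reversed(shape):
        coords.append(remaining % dim)
        remaining //= dim
    return list(reversed(coords))


# Hypothetical attribution grid: 2 visits x 3 features, flattened row by row.
toy_shape = (2, 3)
toy_flat_attr = [0.10, -0.40, 0.05, 0.30, -0.02, 0.00]

# Rank positions by absolute attribution, largest first, then recover coordinates.
ranked = sorted(
    range(len(toy_flat_attr)), key=lambda i: abs(toy_flat_attr[i]), reverse=True
)
top2 = [(unravel_row_major(i, toy_shape), toy_flat_attr[i]) for i in ranked[:2]]
print(top2)  # [([0, 1], -0.4), ([1, 0], 0.3)]
```

Ranking by absolute value surfaces strongly negative attributions as well as strongly positive ones, while the printed value keeps its sign.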

## Initialize SHAP Explainer

In [None]:
print("="*80)
print("Initializing SHAP Explainer")
print("="*80)

# Initialize SHAP explainer with custom parameters
shap_explainer = ShapExplainer(
    model,
    use_embeddings=True,  # Use embeddings for discrete features
    n_background_samples=50,  # Number of background samples
    max_coalitions=200,  # Number of feature coalitions to sample
    random_seed=42,  # For reproducibility
)

print("\nSHAP Configuration:")
print(f"  Use embeddings: {shap_explainer.use_embeddings}")
print(f"  Background samples: {shap_explainer.n_background_samples}")
print(f"  Max coalitions: {shap_explainer.max_coalitions}")
print(f"  Regularization: {shap_explainer.regularization}")
print(f"  Random seed: {shap_explainer.random_seed}")

## Get Sample and Model Prediction

In [None]:
# Get a sample from test set
sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

# Get model prediction
with torch.no_grad():
    output = model(**sample_batch_device)
    probs = output["y_prob"]
    preds = torch.argmax(probs, dim=-1)
    label_key = model.label_key
    true_label = sample_batch_device[label_key]

    print("\n" + "="*80)
    print("Model Prediction for Sampled Patient")
    print("="*80)
    print(f"  True label: {int(true_label.item())} {'(Deceased)' if true_label.item() == 1 else '(Survived)'}")
    print(f"  Predicted class: {int(preds.item())} {'(Deceased)' if preds.item() == 1 else '(Survived)'}")
    print(f"  Probabilities: [Survive={probs[0][0].item():.4f}, Death={probs[0][1].item():.4f}]")

## Compute SHAP Attributions

This cell computes SHAP values for the mortality prediction (class 1). 
**Note**: This may take 1-2 minutes depending on the number of coalitions and background samples.

In [None]:
print("\n" + "="*80)
print("Computing SHAP Attributions (this may take a minute...)")
print("="*80)

attributions = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)

print("\n✓ SHAP computation complete!")
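
For intuition about what a SHAP value is, the sketch below computes exact Shapley values for a toy three-feature function by enumerating every coalition. This is only feasible for a handful of features and is not how `ShapExplainer` works internally, which samples coalitions instead. The toy model, input, and baseline are made up for illustration.

```python
from itertools import combinations
from math import factorial


def exact_shapley(value_fn, n_features):
    """Exact Shapley values by enumerating all coalitions (tiny n only)."""
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(len(others) + 1):
            # Classic Shapley weight for a coalition of this size.
            weight = (
                factorial(size) * factorial(n_features - size - 1)
                / factorial(n_features)
            )
            for subset in combinations(others, size):
                gain = value_fn(set(subset) | {i}) - value_fn(set(subset))
                phi[i] += weight * gain
    return phi


toy_input = [1.0, 2.0, 3.0]
toy_baseline = [0.0, 0.0, 0.0]


def toy_model(v):
    """Hypothetical model: one additive term plus one interaction term."""
    return 2.0 * v[0] + v[1] * v[2]


def coalition_value(coalition):
    """Features in the coalition keep their value; the rest use the baseline."""
    mixed = [
        toy_input[j] if j in coalition else toy_baseline[j]
        for j in range(len(toy_input))
    ]
    return toy_model(mixed)


phi = exact_shapley(coalition_value, len(toy_input))
print([round(p, 6) for p in phi])  # [2.0, 3.0, 3.0]
print(round(sum(phi), 6), toy_model(toy_input) - toy_model(toy_baseline))  # 8.0 8.0
```

The interaction term `v[1] * v[2]` is shared equally between the two features involved, and the values sum to the difference between the prediction for the input and the prediction for the baseline.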

## Display SHAP Attribution Results

In [None]:
print("\n" + "="*80)
print("SHAP Attribution Results")
print("="*80)
print("\nSHAP values explain the contribution of each feature to the model's")
print("prediction of MORTALITY (class 1). Positive values increase the")
print("mortality prediction, negative values decrease it.")

print_top_attributions(attributions, sample_batch_device, input_processors, top_k=15)

## Compare Different Baseline Strategies

In [None]:
print("\n" + "="*80)
print("Testing Different Baseline Strategies")
print("="*80)

# 1. Automatic baseline (default)
print("\n1. Automatic baseline generation:")
attr_auto = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)
print(f"   Total attribution (icd_codes): {attr_auto['icd_codes'][0].sum().item():+.6f}")

# 2. Custom zero baseline
print("\n2. Custom zero baseline:")
zero_baseline = {}
for key in model.feature_keys:
    if key in sample_batch_device:
        feature_input = sample_batch_device[key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        zero_baseline[key] = torch.zeros(
            (shap_explainer.n_background_samples,) + feature_input.shape[1:],
            device=device,
            dtype=feature_input.dtype
        )

attr_zero = shap_explainer.attribute(
    baseline=zero_baseline,
    **sample_batch_device,
    target_class_idx=1
)
print(f"   Total attribution (icd_codes): {attr_zero['icd_codes'][0].sum().item():+.6f}")
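
The two totals above can differ because attributions are always measured relative to the baseline. For a purely linear model with a fixed reference point, the Shapley value of feature i reduces to `w_i * (x_i - baseline_i)`, which makes the effect easy to see. The sketch below uses made-up weights and inputs; StageNet is not linear, so this is only an illustration of the principle.

```python
toy_weights = [0.5, -1.0, 2.0]
toy_x = [4.0, 1.0, 0.5]


def linear_attributions(x, baseline):
    """Shapley values of a linear model relative to a single baseline point."""
    return [w * (xi - bi) for w, xi, bi in zip(toy_weights, x, baseline)]


# Baseline 1: all zeros.
toy_zero_baseline = [0.0, 0.0, 0.0]

# Baseline 2: the feature-wise mean of a small hypothetical background set.
toy_background = [[2.0, 1.0, 0.0], [4.0, 3.0, 1.0]]
toy_mean_baseline = [sum(col) / len(col) for col in zip(*toy_background)]

toy_attr_zero = linear_attributions(toy_x, toy_zero_baseline)
toy_attr_mean = linear_attributions(toy_x, toy_mean_baseline)

print(toy_attr_zero)  # [2.0, -1.0, 1.0]
print(toy_attr_mean)  # [0.5, 1.0, 0.0]
```

The second feature changes sign between the two baselines: its value is above zero but below the background mean, so whether it counts for or against the prediction depends on the reference.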

## Test Callable Interface

Verify that both `explainer.attribute()` and `explainer()` produce identical results when using a random seed.

In [None]:
print("\n" + "="*80)
print("Testing Callable Interface")
print("="*80)

# Both methods should produce identical results (due to random_seed)
attr_from_attribute = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)
attr_from_call = shap_explainer(**sample_batch_device, target_class_idx=1)

print("\nVerifying that explainer(**data) and explainer.attribute(**data) produce")
print("identical results when random_seed is set...")

all_close = True
for key in attr_from_attribute.keys():
    if not torch.allclose(attr_from_attribute[key], attr_from_call[key], atol=1e-6):
        all_close = False
        print(f"  ❌ {key}: Results differ!")
    else:
        print(f"  ✓ {key}: Results match")

if all_close:
    print("\n✓ All attributions match! Callable interface works correctly.")
else:
    print("\n❌ Some attributions differ. Check random seed configuration.")
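
The reason a fixed seed makes repeated calls agree can be shown with a small sampling-based estimator. The sketch below approximates Shapley values by averaging marginal contributions over random feature orderings. This is a generic permutation-sampling scheme chosen for illustration, not necessarily the coalition sampling used by `ShapExplainer`; the toy value function is made up.

```python
import random


def sampled_shapley(value_fn, n_features, n_permutations, seed):
    """Estimate Shapley values from random feature orderings."""
    rng = random.Random(seed)  # a private generator keeps runs independent
    totals = [0.0] * n_features
    for _ in range(n_permutations):
        order = list(range(n_features))
        rng.shuffle(order)
        coalition = set()
        previous = value_fn(coalition)
        for i in order:
            coalition.add(i)
            current = value_fn(coalition)
            totals[i] += current - previous  # marginal contribution of feature i
            previous = current
    return [t / n_permutations for t in totals]


toy_values = [1.0, 2.0, 3.0]


def toy_value(coalition):
    """Hypothetical model evaluated with absent features set to zero."""
    v = [toy_values[j] if j in coalition else 0.0 for j in range(len(toy_values))]
    return 2.0 * v[0] + v[1] * v[2]


run_a = sampled_shapley(toy_value, 3, n_permutations=20, seed=42)
run_b = sampled_shapley(toy_value, 3, n_permutations=20, seed=42)
run_c = sampled_shapley(toy_value, 3, n_permutations=20, seed=7)

print(run_a == run_b)  # True
print(run_a, run_c)    # estimates from different seeds may differ
```

With the same seed the estimator draws the same orderings and returns identical numbers. A different seed may give different estimates for the two interacting features, although every run still sums to the same total.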

## Summary

This notebook demonstrated:

1. **SHAP Initialization**: How to configure the `ShapExplainer` with custom parameters
2. **Attribution Computation**: Computing SHAP values for mortality prediction
3. **Feature Importance**: Identifying the most important features driving predictions
4. **Baseline Strategies**: Comparing automatic vs. custom baseline generation
5. **Reproducibility**: Using random seeds for deterministic results

### Key Takeaways:

- **Positive SHAP values** indicate features that increase the mortality prediction
- **Negative SHAP values** indicate features that decrease the mortality prediction
- The sum of SHAP values approximates the difference between the model's prediction for this patient and its expected prediction over the baseline (background) samples
- Setting a `random_seed` ensures reproducible results across multiple runs

### Next Steps:

- Analyze multiple patients to identify common patterns
- Compare SHAP results with other interpretability methods (DeepLIFT, Integrated Gradients)
- Visualize SHAP values using summary plots or waterfall charts
- Use SHAP insights to improve model performance or identify data quality issues