In [None]:
# Report the PyTorch build and whether a CUDA device is visible to this runtime.
import torch

print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("⚠️ GPU not available. Please enable GPU: Runtime > Change runtime type > GPU")
else:
    # Describe device 0, the device later cells select ("cuda:0").
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 1. Installation

Install PyHealth and required dependencies:

In [None]:
# Install PyHealth (adjust path/version as needed)
!pip install pyhealth polars -q

# If using development version from GitHub:
# !pip install git+https://github.com/sunlabuiuc/PyHealth.git -q

## 2. Download MIMIC-IV Demo Dataset

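Download the MIMIC-IV Clinical Database Demo (v2.2) from PhysioNet. The demo is open access, so no credentialed PhysioNet account is required (the full MIMIC-IV dataset does require credentialed access).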

In [None]:
import os
from pathlib import Path

# Local root of the MIMIC-IV demo; the ICD lookup cell reads tables from <root>/hosp/.
data_dir = Path("/content/mimic-iv-demo/2.2")
data_dir.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {data_dir}")

# Only prompt for a download when no CSV tables are present under hosp/ yet,
# so re-running the notebook with data in place does not show a stale warning.
hosp_dir = data_dir / "hosp"
if hosp_dir.is_dir() and any(hosp_dir.glob("*.csv*")):
    print("MIMIC-IV demo tables found.")
else:
    # Note: Replace with actual download method or mount Google Drive with dataset
    print("\n⚠️ Please download MIMIC-IV demo dataset from:")
    print("https://physionet.org/content/mimic-iv-demo/2.2/")
    print("\nOr mount Google Drive if you have the dataset stored there.")

## 3. Load Pre-trained Model Checkpoint

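Upload or download the pre-trained StageNet model checkpoint (`best.ckpt`). The same `/content/resources/` directory must also contain the input/output processors saved at training time; they are loaded in section 4.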

In [None]:
# Directory that holds the trained checkpoint and the saved processors.
resources_dir = Path("/content/resources")
resources_dir.mkdir(parents=True, exist_ok=True)

# Put the checkpoint here, e.g. via Colab's uploader or a download from a URL:
# from google.colab import files
# uploaded = files.upload()

checkpoint_path = resources_dir / "best.ckpt"
ckpt_present = checkpoint_path.exists()
print(f"Model checkpoint should be at: {checkpoint_path}")
print(f"Checkpoint exists: {ckpt_present}")

## 4. Load Dataset and Processors

In [None]:
from pathlib import Path
import polars as pl
import torch

from pyhealth.datasets import (
    MIMIC4EHRDataset,
    get_dataloader,
    load_processors,
    split_by_patient,
)
from pyhealth.interpret.methods import ShapExplainer
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4

# MIMIC-IV tables loaded for this tutorial.
ehr_tables = [
    "patients",
    "admissions",
    "diagnoses_icd",
    "procedures_icd",
    "labevents",
]

# Point `root` at the local copy of the MIMIC-IV demo (adjust path as needed).
dataset = MIMIC4EHRDataset(root="/content/mimic-iv-demo/2.2/", tables=ehr_tables)

num_patients = len(dataset.patients)
print(f"Dataset loaded: {num_patients} patients")

In [None]:
# Load the processors saved in the resources directory (alongside the
# checkpoint) and apply them when building the task samples.
input_processors, output_processors = load_processors("/content/resources/")

mortality_task = MortalityPredictionStageNetMIMIC4()
sample_dataset = dataset.set_task(
    mortality_task,
    cache_dir="/content/.cache/pyhealth/mimic4_stagenet_mortality",
    input_processors=input_processors,
    output_processors=output_processors,
)
print(f"Total samples: {len(sample_dataset)}")

## 5. Load ICD Code Descriptions

In [None]:
def load_icd_description_map(dataset_root: str) -> dict:
    """Load ICD code → long title mappings from MIMIC-IV reference tables.

    Reads ``hosp/d_icd_diagnoses.csv.gz`` and then ``hosp/d_icd_procedures.csv.gz``
    under ``dataset_root``. Each table is read only if it exists. Every
    ``icd_code`` is mapped to its ``long_title``.

    Args:
        dataset_root: Root directory of the MIMIC-IV dataset (``~`` is expanded).

    Returns:
        Dict mapping ICD code string -> long title. Empty when neither table exists.

    Note:
        Codes are not namespaced by table, so a code that appears more than once
        keeps the title read last. In particular, a procedure title overrides a
        diagnosis title for the same code string.
    """
    root_path = Path(dataset_root).expanduser()
    table_paths = (
        root_path / "hosp" / "d_icd_diagnoses.csv.gz",
        root_path / "hosp" / "d_icd_procedures.csv.gz",
    )

    mapping = {}
    for table_path in table_paths:
        if not table_path.exists():
            continue
        # Read both columns as strings so codes keep any leading zeros.
        # `schema_overrides` replaces Polars' deprecated `dtypes` keyword.
        table = pl.read_csv(
            table_path,
            columns=["icd_code", "long_title"],
            schema_overrides={"icd_code": pl.Utf8, "long_title": pl.Utf8},
        )
        mapping.update(
            zip(table["icd_code"].to_list(), table["long_title"].to_list())
        )

    return mapping

# Global lookup consulted by decode_token to annotate ICD tokens with titles.
ICD_CODE_TO_DESC = load_icd_description_map(dataset_root=dataset.root)
n_icd_descriptions = len(ICD_CODE_TO_DESC)
print(f"Loaded {n_icd_descriptions} ICD code descriptions")

## 6. Load Pre-trained StageNet Model on GPU

In [None]:
# Prefer the first CUDA device; fall back to CPU so the notebook still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters presumably must match the training configuration of the
# checkpoint; otherwise load_state_dict reports mismatched keys/shapes.
model = StageNet(
    dataset=sample_dataset,
    embedding_dim=128,
    chunk_size=128,
    levels=3,
    dropout=0.3,
)

# The checkpoint is loaded as a state_dict of tensors. weights_only=True
# refuses to unpickle arbitrary Python objects, so an untrusted/uploaded
# checkpoint file cannot execute code on load.
state_dict = torch.load(
    "/content/resources/best.ckpt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # disable dropout for deterministic inference

print(f"\nModel loaded successfully on {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## 7. Prepare Test Data

In [None]:
# Patient-level 70/10/20 split with a fixed seed; only the test portion is used below.
split_ratios = [0.7, 0.1, 0.2]
train_data, val_data, test_data = split_by_patient(sample_dataset, split_ratios, seed=42)

# One sample per batch: later cells explain and inspect index 0 of a single batch.
test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)
print(f"Test samples: {len(test_data)}")

def move_batch_to_device(batch, target_device):
    """Move all tensors in a batch to ``target_device``.

    Tensors are moved directly. Tuples and lists are processed element-wise
    (recursively): tuples come back as plain tuples and lists as lists.
    Anything that is not a tensor is passed through unchanged, e.g. ``None``
    placeholders inside a tuple, strings, or IDs. The input dict is not
    modified.

    Args:
        batch: Mapping of feature name -> value, as yielded by the dataloader.
        target_device: Device (or device string) to move tensors to.

    Returns:
        A new dict with the same keys and every tensor moved to ``target_device``.
    """

    def _move(value):
        # Only tensors have `.to`; recurse into sequences, leave everything else.
        if isinstance(value, torch.Tensor):
            return value.to(target_device)
        if isinstance(value, tuple):
            return tuple(_move(item) for item in value)
        if isinstance(value, list):
            return [_move(item) for item in value]
        return value

    return {key: _move(value) for key, value in batch.items()}

## 8. Define Helper Functions for Visualization

In [None]:
# Lab category names taken from the task class. print_top_attributions indexes
# this by the last coordinate of the `labs` attribution tensor. That presumes the
# last axis of `labs` follows this ordering — TODO confirm against the task.
LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES

def decode_token(idx: int, processor, feature_key: str):
    """Map a vocabulary index back to a human-readable token string.

    Returns ``str(idx)`` when there is no processor exposing ``code_vocab``.
    Returns ``"<UNK:idx>"`` when the index is not in that vocabulary. For the
    ``icd_codes`` feature, real codes (not ``<unk>``/``<pad>``) are rendered as
    ``"code: long title"`` when a title is known in ``ICD_CODE_TO_DESC``.
    """
    if processor is None or not hasattr(processor, "code_vocab"):
        return str(idx)

    vocab = processor.code_vocab
    index_to_token = dict(zip(vocab.values(), vocab.keys()))
    token = index_to_token.get(idx, f"<UNK:{idx}>")

    # Only ICD tokens get annotated; special tokens are returned as-is.
    if feature_key != "icd_codes" or token in ("<unk>", "<pad>"):
        return token

    description = ICD_CODE_TO_DESC.get(token)
    return f"{token}: {description}" if description else token

def unravel(flat_index: int, shape: torch.Size):
    """Convert a row-major flat index into a list of per-axis coordinates.

    Walks the axes from last to first, peeling off one coordinate per axis
    with ``divmod``. An empty ``shape`` yields ``[]``.
    """
    coords = [0] * len(shape)
    remainder = flat_index
    for axis in range(len(shape) - 1, -1, -1):
        remainder, coords[axis] = divmod(remainder, shape[axis])
    return coords

def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    """Print a per-feature summary of SHAP attributions for the first sample.

    For each feature, prints the attribution tensor's shape, sum and mean. It
    then lists the ``top_k`` positions ranked by absolute attribution. Each
    position shows either the raw input value (floating-point inputs, with a
    lab category name for ``labs``) or the decoded token (integer inputs).

    Args:
        attributions: Mapping of feature name -> attribution tensor whose
            first dimension is the batch.
        batch: Model input batch. For tuple-valued entries, the tensor at
            index 1 is used as the input values.
        processors: Optional mapping of feature name -> input processor, used
            to decode token indices.
        top_k: Maximum number of positions to list per feature.
    """
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        # Nothing to report for scalars or an empty batch.
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        raw_input = batch[feature_key]
        values = raw_input[1] if isinstance(raw_input, tuple) else raw_input
        values = values.detach().cpu()

        sample_attr = attr_cpu[0]
        flat_attr = sample_attr.flatten()
        if flat_attr.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        print(f"  Shape: {sample_attr.shape}")
        print(f"  Total attribution sum: {flat_attr.sum().item():+.6f}")
        print(f"  Mean attribution: {flat_attr.mean().item():+.6f}")

        # Rank by magnitude, but report the signed attribution below.
        k = min(top_k, flat_attr.numel())
        _, ranked_positions = torch.topk(flat_attr.abs(), k=k)
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(values)

        print(f"\n  Top {k} most important features:")
        for rank, position in enumerate(ranked_positions, start=1):
            shap_value = flat_attr[position].item()
            coords = unravel(position.item(), sample_attr.shape)
            input_value = values[0][tuple(coords)].item()

            if not is_continuous:
                token = decode_token(int(input_value), processor, feature_key)
                print(
                    f"    {rank:2d}. idx={coords} token='{token}' "
                    f"SHAP={shap_value:+.6f}"
                )
                continue

            label = ""
            if feature_key == "labs" and coords:
                lab_idx = coords[-1]
                if lab_idx < len(LAB_CATEGORY_NAMES):
                    label = f"{LAB_CATEGORY_NAMES[lab_idx]} "
            print(
                f"    {rank:2d}. idx={coords} {label}value={input_value:.4f} "
                f"SHAP={shap_value:+.6f}"
            )

## 9. Initialize SHAP Explainer

Initialize the SHAP explainer with Kernel SHAP configuration.

In [None]:
rule = "=" * 80
print(rule)
print("Initializing SHAP Explainer")
print(rule)

# Kernel SHAP explainer constructed with its default settings.
shap_explainer = ShapExplainer(model)

print("\nSHAP Configuration:")
shap_settings = (
    ("Use embeddings", "use_embeddings"),
    ("Background samples", "n_background_samples"),
    ("Max coalitions", "max_coalitions"),
    ("Regularization", "regularization"),
)
for setting_label, setting_attr in shap_settings:
    print(f"  {setting_label}: {getattr(shap_explainer, setting_attr)}")
explainer_device = next(shap_explainer.model.parameters()).device
print(f"  Device: {explainer_device}")

## 10. Get Model Prediction on Test Sample

In [None]:
# Take the first batch (a single patient) from the unshuffled test loader
# and move it to the model's device.
sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

# Report the device of every tensor input. Tuple inputs may hold non-tensor
# entries (e.g. None), so report the first tensor found rather than only
# checking element 0.
for key, val in sample_batch_device.items():
    if isinstance(val, torch.Tensor):
        print(f"{key}: device={val.device}")
    elif isinstance(val, tuple):
        first_tensor = next((v for v in val if isinstance(v, torch.Tensor)), None)
        if first_tensor is not None:
            print(f"{key}: device={first_tensor.device}")

# Inference only: no autograd graph needed.
with torch.no_grad():
    output = model(**sample_batch_device)

probs = output["y_prob"]
label_key = model.label_key
true_label = sample_batch_device[label_key]

if probs.shape[-1] == 1:
    # Single probability output, read here as P(death).
    prob_death = probs[0].item()
    prob_survive = 1 - prob_death
    preds = (probs > 0.5).long()
else:
    # Multi-column output: column 0 = survive, column 1 = death.
    preds = torch.argmax(probs, dim=-1)
    prob_survive = probs[0][0].item()
    prob_death = probs[0][1].item()

print("\n" + "="*80)
print("Model Prediction for Sampled Patient")
print("="*80)
print(f"  True label: {int(true_label.item())} {'(Deceased)' if true_label.item() == 1 else '(Survived)'}")
print(f"  Predicted class: {int(preds.item())} {'(Deceased)' if preds.item() == 1 else '(Survived)'}")
print(f"  Probabilities: [Survive={prob_survive:.4f}, Death={prob_death:.4f}]")

## 11. Compute SHAP Attributions (GPU-Accelerated)

This step computes SHAP values using Kernel SHAP. It runs on the GPU when one is available and falls back to the CPU otherwise.

In [None]:
import time

print("\n" + "="*80)
print(f"Computing SHAP Attributions on {device}")
print("="*80)

# CUDA kernels run asynchronously, so synchronize around the call; otherwise
# the timer may stop before the queued GPU work has finished.
if device.type == "cuda":
    torch.cuda.synchronize(device)
start_time = time.perf_counter()

attributions = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)

if device.type == "cuda":
    torch.cuda.synchronize(device)
elapsed = time.perf_counter() - start_time
print(f"\n✓ Computation completed in {elapsed:.2f} seconds")

# Report where each attribution tensor lives and its shape.
print("\nAttribution tensor devices:")
for key, val in attributions.items():
    print(f"  {key}: device={val.device}, shape={val.shape}")

## 12. Analyze SHAP Results

In [None]:
divider = "=" * 80
print("\n" + divider)
print("SHAP Attribution Results")
print(divider)

explanation_lines = (
    "\nSHAP values explain the contribution of each feature to the model's",
    "prediction of MORTALITY (class 1). Positive values increase the",
    "mortality prediction, negative values decrease it.",
)
for text in explanation_lines:
    print(text)

print_top_attributions(attributions, sample_batch_device, input_processors, top_k=15)

## 13. Test Different Target Classes

In [None]:
print("\n" + "="*80)
print("Comparing SHAP Attributions for Different Target Classes")
print("="*80)

# Survival (class 0) has not been computed yet.
# NOTE(review): if the model emits a single probability column (see section 10),
# confirm how ShapExplainer interprets target_class_idx=0.
print("\nComputing attributions for SURVIVAL (class 0)...")
attr_survive = shap_explainer.attribute(**sample_batch_device, target_class_idx=0)

# Mortality (class 1) was already computed in section 11 for this same batch.
# Reuse it instead of paying for another Kernel SHAP run, which (being
# presumably sampling-based) could also differ slightly from section 12's output.
print("Reusing attributions for MORTALITY (class 1) from section 11...")
attr_death = attributions

# print_top_attributions ranks by |SHAP|, so these are the most influential
# features for each class. Positive values push toward that class and
# negative values push away from it.
print("\n--- Most influential features for SURVIVAL (class 0) ---")
print_top_attributions(attr_survive, sample_batch_device, input_processors, top_k=5)

print("\n--- Most influential features for MORTALITY (class 1) ---")
print_top_attributions(attr_death, sample_batch_device, input_processors, top_k=5)

## 14. Verify GPU Memory Usage

In [None]:
if not torch.cuda.is_available():
    print("GPU not available")
else:
    print("\n" + "="*80)
    print("GPU Memory Usage")
    print("="*80)

    # Device-0 memory counters, in GB (1e9 bytes).
    memory_stats = {
        "Currently allocated": torch.cuda.memory_allocated(0) / 1e9,
        "Reserved": torch.cuda.memory_reserved(0) / 1e9,
        "Peak allocated": torch.cuda.max_memory_allocated(0) / 1e9,
    }
    for stat_name, stat_gb in memory_stats.items():
        print(f"  {stat_name}: {stat_gb:.2f} GB")

    # Reset the peak counter so a later reading reflects only subsequent work.
    torch.cuda.reset_peak_memory_stats(0)

## 15. Test Callable Interface

In [None]:
import random

import numpy as np

print("\n" + "="*80)
print("Testing Callable Interface")
print("="*80)

COMPARE_SEED = 0


def _seed_all(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs (torch.manual_seed covers CUDA too)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# The explainer samples background points / coalitions (see its
# n_background_samples / max_coalitions settings). Two unseeded runs can
# therefore differ. Re-seed before each call so a mismatch reflects the
# interface rather than sampling noise.
_seed_all(COMPARE_SEED)
attr_from_attribute = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)
_seed_all(COMPARE_SEED)
attr_from_call = shap_explainer(**sample_batch_device, target_class_idx=1)

print("\nVerifying that explainer(**data) and explainer.attribute(**data) produce")
print("identical results...")

all_close = True

# Keys present in only one result count as a mismatch instead of a KeyError.
unmatched_keys = set(attr_from_attribute) ^ set(attr_from_call)
if unmatched_keys:
    all_close = False
    print(f"  ❌ Keys present in only one result: {sorted(unmatched_keys)}")

for key in [k for k in attr_from_attribute if k in attr_from_call]:
    from_attribute = attr_from_attribute[key]
    from_call = attr_from_call[key]
    # Compare shapes first: allclose raises on non-broadcastable shapes.
    if from_attribute.shape != from_call.shape or not torch.allclose(
        from_attribute, from_call, atol=1e-6
    ):
        all_close = False
        print(f"  ❌ {key}: Results differ!")
    else:
        print(f"  ✓ {key}: Results match")

if all_close:
    print("\n✓ All attributions match! Callable interface works correctly.")
else:
    print("\n❌ Some attributions differ.")

## Summary

This notebook demonstrated:

1. ✅ **GPU Setup**: Verified GPU availability and configured PyTorch to use CUDA
2. ✅ **Model Loading**: Loaded pre-trained StageNet model on GPU
3. ✅ **SHAP Computation**: Computed SHAP attributions on GPU for both discrete features (ICD codes) and continuous features (lab values)
4. ✅ **Feature Interpretation**: Identified which diagnosis/procedure codes and lab values most influenced mortality predictions
5. ✅ **Target-class Comparison**: Compared attributions for the two target classes (survival vs. mortality)
6. ✅ **GPU Optimization**: Verified all tensors and computations run on GPU

**Key Takeaways:**
- SHAP provides interpretable, theoretically-grounded feature attributions
- GPU acceleration can substantially speed up the many model evaluations Kernel SHAP performs (compare the timing from section 11 against a CPU run)
- The method works seamlessly with discrete healthcare features like ICD codes
- Positive SHAP values indicate features that increase the prediction, negative values decrease it