# Medical Image Classification with PyHealth

Welcome to the PyHealth tutorial on image classification and saliency mapping. In this notebook, we will use PyHealth to classify chest X-ray images as COVID-19, lung opacity, viral pneumonia, or normal, and then visualize the model's decision-making process with gradient saliency maps and layer-wise relevance propagation (LRP).

## Environment Setup

First, let's install the required packages and set up our environment.

In [None]:
!pip install mne pandarallel rdkit transformers torch torchvision openpyxl polars

In [None]:
!rm -rf PyHealth
# !git clone https://github.com/sunlabuiuc/PyHealth.git
!git clone -b SaliencyMappingClass https://github.com/Nimanui/PyHealth-fitzpa15.git PyHealth

In [None]:
import sys

sys.path.append("./PyHealth")

## Download Data

Next, we will download the COVID-19 Radiography Database, which contains chest X-ray images of normal cases, lung opacity, viral pneumonia, and COVID-19 patients. You can find more information about the dataset [here](https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database).

Download and extract the dataset:

In [None]:
!wget -N https://storage.googleapis.com/pyhealth/covid19_cxr_data/archive.zip

In [None]:
!unzip -q -o archive.zip

In [None]:
!ls -1 COVID-19_Radiography_Dataset

Next, we will proceed with the chest X-ray classification task using PyHealth, following a five-stage pipeline.

## Step 1. Load Data in PyHealth

The initial step involves loading the data into PyHealth's internal structure. This process is straightforward: import the appropriate dataset class from PyHealth and specify the root directory where the raw dataset is stored. PyHealth will handle the dataset processing automatically.

In [None]:
from pyhealth.datasets import COVID19CXRDataset

root = "COVID-19_Radiography_Dataset"
base_dataset = COVID19CXRDataset(root)

Once the data is loaded, we can perform simple queries on the dataset:

In [None]:
base_dataset.stats()

In [None]:
base_dataset.get_patient("0").get_events()

## Step 2. Define the Task

The next step is to define the machine learning task. This step instructs the package to generate a list of samples with the desired features and labels based on the data for each individual patient. Please note that in this dataset, patient identification information is not available. Therefore, we will assume that each chest X-ray belongs to a unique patient.
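Conceptually, a task function turns each patient's raw records into sample dictionaries. Here is a minimal plain-Python sketch of that idea for this dataset, assuming hypothetical image records. `xray_classification_task` and the record fields are illustrative, not PyHealth's internal API, although the `image` and `disease` sample keys match those used later in this notebook:

```python
# Hypothetical raw records: one entry per chest X-ray image. The field names
# ("path", "label") are illustrative, not PyHealth's internal schema.
records = [
    {"path": "images/xray_0001.png", "label": "COVID"},
    {"path": "images/xray_0002.png", "label": "Normal"},
    {"path": "images/xray_0003.png", "label": "Viral Pneumonia"},
]

def xray_classification_task(records):
    """Turn raw image records into samples, treating each image as its own patient."""
    samples = []
    for i, record in enumerate(records):
        samples.append({
            "patient_id": str(i),        # no real patient IDs: one "patient" per image
            "image": record["path"],     # feature: a path here; real samples hold an image tensor
            "disease": record["label"],  # label: real samples store an integer class index
        })
    return samples

samples = xray_classification_task(records)
print(len(samples), samples[0]["disease"])
```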

For this dataset, PyHealth offers a default task specifically for chest X-ray classification. This task takes the image as input and aims to predict the chest diseases associated with it.

In [None]:
base_dataset.default_task

In [None]:
sample_dataset = base_dataset.set_task()

Here is an example of a single sample, represented as a dictionary. The dictionary contains keys for feature names, label names, and other metadata associated with the sample.

In [None]:
sample_dataset[0]

We can also check the input and output schemas, which specify the data types of the features and labels.

In [None]:
sample_dataset.input_schema

In [None]:
sample_dataset.output_schema

Below, we plot the number of samples per class and visualize one random sample from each class.

In [None]:
label2id = sample_dataset.output_processors["disease"].label_vocab
id2label = {v: k for k, v in label2id.items()}

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt

label_counts = defaultdict(int)
for sample in sample_dataset.samples:
    label_counts[id2label[sample["disease"].item()]] += 1
print(label_counts)
plt.bar(label_counts.keys(), label_counts.values())

In [None]:
import random

label_to_idxs = defaultdict(list)
for idx, sample in enumerate(sample_dataset.samples):
    label_to_idxs[sample["disease"].item()].append(idx)

fig, axs = plt.subplots(1, 4, figsize=(15, 3))
for ax, label in zip(axs, label_to_idxs.keys()):
    ax.set_title(id2label[label], fontsize=15)
    idx = random.choice(label_to_idxs[label])
    sample = sample_dataset[idx]
    image = sample["image"][0]
    ax.imshow(image, cmap="gray")

Finally, we will split the entire dataset into training, validation, and test sets using the ratios of 70%, 10%, and 20%, respectively. We will then obtain the corresponding data loaders for each set.
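With a sample-level split, every image is assigned independently, which is consistent with treating each X-ray as its own patient. The ratios translate into sample counts roughly as in this sketch of a seeded random split. `split_indices` is a hypothetical helper, and PyHealth's `split_by_sample` may shuffle or round differently:

```python
import random

def split_indices(n, ratios, seed=0):
    """Shuffle sample indices, then cut them into train/val/test chunks."""
    if abs(sum(ratios) - 1.0) > 1e-8:
        raise ValueError("ratios must sum to 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)      # seeded, so the split is reproducible
    n_train = round(n * ratios[0])
    n_val = round(n * ratios[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]          # the remainder goes to the test set
    return train, val, test

train_idx, val_idx, test_idx = split_indices(1000, [0.7, 0.1, 0.2])
print(len(train_idx), len(val_idx), len(test_idx))  # 700 100 200
```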

In [None]:
from pyhealth.datasets import split_by_sample

train_dataset, val_dataset, test_dataset = split_by_sample(
    dataset=sample_dataset,
    ratios=[0.7, 0.1, 0.2]
)

In [None]:
from pyhealth.datasets import get_dataloader

train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

## Step 3. Define the Model

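Next, we will define the deep learning model for our task. PyHealth's `TorchvisionModel` supports the major vision architectures available in Torchvision: choose one with the `model_name` argument and pass constructor options, such as `{"weights": "DEFAULT"}` for pretrained weights, through `model_config`. Below we build a ResNet-18 and a ViT-B/16.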

In [None]:
from pyhealth.models import TorchvisionModel

resnet = TorchvisionModel(
    dataset=sample_dataset,
    model_name="resnet18",
    model_config={"weights": "DEFAULT"}
)

resnet

In [None]:
from pyhealth.models import TorchvisionModel

vit = TorchvisionModel(
    dataset=sample_dataset,
    model_name="vit_b_16",
    model_config={"weights": "DEFAULT"}
)

vit

## Step 4. Training

In this step, we will train the model with PyHealth's `Trainer` class, which wraps the training loop, per-epoch validation, and metric monitoring.

Let us first train the ResNet model.

In [None]:
from pyhealth.trainer import Trainer

resnet_trainer = Trainer(model=resnet)

Before we begin training, let's first evaluate the initial performance of the model.

In [None]:
print(resnet_trainer.evaluate(test_dataloader))

Now, let's start the training process. Due to computational constraints, we will train the model for only one epoch.

In [None]:
resnet_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=1,
    monitor="accuracy"
)
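As we understand it, `monitor="accuracy"` tells the trainer which validation metric to use when deciding which epoch's weights to keep as the best model; with `epochs=1` there is only one candidate. The selection logic can be sketched in plain Python. `select_best_epoch`, its `mode` parameter, and the history values are hypothetical:

```python
# Made-up validation history, one record per epoch, for illustration only
val_history = [
    {"epoch": 0, "accuracy": 0.81, "loss": 0.52},
    {"epoch": 1, "accuracy": 0.86, "loss": 0.43},
    {"epoch": 2, "accuracy": 0.84, "loss": 0.40},
]

def select_best_epoch(history, monitor="accuracy", mode="max"):
    """Return the epoch record with the best value of the monitored metric.

    `mode` is a hypothetical parameter: "max" for scores such as accuracy,
    "min" for losses.
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    pick = max if mode == "max" else min
    return pick(history, key=lambda record: record[monitor])

print(select_best_epoch(val_history)["epoch"])                              # 1
print(select_best_epoch(val_history, monitor="loss", mode="min")["epoch"])  # 2
```

Whether a score should be maximized or minimized depends on the metric, which is why the sketch makes the direction explicit.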

After training, we can compare the model's performance before and after training; accuracy should increase as the model learns from the training data.

## Step 5. Evaluation

Lastly, we can evaluate the ResNet model on the test set. This can be done using PyHealth's `Trainer.evaluate()` function.

In [None]:
print(resnet_trainer.evaluate(test_dataloader))

Additionally, you can perform inference using the `Trainer.inference()` function.

In [None]:
y_true, y_prob, loss = resnet_trainer.inference(test_dataloader)
y_pred = y_prob.argmax(axis=1)
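`y_prob` holds one row of class probabilities per test image, so `argmax(axis=1)` picks the predicted class, and accuracy is the fraction of predictions that match `y_true`. A toy illustration with made-up arrays, assuming `y_true` holds integer class indices (as the confusion-matrix cell below also assumes):

```python
import numpy as np

# Toy stand-ins for Trainer.inference() outputs (values are made up):
# y_true holds integer class indices, y_prob one row of class probabilities per image.
y_true_toy = np.array([0, 2, 1, 3, 2])
y_prob_toy = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.10, 0.20, 0.60, 0.10],
    [0.30, 0.40, 0.20, 0.10],
    [0.20, 0.20, 0.10, 0.50],
    [0.50, 0.10, 0.30, 0.10],
])

y_pred_toy = y_prob_toy.argmax(axis=1)        # most probable class in each row
accuracy = (y_pred_toy == y_true_toy).mean()  # fraction of correct predictions
print(y_pred_toy, accuracy)                   # [0 2 1 3 0] 0.8
```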

Below we show a confusion matrix of the trained ResNet model.

In [None]:
!pip install seaborn

from sklearn.metrics import confusion_matrix
import seaborn as sns

cf_matrix = confusion_matrix(y_true, y_pred)
ax = sns.heatmap(cf_matrix, linewidths=1, annot=True, fmt='g')
class_names = [id2label[i] for i in range(len(id2label))]
ax.set_xticklabels(class_names, rotation=45, ha="right")
ax.set_yticklabels(class_names, rotation=0)
ax.set_xlabel("Pred")
ax.set_ylabel("True")
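Each cell of the heatmap counts test images with a given true class (row) and predicted class (column), the same layout `sklearn.metrics.confusion_matrix` uses. The diagonal holds correct predictions, so dividing it by the row sums gives per-class recall. A NumPy sketch with made-up labels; `confusion_counts` is a hypothetical helper:

```python
import numpy as np

def confusion_counts(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs: rows are true classes, columns predicted classes."""
    matrix = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(matrix, (y_true, y_pred), 1)  # unbuffered add counts repeated pairs correctly
    return matrix

# Toy labels and predictions (made up)
y_true_toy = np.array([0, 2, 1, 3, 2])
y_pred_toy = np.array([0, 2, 1, 3, 0])

cm = confusion_counts(y_true_toy, y_pred_toy, num_classes=4)
recall = cm.diagonal() / cm.sum(axis=1)  # per-class recall = correct / number of true samples
print(cm)
print(recall)
```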

## 6. Gradient Saliency Mapping
As a bonus, let's look at some simple gradient saliency maps computed on samples from our dataset.
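A gradient saliency map scores each input pixel by the magnitude of the gradient of the target-class score with respect to that pixel, so pixels whose small changes would move the score most are highlighted. The following self-contained NumPy sketch uses a toy two-layer ReLU network and checks the analytic gradient against finite differences. The network, its random weights, and taking the absolute value are assumptions for illustration; `BasicGradientSaliencyMaps` may post-process gradients differently:

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(8, 16))  # hidden units x input "pixels"
W2 = rng.normal(size=(4, 8))   # classes x hidden units
x = rng.normal(size=16)        # one flattened toy "image"

def class_scores(x):
    return W2 @ np.maximum(W1 @ x, 0.0)  # logits of a bias-free two-layer ReLU net

def gradient_saliency(x, target):
    relu_mask = (W1 @ x > 0).astype(float)  # ReLU derivative at each hidden unit
    grad = W1.T @ (W2[target] * relu_mask)  # chain rule: d score_target / d x
    return np.abs(grad)                     # saliency = gradient magnitude

target = int(np.argmax(class_scores(x)))    # explain the predicted class
saliency = gradient_saliency(x, target)

# Sanity check: central finite differences of the target score
h = 1e-6
numeric = np.array([
    (class_scores(x + h * e)[target] - class_scores(x - h * e)[target]) / (2 * h)
    for e in np.eye(x.size)
])
print(np.allclose(saliency, np.abs(numeric), atol=1e-5))
```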

In [None]:
def add_requires_grad(in_dataset):
    """Mark each sample's image tensor as requiring gradients (modifies tensors in place)."""
    for sample in in_dataset:
        sample['image'].requires_grad_()

In [None]:
from pyhealth.interpret.methods.basic_gradient import BasicGradientSaliencyMaps
from pyhealth.interpret.methods import SaliencyVisualizer
import torch

# Create a batch with only COVID samples
covid_label = label2id['COVID']
covid_samples = [sample for sample in sample_dataset.samples if sample['disease'].item() == covid_label]

# Take the first 32 COVID samples and create a batch
batch_size = min(32, len(covid_samples))
covid_batch = {
    'image': torch.stack([covid_samples[i]['image'] for i in range(batch_size)]),
    'disease': torch.stack([covid_samples[i]['disease'] for i in range(batch_size)])
}

print(f"Created COVID batch with {batch_size} samples")

# Initialize saliency maps with batch input only
saliency_maps = BasicGradientSaliencyMaps(
    resnet,
    input_batch=covid_batch
)

# Initialize the overlay visualizer: colormap, overlay opacity, and figure size
visualizer = SaliencyVisualizer(default_cmap='hot', default_alpha=0.6, figure_size=(15, 7))

In [None]:
# Show saliency map for the first image in the batch
image_0 = covid_batch['image'][0]
# Compute saliency for single image using attribute method
saliency_result_0 = saliency_maps.attribute(image=image_0.unsqueeze(0), disease=covid_batch['disease'][0:1])
visualizer.plot_saliency_overlay(
    plt, 
    image=image_0, 
    saliency=saliency_result_0['image'][0],
    title=f"Gradient Saliency - {id2label[covid_label]} (Sample 0)"
)

# Show saliency map for another image in the batch
image_3 = covid_batch['image'][3]
saliency_result_3 = saliency_maps.attribute(image=image_3.unsqueeze(0), disease=covid_batch['disease'][3:4])
visualizer.plot_saliency_overlay(
    plt, 
    image=image_3, 
    saliency=saliency_result_3['image'][0],
    title=f"Gradient Saliency - {id2label[covid_label]} (Sample 3)"
)

## 7. Layer-wise Relevance Propagation (LRP)

LRP is a powerful interpretability method that explains neural network predictions by propagating relevance scores backward through the network. Unlike gradient-based methods, LRP satisfies the conservation property: the sum of relevances at the input layer approximately equals the model's output for the target class.
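For a bias-free linear layer y = W a, the epsilon rule redistributes the relevance R_k of each output to the inputs in proportion to their contributions a_j * W[k, j], with a small ε added to the denominator; ReLU layers pass relevance through unchanged. The NumPy sketch below applies this to a toy two-layer network with random weights (assumed for illustration; this is not UnifiedLRP's code) and shows conservation holding almost exactly when ε is tiny and there are no biases. Larger ε values, such as the 0.1 used below, as well as bias terms and normalization layers, absorb part of the relevance, which is one plausible reason the totals reported below only approximately match the model output.

```python
import numpy as np

def lrp_epsilon_linear(a, W, relevance_out, eps):
    """Epsilon rule for a bias-free linear layer y = W @ a.

    R_j = sum_k a_j * W[k, j] / (z_k + eps * sign(z_k)) * R_k, with z = W @ a.
    """
    z = W @ a
    stabilizer = eps * np.where(z >= 0, 1.0, -1.0)  # keep denominators away from zero
    s = relevance_out / (z + stabilizer)
    return a * (W.T @ s)

rng = np.random.default_rng(1)
W1 = rng.normal(size=(8, 16))
W2 = rng.normal(size=(4, 8))
x = np.abs(rng.normal(size=16))   # non-negative toy "pixel intensities"

hidden = np.maximum(W1 @ x, 0.0)  # ReLU: relevance passes through unchanged
logits = W2 @ hidden
target = int(np.argmax(logits))

R_logits = np.zeros_like(logits)
R_logits[target] = logits[target]  # start from the target-class score only
R_hidden = lrp_epsilon_linear(hidden, W2, R_logits, eps=1e-9)
R_input = lrp_epsilon_linear(x, W1, R_hidden, eps=1e-9)

print(logits[target], R_input.sum())  # nearly equal: relevance is conserved
```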

**New Implementation**: PyHealth now includes **UnifiedLRP**, a modular implementation that supports both CNNs and embedding-based models through 11 layer handlers, including Conv2d, MaxPool2d, and BatchNorm2d.

Let's demonstrate LRP on our ResNet model:

In [None]:
import logging
from pyhealth.interpret.methods import UnifiedLRP

# Suppress conservation warnings for cleaner output
logging.getLogger('pyhealth.interpret.methods.lrp_base').setLevel(logging.ERROR)

# Initialize UnifiedLRP with epsilon rule
lrp = UnifiedLRP(
    model=resnet.model,  # Use the underlying PyTorch model
    rule='epsilon',
    epsilon=0.1,  # Larger epsilon for numerical stability
    validate_conservation=False
)

# Compute LRP attributions for the first COVID sample
print("Computing LRP attributions for COVID-19 sample...")
covid_image = covid_batch['image'][0:1]

# Convert grayscale to RGB (ResNet expects 3 channels)
if covid_image.shape[1] == 1:
    covid_image = covid_image.repeat(1, 3, 1, 1)

# Move to the same device as the model
device = next(resnet.model.parameters()).device
covid_image = covid_image.to(device)

# Forward pass to get prediction
with torch.no_grad():
    output = resnet.model(covid_image)
    predicted_class = output.argmax(dim=1).item()

# Compute LRP attributions
lrp_attributions = lrp.attribute(
    inputs={'x': covid_image},
    target_class=predicted_class
)

print("✓ LRP attributions computed!")
print(f"  Input shape: {covid_image.shape}")
print(f"  Attribution shape: {lrp_attributions['x'].shape}")
print(f"  Predicted class: {id2label[predicted_class]}")
print(f"  Total relevance: {lrp_attributions['x'].sum().item():.4f}")
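The RGB tensor is three copies of the grayscale channel, but the first convolution weights each copy differently, so LRP generally assigns different relevance to the three copies of the same pixel. Because relevance is additive, summing over channels gives the relevance of the original grayscale pixel. A one-pixel, one-neuron NumPy illustration with hypothetical weights:

```python
import numpy as np

# One grayscale pixel copied into three identical RGB channels,
# as covid_image.repeat(1, 3, 1, 1) does for every pixel
gray_value = 0.8
rgb = np.array([gray_value, gray_value, gray_value])

# Hypothetical first-layer weights for this pixel, one per input channel
w = np.array([0.5, -0.2, 0.9])

contributions = rgb * w       # z_c = x_c * w_c for each channel
relevance_at_neuron = 1.0     # toy relevance arriving from the layer above
per_channel = contributions / contributions.sum() * relevance_at_neuron  # basic (LRP-0) split

print(per_channel)        # three different values
print(per_channel.sum())  # recovers the relevance of the single grayscale pixel
```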

## Visualizing LRP Results

LRP provides pixel-level explanations showing which image regions contributed to the model's prediction.

In [None]:
# Visualize the LRP relevance map.
# The RGB input repeats the grayscale channel, but the first convolution weights
# each channel differently, so sum the per-channel relevances to map them back
# onto the original grayscale pixels.
relevance_map = lrp_attributions['x'].squeeze(0)
if relevance_map.dim() == 3:  # (3, H, W) -> (H, W)
    relevance_map = relevance_map.sum(dim=0)

visualizer.plot_saliency_overlay(
    plt,
    image=covid_batch['image'][0],  # Original grayscale image
    saliency=relevance_map,
    title=f"LRP Relevance Map - {id2label[predicted_class]} (Epsilon Rule)",
)

# Also show gradient saliency for comparison
saliency_comparison = saliency_maps.attribute(image=covid_batch['image'][0:1], disease=covid_batch['disease'][0:1])
visualizer.plot_saliency_overlay(
    plt,
    image=covid_batch['image'][0],
    saliency=saliency_comparison['image'][0],
    title=f"Gradient Saliency (for comparison) - {id2label[predicted_class]}"
)

## Comparing Different LRP Rules

LRP supports different propagation rules that handle positive and negative contributions differently:

**Epsilon Rule (`rule="epsilon"`):**
- Adds a small stabilizer ε to the denominator to prevent division by zero
- Best for: General use, numerical stability
- Good for layers where both positive and negative activations matter equally
- Conservation violations: 5-50% (acceptable)

**Alpha-Beta Rule (`rule="alphabeta"`):**
- Separates positive and negative contributions with different weights (α and β)
- Default: α=2, β=1 (emphasizes positive contributions)
- Best for: When you want to focus on excitatory (positive) evidence
- Often produces sharper, more focused heatmaps
- Conservation violations: 50-150% (acceptable)
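For a bias-free linear layer, the alpha-beta rule splits each contribution a_j * W[k, j] into its positive and negative parts, redistributes the output relevance R_k in proportion to each part, and weights the two shares by α and -β. Because each share sums to one, the layer conserves relevance exactly when α - β = 1, which the default α=2, β=1 satisfies. A textbook NumPy sketch with random weights (assumed for illustration; UnifiedLRP's handlers may treat biases and edge cases differently):

```python
import numpy as np

def safe_fraction(part, total):
    """Column-wise part / total, returning 0 where the column total is 0."""
    return np.divide(part, total, out=np.zeros_like(part), where=total != 0)

def lrp_alphabeta_linear(a, W, relevance_out, alpha=2.0, beta=1.0):
    """Alpha-beta rule for a bias-free linear layer y = W @ a (W has shape out x in)."""
    Z = a[:, None] * W.T                   # Z[j, k] = a_j * W[k, j]: input j's contribution to output k
    Z_pos = np.maximum(Z, 0.0)
    Z_neg = np.minimum(Z, 0.0)
    pos_share = safe_fraction(Z_pos, Z_pos.sum(axis=0))  # each column sums to 1
    neg_share = safe_fraction(Z_neg, Z_neg.sum(axis=0))  # each column sums to 1
    return (alpha * pos_share - beta * neg_share) @ relevance_out

rng = np.random.default_rng(3)
W = rng.normal(size=(5, 20))       # 20 inputs, so every output almost surely mixes signs
a = np.abs(rng.normal(size=20))    # non-negative inputs, e.g. post-ReLU activations
R_out = np.abs(rng.normal(size=5))

R_in = lrp_alphabeta_linear(a, W, R_out)
print(R_out.sum(), R_in.sum())     # equal, because alpha - beta = 1
```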

Let's compare both rules on the same image:

In [None]:
# Epsilon rule (already computed)
print("LRP with Epsilon Rule (ε=0.1)")
visualizer.plot_saliency_overlay(
    plt,
    image=covid_batch['image'][0],
    saliency=relevance_map[0] if relevance_map.dim() == 3 else relevance_map,
    title=f"LRP Epsilon Rule - {id2label[predicted_class]}",
)

# Now compute LRP with Alpha-Beta Rule
print("\nComputing LRP with Alpha-Beta Rule (α=2, β=1)...")
lrp_alphabeta = UnifiedLRP(
    model=resnet.model,
    rule='alphabeta',
    alpha=2.0,
    beta=1.0,
    validate_conservation=False
)

alphabeta_attributions = lrp_alphabeta.attribute(
    inputs={'x': covid_image},
    target_class=predicted_class
)

# The RGB input repeats the grayscale channel; sum the per-channel relevances back onto it
alphabeta_relevance = alphabeta_attributions['x'].squeeze(0)
if alphabeta_relevance.dim() == 3:  # (3, H, W) -> (H, W)
    alphabeta_relevance = alphabeta_relevance.sum(dim=0)
visualizer.plot_saliency_overlay(
    plt,
    image=covid_batch['image'][0],
    saliency=alphabeta_relevance,
    title=f"LRP Alpha-Beta Rule (α=2, β=1) - {id2label[predicted_class]}",
)

print("\n✓ Results:")
print(f"  Epsilon Rule - Total relevance: {lrp_attributions['x'].sum().item():.4f}")
print(f"  Alpha-Beta Rule - Total relevance: {alphabeta_attributions['x'].sum().item():.4f}")

### Side-by-Side Comparison of All Interpretation Methods

Let's create a comprehensive comparison showing gradient saliency and both LRP rules side by side:

In [None]:
# Create side-by-side comparison of all three methods
attributions_dict = {
    'Gradient Saliency': saliency_comparison['image'][0],
    'LRP Epsilon (ε=0.1)': relevance_map[0] if relevance_map.dim() == 3 else relevance_map,
    'LRP Alpha-Beta (α=2, β=1)': alphabeta_relevance[0] if alphabeta_relevance.dim() == 3 else alphabeta_relevance
}

visualizer.plot_multiple_attributions(
    plt,
    image=covid_batch['image'][0],
    attributions=attributions_dict
)

print("\n📊 Key Observations:")
print("  • Gradient Saliency: Shows regions with high gradient magnitude")
print("  • LRP Epsilon: More balanced, stable attribution across the image")
print("  • LRP Alpha-Beta: Sharper focus on positive evidence regions")
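Gradient magnitudes and LRP relevances live on very different numeric scales, so visual comparison is easiest after rescaling each map. If you plot the maps yourself, a minimal approach is to divide each map by its largest absolute value. `normalize_attribution` is a hypothetical helper and the toy maps are made up; whether `plot_multiple_attributions` already normalizes internally is not stated here.

```python
import numpy as np

def normalize_attribution(attr, use_abs=False):
    """Scale a 2-D attribution map into [-1, 1] (or [0, 1] with use_abs) by its max magnitude."""
    attr = np.asarray(attr, dtype=float)
    if use_abs:
        attr = np.abs(attr)
    scale = np.abs(attr).max()
    return attr / scale if scale > 0 else attr  # leave an all-zero map unchanged

# Toy maps on very different scales (values are made up)
toy_maps = {
    "Gradient Saliency": np.array([[0.001, 0.004], [0.002, 0.0]]),
    "LRP Epsilon": np.array([[0.5, -1.5], [3.0, 0.0]]),
}
normalized = {name: normalize_attribution(m) for name, m in toy_maps.items()}
for name, m in normalized.items():
    print(name, m.min(), m.max())
```

To apply it to the tensors in `attributions_dict`, convert each one first with `.detach().cpu().numpy()`.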

## UnifiedLRP Implementation Details

The **UnifiedLRP** implementation supports a wide range of neural network architectures through modular layer handlers:

**Supported Layers (11 handlers), including:**
- **Dense/Embedding**: Linear, ReLU, Embedding
- **Convolutional**: Conv2d, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d  
- **Normalization**: BatchNorm2d
- **Utility**: Flatten, Dropout

This modular design makes it easy to:
- Apply LRP to both CNNs (images) and MLPs (tabular/embedding data)
- Extend with custom handlers for new layer types
- Validate conservation property at each layer
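The handler-based design can be illustrated with a small dispatch registry: each layer type maps to a function that redistributes relevance from that layer's output to its input, and backward propagation walks the layers in reverse. A toy NumPy sketch; `HANDLERS`, `register_handler`, and `propagate` are hypothetical names and do not reflect UnifiedLRP's actual internals:

```python
import numpy as np

HANDLERS = {}  # hypothetical registry: layer type name -> relevance handler

def register_handler(layer_type):
    def decorator(fn):
        HANDLERS[layer_type] = fn
        return fn
    return decorator

@register_handler("ReLU")
def relu_handler(layer, a, relevance_out):
    return relevance_out  # ReLU passes relevance through unchanged

@register_handler("Linear")
def linear_handler(layer, a, relevance_out, eps=1e-9):
    W = layer["weight"]
    z = W @ a
    s = relevance_out / (z + eps * np.where(z >= 0, 1.0, -1.0))  # epsilon rule
    return a * (W.T @ s)

def propagate(layers, layer_inputs, relevance):
    """Walk the layers backward, dispatching each to its registered handler."""
    for layer, a in zip(reversed(layers), reversed(layer_inputs)):
        handler = HANDLERS.get(layer["type"])
        if handler is None:
            raise NotImplementedError(f"no LRP handler for {layer['type']}")
        relevance = handler(layer, a, relevance)
    return relevance

# Toy network: Linear -> ReLU -> Linear, with random bias-free weights
rng = np.random.default_rng(2)
layers = [
    {"type": "Linear", "weight": rng.normal(size=(6, 10))},
    {"type": "ReLU"},
    {"type": "Linear", "weight": rng.normal(size=(3, 6))},
]

# Forward pass, recording the input that each layer received
x = np.abs(rng.normal(size=10))
layer_inputs, a = [], x
for layer in layers:
    layer_inputs.append(a)
    a = layer["weight"] @ a if layer["type"] == "Linear" else np.maximum(a, 0.0)
logits = a

target = int(np.argmax(logits))
R = np.zeros_like(logits)
R[target] = logits[target]
R_input = propagate(layers, layer_inputs, R)
print(logits[target], R_input.sum())  # nearly equal
```

Supporting a new layer type then means registering one more handler, and an unregistered type fails loudly instead of silently dropping relevance.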

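**Current Status**: Production-ready for standard CNN architectures. Future updates will add support for ResNet skip connections and Transformer attention mechanisms. Since ResNet-18 is built from residual blocks, treat the ResNet attributions in this notebook as approximate until skip connections are supported.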

In [None]:
# Let's apply LRP to multiple samples from the batch
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Get the device where the model is located
device = next(resnet.model.parameters()).device

for idx in range(3):
    sample_image = covid_batch['image'][idx:idx+1]
    
    # Convert grayscale to RGB for ResNet
    sample_image_rgb = sample_image.repeat(1, 3, 1, 1) if sample_image.shape[1] == 1 else sample_image
    
    # Move to the correct device
    sample_image_rgb = sample_image_rgb.to(device)
    
    # Get prediction
    with torch.no_grad():
        output = resnet.model(sample_image_rgb)
        pred_class = output.argmax(dim=1).item()
    
    # Compute LRP
    sample_lrp = lrp.attribute(
        inputs={'x': sample_image_rgb},
        target_class=pred_class
    )
    
    # Plot original image (grayscale)
    axes[0, idx].imshow(sample_image.squeeze().cpu().numpy(), cmap='gray')
    axes[0, idx].set_title(f'Sample {idx}: {id2label[pred_class]}', fontsize=12, fontweight='bold')
    axes[0, idx].axis('off')
    
    # Plot LRP heatmap: sum the per-channel relevances back onto the grayscale
    # pixels (the first conv weights the three repeated channels differently)
    relevance = sample_lrp['x'].squeeze(0)
    if relevance.dim() == 3:  # (3, H, W) -> (H, W)
        relevance = relevance.sum(dim=0)
    im = axes[1, idx].imshow(relevance.detach().cpu().numpy(), cmap='seismic', vmin=-0.1, vmax=0.1)
    axes[1, idx].set_title('LRP Heatmap (ε=0.1)', fontsize=10)
    axes[1, idx].axis('off')

plt.tight_layout()
plt.show()

print("✓ Applied LRP to 3 different COVID-19 X-ray samples")

## Key Takeaways: Gradient Saliency vs. LRP

**Gradient Saliency Maps:**
- ✓ Fast: a single backward pass to compute input gradients
- ✓ Works with any differentiable model
- ✓ Good for identifying "where" the model looks
- ✓ Straightforward implementation
- ⚠️ Can be noisy and may require smoothing
- ⚠️ Does not satisfy the conservation property

**Layer-wise Relevance Propagation (LRP):**
- ✓ **Conservation property**: relevances approximately sum to the model's output for the target class
- ✓ More theoretically grounded attribution
- ✓ Modular design with layer-specific handlers
- ✓ Better captures "how much" each pixel contributes
- ✓ Supports both CNNs and MLPs with UnifiedLRP
- ⚠️ Requires layer-specific propagation rules
- ⚠️ Conservation is approximate in practice: expect violations of 5-150% depending on the rule

**Which one to use?**
- Use **Gradient Saliency** for quick exploration and fast prototyping
- Use **LRP** when you need quantifiable attributions that approximately conserve the model's output
- Use **LRP Epsilon Rule** for numerically stable, balanced attributions
- Use **LRP Alpha-Beta Rule** for sharper visualizations emphasizing positive evidence
- Use **both** to get complementary insights into your model's behavior!

**UnifiedLRP Status:**
- ✅ Production-ready for CNNs (11 layer handlers implemented)
- ✅ Supports: Conv2d, MaxPool2d, BatchNorm2d, Linear, ReLU, Flatten, Dropout, and more
- ⏳ Future: ResNet skip connections, Transformer attention, RNN support