<a href="https://colab.research.google.com/github/thedatasense/medgemma-explainer/blob/master/notebooks/tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MedGemma Explainability Tutorial

## Implementing Chefer et al. for Vision-Language Model Interpretability

This tutorial demonstrates how to generate explanations for MedGemma predictions using the gradient-weighted attention method from Chefer et al. (2021).

**Reference:** Chefer, H., Gur, S., & Wolf, L. (2021). *Transformer Interpretability Beyond Attention Visualization.* CVPR 2021.

## Table of Contents

1. [Introduction: Why Explainability Matters](#1-introduction)
2. [Method Overview: The Chefer Approach](#2-method-overview)
3. [Setup and Dependencies](#3-setup)
4. [MedGemma Architecture Overview](#4-architecture)
5. [Loading MedGemma](#5-loading-model)
6. [Basic Usage with Test Images](#6-basic-usage)
7. [Step-by-Step Explanation Pipeline](#7-step-by-step)
8. [Medical Imaging: Chest X-ray Analysis](#8-chest-xray)
9. [Comparing Methods](#9-comparing-methods)
10. [Region Analysis](#10-region-analysis)
11. [Advanced: Layer-wise Analysis](#11-layer-analysis)
12. [Limitations and Considerations](#12-limitations)
13. [Summary](#13-summary)

## 1. Introduction: Why Explainability Matters <a name="1-introduction"></a>

Medical AI systems require **transparency** and **interpretability** for several reasons:

1. **Clinical Trust**: Physicians need to understand *why* a model made a prediction to trust and act on it
2. **Regulatory Compliance**: Medical devices require explainability for FDA approval
3. **Error Detection**: Understanding model attention helps identify when it's focusing on irrelevant features
4. **Bias Detection**: Explanations can reveal if a model relies on spurious correlations

This tutorial shows how to generate **relevancy maps** that highlight which parts of an image and text prompt contribute to MedGemma's predictions.

## 2. Method Overview: The Chefer Approach <a name="2-method-overview"></a>

### Key Insight
Raw attention weights show where the model *looks*, but not whether that attention *matters* for the final prediction. The Chefer method combines attention with **gradients** to identify truly important regions.

### Core Equations

**Equation 5: Gradient-Weighted Attention**
$$\bar{A} = \mathbb{E}_h \left[ (\nabla A \odot A)^+ \right]$$

Where:
- $A$ = attention weights
- $\nabla A$ = gradient of the loss w.r.t. attention
- $\odot$ = element-wise multiplication
- $(\cdot)^+$ = keeping only positive values
- $\mathbb{E}_h$ = averaging over attention heads

**Equation 6: Relevancy Propagation**
$$R = \bar{A} \cdot R$$

Starting with $R = I$ (identity matrix), we propagate relevancy through each layer.

### Critical: Correct Backprop Target for Causal LM

**This is the most important insight for correct explanations:**

For causal language models like MedGemma:
- **Logit at position i predicts token at position i+1**
- To explain why token at position p was generated:
  1. Backprop from logit at position **p-1**
  2. Use the **actual token id** at position p (not argmax)
  3. Extract relevancy from row **p-1** in the R matrix

**Wrong approach (common mistake):**
```python
# WRONG: Using last position with argmax
target_logit = logits[0, -1, logits[0, -1].argmax()]
```

**Correct approach:**
```python
# CORRECT: For token at position p, use logit at p-1
logit_position = target_token_position - 1
target_token_id = input_ids[0, target_token_position]  # Actual token
target_logit = logits[0, logit_position, target_token_id]
```

### Why This Works
- **Correct gradient flow**: Gradients reflect which attention patterns led to generating this specific token
- **Actual token id**: Using the real token (not argmax) ensures we explain what actually happened
- **Position p-1**: This is where the model "decided" to output the token at p

## 3. Setup and Dependencies <a name="3-setup"></a>

First, let's set up HuggingFace authentication and install dependencies.

In [1]:
# Setup HuggingFace authentication
from google.colab import userdata
import os

# Get token from Colab secrets
hf_token = userdata.get('HF_TOKEN')
os.environ['HF_TOKEN'] = hf_token

from huggingface_hub import login
login(token=hf_token, add_to_git_credential=False)
print("HuggingFace authentication successful!")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


HuggingFace authentication successful!


In [2]:
import sys
sys.path.insert(0, '/content')

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import requests
from io import BytesIO

# Check GPU
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
GPU Memory: 85.2 GB


## 4. MedGemma Architecture Overview <a name="4-architecture"></a>

MedGemma 1.5 4B is a vision-language model with:

### Vision Encoder (SigLIP)
- 27 transformer layers
- Processes images to 4096 patches (64×64)
- Pooled via 4×4 AvgPool2d to **256 image tokens** (16×16 grid)

### Language Model (Gemma3)
- **34 transformer layers**
- Grouped-Query Attention (GQA): 8 query heads, 4 KV heads
- 5:1 local:global attention ratio
- Global layers at indices: 5, 11, 17, 23, 29
- Local attention window: 1024 tokens

### Token Structure
```
Position 0:     <bos>
Position 1:     <start_of_turn>
Position 2:     user
Position 3:     \n\n
Position 4-5:   <start_of_image> tokens
Position 6-261: 256 IMAGE TOKENS (16×16 grid)
Position 262:   <end_of_image>
Position 263+:  Text prompt and generated response
```

**Important for Chest X-rays:**
- Left side of image = Patient's RIGHT side
- Right side of image = Patient's LEFT side
- The 16×16 grid maps to anatomical regions

## 5. Loading MedGemma <a name="5-loading-model"></a>

**Critical:** We must use `attn_implementation='eager'` to get attention outputs. SDPA (the default) doesn't support `output_attentions=True`.

In [3]:
from transformers import AutoProcessor, AutoModelForImageTextToText

model_name = "google/medgemma-1.5-4b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {model_name}...")
print("(This may take 1-2 minutes)")

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,
    attn_implementation="eager",  # REQUIRED for attention output!
)

print("\nModel loaded successfully!")

# Print key config
if hasattr(model.config, 'text_config'):
    tc = model.config.text_config
    print(f"\nLanguage Model Config:")
    print(f"  Layers: {tc.num_hidden_layers}")
    print(f"  Query heads: {tc.num_attention_heads}")
    print(f"  KV heads: {tc.num_key_value_heads}")

Loading google/medgemma-1.5-4b-it...
(This may take 1-2 minutes)


processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.55k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]


Model loaded successfully!

Language Model Config:
  Layers: 34
  Query heads: 8
  KV heads: 4


## 6. Basic Usage with Test Images <a name="6-basic-usage"></a>

Let's start with simple test images to understand how the explainer works.

In [4]:
# Import the explainability library
from medgemma_explainability import MedGemmaExplainer
from medgemma_explainability.visualization import (
    ExplanationResult,
    visualize_explanation,
    visualize_comparison,
)
from medgemma_explainability.relevancy import (
    compute_Abar,
    propagate_relevancy,
    extract_token_relevancy,
    split_relevancy,
    compute_raw_attention_relevancy,
)
from medgemma_explainability.utils import (
    normalize_relevancy,
    reshape_image_relevancy,
    is_global_layer,
)

# Create explainer
explainer = MedGemmaExplainer(model, processor, device=device)
print("Explainer created!")

ModuleNotFoundError: No module named 'medgemma_explainability'

In [None]:
# Create test images
def create_test_image(pattern='split'):
    """Create test images with different patterns."""
    img = np.zeros((224, 224, 3), dtype=np.uint8)

    if pattern == 'split':
        # Left red, right blue
        img[:, :112, 0] = 255
        img[:, 112:, 2] = 255
    elif pattern == 'quadrant':
        # Four colored quadrants
        img[:112, :112, 0] = 255  # Red top-left
        img[:112, 112:, 1] = 255  # Green top-right
        img[112:, :112, 2] = 255  # Blue bottom-left
        img[112:, 112:] = [255, 255, 0]  # Yellow bottom-right
    elif pattern == 'circle':
        # Red circle on blue background
        y, x = np.ogrid[:224, :224]
        mask = (x - 112)**2 + (y - 112)**2 <= 50**2
        img[mask] = [255, 0, 0]
        img[~mask] = [0, 0, 255]

    return Image.fromarray(img)

# Display test images
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, pattern in zip(axes, ['split', 'quadrant', 'circle']):
    ax.imshow(create_test_image(pattern))
    ax.set_title(f'{pattern.capitalize()} Pattern')
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Quick test with the high-level API
test_image = create_test_image('split')
prompt = "What colors do you see in this image?"

print("Running explanation...")

# The explain() method now uses the CORRECT backprop target:
# - For causal LM, logit at position i predicts token at i+1
# - To explain token at position p, backprop from logit at p-1
result = explainer.explain(test_image, prompt, max_new_tokens=50)

print(f"\nGenerated text: {result.generated_text}")
print(f"\nImage relevancy shape: {result.image_relevancy.shape}")
print(f"Text relevancy shape: {result.text_relevancy.shape}")
print(f"\nMetadata: {result.metadata}")

In [None]:
# Visualize the results
fig = visualize_explanation(test_image, result)
plt.show()

In [None]:
# NEW FEATURE: Explain a specific keyword in the response
# This is useful for understanding why the model generated specific words

print("Explaining specific keyword 'red'...")
try:
    result_keyword = explainer.explain_keyword(test_image, prompt, keyword="red", max_new_tokens=50)

    print(f"\nGenerated text: {result_keyword.generated_text}")
    print(f"Target token: {result_keyword.metadata.get('target_token', 'N/A')}")
    print(f"Token position: {result_keyword.metadata.get('target_token_position', 'N/A')}")

    # Visualize
    fig = visualize_explanation(test_image, result_keyword, title_suffix=" (keyword: 'red')")
    plt.show()
except ValueError as e:
    print(f"Keyword not found: {e}")

## 7. Step-by-Step Explanation Pipeline <a name="7-step-by-step"></a>

Let's break down the explanation process to understand each step.

In [None]:
# Constants for MedGemma token structure
NUM_IMAGE_TOKENS = 256
IMAGE_START_IDX = 6

# Create test image and prompt
test_image = create_test_image('quadrant')
prompt = "Describe this image in detail."

# STEP 1: Prepare inputs
print("STEP 1: Prepare inputs")
messages = [{'role': 'user', 'content': [
    {'type': 'image', 'image': test_image},
    {'type': 'text', 'text': prompt}
]}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=text, images=test_image, return_tensors='pt').to(device)
print(f"  Input sequence length: {inputs['input_ids'].shape[1]}")

In [None]:
# STEP 2: Generate response
print("STEP 2: Generate response")
with torch.no_grad():
    gen_outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)

generated_text = processor.decode(gen_outputs[0], skip_special_tokens=True)
if 'model' in generated_text:
    response = generated_text.split('model')[-1].strip()
else:
    response = generated_text
print(f"  Response: {response[:200]}...")

In [None]:
# STEP 3: Forward pass with attention
print("STEP 3: Forward pass with attention output")
model.train()  # Enable gradients

full_inputs = {
    'input_ids': gen_outputs,
    'attention_mask': torch.ones_like(gen_outputs),
    'pixel_values': inputs['pixel_values'],
}

outputs = model(**full_inputs, output_attentions=True, return_dict=True)

# Retain gradients on attention tensors
for attn in outputs.attentions:
    if attn is not None and attn.requires_grad:
        attn.retain_grad()

print(f"  Got {len(outputs.attentions)} attention layers")
print(f"  Attention shape per layer: {outputs.attentions[0].shape}")

In [None]:
# STEP 4: Compute attention gradients via backprop (CORRECT METHOD)
print("STEP 4: Compute attention gradients via backprop")
print("  IMPORTANT: Using CORRECT backprop target for causal LM")

# Keep model in eval mode, use torch.enable_grad() context
model.eval()

with torch.enable_grad():
    # Re-run forward pass with gradients
    outputs = model(**full_inputs, output_attentions=True, return_dict=True)

    # Retain gradients on attention tensors
    for attn in outputs.attentions:
        if attn is not None:
            attn.requires_grad_(True)
            attn.retain_grad()

    # CORRECT TARGET: To explain the last token, use logit at position -2
    # Because logit[i] predicts token[i+1]
    target_token_position = gen_outputs.shape[1] - 1  # Last token position
    target_token_id = gen_outputs[0, target_token_position].item()  # Actual token id
    logit_position = target_token_position - 1  # Logit that predicted this token

    target_logit = outputs.logits[0, logit_position, target_token_id]
    print(f"  Target token position: {target_token_position}")
    print(f"  Logit position (p-1): {logit_position}")
    print(f"  Target token: '{processor.decode([target_token_id])}'")

    target_logit.backward(retain_graph=True)

# Collect attention and gradients
attention_maps = {}
attention_grads = {}

for i, attn in enumerate(outputs.attentions):
    if attn is not None:
        attention_maps[i] = attn.detach().float()
        if attn.grad is not None:
            attention_grads[i] = attn.grad.detach().float()

print(f"  Collected {len(attention_maps)} attention maps")
print(f"  Collected {len(attention_grads)} gradient maps")

In [None]:
# STEP 5: Propagate relevancy (Chefer method)
print("STEP 5: Propagate relevancy through layers")

seq_len = gen_outputs.shape[1]

R = propagate_relevancy(
    attention_maps,
    attention_grads,
    seq_len,
    handle_local_attention=True,
    local_window_size=1024,
    device=device,
)

print(f"  Relevancy matrix shape: {R.shape}")

In [None]:
# STEP 6: Extract token relevancy (from CORRECT position)
print("STEP 6: Extract token relevancy")
print(f"  IMPORTANT: Extract from row {logit_position} (the predicting position)")

# Extract relevancy from the PREDICTING position (p-1), not the target position
token_relevancy = extract_token_relevancy(R, target_token_idx=logit_position)
image_relevancy, text_relevancy = split_relevancy(
    token_relevancy, NUM_IMAGE_TOKENS, IMAGE_START_IDX
)

# Reshape to 16x16 grid and normalize
image_relevancy_2d = reshape_image_relevancy(image_relevancy)
image_relevancy_2d = normalize_relevancy(image_relevancy_2d)
text_relevancy_norm = normalize_relevancy(text_relevancy)

print(f"  Image relevancy shape: {image_relevancy_2d.shape}")
print(f"  Text relevancy tokens: {len(text_relevancy)}")

In [None]:
# STEP 7: Visualize results
print("STEP 7: Visualize")

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original image
axes[0].imshow(test_image)
axes[0].set_title('Original Image')
axes[0].axis('off')

# Overlay heatmap
heatmap = image_relevancy_2d.cpu().numpy()
heatmap_resized = np.array(Image.fromarray(heatmap).resize(test_image.size, Image.BILINEAR))
axes[1].imshow(test_image)
im = axes[1].imshow(heatmap_resized, cmap='jet', alpha=0.5)
axes[1].set_title('Relevancy Heatmap Overlay')
axes[1].axis('off')
plt.colorbar(im, ax=axes[1], fraction=0.046)

# Text token relevancy
text_start_idx = IMAGE_START_IDX + NUM_IMAGE_TOKENS
token_labels = [processor.decode([t.item()]) for t in gen_outputs[0][text_start_idx:]]
text_rel_np = text_relevancy_norm.cpu().numpy()

n_tokens = min(15, len(text_rel_np))
sorted_idx = np.argsort(text_rel_np)[::-1][:n_tokens]
labels = [token_labels[i][:15].replace('\n', '\\n') for i in sorted_idx]
values = [text_rel_np[i] for i in sorted_idx]

axes[2].barh(range(len(labels)), values, color='steelblue')
axes[2].set_yticks(range(len(labels)))
axes[2].set_yticklabels(labels, fontsize=9)
axes[2].invert_yaxis()
axes[2].set_xlabel('Relevancy')
axes[2].set_title('Top Text Token Relevancy')

plt.tight_layout()
plt.show()

## 8. Medical Imaging: Chest X-ray Analysis <a name="8-chest-xray"></a>

Now let's apply the explainability method to a real medical image - a chest X-ray with pneumonia.

In [None]:
# Load chest X-ray image
# Option 1: Use local file if available
import os

if os.path.exists('/content/chest_xray.jpg'):
    chest_xray = Image.open('/content/chest_xray.jpg').convert('RGB')
    print("Loaded local chest X-ray")
else:
    # Option 2: Download from URL
    url = "https://prod-images-static.radiopaedia.org/images/1371188/0a1f5edc85aa58d5780928cb39b08659c1fc4d6d7c7dce2f8db1d63c7c737234_big_gallery.jpeg"
    response = requests.get(url)
    chest_xray = Image.open(BytesIO(response.content)).convert('RGB')
    print("Downloaded chest X-ray from URL")

print(f"Image size: {chest_xray.size}")

# Display
plt.figure(figsize=(8, 8))
plt.imshow(chest_xray, cmap='gray')
plt.title('Chest X-ray (PA View)\nLeft of image = Patient RIGHT side')
plt.axis('off')
plt.show()

In [None]:
# Analyze chest X-ray for pneumonia
medical_prompt = "Analyze this chest X-ray. Is there evidence of pneumonia? If so, describe the location and appearance of the consolidation."

print("Analyzing chest X-ray...")
print(f"Prompt: {medical_prompt}")
print()

# Prepare inputs
messages = [{'role': 'user', 'content': [
    {'type': 'image', 'image': chest_xray},
    {'type': 'text', 'text': medical_prompt}
]}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=text, images=chest_xray, return_tensors='pt').to(device)

# Generate response
with torch.no_grad():
    gen_outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)

generated_text = processor.decode(gen_outputs[0], skip_special_tokens=True)
if 'model' in generated_text:
    response = generated_text.split('model')[-1].strip()
else:
    response = generated_text

print("MedGemma's Assessment:")
print("-" * 50)
print(response)

In [None]:
# Extract attention and gradients for chest X-ray
print("Extracting attention and gradients...")

model.train()

full_inputs = {
    'input_ids': gen_outputs,
    'attention_mask': torch.ones_like(gen_outputs),
    'pixel_values': inputs['pixel_values'],
}

outputs = model(**full_inputs, output_attentions=True, return_dict=True)

for attn in outputs.attentions:
    if attn is not None and attn.requires_grad:
        attn.retain_grad()

# Backward pass
logits = outputs.logits
target_logit = logits[0, -1, logits[0, -1].argmax()]
target_logit.backward(retain_graph=True)

# Collect
attention_maps = {}
attention_grads = {}

for i, attn in enumerate(outputs.attentions):
    if attn is not None:
        attention_maps[i] = attn.detach()
        if attn.grad is not None:
            attention_grads[i] = attn.grad.detach()

print(f"Collected {len(attention_grads)} gradient maps")
model.eval()

In [None]:
# Propagate relevancy
seq_len = gen_outputs.shape[1]

R = propagate_relevancy(
    attention_maps,
    attention_grads,
    seq_len,
    handle_local_attention=True,
    local_window_size=1024,
    device=device,
)

# Extract relevancy
token_relevancy = extract_token_relevancy(R, target_token_idx=-1)
image_relevancy, text_relevancy = split_relevancy(
    token_relevancy, NUM_IMAGE_TOKENS, IMAGE_START_IDX
)

image_relevancy_2d = reshape_image_relevancy(image_relevancy)
image_relevancy_2d = normalize_relevancy(image_relevancy_2d)
text_relevancy = normalize_relevancy(text_relevancy)

# Also compute raw attention baseline
raw_attn = compute_raw_attention_relevancy(attention_maps)
raw_token_rel = raw_attn[-1, :]
raw_img, raw_txt = split_relevancy(raw_token_rel, NUM_IMAGE_TOKENS, IMAGE_START_IDX)
raw_image_2d = normalize_relevancy(reshape_image_relevancy(raw_img))

print("Relevancy computed!")

In [None]:
# Comprehensive chest X-ray visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Row 1: Images
axes[0, 0].imshow(chest_xray, cmap='gray')
axes[0, 0].set_title('Original Chest X-ray', fontsize=14)
axes[0, 0].axis('off')

# Chefer method heatmap
heatmap = image_relevancy_2d.cpu().numpy()
heatmap_resized = np.array(Image.fromarray(heatmap).resize(chest_xray.size, Image.BILINEAR))

axes[0, 1].imshow(chest_xray, cmap='gray')
im1 = axes[0, 1].imshow(heatmap_resized, cmap='jet', alpha=0.5)
axes[0, 1].set_title('Chefer Method - Relevancy Map', fontsize=14)
axes[0, 1].axis('off')
plt.colorbar(im1, ax=axes[0, 1], fraction=0.046)

# Raw attention
raw_heatmap = raw_image_2d.cpu().numpy()
raw_resized = np.array(Image.fromarray(raw_heatmap).resize(chest_xray.size, Image.BILINEAR))

axes[0, 2].imshow(chest_xray, cmap='gray')
im2 = axes[0, 2].imshow(raw_resized, cmap='jet', alpha=0.5)
axes[0, 2].set_title('Raw Attention (Last Layer)', fontsize=14)
axes[0, 2].axis('off')
plt.colorbar(im2, ax=axes[0, 2], fraction=0.046)

# Row 2: Analysis
# Heatmap grid
axes[1, 0].imshow(heatmap, cmap='hot')
axes[1, 0].set_title('Relevancy Heatmap (16x16 grid)', fontsize=14)
axes[1, 0].set_xlabel('Image column')
axes[1, 0].set_ylabel('Image row')
plt.colorbar(axes[1, 0].images[0], ax=axes[1, 0], fraction=0.046)

# Region analysis (with correct anatomical orientation!)
# Left of image = Patient's RIGHT side
regions = {
    'Pt Right Upper': heatmap[:5, :8].mean(),
    'Pt Right Middle': heatmap[5:11, :8].mean(),
    'Pt Right Lower': heatmap[11:, :8].mean(),
    'Pt Left Upper': heatmap[:5, 8:].mean(),
    'Pt Left Middle': heatmap[5:11, 8:].mean(),
    'Pt Left Lower': heatmap[11:, 8:].mean(),
}

colors = ['red' if 'Pt Right' in k else 'steelblue' for k in regions.keys()]
bars = axes[1, 1].barh(list(regions.keys()), list(regions.values()), color=colors)
axes[1, 1].set_xlabel('Mean Relevancy')
axes[1, 1].set_title('Relevancy by Lung Region\n(Red = Pneumonia side)', fontsize=14)

for bar, (name, val) in zip(bars, regions.items()):
    axes[1, 1].text(val + 0.005, bar.get_y() + bar.get_height()/2,
                    f'{val:.3f}', va='center', fontsize=10)

# Top text tokens
text_start_idx = IMAGE_START_IDX + NUM_IMAGE_TOKENS
token_labels = [processor.decode([t.item()]) for t in gen_outputs[0][text_start_idx:]]
text_rel_np = text_relevancy.cpu().numpy()

n_tokens = min(15, len(text_rel_np))
sorted_idx = np.argsort(text_rel_np)[::-1][:n_tokens]
labels = [token_labels[i][:15].replace('\n', '\\n') for i in sorted_idx]
values = [text_rel_np[i] for i in sorted_idx]

axes[1, 2].barh(range(len(labels)), values, color='steelblue')
axes[1, 2].set_yticks(range(len(labels)))
axes[1, 2].set_yticklabels(labels, fontsize=9)
axes[1, 2].invert_yaxis()
axes[1, 2].set_xlabel('Relevancy Score')
axes[1, 2].set_title('Top Text Token Relevancy', fontsize=14)

plt.suptitle('MedGemma Explainability: Chest X-ray Pneumonia Analysis', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

# Print region analysis
print("\nRegion Analysis (Left of image = Patient's RIGHT side):")
print("-" * 50)
for region, value in sorted(regions.items(), key=lambda x: -x[1]):
    marker = " <-- PNEUMONIA SIDE" if 'Pt Right' in region else ""
    print(f"  {region}: {value:.4f}{marker}")

## 9. Comparing Methods <a name="9-comparing-methods"></a>

Let's compare different explainability approaches:
1. **Standard Chefer** (all layers with gradients)
2. **Global Layers Only** (layers 5, 11, 17, 23, 29)
3. **Attention Rollout** (no gradients)

In [None]:
def attention_rollout(attention_maps, add_residual=True):
    """Compute attention rollout (Abnar & Zuidema, 2020)."""
    layer_indices = sorted(attention_maps.keys())
    seq_len = attention_maps[layer_indices[0]].shape[-1]
    device = attention_maps[layer_indices[0]].device

    rollout = torch.eye(seq_len, device=device, dtype=torch.float32)

    for layer_idx in layer_indices:
        attn = attention_maps[layer_idx].float().mean(dim=(0, 1))

        if add_residual:
            attn = 0.5 * attn + 0.5 * torch.eye(seq_len, device=device)
            attn = attn / attn.sum(dim=-1, keepdim=True)

        rollout = torch.matmul(attn, rollout)

    return rollout


def global_layers_only_relevancy(attention_maps, attention_grads, seq_len, device):
    """Compute relevancy using only global attention layers."""
    R = torch.eye(seq_len, device=device, dtype=torch.float32)

    global_layers = [i for i in sorted(attention_maps.keys()) if is_global_layer(i)]

    for layer_idx in global_layers:
        if layer_idx not in attention_grads:
            continue

        A = attention_maps[layer_idx].float()
        grad_A = attention_grads[layer_idx].float()

        Abar = compute_Abar(A, grad_A, normalize=True)
        Abar = Abar + torch.eye(seq_len, device=device, dtype=torch.float32)
        Abar = Abar / Abar.sum(dim=-1, keepdim=True).clamp(min=1e-8)

        R = torch.matmul(Abar, R)

    return R

In [None]:
# Compute all three methods
print("Computing comparison methods...")

# Method 1: Standard Chefer (already computed)
img_2d_std = image_relevancy_2d

# Method 2: Global layers only
R_global = global_layers_only_relevancy(attention_maps, attention_grads, seq_len, device)
token_rel_global = extract_token_relevancy(R_global, -1)
img_rel_global, _ = split_relevancy(token_rel_global, NUM_IMAGE_TOKENS, IMAGE_START_IDX)
img_2d_global = normalize_relevancy(reshape_image_relevancy(img_rel_global))

# Method 3: Attention rollout
R_rollout = attention_rollout(attention_maps, add_residual=True)
token_rel_rollout = extract_token_relevancy(R_rollout, -1)
img_rel_rollout, _ = split_relevancy(token_rel_rollout, NUM_IMAGE_TOKENS, IMAGE_START_IDX)
img_2d_rollout = normalize_relevancy(reshape_image_relevancy(img_rel_rollout))

print("Done!")

In [None]:
# Compare methods
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

def overlay_heatmap(ax, img, heatmap_tensor, title):
    hm = heatmap_tensor.cpu().numpy() if torch.is_tensor(heatmap_tensor) else heatmap_tensor
    hm_resized = np.array(Image.fromarray(hm.astype(np.float32)).resize(img.size, Image.BILINEAR))
    ax.imshow(img, cmap='gray')
    im = ax.imshow(hm_resized, cmap='jet', alpha=0.5)
    ax.set_title(title, fontsize=12)
    ax.axis('off')
    return im

axes[0, 0].imshow(chest_xray, cmap='gray')
axes[0, 0].set_title('Original Chest X-ray', fontsize=12)
axes[0, 0].axis('off')

overlay_heatmap(axes[0, 1], chest_xray, img_2d_std,
                'Method 1: Standard Chefer\n(All 34 layers + gradients)')

overlay_heatmap(axes[1, 0], chest_xray, img_2d_global,
                'Method 2: Global Layers Only\n(Layers 5, 11, 17, 23, 29)')

overlay_heatmap(axes[1, 1], chest_xray, img_2d_rollout,
                'Method 3: Attention Rollout\n(No gradients)')

plt.suptitle('Comparison of Explainability Methods', fontsize=14)
plt.tight_layout()
plt.show()

## 10. Region Analysis <a name="10-region-analysis"></a>

Let's quantify the relevancy distribution across anatomical regions.

In [None]:
def analyze_regions(heatmap_tensor, name):
    """Analyze relevancy by lung region."""
    hm = heatmap_tensor.cpu().numpy() if torch.is_tensor(heatmap_tensor) else heatmap_tensor

    # Remember: Left of image = Patient's RIGHT side
    regions = {
        'Pt Right Upper': hm[:5, :8].mean(),
        'Pt Right Middle': hm[5:11, :8].mean(),
        'Pt Right Lower': hm[11:, :8].mean(),
        'Pt Left Upper': hm[:5, 8:].mean(),
        'Pt Left Middle': hm[5:11, 8:].mean(),
        'Pt Left Lower': hm[11:, 8:].mean(),
    }

    right_avg = (regions['Pt Right Upper'] + regions['Pt Right Middle'] + regions['Pt Right Lower']) / 3
    left_avg = (regions['Pt Left Upper'] + regions['Pt Left Middle'] + regions['Pt Left Lower']) / 3

    return {
        'name': name,
        'regions': regions,
        'right_avg': right_avg,
        'left_avg': left_avg,
        'ratio': right_avg / left_avg if left_avg > 0 else float('inf'),
    }

# Analyze all methods
methods = [
    ('Standard Chefer', img_2d_std),
    ('Global Only', img_2d_global),
    ('Rollout', img_2d_rollout),
    ('Raw Attention', raw_image_2d),
]

results = [analyze_regions(hm, name) for name, hm in methods]

# Print results
print("Region Analysis Results")
print("=" * 60)
print("\nNote: Left of image = Patient's RIGHT side (pneumonia location)\n")

for r in results:
    print(f"\n{r['name']}:")
    print(f"  Patient's Right side (avg): {r['right_avg']:.4f}")
    print(f"  Patient's Left side (avg):  {r['left_avg']:.4f}")
    print(f"  Right/Left Ratio: {r['ratio']:.2f}x")
    print(f"  Best region: {max(r['regions'], key=r['regions'].get)}")

In [None]:
# Visualize region comparison
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(6)
width = 0.2
region_names = list(results[0]['regions'].keys())

for i, r in enumerate(results):
    values = [r['regions'][name] for name in region_names]
    ax.bar(x + i*width, values, width, label=r['name'])

ax.set_xticks(x + width * 1.5)
ax.set_xticklabels(region_names, rotation=45, ha='right')
ax.set_ylabel('Mean Relevancy')
ax.set_title('Region Relevancy Comparison Across Methods')
ax.legend()
ax.axvline(x=2.5, color='gray', linestyle='--', alpha=0.5, label='Patient midline')

plt.tight_layout()
plt.show()

## 11. Advanced: Layer-wise Analysis <a name="11-layer-analysis"></a>

Let's examine how attention to image tokens varies across layers.

In [None]:
# Analyze attention to image tokens per layer
layer_stats = []

for layer_idx in sorted(attention_maps.keys()):
    attn = attention_maps[layer_idx].float().mean(dim=(0, 1))  # Average over heads

    # Attention from last token to image tokens
    last_token_attn = attn[-1, IMAGE_START_IDX:IMAGE_START_IDX + NUM_IMAGE_TOKENS]
    img_attn_sum = last_token_attn.sum().item()

    layer_stats.append({
        'layer': layer_idx,
        'img_attention': img_attn_sum,
        'is_global': is_global_layer(layer_idx),
    })

# Plot
fig, ax = plt.subplots(figsize=(14, 5))

layers = [s['layer'] for s in layer_stats]
img_attn = [s['img_attention'] for s in layer_stats]
colors = ['red' if s['is_global'] else 'steelblue' for s in layer_stats]

bars = ax.bar(layers, img_attn, color=colors)
ax.set_xlabel('Layer Index')
ax.set_ylabel('Attention to Image Tokens')
ax.set_title('Attention to Image Tokens by Layer\n(Red = Global layers, Blue = Local layers)')

# Highlight top layers
top_layers = sorted(layer_stats, key=lambda x: -x['img_attention'])[:5]
print("Top 5 layers by image attention:")
for s in top_layers:
    global_str = " (GLOBAL)" if s['is_global'] else ""
    print(f"  Layer {s['layer']}: {s['img_attention']*100:.1f}%{global_str}")

plt.tight_layout()
plt.show()

## 13. Summary <a name="13-summary"></a>

In this tutorial, we:

1. **Loaded MedGemma** with eager attention for interpretability
2. **Understood the architecture** (34 layers, 256 image tokens, local/global attention)
3. **Implemented Chefer et al. method** with **correct backprop targets**
4. **Applied to medical imaging** (chest X-ray pneumonia detection)
5. **Compared multiple methods** (Chefer, global-only, rollout, raw attention)
6. **Analyzed anatomical regions** with correct orientation
7. **Examined layer-wise attention** patterns

### Critical Insight: Correct Backprop Target

The most important fix for accurate explanations:

```python
# For causal LM: logit at position i predicts token at i+1
# To explain token at position p:
logit_position = p - 1
target_token_id = input_ids[0, p]  # Actual token, NOT argmax
target_logit = logits[0, logit_position, target_token_id]
# Extract relevancy from R[logit_position, :]
```

### New API Methods

The updated `MedGemmaExplainer` provides:
- `explain(image, prompt, target_token_position)` - Explain a specific token
- `explain_keyword(image, prompt, keyword)` - Explain a keyword in the response  
- `explain_answer_span(image, prompt)` - Explain the entire answer

### Key Findings
- The Chefer method correctly highlights the patient's right lung field for pneumonia
- Global attention layers (5, 11, 17, 23, 29) have strongest image attention
- Multiple methods should be compared for robust interpretability

### Files Included
- `medgemma_explainability/` - Core library (v0.3.0)
- `scripts/` - Standalone analysis scripts
- `tests/` - Unit tests
- `outputs/` - Generated visualizations

## 13. Summary <a name="13-summary"></a>

In this tutorial, we:

1. **Loaded MedGemma** with eager attention for interpretability
2. **Understood the architecture** (34 layers, 256 image tokens, local/global attention)
3. **Implemented Chefer et al. method** for gradient-weighted attention
4. **Applied to medical imaging** (chest X-ray pneumonia detection)
5. **Compared multiple methods** (Chefer, global-only, rollout, raw attention)
6. **Analyzed anatomical regions** with correct orientation
7. **Examined layer-wise attention** patterns

### Key Findings
- The Chefer method correctly highlights the patient's right lung field for pneumonia
- Global attention layers (5, 11, 17, 23, 29) have strongest image attention
- Multiple methods should be compared for robust interpretability

### Files Included
- `medgemma_explainability/` - Core library
- `scripts/` - Standalone analysis scripts
- `tests/` - Unit tests
- `outputs/` - Generated visualizations

In [None]:
# Cleanup
torch.cuda.empty_cache()
print("Tutorial complete!")
print("\nGenerated outputs are saved in the /content/outputs/ directory.")