# MedGemma Explainability - Setup and Verification

This notebook verifies that MedGemma can be loaded, that attention weights can be extracted from it, and that gradients can be captured on those attention weights.

## Step 1: Setup HuggingFace Authentication

Make sure you have:
1. Created a HuggingFace account
2. Accepted the MedGemma license at https://huggingface.co/google/medgemma-1.5-4b-it
3. Created an access token at https://huggingface.co/settings/tokens
4. Added the token as a Colab secret named `HF_TOKEN` and enabled notebook access for that secret

In [None]:
# Authenticate with the Hugging Face Hub using the `HF_TOKEN` Colab secret.
# The token is exported under both environment-variable names and then used
# for an explicit hub login (without touching the git credential store).
import os

from google.colab import userdata
from huggingface_hub import login

hf_token = userdata.get('HF_TOKEN')
for env_var in ('HF_TOKEN', 'HUGGING_FACE_HUB_TOKEN'):
    os.environ[env_var] = hf_token

login(token=hf_token, add_to_git_credential=False)
print("HuggingFace authentication successful!")

## Step 2: Check GPU

In [None]:
import torch

# Report the PyTorch build and, when a CUDA device is visible, its name and
# total memory (the 4B model is loaded in bfloat16 in the next step).
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_mem_gb:.1f} GB")

## Step 3: Load MedGemma Model

In [None]:
from transformers import AutoProcessor, AutoModelForImageTextToText

model_name = "google/medgemma-1.5-4b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {model_name}...")

# Load processor
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
print("Processor loaded")

# Load model in bfloat16 with the *eager* attention implementation.
# Fused attention kernels (SDPA / FlashAttention) never materialise the
# softmax attention matrix, so `output_attentions=True` (Steps 7-8) can only
# return real, differentiable attention weights on the eager path.
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="eager",
    trust_remote_code=True,
)
print("Model loaded successfully!")

## Step 4: Explore Model Architecture

In [None]:
def print_model_structure(model, max_depth=3):
    """Print the module hierarchy of ``model`` as an indented ASCII tree.

    Args:
        model: an object exposing ``named_children()`` (e.g. a torch.nn.Module).
        max_depth: number of levels below the root to print. With 0 or less,
            only the ``Model: <ClassName>`` header line is printed.
    """
    def _print(module, prefix="", depth=0):
        if depth >= max_depth:
            return
        # Materialise the children once so that detecting the last child does
        # not re-walk the whole child list for every child (quadratic fan-out).
        children = list(module.named_children())
        last_index = len(children) - 1
        for i, (name, child) in enumerate(children):
            is_last = i == last_index
            branch = '└── ' if is_last else '├── '
            print(f"{prefix}{branch}{name}: {child.__class__.__name__}")
            # Continue the vertical guide line only under non-final siblings.
            _print(child, prefix + ('    ' if is_last else '│   '), depth + 1)

    print(f"Model: {model.__class__.__name__}")
    _print(model)

# Print the loaded model's module tree, using print_model_structure's default depth.
print_model_structure(model)

In [None]:
# Summarise the language-tower and vision-tower hyper-parameters exposed by
# the model config; any attribute a sub-config lacks is shown as 'N/A'.
config = model.config
print("Model Configuration")
print("=" * 50)

TEXT_CONFIG_FIELDS = (
    'num_hidden_layers',
    'num_attention_heads',
    'num_key_value_heads',
    'hidden_size',
    'head_dim',
)
VISION_CONFIG_FIELDS = (
    'num_hidden_layers',
    'num_attention_heads',
    'hidden_size',
    'image_size',
    'patch_size',
)


def _print_config_fields(title, sub_config, fields):
    """Print a titled section listing ``fields`` of ``sub_config``."""
    print(f"\n{title}:")
    for field in fields:
        print(f"  {field}: {getattr(sub_config, field, 'N/A')}")


if hasattr(config, 'text_config'):
    tc = config.text_config
    _print_config_fields("Language Model", tc, TEXT_CONFIG_FIELDS)

if hasattr(config, 'vision_config'):
    vc = config.vision_config
    _print_config_fields("Vision Model", vc, VISION_CONFIG_FIELDS)

## Step 5: Find Attention Modules

In [None]:
# Find all attention modules.
# Only the attention block itself is collected (e.g. `...layers.0.self_attn`),
# not its sub-layers: a plain substring test on the qualified name would also
# match `...self_attn.q_proj`, `...self_attn.k_norm`, etc. and inflate the
# per-tower counts several-fold.
def _is_attention_module(name, module):
    """Return True if `module` is an attention block rather than one of its children."""
    leaf_name = name.rsplit('.', 1)[-1].lower()
    class_name = module.__class__.__name__.lower()
    return leaf_name == 'self_attn' or class_name.endswith('attention')


language_attn = []
vision_attn = []

for name, module in model.named_modules():
    if not _is_attention_module(name, module):
        continue
    qualified = name.lower()
    # Route by which tower the module lives in, judged from its qualified name.
    if 'language' in qualified:
        language_attn.append((name, module))
    elif 'vision' in qualified:
        vision_attn.append((name, module))

print(f"Language model attention layers: {len(language_attn)}")
print(f"Vision model attention layers: {len(vision_attn)}")

if language_attn:
    print(f"\nFirst language attention: {language_attn[0][0]}")
    module = language_attn[0][1]
    print(f"  Class: {module.__class__.__name__}")
    for attr in ['num_heads', 'num_key_value_heads', 'head_dim']:
        if hasattr(module, attr):
            print(f"  {attr}: {getattr(module, attr)}")

## Step 6: Test Inference

In [None]:
from PIL import Image
import numpy as np

# Synthetic 224x224 RGB probe: solid red on the left half, solid blue on the
# right half, so the model's colour description is trivial to sanity-check.
side = 224
half = side // 2
img_array = np.zeros((side, side, 3), dtype=np.uint8)
img_array[:, :half] = (255, 0, 0)
img_array[:, half:] = (0, 0, 255)
test_image = Image.fromarray(img_array)

# Render the probe image inline
display(test_image)

In [None]:
# Run inference
prompt = "What colors do you see in this image?"

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": test_image},
            {"type": "text", "text": prompt},
        ],
    }
]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Move inputs to the model's device and cast floating-point tensors
# (pixel_values) to the model's bfloat16 dtype; integer tensors such as
# input_ids keep their dtype.
inputs = processor(text=text, images=test_image, return_tensors="pt").to(device, dtype=torch.bfloat16)

print("Input shapes:")
for k, v in inputs.items():
    if hasattr(v, 'shape'):
        print(f"  {k}: {v.shape}")

# Generate
input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)

# generate() returns prompt + continuation; decode only the newly generated
# tokens so the printed response is not prefixed by the chat-template prompt.
generated = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
print(f"\nGenerated response:\n{generated}")

## Step 7: Test Attention Extraction

In [None]:
# Test with output_attentions=True
print("Testing attention extraction with output_attentions=True...")

with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, return_dict=True)

# Depending on the transformers version and attention backend, `attentions`
# may be missing, None, or a tuple of None placeholders; only real tensors
# count as a successful extraction.
attentions = getattr(outputs, 'attentions', None)
if attentions and all(attn is not None for attn in attentions):
    print(f"\nSUCCESS! Got {len(attentions)} attention tensors")
    print(f"\nFirst attention shape: {attentions[0].shape}")
    print(f"Last attention shape: {attentions[-1].shape}")

    # Analyze first attention tensor
    first_attn = attentions[0]
    print("\nFirst attention stats:")
    print(f"  Shape: {first_attn.shape}")
    print(f"  dtype: {first_attn.dtype}")
    print(f"  Min: {first_attn.min().item():.6f}")
    print(f"  Max: {first_attn.max().item():.6f}")
    print(f"  Mean: {first_attn.mean().item():.6f}")

    # Check softmax (rows should sum to 1). Accumulate in float32: summing a
    # long bfloat16 row loses enough precision to hide or fake a
    # normalisation error.
    row_sums = first_attn[0, 0].float().sum(dim=-1)
    print("\n  Row sums (should be ~1):")
    print(f"    Min: {row_sums.min().item():.4f}")
    print(f"    Max: {row_sums.max().item():.4f}")
else:
    print("\nAttentions not available in outputs")
    print("Will need to use hooks for attention extraction")

In [None]:
# Check attention shape details
def print_attention_shapes(attentions, head=5, tail=2):
    """Print each layer's attention-tensor shape, eliding the middle layers.

    Args:
        attentions: sequence of per-layer attention tensors; None entries
            (returned by some attention backends) are printed as ``None``.
        head: number of leading layers to print.
        tail: number of trailing layers to print.
    """
    print("Attention tensor shapes across layers:")
    print("(batch, heads, seq_len, seq_len)")
    print("-" * 50)

    n_layers = len(attentions)
    for i, attn in enumerate(attentions):
        if i < head or i >= n_layers - tail:
            shape = tuple(attn.shape) if attn is not None else None
            print(f"Layer {i:2d}: {shape}")
        elif i == head:
            print("...")


if hasattr(outputs, 'attentions') and outputs.attentions:
    print_attention_shapes(outputs.attentions)

## Step 8: Test Gradient Capture

In [None]:
# Test gradient capture
print("Testing gradient capture...")

# Gradient flow only requires autograd to be enabled (no surrounding
# torch.no_grad()); train mode is not needed. Staying in eval mode keeps this
# forward pass identical to the inference pass being explained.
model.eval()

try:
    # Forward pass with gradients
    with torch.enable_grad():
        outputs = model(**inputs, output_attentions=True, return_dict=True)

    grad_attentions = outputs.attentions
    # None placeholders cannot carry gradients; treat them as "no attentions".
    if grad_attentions and all(attn is not None for attn in grad_attentions):
        # Attention maps are non-leaf tensors: ask autograd to keep their .grad
        for attn in grad_attentions:
            if attn.requires_grad:
                attn.retain_grad()

        # Backward from the top-scoring logit at the last position. A single
        # backward is all this check needs, so the graph is not retained.
        logits = outputs.logits
        target_logit = logits[0, -1, logits[0, -1].argmax()]
        target_logit.backward()

        # Check gradients
        grad_count = 0
        for i, attn in enumerate(grad_attentions):
            if attn.grad is not None:
                grad_count += 1
                if i < 3:
                    print(f"Layer {i}: grad shape {attn.grad.shape}, "
                          f"grad norm {attn.grad.float().norm().item():.6f}")

        print(f"\nGradients captured for {grad_count}/{len(grad_attentions)} layers")
    else:
        print("No attentions available for gradient capture")
finally:
    # backward() also fills .grad on every trainable parameter (one extra
    # tensor per weight). Only the attention gradients were needed, so free
    # the parameter gradients even if the check above failed.
    model.zero_grad(set_to_none=True)

## Summary

In [None]:
print("=" * 60)
print("Verification Summary")
print("=" * 60)
print(f"\nModel: {model_name}")
print(f"Device: {device}")
# These are counts of matched attention modules, not of decoder/encoder layers.
print(f"\nLanguage model attention modules: {len(language_attn)}")
print(f"Vision model attention modules: {len(vision_attn)}")

if hasattr(config, 'text_config'):
    tc = config.text_config
    print("\nGQA configuration:")
    print(f"  Query heads: {getattr(tc, 'num_attention_heads', 'N/A')}")
    print(f"  KV heads: {getattr(tc, 'num_key_value_heads', 'N/A')}")

# `attentions` may be a tuple of None placeholders with some attention
# backends, so require real tensors before reporting extraction as working.
summary_attentions = getattr(outputs, 'attentions', None)
attention_ok = bool(summary_attentions) and all(a is not None for a in summary_attentions)
print(f"\nAttention extraction: {'Working' if attention_ok else 'Needs hooks'}")
print("\nReady for Chefer method implementation!")

In [None]:
# Nothing is written to disk here: this cell only confirms that the model and
# processor objects created earlier are still present in this kernel session.
model_ready = model is not None
processor_ready = processor is not None
print(f"Model loaded: {model_ready}")
print(f"Processor loaded: {processor_ready}")
print("\nYou can now proceed to the explainability implementation!")