# TorchXRayVision Grad-CAM quickstart

This notebook demonstrates how to use the pretrained `torchxrayvision` DenseNet121 (CheXpert weights, `densenet121-res224-chex`) that powers the web demo. Set `IMAGE_PATH` in the image-loading cell below to a chest X-ray image on disk, then run the remaining cells to reproduce the backend predictions and Grad-CAM overlays locally.

In [None]:
import base64
import io
from pathlib import Path

import torch
import torchxrayvision as xrv
from PIL import Image

from medmnist_web.utils import (
    GradCAM,
    get_device,
    get_display_names,
    logits_to_output,
    pil_to_tensor,
    render_gradcam_overlay,
    resolve_densenet_target_layer,
    set_class_names,
)


In [None]:
# Pretrained torchxrayvision weights used by the web demo.
MODEL_WEIGHTS = "densenet121-res224-chex"

device = get_device()
model = xrv.models.DenseNet(weights=MODEL_WEIGHTS)
model = model.to(device).eval()

# Label names come from `pathologies` if the model has it, else `classes`, else none.
label_names = getattr(model, "pathologies", getattr(model, "classes", []))
set_class_names(label_names)
print(f"Loaded DenseNet on {device}. Classes: {', '.join(get_display_names())}")

# Grad-CAM hooks onto the DenseNet layer chosen by the project helper.
target_layer = resolve_densenet_target_layer(model)
gradcam = GradCAM(model, target_layer)


In [None]:
# Point IMAGE_PATH at a chest X-ray image before running this cell.
IMAGE_PATH = Path("path/to/your/xray.png")
# is_file() (not exists()) so a directory path fails here with a clear message
# instead of deep inside PIL.
if not IMAGE_PATH.is_file():
    raise FileNotFoundError(f"Image not found: {IMAGE_PATH}. Update IMAGE_PATH above with your image path.")

# Image.open is lazy and keeps the file handle open until the pixels are read.
# copy() forces a full decode into memory, so the `with` block can close the
# file immediately instead of leaking the handle for the kernel's lifetime.
with Image.open(IMAGE_PATH) as opened_image:
    pil_image = opened_image.copy()
print(f"Using image: {IMAGE_PATH}")
pil_image


In [None]:
# Preprocess the image, then run one forward pass that also yields the Grad-CAM map.
x, resized = pil_to_tensor(pil_image, device)
logits, cam, target_idx = gradcam(x)

# Convert logits to the prediction summary. The `top_indices` entry (if any)
# is dropped because only `pred_class` is used below.
summary = logits_to_output(logits)
summary.pop("top_indices", None)
print("Predicted findings:", summary["pred_class"])
target_name = get_display_names()[target_idx]
print("Grad-CAM target:", target_name)

# Encode the overlay and standalone heatmap as base64 images.
overlay_b64, heatmap_b64 = render_gradcam_overlay(resized, cam)

def b64_to_image(blob: str) -> Image.Image:
    """Decode a base64-encoded image into a PIL image.

    Args:
        blob: Raw base64 image data, or a ``data:<mime>;base64,<payload>`` URI.
            Bytes input is passed through to ``base64.b64decode`` unchanged.

    Returns:
        A PIL image opened from the decoded bytes.
    """
    if isinstance(blob, str) and blob.startswith("data:"):
        # b64decode's default (non-validating) mode silently discards characters
        # outside the base64 alphabet, so a data-URI header would otherwise be
        # decoded as garbage bytes. Keep only the payload after the first comma.
        _, sep, payload = blob.partition(",")
        if sep:
            blob = payload
    return Image.open(io.BytesIO(base64.b64decode(blob)))

# Decode both renderings (overlay first, then heatmap) and show the overlay.
overlay_img, heatmap_img = (b64_to_image(encoded) for encoded in (overlay_b64, heatmap_b64))
overlay_img


In [None]:
# Display the standalone Grad-CAM heatmap decoded in the previous cell.
heatmap_img
