In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [2]:
# Preprocessing applied to every input image before it reaches the model:
# fixed 224x224 resize, conversion to a float tensor, then per-channel
# normalisation with the standard ImageNet mean / std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # PIL image -> CHW float tensor in [0, 1]
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [3]:
def run_gradcam_on_image(image_path, model, target_layer):
    """Classify one image and display a Grad-CAM explanation of the prediction.

    Shows a 1x3 figure: the original image (titled with the predicted class
    index), the raw Grad-CAM heatmap, and the heatmap overlaid on the image
    resized to 224x224.

    Args:
        image_path: Path to an image file readable by PIL; it is converted to RGB.
        model: PyTorch classifier. Inference runs on whatever device the
            model's parameters are on.
        target_layer: Sub-module of ``model`` whose activations and gradients
            Grad-CAM uses to build the map.

    Uses the notebook-level ``transform`` for preprocessing. Returns None.
    """
    # Take the device from the model itself instead of a global `device`
    # variable that this notebook never defines.
    device = next(model.parameters()).device

    img = Image.open(image_path).convert('RGB')
    input_tensor = transform(img).unsqueeze(0).to(device)

    # Forward pass only to pick the target class, so no autograd graph is needed.
    model.eval()
    with torch.no_grad():
        output = model(input_tensor)
    pred = torch.argmax(output, 1).item()

    # Grad-CAM. `use_cuda` is not passed: newer pytorch-grad-cam releases
    # removed it and use the model's device. The context manager releases
    # the hooks GradCAM registers on `target_layer` deterministically,
    # even if an exception is raised, instead of relying on garbage collection.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(pred)],
        )[0, :]

    # show_cam_on_image expects an RGB float image in [0, 1] matching the CAM size.
    rgb_img = np.array(img.resize((224, 224))) / 255.0
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # Plot results
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    ax1.imshow(img)  # RGB image: a colormap would be ignored
    ax1.set_title(f"Original (Pred: {pred})")
    ax1.axis('off')

    ax2.imshow(grayscale_cam, cmap='jet')
    ax2.set_title("Grad-CAM Heatmap")
    ax2.axis('off')

    ax3.imshow(cam_image)
    ax3.set_title("Overlay")
    ax3.axis('off')

    plt.show()


In [4]:
def run_on_multiple_images(image_paths, model, target_layer):
    """Run ``run_gradcam_on_image`` on each path in ``image_paths``, in order.

    Prints a header line before each image so the figures can be matched to
    their files. ``model`` and ``target_layer`` are forwarded unchanged.
    Returns None.
    """
    for current_path in image_paths:
        header = f"\nProcessing: {current_path}"
        print(header)
        run_gradcam_on_image(current_path, model, target_layer)

In [5]:
# Validation images to inspect. Raw strings are required for these Windows
# paths: in a normal string "\v" (as in "\val") is the vertical-tab escape,
# which silently corrupts the path, and "\d", "\c", "\P", "\p" are invalid escapes.
image_files = [
    r"D:\datasets\chest_xray\val\PNEUMONIA\person1946_bacteria_4875.jpeg",
    r"D:\datasets\chest_xray\val\PNEUMONIA\person1954_bacteria_4886.jpeg",
    r"D:\datasets\chest_xray\val\PNEUMONIA\person1951_bacteria_4882.jpeg"
]

# `model` must be created or loaded in an earlier cell; no cell in this
# notebook defines it, which is why this cell previously raised NameError.
# NOTE(review): `model.features[-1]` assumes a backbone exposing an indexable
# `features` container (e.g. torchvision VGG/DenseNet/MobileNet) — confirm.
if "model" not in globals():
    raise NameError(
        "`model` is not defined: load or train the classifier in an earlier "
        "cell before running Grad-CAM."
    )

run_on_multiple_images(image_files, model, model.features[-1])

NameError: name 'model' is not defined