In [68]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [69]:
# Preprocessing applied to every PIL image before it is fed to the model.
preprocess = transforms.Compose(
    [
        # Resize so the shorter side is 256 pixels.
        transforms.Resize(size=256),
        # Keep only the central 224x224 region.
        transforms.CenterCrop(size=224),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel normalization (mean, std).
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

In [70]:
def load_image(image_path):
    """Open an image file and prepare it as model input.

    Args:
        image_path: Location of the image, passed straight to ``Image.open``.

    Returns:
        A ``(input_tensor, image)`` tuple. ``input_tensor`` is the output of
        the module-level ``preprocess`` pipeline with a leading batch
        dimension of size 1 added. ``image`` is the RGB-converted PIL image
        *before* preprocessing, i.e. it is not resized or cropped.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    # The model expects a batch, so prepend a batch axis of length 1.
    batched = torch.unsqueeze(preprocess(rgb_image), 0)
    return batched, rgb_image

In [71]:
# Load the pretrained ResNet50 model and switch it to evaluation mode.
# (nn.Module.eval() returns the module itself, so the calls can be chained.)
model = models.resnet50(pretrained=True).eval()

# Layer name -> activation tensor; filled in by the forward hooks.
activations = dict()

In [None]:
# Hook function to save activations from any layer
def get_activation(name):
    """Build a forward hook that records a module's output.

    Args:
        name: Key under which the output is stored in the module-level
            ``activations`` dictionary.

    Returns:
        A callable with the ``(module, inputs, output)`` signature expected by
        ``register_forward_hook``.
    """
    def hook(module, inputs, output):
        # Store the tensor as-is (still attached to the autograd graph) so
        # gradients with respect to it can be requested after the forward pass.
        activations[name] = output
    return hook

# Register hooks on all layers that are convolutional (we only care about
# Conv layers for Grad-CAM). The returned handles are kept so the hooks can be
# removed later with `handle.remove()`.
hook_handles = [
    layer.register_forward_hook(get_activation(name))
    for name, layer in model.named_modules()
    if isinstance(layer, torch.nn.Conv2d)
]

# Print the number of layers for reference. `activations` is still empty at
# this point (it is only filled during a forward pass), so count the hooks
# that were actually registered instead.
print(f"Total convolutional layers with hooks: {len(hook_handles)}")

In [73]:
# Function to compute the Grad-CAM heatmap
def grad_cam(activation, gradients):
    """Compute a normalized Grad-CAM heatmap for one layer.

    Args:
        activation: Feature maps indexed as (batch, channel, height, width).
            The tensor is NOT modified.
        gradients: Gradients of the target score with respect to
            ``activation``; same layout.

    Returns:
        A NumPy array holding the heatmap (singleton dimensions squeezed
        out), scaled to the range [0, 1]. If the heatmap is zero everywhere
        after ReLU, an all-zero map is returned instead of NaNs.
    """
    # Work on detached tensors: the heatmap is only used for visualization,
    # and the caller's activation may still be part of the autograd graph.
    activation = activation.detach()
    gradients = gradients.detach()

    # Global average pooling over the gradients (one weight per feature map)
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    # Weight each channel by its pooled gradient. Broadcasting creates a new
    # tensor, so the input activation is left untouched.
    weighted = activation * pooled_gradients.view(1, -1, 1, 1)

    # Create the heatmap by averaging the weighted activation maps, then
    # apply ReLU to remove negative values
    heatmap = F.relu(torch.mean(weighted, dim=1).squeeze())

    # Normalize the heatmap between 0 and 1 for visualization
    heatmap = heatmap - heatmap.min()
    peak = heatmap.max()
    if peak > 0:  # avoid 0/0 -> NaN when the map is constant
        heatmap = heatmap / peak

    return heatmap.cpu().numpy()

In [74]:
from PIL import Image

# Function to visualize the heatmap overlayed on the original image with color
def visualize_heatmap(img, heatmap, alpha=0.5, colormap='jet'):
    """Blend a colorized heatmap over an image and display the result.

    Args:
        img: PIL image to draw on; its ``size`` determines the output size.
            Expected to be RGB so it lines up with the 3-channel heatmap.
        heatmap: 2-D array, expected to hold values in [0, 1].
        alpha: Weight of the heatmap in the blend; the image gets ``1 - alpha``.
        colormap: Name of the matplotlib colormap used to colorize the heatmap.

    Returns:
        None. The overlay is shown with matplotlib.
    """
    # Quantize to 0..255 so the colormap is indexed with integers.
    heatmap_u8 = np.uint8(255 * heatmap)

    # Look up RGBA colors, then keep only the first three (RGB) channels.
    cmap = plt.get_cmap(colormap)
    rgb = cmap(heatmap_u8)[:, :, :3]

    # Scale the colored heatmap up to the image size via PIL.
    heat_img = Image.fromarray(np.uint8(rgb * 255))
    heat_img = heat_img.resize(img.size, Image.LANCZOS)
    heat_pixels = np.array(heat_img)

    # Weighted blend of the image pixels and the heatmap pixels.
    base_pixels = np.array(img)
    overlay = np.uint8(base_pixels * (1 - alpha) + heat_pixels * alpha)

    # Display the overlay
    plt.figure(figsize=(6, 6))
    plt.imshow(overlay)
    plt.axis('off')
    plt.show()


In [75]:
# Load and preprocess your image
# NOTE: this is an absolute, machine-specific path; point it at your own file.
IMAGE_PATH = '/Users/ahmtox/tmp/Homework/CS282r/cs2822r-project/wangzai.jpeg'
input_tensor, image = load_image(IMAGE_PATH)

# Keep the model in evaluation mode. Gradients are computed by autograd
# regardless of train/eval mode; train mode would make BatchNorm normalize
# with the statistics of this single image and overwrite its running
# statistics, changing the prediction and corrupting the pretrained model.
model.eval()

# Forward pass to get model's output. The forward hooks registered on the
# convolutional layers fill `activations` during this call.
with torch.enable_grad():
    output = model(input_tensor)

# Get the predicted class (class with the highest score)
target_class = output.argmax().item()

# Zero out any previous gradients
model.zero_grad()

# Backward pass to compute gradients for the target class. The graph is
# retained because per-layer gradients are requested from it afterwards.
output[:, target_class].backward(retain_graph=True)

In [None]:
# Function to visualize Grad-CAM for multiple layers
def generate_gradcam_for_all_layers(activations, target_class, gradients=None):
    """Compute and display a Grad-CAM overlay for every recorded layer.

    Reads the module-level ``output`` (model scores from the forward pass) and
    ``image`` (the picture to draw on).

    Args:
        activations: Mapping of layer name to the activation tensor captured
            for that layer. Each tensor must belong to the autograd graph
            that produced ``output``, and that graph must have been retained.
        target_class: Column index into ``output`` whose score is explained.
        gradients: Unused. Gradients are computed per layer inside the loop;
            the parameter is kept only so existing three-argument calls
            continue to work.
    """
    class_score = output[:, target_class]

    for layer_name, activation in activations.items():
        # Gradient of the class score with respect to this layer's output.
        # retain_graph=True keeps the graph alive for the remaining layers.
        layer_gradients = torch.autograd.grad(
            class_score, activation, retain_graph=True
        )[0]

        # Hand over a detached copy so the stored activation, which is still
        # part of the autograd graph, can never be modified downstream.
        heatmap = grad_cam(activation.detach().clone(), layer_gradients)

        # Visualize the heatmap overlayed on the original image
        print(f"Visualizing Grad-CAM for layer: {layer_name}")
        visualize_heatmap(image, heatmap)

# Generate and visualize Grad-CAM heatmaps for all convolutional layers with a colormap.
# No `gradients` argument is passed: that name is not defined at module level.
generate_gradcam_for_all_layers(activations, target_class)