# In [None]:
import torch

import cv2

import torch.nn.functional  as F

import numpy as np

import matplotlib.pyplot as plt

# In [None]:
class GradCam:
    """Grad-CAM (gradient-weighted class activation mapping) for one layer.

    Hooks ``target_layer`` of ``model`` to capture its forward output and the
    gradient of a chosen class score with respect to that output, then
    combines them into a coarse saliency heatmap.

    Args:
        model: A ``torch.nn.Module``. It is switched to eval mode.
        target_layer: Dotted attribute path of the layer to inspect, resolved
            on ``model`` with ``getattr`` one component at a time
            (e.g. ``'layer4'`` or ``'layer4.1.conv2'``).
    """

    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = self._get_layer(target_layer)

        # Gradient of the backpropagated score w.r.t. the layer output (set by
        # the backward hook) and the layer output itself (set by the forward hook).
        self.gradients = None
        self.activations = None

        # register_full_backward_hook replaces the deprecated
        # register_backward_hook, whose grad_output is not reliable for modules
        # built from several autograd ops. The handles are kept so the hooks
        # can be detached with remove_hooks().
        self._handles = [
            self.target_layer.register_forward_hook(self._save_activation),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self):
        """Detach both hooks from ``target_layer``; ``generate`` will fail afterwards."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def _get_layer(self, target_layer):
        """Resolve a dotted attribute path such as ``'layer4'`` on ``self.model``."""
        layer = self.model
        for name in target_layer.split('.'):
            layer = getattr(layer, name)
        return layer

    def _save_gradient(self, module, grad_input, grad_output):
        # Backward hook: keep the gradient w.r.t. the layer's output.
        # Detached so no autograd graph is retained between calls.
        self.gradients = grad_output[0].detach()

    def _save_activation(self, module, inputs, output):
        # Forward hook: keep the layer's output (activation maps), detached
        # because only the values are needed to build the heatmap.
        self.activations = output.detach()

    def generate(self, input_tensor, target_class=None):
        """Compute the Grad-CAM heatmap for the first sample of ``input_tensor``.

        Args:
            input_tensor: Model input. Only sample 0 of the batch is explained.
            target_class: Index of the output logit to explain. ``None`` picks
                the highest-scoring class of sample 0.

        Returns:
            A float32 ``np.ndarray`` with the spatial size of the target
            layer's activations (assumed ``[N, C, H, W]``), shifted so its
            minimum is 0 and scaled so its maximum is 1. It is all zeros when
            the map is constant.

        Raises:
            RuntimeError: if the hooks did not fire, e.g. the target layer does
                not contribute to the output or the hooks were removed.
        """
        # Clear results of a previous call so stale tensors are never reused.
        self.gradients = None
        self.activations = None
        self.model.zero_grad()

        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            if target_class is None:
                # Index sample 0 first so .item() also works for batch > 1.
                target_class = output[0].argmax().item()
            output[0, target_class].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                'Grad-CAM hooks did not fire; check that target_layer is used '
                'in the forward pass and contributes to the output.'
            )

        gradients = self.gradients[0].float()      # [C, H, W]
        activations = self.activations[0].float()  # [C, H, W]

        # Channel weights: global-average-pooled gradients, shape [C].
        weights = gradients.mean(dim=(1, 2))

        # Weighted sum of activation maps over channels. Broadcasting keeps the
        # result on the activations' device (no fixed-CPU buffer).
        cam = (weights[:, None, None] * activations).sum(dim=0)
        cam = F.relu(cam)

        cam = cam - cam.min()
        max_val = cam.max()
        if max_val > 0:  # constant map: keep zeros instead of 0/0 = NaN
            cam = cam / max_val

        return cam.cpu().numpy()

    def overlay(self, img_pil, cam, alpha=0.5, size=(224, 224)):
        """Blend a Grad-CAM heatmap onto a PIL image using the ``jet`` colormap.

        Args:
            img_pil: PIL image. It is converted to RGB and resized to ``size``.
            cam: 2-D heatmap, expected in [0, 1] (e.g. the output of ``generate``).
            alpha: Weight of the image; the heatmap gets ``1 - alpha``.
            size: Output ``(width, height)``. PIL ``resize`` and ``cv2.resize``
                both take width first, so the two outputs line up.

        Returns:
            Float ``np.ndarray`` of shape ``(height, width, 3)`` clipped to [0, 1].
        """
        size = tuple(size)
        # convert('RGB') so grayscale / RGBA inputs also give 3 channels,
        # matching the 3-channel heatmap below.
        img_np = np.asarray(img_pil.convert('RGB').resize(size)) / 255.0
        cam_resized = cv2.resize(np.asarray(cam, dtype=np.float32), size)
        heatmap = plt.cm.jet(cam_resized)[..., :3]  # drop the colormap's alpha channel
        blended = alpha * img_np + (1 - alpha) * heatmap
        return np.clip(blended, 0, 1)
    
    
    
    
# NOTE(review): `model`, `input_tensor` and `original_image` are not defined in
# this cell; they are assumed to come from earlier notebook cells.
gradcam = GradCam(model, target_layer='layer4')

# Explain logit 243 explicitly; pass target_class=None to explain the model's
# top-scoring class instead.
cam = gradcam.generate(input_tensor, target_class=243)

overlay = gradcam.overlay(original_image, cam)

# Display the blended image array (the variable, not the string 'overlay').
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM Overlay")
plt.show()