<a href="https://colab.research.google.com/github/tharushaliyanagama/OralCancerEarlyDetection-DSGP/blob/Image-Prediction-and-XAI/Grad_CAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import numpy as np
import shap
import matplotlib.pyplot as plt
import cv2
import os
from typing import Tuple, List


In [2]:
# Run on the current CUDA device when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Human-readable label for each output index of the classifier head
# (presumably the label order used at training time — confirm).
class_names = [
    'cancer',
    'non-cancer',
    'leukoplakia',
]

In [3]:
# Evaluation-time preprocessing applied to each input image:
#   1. resize to exactly 256x256 pixels (aspect ratio is not preserved),
#   2. crop the central 224x224 region,
#   3. convert to a float tensor laid out as (C, H, W) (values in [0, 1] for
#      8-bit images),
#   4. normalise each of the three channels with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [4]:
# Load the trained model
def load_model(model_path: str, num_classes: int = 3) -> nn.Module:
    """Build a ResNet50 classifier and load trained weights into it.

    The network's final ``fc`` layer is replaced by a fresh ``nn.Linear`` with
    ``num_classes`` outputs before the checkpoint is loaded.

    Accepted checkpoint formats:
      * a bare ``state_dict`` (as written by ``torch.save(model.state_dict())``);
      * a dict wrapping one under ``'state_dict'`` or ``'model_state_dict'``;
      * a pickled ``nn.Module`` (its ``state_dict()`` is used).
    A ``'module.'`` prefix on every key, as left by saving an
    ``nn.DataParallel``-wrapped model, is stripped.

    If the checkpoint's ``fc`` layer has a different number of outputs, its
    weights are discarded. The new head then keeps its random initialisation,
    so its predictions are not meaningful. Any model parameters/buffers not
    found in the checkpoint, and any checkpoint entries that were not used,
    are reported with a printed warning.

    Args:
        model_path: Path to a checkpoint file readable by ``torch.load``.
        num_classes: Number of outputs of the replacement ``fc`` layer.

    Returns:
        The model, moved to the module-level ``device`` and set to eval mode.
    """
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    checkpoint = torch.load(model_path, map_location=device)

    # Unwrap the common checkpoint layouts down to a name -> tensor mapping.
    if isinstance(checkpoint, nn.Module):
        state_dict = checkpoint.state_dict()
    else:
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ('state_dict', 'model_state_dict'):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    state_dict = checkpoint[wrapper_key]
                    break

    # Weights saved from an nn.DataParallel-wrapped model are all prefixed with
    # 'module.'; left as-is, strict=False below would silently match nothing.
    prefix = 'module.'
    if state_dict and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

    # A head trained for a different number of classes cannot be copied into the
    # resized fc layer: load_state_dict raises on shape mismatches even with
    # strict=False. Drop it so the rest of the network still loads.
    head_weight = state_dict.get('fc.weight')
    if head_weight is not None and head_weight.shape[0] != num_classes:
        print(f"Warning: Number of classes in model ({head_weight.shape[0]}) doesn't match expected ({num_classes})")
        state_dict.pop('fc.weight', None)
        state_dict.pop('fc.bias', None)

    # strict=False is needed to tolerate the dropped head keys, but it also hides
    # a checkpoint that does not fit the architecture at all. Report what was
    # skipped so a failed load is visible.
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing:
        print(f"Warning: {len(missing)} model parameters/buffers not found in checkpoint "
              f"(left at their initial values), e.g. {missing[:5]}")
    if unexpected:
        print(f"Warning: {len(unexpected)} checkpoint entries were not used, e.g. {unexpected[:5]}")

    model = model.to(device)
    model.eval()
    return model


In [5]:
# Grad-CAM implementation
class GradCAM:
    """Grad-CAM saliency maps for one layer of a classifier.

    A forward hook on ``target_layer`` stores its output activations. It also
    attaches a tensor hook that stores the gradient of the selected class score
    with respect to those activations during the backward pass. The per-channel
    weights are the spatially averaged gradients. The heatmap is the ReLU of
    the channel-averaged weighted activations, scaled so its maximum is 1.

    The activations are assumed to be 4-D, e.g. (N, C, H, W) from a conv layer,
    because spatial pooling reduces dims 2 and 3.

    Call :meth:`remove_hooks` when finished. Otherwise the hook stays registered
    on ``target_layer`` and keeps firing on every forward pass of the model.
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # gradient w.r.t. target_layer output, last backward pass
        self.activations = None  # target_layer output, last forward pass

        # Gradients are captured by a tensor hook attached in save_activation,
        # not by a module backward hook. register_backward_hook is deprecated,
        # and register_full_backward_hook can fail when the layer's output is
        # later modified in place (e.g. by ReLU(inplace=True)).
        self._handles = [self.target_layer.register_forward_hook(self.save_activation)]

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output and hook its gradient."""
        # clone(): a detached tensor shares storage with `output`, so a later
        # in-place op in the model would otherwise overwrite the saved values.
        self.activations = output.detach().clone()
        # Only tensors inside an autograd graph can receive gradients
        # (not the case under torch.no_grad()).
        if output.requires_grad:
            output.register_hook(self._save_tensor_gradient)

    def _save_tensor_gradient(self, grad: torch.Tensor) -> None:
        """Tensor hook: store the gradient flowing into the layer output."""
        self.gradients = grad.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Module-backward-hook-style gradient store.

        Kept for API compatibility; it is no longer registered by __init__.
        """
        self.gradients = grad_output[0].detach()

    def remove_hooks(self) -> None:
        """Unregister every hook this instance installed on target_layer."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model; this triggers the forward hook."""
        return self.model(x)

    def backward(self, outputs: torch.Tensor, class_idx: int):
        """Backpropagate the summed score of column class_idx over the batch."""
        outputs[:, class_idx].sum().backward(retain_graph=True)

    def generate(self, x: torch.Tensor, class_idx: int) -> np.ndarray:
        """Return the Grad-CAM heatmap of x for class index class_idx.

        For a single-image batch the result is 2-D (H, W) at the target layer's
        spatial resolution. Values lie in [0, 1]; an all-zero map is returned
        unchanged when no location contributes positively.

        Raises:
            RuntimeError: if the hooks captured no activations or gradients,
                e.g. because target_layer is not used by the model's forward.
        """
        # Reset so a hook that fails to fire is detected instead of silently
        # reusing tensors from a previous call.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs a backward pass, so enable autograd even if the caller
        # is inside torch.no_grad().
        with torch.enable_grad():
            self.model.zero_grad()
            output = self.forward(x)
            self.backward(output, class_idx)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks captured no activations/gradients; check that "
                "target_layer is part of the model's forward pass"
            )

        # Channel weights: gradients averaged over the spatial dims.
        pooled_gradients = torch.mean(self.gradients, dim=[2, 3], keepdim=True)

        # Weight the activations and combine channels.
        weighted_activations = pooled_gradients * self.activations
        heatmap = torch.mean(weighted_activations, dim=1).squeeze()
        heatmap = torch.relu(heatmap)  # keep only positive evidence for the class

        # Scale to [0, 1]. Dividing an all-zero map by its max would give NaNs.
        max_val = torch.max(heatmap)
        if max_val > 0:
            heatmap = heatmap / max_val
        return heatmap.cpu().numpy()