# Model Explainability (Grad-CAM)

This notebook loads a pre-trained model and uses **Grad-CAM** (Gradient-weighted Class Activation Mapping) to visualize the regions of an image that influenced the model's prediction (AI-generated vs. Human).

**Prerequisites:**
- A trained model file (e.g., `best_model.pth`) must exist.
- The `pytorch-grad-cam` library must be installed (`pip install grad-cam`; it is imported as `pytorch_grad_cam`).

In [None]:
import os
import torch
import torch.nn as nn
import timm
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# Explainability libraries
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image


In [None]:
# --- Configuration ---
# Notebook-wide settings. Every later cell reads these module-level constants.

# Side length (pixels) that images are resized to before entering the model.
IMG_SIZE = 256

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# timm architecture identifier. It must name the same architecture the
# checkpoint at MODEL_PATH was saved from.
MODEL_NAME = 'convnext_base.clip_laion2b_augreg_ft_in1k'
# MODEL_NAME = 'vit_small_patch14_dinov2.lvd142m' # Example if you used this one

# Checkpoint holding the trained weights.
MODEL_PATH = 'best_model.pth'

# Dataset root and the CSV listing the test images.
DATA_DIR = "/kaggle/input/ai-vs-human-generated-dataset/"
TEST_CSV_PATH = "/kaggle/input/ai-vs-human-generated-dataset/test.csv"

print(f"Device: {DEVICE}")

In [None]:
# --- Transforms ---
# Deterministic preprocessing applied to every image before inference:
# resize to a fixed square, convert to a tensor, then normalise per channel.
# NOTE(review): the mean/std below are the usual ImageNet statistics; they
# should match whatever normalisation was used when the checkpoint was
# trained -- confirm against the training code.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [None]:
# --- Load Model ---
def load_trained_model(model_name, model_path, device):
    """Build a two-class timm model and load trained weights into it.

    Besides a bare state dict, two common checkpoint layouts are accepted:
    a wrapper dict that stores the weights under ``'state_dict'`` or
    ``'model_state_dict'``, and weights saved from an ``nn.DataParallel``
    model, where every key carries a ``'module.'`` prefix.

    Args:
        model_name: Architecture identifier passed to ``timm.create_model``.
        model_path: Path of the checkpoint file to read.
        device: Device the tensors are mapped to while loading and the
            model is moved to afterwards.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the
            checkpoint keys or shapes do not match the architecture.
    """
    print(f"Creating model: {model_name}")
    # pretrained=False: the weights come from the checkpoint below, so no
    # download of the original pretrained weights is needed.
    model = timm.create_model(model_name, pretrained=False, num_classes=2)

    print(f"Loading weights from {model_path}...")
    # NOTE(review): torch.load unpickles the file. Only load checkpoints you
    # trust, or pass weights_only=True on PyTorch versions that support it.
    checkpoint = torch.load(model_path, map_location=device)

    # Unwrap training-loop checkpoints that nest the weights under a key.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ('state_dict', 'model_state_dict'):
            if isinstance(checkpoint.get(wrapper_key), dict):
                state_dict = checkpoint[wrapper_key]
                break

    # Weights saved from nn.DataParallel prefix every key with "module.";
    # strip it so the keys line up with the plain model built above.
    prefix = 'module.'
    if (
        isinstance(state_dict, dict)
        and state_dict
        and all(isinstance(key, str) and key.startswith(prefix) for key in state_dict)
    ):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict)
    print("Weights loaded successfully.")

    model.to(device)
    # Eval mode: disables dropout and uses running statistics in norm layers.
    model.eval()
    return model

# Build the network once; the explanation cells below reuse this global.
model = load_trained_model(
    model_name=MODEL_NAME,
    model_path=MODEL_PATH,
    device=DEVICE,
)

In [None]:
# --- Explainability Helpers ---
def get_target_layer(model, model_name):
    """Choose the layer whose activations Grad-CAM should visualise.

    Selection order:
      1. ``'convnext'`` in ``model_name`` and the model has ``stages``:
         the last block of the last stage.
      2. ``'efficientnet'`` in ``model_name`` and the model has
         ``conv_head``: that layer.
      3. Otherwise, or if the branch above found nothing: the last
         ``nn.Conv2d`` yielded by ``model.modules()``.

    Note that step 3 also applies to models that contain only an early
    convolution (for a ViT this is presumably the patch-embedding
    projection), so a non-None result does not by itself guarantee a
    meaningful heatmap. Transformer models generally need a
    ``reshape_transform`` in Grad-CAM, which this helper does not set up.

    Args:
        model: The network to inspect.
        model_name: Architecture name used to pick a branch.

    Returns:
        The selected module, or ``None`` when no candidate was found.
    """
    chosen = None

    if 'convnext' in model_name:
        # Architecture-specific pick; left as None when the attribute is absent.
        if hasattr(model, 'stages'):
            chosen = model.stages[-1].blocks[-1]
    elif 'efficientnet' in model_name:
        chosen = getattr(model, 'conv_head', None)

    if chosen is None:
        # Generic fallback: keep overwriting so the final Conv2d wins.
        for candidate in model.modules():
            if isinstance(candidate, nn.Conv2d):
                chosen = candidate

    target_layer = chosen
    print(f"Target Layer for Grad-CAM: {target_layer}")
    return target_layer

def run_grad_cam(model, img_tensor, target_layer, target_class=None):
    """Compute a Grad-CAM heatmap for a single image.

    Args:
        model: Network to explain.
        img_tensor: One preprocessed image without a batch dimension; a
            batch dimension is added here and the tensor is moved to
            ``DEVICE``.
        target_layer: Module whose activations and gradients are used.
        target_class: Optional class index to explain. ``None`` (default)
            leaves the choice to pytorch-grad-cam, which explains the
            highest-scoring class.

    Returns:
        The heatmap for the first (and only) image in the batch, as
        returned by pytorch-grad-cam (presumably a 2-D array scaled to
        [0, 1] -- confirm against the library version in use).

    Raises:
        RuntimeError: If the CAM computation ended without a result.
    """
    # Add batch dimension
    input_tensor = img_tensor.unsqueeze(0).to(DEVICE)

    # None lets the library pick the top predicted class.
    targets = None
    if target_class is not None:
        targets = [ClassifierOutputTarget(target_class)]

    # The context manager removes the forward/backward hooks from the model
    # as soon as the block exits, even on error, instead of leaving that to
    # garbage collection.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        # Take the first image in the batch
        return grayscale_cam[0, :]

    # Only reachable if the context manager suppressed an exception (some
    # pytorch-grad-cam versions appear to swallow IndexError in __exit__).
    raise RuntimeError("Grad-CAM computation did not produce a heatmap.")

def visualize_cam(img_tensor, heatmap, title="Grad-CAM"):
    """Show the de-normalised input next to its Grad-CAM overlay.

    Args:
        img_tensor: Normalised image tensor in channel-first layout
            (it is permuted to channel-last before plotting).
        heatmap: Grad-CAM map passed to ``show_cam_on_image``.
        title: Caption for the overlay panel.
    """
    # Per-channel statistics used to undo the normalisation.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # Channel-first tensor -> channel-last array, then invert the
    # normalisation and clamp to the displayable [0, 1] range.
    hwc = img_tensor.permute(1, 2, 0).cpu().numpy()
    img_np = np.clip(channel_std * hwc + channel_mean, 0, 1)

    overlay = show_cam_on_image(img_np, heatmap, use_rgb=True)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((img_np, "Original Image"), (overlay, title))
    for ax, (picture, caption) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(caption)
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# --- Run Inference & Explanation on a Sample ---

# 1. Pick one random row from the test CSV and resolve its image path
df = pd.read_csv(TEST_CSV_PATH)
sample_row = df.sample(1).iloc[0]
# Adjust column name based on your CSV structure ('id' or 'file_name')
fname = sample_row['id'] if 'id' in sample_row else sample_row['file_name']
img_path = os.path.join(DATA_DIR, fname)

# 2. Process and Explain
if os.path.exists(img_path):
    try:
        # Load Image
        image = Image.open(img_path).convert('RGB')
        input_tensor = val_transform(image)

        # Get Prediction (no gradients needed for the plain forward pass)
        with torch.no_grad():
            output = model(input_tensor.unsqueeze(0).to(DEVICE))
            probs = torch.nn.functional.softmax(output, dim=1)
            # Arg-max over the class dimension rather than the flattened tensor.
            pred_idx = torch.argmax(probs, dim=1).item()
            conf = probs[0][pred_idx].item()

        label_map = {0: "Real/Human", 1: "Fake/AI"}
        pred_label = label_map.get(pred_idx, str(pred_idx))

        print(f"Prediction: {pred_label} (Confidence: {conf:.4f})")

        # Run Grad-CAM
        target_layer = get_target_layer(model, MODEL_NAME)
        # Compare against None explicitly: truthiness of an nn.Module is
        # unreliable (container modules define __len__ and are falsy when empty).
        if target_layer is not None:
            heatmap = run_grad_cam(model, input_tensor, target_layer)
            visualize_cam(input_tensor, heatmap, title=f"Grad-CAM: {pred_label}")
        else:
            print("Could not identify target layer for this model architecture.")

    except Exception as e:
        # Best-effort demo cell: report the failure instead of aborting the notebook.
        print(f"Error during processing: {e}")
else:
    print("Image file not found. Please check the path.")