XAI Comparison: Grad-CAM vs LIME

Setup & Imports

In [None]:
# System
import os
import numpy as np
import matplotlib.pyplot as plt

# PyTorch & Vision
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Pretrained model
import timm

# Grad-CAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# LIME
from lime import lime_image
from skimage.segmentation import mark_boundaries
import torchvision.transforms.functional as TF
import torch.nn.functional as F

# Dataset / Metrics
import pandas as pd

# Device: run on the GPU when CUDA reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Load Model & Test Dataset

In [None]:
# 1) Reconstruct model architecture, load weights
# The backbone is created without pretrained weights and its classification
# head is replaced by a 2-output linear layer before the checkpoint is loaded.
model = timm.create_model('vit_base_patch16_224', pretrained=False)
model.head = nn.Linear(model.head.in_features, 2)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state = torch.load('models/vit_ai_detection_10.pth', map_location=device)

# Strip a leading 'module.' from every key (presumably left behind by an
# nn.DataParallel wrapper at save time). Only the prefix is removed: a blanket
# str.replace would also mangle keys that merely contain the substring, e.g.
# 'encoder.submodule.weight' -> 'encoder.subweight'.
_wrapper_prefix = 'module.'
new_state = {
    (k[len(_wrapper_prefix):] if k.startswith(_wrapper_prefix) else k): v
    for k, v in state.items()
}
model.load_state_dict(new_state)
model = model.to(device)
model.eval()

# 2) Prepare test dataset (same transform as training)
# Pipeline: resize to 224x224, convert to a tensor, then normalise every
# channel with mean 0.5 and std 0.5.
_resize = transforms.Resize((224, 224))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
transform = transforms.Compose([_resize, _to_tensor, _normalize])

class ImageDataset(Dataset):
    """Image/label dataset indexed by a CSV file.

    The CSV is read with pandas and must contain a ``file_name`` column and a
    ``label`` column. Each ``file_name`` is joined onto ``root_dir`` once, at
    construction time, so ``self.df['file_name']`` holds ready-to-open paths.

    Args:
        csv_file: Path of the CSV index.
        root_dir: Directory prepended to every ``file_name`` entry.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.df['file_name'] = self.df['file_name'].apply(lambda x: os.path.join(root_dir, x))
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is converted to RGB and passed through ``self.transform``
        when one was supplied; the label is cast to ``int``.
        """
        # Fetch the row once instead of indexing the frame per column.
        row = self.df.iloc[idx]
        # Close the file handle deterministically; convert() returns a new
        # in-memory image, so it stays valid after the handle is closed.
        with Image.open(row['file_name']) as handle:
            img = handle.convert('RGB')
        label = int(row['label'])
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# Build the test set from the CSV index, using the shared preprocessing.
test_dataset = ImageDataset(
    'datasets/test.csv',
    'datasets',
    transform=transform,
)
# pick one sample (the first row of the index)
image_tensor, true_label = test_dataset[0]
# for visualization later: channels-last array, with the mean-0.5/std-0.5
# normalisation undone and the result clamped to [0, 1]
rgb = image_tensor.permute(1, 2, 0).cpu().numpy()
rgb_vis = np.clip(rgb * 0.5 + 0.5, 0, 1)

Grad-CAM Explanation

In [None]:
# reshape_transform for ViT (remove class token, reshape to H×W)
def reshape_transform(tensor, height=14, width=14):
    """Turn a token sequence into a channels-first spatial map.

    The first token along dimension 1 is dropped, the remaining tokens are
    laid out on a ``height`` x ``width`` grid, and the last (feature)
    dimension is moved in front of the two spatial dimensions.

    Args:
        tensor: Tensor indexed as (batch, tokens, features).
        height: Number of grid rows.
        width: Number of grid columns.

    Returns:
        Tensor of shape (batch, features, height, width).
    """
    batch_size = tensor.size(0)
    feature_dim = tensor.size(2)
    patch_tokens = tensor[:, 1:, :]  # skip the leading token
    grid = patch_tokens.reshape(batch_size, height, width, feature_dim)
    return grid.permute(0, 3, 1, 2)

# unwrap DataParallel if used
base_model = model.module if isinstance(model, nn.DataParallel) else model

# choose the last transformer block's norm1 as the layer to explain
target_layers = [base_model.blocks[-1].norm1]

# prepare input batch (single image, on the same device as the model)
input_tensor = image_tensor.unsqueeze(0).to(device)
# explain the score of the ground-truth class
targets = [ClassifierOutputTarget(true_label)]

# NOTE: `use_cuda` is intentionally not passed. Model and input are already on
# `device`, and recent pytorch-grad-cam releases no longer accept that keyword
# (verify against the installed version).
# The `with` form releases the hooks GradCAM registers on the target layer, so
# later forward passes through the model are no longer recorded by it.
with GradCAM(
    model=base_model,
    target_layers=target_layers,
    reshape_transform=reshape_transform,
) as cam:
    # compute CAM; [0] selects the map of the only image in the batch
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

# overlay the heat map on the de-normalised RGB image
cam_vis = show_cam_on_image(rgb_vis, grayscale_cam, use_rgb=True)

# plot
plt.figure(figsize=(6,6))
plt.imshow(cam_vis)
plt.title(f"Grad-CAM (True={true_label})")
plt.axis('off')
plt.show()

LIME Explanation

In [None]:
# convert image to uint8 HWC for LIME
# Round to the nearest integer before casting: astype(np.uint8) truncates, so a
# value such as 127.99999 (float error from the normalise/de-normalise round
# trip) would otherwise become 127 instead of 128.
rgb_uint8 = np.rint(rgb_vis * 255).astype(np.uint8)

# define the wrapper for LIME
def predict_fn(images):
    """Return class probabilities for a batch of HWC uint8 images.

    Each image is converted to a CHW float tensor, resized to 224x224 and
    normalised with mean 0.5 / std 0.5 per channel. The stacked batch is run
    through ``model`` on ``device`` without gradient tracking.

    Args:
        images: Iterable of HWC uint8 arrays.

    Returns:
        NumPy array of softmax probabilities, one row per input image.
    """
    def _preprocess(array):
        # HWC uint8 -> CHW float tensor
        tensor = TF.to_tensor(Image.fromarray(array))
        tensor = TF.resize(tensor, (224,224))
        return TF.normalize(tensor, [0.5]*3, [0.5]*3)

    model.eval()
    inputs = torch.stack([_preprocess(array) for array in images]).to(device)
    with torch.no_grad():
        scores = model(inputs)
        return F.softmax(scores, dim=1).cpu().numpy()

# Explain the prediction for rgb_uint8 with LIME, scoring 1000 samples
# through predict_fn and keeping the top 2 labels.
explainer = lime_image.LimeImageExplainer()
_lime_options = dict(
    classifier_fn=predict_fn,
    top_labels=2,
    hide_color=0,
    num_samples=1000,
)
explanation = explainer.explain_instance(rgb_uint8, **_lime_options)

# get mask for the true class (up to 10 positively contributing features,
# with the rest of the image left visible)
temp, mask = explanation.get_image_and_mask(
    true_label,
    positive_only=True,
    num_features=10,
    hide_rest=False,
)

# visualize the mask outline on top of the image
fig, ax = plt.subplots(figsize=(6,6))
ax.imshow(mark_boundaries(temp, mask))
ax.set_title(f"LIME (True={true_label})")
ax.axis('off')
plt.show()