# SignXAI2 PyTorch Tutorial - Image Classification

This tutorial demonstrates how to use SignXAI2 for explaining image classification models with PyTorch.

## Setup

⚠️ **Data Requirements**: This tutorial requires example data from the GitHub repository. Please ensure you have downloaded the necessary data files or cloned the repository.

First, let's install the signxai2 package and download a sample image to work with:

In [None]:
# Download the signxai2 package if not already installed
 !pip install signxai2[pytorch]

# Download an example image
import urllib.request

# Download an image of a dog
url = "http://vision.stanford.edu/aditya86/ImageNetDogs/images/n02106030-collie/n02106030_16370.jpg"
urllib.request.urlretrieve(url, "dog.jpg")

## PyTorch Implementation

Now let's load a pre-trained VGG16 model with PyTorch, preprocess the image, and predict its class:

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms
from signxai import explain, list_methods
from signxai.utils.utils import normalize_heatmap

# Load the pre-trained model. ``weights=`` replaces the deprecated
# ``pretrained=True`` argument (torchvision >= 0.13).
weights = models.VGG16_Weights.IMAGENET1K_V1
model = models.vgg16(weights=weights)
model.eval()

# Explanation methods must see the raw pre-softmax scores (logits). torchvision's
# VGG16 has no softmax layer -- it already returns logits -- so nothing needs to
# be removed. In particular, ``model.classifier[-1]`` is the final
# Linear(4096, 1000) class layer: replacing it with ``Identity`` would turn the
# output into 4096-dim features and make every predicted/target class meaningless.

# Load and preprocess the image
img_path = "dog.jpg"
img = Image.open(img_path).convert('RGB')

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the pre-trained weights
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(img).unsqueeze(0)  # Add batch dimension -> (1, 3, 224, 224)
img_np = np.array(img.resize((224, 224))) / 255.0  # For visualization, values in [0, 1]

# Make prediction (no gradients needed just to pick the class)
with torch.no_grad():
    output = model(input_tensor)  # (1, 1000) logits

# Get the predicted class
_, predicted_idx = torch.max(output, 1)
print(f"Predicted class {predicted_idx.item()}: "
      f"{weights.meta['categories'][predicted_idx.item()]}")

In [None]:
# Attribution methods to compare. The last entry, "lrpsign_z", is the SIGN
# variant of LRP-z; it is shown separately in the comparison plot further down.
methods = [
    "gradient",
    "gradient_x_input",
    "integrated_gradients",
    "smoothgrad",
    "grad_cam",
    "lrp_z",
    "lrp_epsilon_0_1",
    "lrpsign_z"  # The SIGN method
]

# Every method explains the same (predicted) class, so read the index once.
target_idx = predicted_idx.item()

# One explanation per method, keyed by method name (insertion order = methods order).
explanations = {
    name: explain(
        model=model,
        x=input_tensor,
        method_name=name,
        target_class=target_idx,
    )
    for name in methods
}

In [None]:
# Visualize explanations
fig, axs = plt.subplots(2, 4, figsize=(20, 10))
axs = axs.flatten()

# Original image
axs[0].imshow(img_np)
axs[0].set_title('Original Image', fontsize=14)
axs[0].axis('off')

# Explanations
for i, method in enumerate(methods[:7]):
    explanation = explanations[method][0].sum(axis=0)
    axs[i+1].imshow(normalize_heatmap(explanation), cmap='seismic', clim=(-1, 1))
    axs[i+1].set_title(method, fontsize=14)
    axs[i+1].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Highlight the difference between standard LRP and SIGN
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.imshow(img_np)
plt.title('Original Image', fontsize=14)
plt.axis('off')

plt.subplot(1, 3, 2)
# For PyTorch, we need to handle the tensor format
lrp_z_expl = explanations['lrp_z']
if hasattr(lrp_z_expl, 'detach'):
    lrp_z_expl = lrp_z_expl.detach().cpu().numpy()
if lrp_z_expl.ndim == 4:
    lrp_z_expl = lrp_z_expl[0]
if lrp_z_expl.shape[0] == 3:  # CHW format
    lrp_z_expl = lrp_z_expl.transpose(1, 2, 0)
lrp_z_heatmap = lrp_z_expl.sum(axis=-1) if lrp_z_expl.ndim == 3 else lrp_z_expl
plt.imshow(normalize_heatmap(lrp_z_heatmap), cmap='seismic', clim=(-1, 1))
plt.title('LRP-Z', fontsize=14)
plt.axis('off')

plt.subplot(1, 3, 3)
# For PyTorch, we need to handle the tensor format
lrpsign_z_expl = explanations['lrpsign_z']
if hasattr(lrpsign_z_expl, 'detach'):
    lrpsign_z_expl = lrpsign_z_expl.detach().cpu().numpy()
if lrpsign_z_expl.ndim == 4:
    lrpsign_z_expl = lrpsign_z_expl[0]
if lrpsign_z_expl.shape[0] == 3:  # CHW format
    lrpsign_z_expl = lrpsign_z_expl.transpose(1, 2, 0)
lrpsign_z_heatmap = lrpsign_z_expl.sum(axis=-1) if lrpsign_z_expl.ndim == 3 else lrpsign_z_expl
plt.imshow(normalize_heatmap(lrpsign_z_heatmap), cmap='seismic', clim=(-1, 1))
plt.title('LRP-SIGN', fontsize=14)
plt.axis('off')

plt.tight_layout()
plt.show()