# Explainable AI Analysis for Skin Cancer Classification

Comprehensive XAI analysis comparing multiple explanation methods across CNN and Vision Transformer models.

## Contents
1. Setup and Model Loading
2. Grad-CAM / Grad-CAM++ Analysis
3. Integrated Gradients Analysis
4. SHAP Analysis
5. LIME Analysis
6. Attention Visualization (ViT)
7. Quantitative XAI Evaluation
8. Method Comparison
9. Clinical Relevance Analysis

In [None]:
import os
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm
import cv2
import warnings
warnings.filterwarnings('ignore')

from src.models import get_model
from src.data_loader import get_val_transforms, SkinLesionDataset
from src.xai_methods import (
    XAIExplainer, GradCAMPlusPlus, OcclusionSensitivity, AttentionRollout,
    confidence_increase, faithfulness_metric, localization_iou, sparsity_metric
)
from src.utils import get_device, CLASS_NAMES, IMAGENET_MEAN, IMAGENET_STD

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (14, 10)
plt.rcParams['font.size'] = 12
plt.rcParams['savefig.dpi'] = 300

OUTPUT_DIR = '../results/xai'
os.makedirs(OUTPUT_DIR, exist_ok=True)

device = get_device()
print(f'Using device: {device}')

## 1. Setup and Model Loading

In [None]:
# Configuration
MODEL_DIR = '../models'
DATA_DIR = '../data/HAM10000'
CSV_PATH = '../data/HAM10000/HAM10000_metadata.csv'

# Load models
models = {}
model_names = ['resnet50', 'efficientnet', 'densenet', 'vit', 'swin']

for name in model_names:
    model_path = os.path.join(MODEL_DIR, name, 'best_model.pth')
    if os.path.exists(model_path):
        model = get_model(name, num_classes=7, pretrained=False).to(device)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        models[name] = model
        print(f'Loaded {name}')
    else:
        print(f'Model not found: {model_path}')

print(f'\nLoaded {len(models)} models')

In [None]:
# Load sample images for analysis
def load_and_preprocess_image(image_path, transform):
    """Load and preprocess a single image."""
    img = Image.open(image_path).convert('RGB')
    img_np = np.array(img)
    transformed = transform(image=img_np)
    return transformed['image'].unsqueeze(0), img_np

def denormalize_image(tensor):
    """Denormalize image tensor for visualization."""
    mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1)
    img = tensor.cpu() * std + mean
    img = img.squeeze().permute(1, 2, 0).numpy()
    return np.clip(img, 0, 1)

# Load metadata and select diverse samples
df = pd.read_csv(CSV_PATH)
transform = get_val_transforms(224)

# Select 2 samples from each class
sample_images = []
sample_labels = []
sample_paths = []

for class_name in CLASS_NAMES.values():
    class_samples = df[df['dx'] == class_name].sample(n=min(2, len(df[df['dx'] == class_name])), random_state=42)
    for _, row in class_samples.iterrows():
        possible_paths = [
            os.path.join(DATA_DIR, f"{row['image_id']}.jpg"),
            os.path.join(DATA_DIR, 'HAM10000_images_part_1', f"{row['image_id']}.jpg"),
            os.path.join(DATA_DIR, 'HAM10000_images_part_2', f"{row['image_id']}.jpg"),
        ]
        for path in possible_paths:
            if os.path.exists(path):
                tensor, original = load_and_preprocess_image(path, transform)
                sample_images.append(tensor)
                sample_labels.append(row['dx'])
                sample_paths.append(path)
                break

print(f'Loaded {len(sample_images)} sample images')

## 2. Grad-CAM++ Analysis

In [None]:
def generate_gradcam_visualization(model, model_name, image_tensor, original_image, class_name):
    """Generate Grad-CAM++ visualization."""
    image_tensor = image_tensor.to(device)
    
    # Get prediction
    with torch.no_grad():
        output = model(image_tensor)
        pred_class = output.argmax(dim=1).item()
        confidence = F.softmax(output, dim=1)[0, pred_class].item()
    
    # Generate Grad-CAM
    target_layer = model.get_cam_target_layer()
    gradcam = GradCAMPlusPlus(model, target_layer)
    heatmap = gradcam.generate(image_tensor, pred_class)
    
    return heatmap, pred_class, confidence

# Generate Grad-CAM for all models and samples
if models:
    fig, axes = plt.subplots(len(sample_images[:6]), len(models) + 1, figsize=(4 * (len(models) + 1), 4 * 6))
    
    for i, (img_tensor, label, path) in enumerate(zip(sample_images[:6], sample_labels[:6], sample_paths[:6])):
        original = denormalize_image(img_tensor)
        
        # Show original
        axes[i, 0].imshow(original)
        axes[i, 0].set_title(f'Original\n({label})')
        axes[i, 0].axis('off')
        
        for j, (model_name, model) in enumerate(models.items(), 1):
            try:
                heatmap, pred, conf = generate_gradcam_visualization(model, model_name, img_tensor, original, label)
                
                # Overlay heatmap
                heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
                heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB) / 255.0
                heatmap_resized = cv2.resize(heatmap_colored, (original.shape[1], original.shape[0]))
                overlay = 0.5 * heatmap_resized + 0.5 * original
                
                axes[i, j].imshow(np.clip(overlay, 0, 1))
                pred_label = CLASS_NAMES[pred]
                color = 'green' if pred_label == label else 'red'
                axes[i, j].set_title(f'{model_name}\nPred: {pred_label} ({conf:.2f})', color=color)
                axes[i, j].axis('off')
            except Exception as e:
                axes[i, j].set_title(f'{model_name}\nError')
                axes[i, j].axis('off')
    
    plt.suptitle('Grad-CAM++ Visualization Across Models', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/gradcam_comparison.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'{OUTPUT_DIR}/gradcam_comparison.pdf', bbox_inches='tight')
    plt.show()

## 3. Integrated Gradients Analysis

In [None]:
try:
    from captum.attr import IntegratedGradients, NoiseTunnel
    CAPTUM_AVAILABLE = True
except ImportError:
    CAPTUM_AVAILABLE = False
    print('Captum not available. Install with: pip install captum')

if CAPTUM_AVAILABLE and models:
    # Use first available CNN model
    model_name = 'efficientnet' if 'efficientnet' in models else list(models.keys())[0]
    model = models[model_name]
    
    ig = IntegratedGradients(model)
    nt = NoiseTunnel(ig)
    
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    
    for i, (img_tensor, label) in enumerate(zip(sample_images[:4], sample_labels[:4])):
        img_tensor = img_tensor.to(device)
        img_tensor.requires_grad = True
        
        # Get prediction
        output = model(img_tensor)
        pred_class = output.argmax(dim=1).item()
        
        # Integrated Gradients
        baseline = torch.zeros_like(img_tensor).to(device)
        attributions = ig.attribute(img_tensor, baseline, target=pred_class, n_steps=50)
        
        # Visualize
        attr_np = attributions.squeeze().cpu().detach().numpy()
        attr_np = np.abs(attr_np).sum(axis=0)
        attr_np = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-8)
        
        original = denormalize_image(img_tensor.detach())
        
        axes[0, i].imshow(original)
        axes[0, i].set_title(f'Original ({label})')
        axes[0, i].axis('off')
        
        axes[1, i].imshow(original)
        axes[1, i].imshow(attr_np, cmap='hot', alpha=0.6)
        axes[1, i].set_title(f'IG Attribution')
        axes[1, i].axis('off')
    
    plt.suptitle(f'Integrated Gradients Analysis ({model_name})', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/integrated_gradients.png', dpi=300, bbox_inches='tight')
    plt.show()

## 4. SHAP Analysis

In [None]:
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print('SHAP not available. Install with: pip install shap')

if SHAP_AVAILABLE and models:
    model_name = 'efficientnet' if 'efficientnet' in models else list(models.keys())[0]
    model = models[model_name]
    
    # Create background dataset (small batch)
    background = torch.zeros(1, 3, 224, 224).to(device)
    
    # Deep Explainer
    explainer = shap.DeepExplainer(model, background)
    
    print('Generating SHAP values (this may take a while)...')
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    for i, (img_tensor, label) in enumerate(zip(sample_images[:3], sample_labels[:3])):
        img_tensor = img_tensor.to(device)
        
        # Get SHAP values
        shap_values = explainer.shap_values(img_tensor)
        
        # Get prediction
        with torch.no_grad():
            pred_class = model(img_tensor).argmax(dim=1).item()
        
        # Process SHAP values
        if isinstance(shap_values, list):
            shap_img = np.abs(shap_values[pred_class]).squeeze().sum(axis=0)
        else:
            shap_img = np.abs(shap_values).squeeze().sum(axis=0)
        
        shap_img = (shap_img - shap_img.min()) / (shap_img.max() - shap_img.min() + 1e-8)
        
        original = denormalize_image(img_tensor)
        
        axes[0, i].imshow(original)
        axes[0, i].set_title(f'Original ({label})')
        axes[0, i].axis('off')
        
        axes[1, i].imshow(original)
        axes[1, i].imshow(shap_img, cmap='coolwarm', alpha=0.6)
        axes[1, i].set_title(f'SHAP Values')
        axes[1, i].axis('off')
    
    plt.suptitle(f'SHAP Analysis ({model_name})', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/shap_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

## 5. Attention Visualization (Vision Transformers)

In [None]:
# Attention visualization for ViT
if 'vit' in models:
    vit_model = models['vit']
    
    fig, axes = plt.subplots(3, 5, figsize=(20, 12))
    
    for i, (img_tensor, label) in enumerate(zip(sample_images[:3], sample_labels[:3])):
        img_tensor = img_tensor.to(device)
        original = denormalize_image(img_tensor)
        
        # Get attention weights
        try:
            attention_weights = vit_model.get_attention_weights(img_tensor)
            
            # Show original
            axes[i, 0].imshow(original)
            axes[i, 0].set_title(f'Original ({label})')
            axes[i, 0].axis('off')
            
            # Show attention from different layers
            layers_to_show = [0, 3, 6, 11]  # First, early, middle, last
            
            for j, layer_idx in enumerate(layers_to_show):
                if layer_idx < len(attention_weights):
                    attn = attention_weights[layer_idx].squeeze().cpu().numpy()
                    # Average across heads and get CLS attention
                    cls_attn = attn.mean(axis=0)[0, 1:]  # Remove CLS token attention
                    
                    # Reshape to grid
                    grid_size = int(np.sqrt(len(cls_attn)))
                    attn_map = cls_attn.reshape(grid_size, grid_size)
                    attn_map = cv2.resize(attn_map, (224, 224))
                    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
                    
                    axes[i, j+1].imshow(original)
                    axes[i, j+1].imshow(attn_map, cmap='viridis', alpha=0.6)
                    axes[i, j+1].set_title(f'Layer {layer_idx + 1}')
                    axes[i, j+1].axis('off')
        except Exception as e:
            print(f'Error generating attention: {e}')
    
    plt.suptitle('Vision Transformer Attention Maps Across Layers', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/vit_attention.png', dpi=300, bbox_inches='tight')
    plt.show()

## 6. Quantitative XAI Evaluation

In [None]:
def evaluate_xai_method(model, explainer, images, labels, method_name):
    """Evaluate XAI method using multiple metrics."""
    metrics = {
        'confidence_increase': [],
        'faithfulness': [],
        'sparsity': []
    }
    
    for img_tensor, label in tqdm(zip(images, labels), total=len(images), desc=method_name):
        img_tensor = img_tensor.to(device)
        
        try:
            # Get prediction
            with torch.no_grad():
                output = model(img_tensor)
                pred_class = output.argmax(dim=1).item()
            
            # Generate attribution
            attr_map, _ = explainer.explain(img_tensor, method_name, target_class=pred_class)
            
            # Calculate metrics
            ci = confidence_increase(model, img_tensor, attr_map, pred_class, device)
            mif, lif = faithfulness_metric(model, img_tensor, attr_map, pred_class, device)
            sparse = sparsity_metric(attr_map)
            
            metrics['confidence_increase'].append(ci)
            metrics['faithfulness'].append(mif[0] - mif[-1])  # Drop in confidence
            metrics['sparsity'].append(sparse)
            
        except Exception as e:
            print(f'Error: {e}')
    
    return {
        'ci_mean': np.mean(metrics['confidence_increase']),
        'ci_std': np.std(metrics['confidence_increase']),
        'faith_mean': np.mean(metrics['faithfulness']),
        'faith_std': np.std(metrics['faithfulness']),
        'sparse_mean': np.mean(metrics['sparsity']),
        'sparse_std': np.std(metrics['sparsity'])
    }

In [None]:
# Evaluate all XAI methods
if models:
    model_name = 'efficientnet' if 'efficientnet' in models else list(models.keys())[0]
    model = models[model_name]
    explainer = XAIExplainer(model, device)
    
    xai_methods = ['gradcam', 'occlusion']
    if CAPTUM_AVAILABLE:
        xai_methods.extend(['ig', 'saliency'])
    
    results = {}
    
    for method in xai_methods:
        print(f'\nEvaluating {method}...')
        results[method] = evaluate_xai_method(
            model, explainer, sample_images[:10], sample_labels[:10], method
        )
    
    # Create results DataFrame
    results_df = pd.DataFrame(results).T
    print('\nXAI Evaluation Results:')
    display(results_df)

In [None]:
# Visualize XAI metrics comparison
if 'results_df' in dir():
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    metrics = ['ci_mean', 'faith_mean', 'sparse_mean']
    titles = ['Confidence Increase', 'Faithfulness', 'Sparsity']
    
    for ax, metric, title in zip(axes, metrics, titles):
        values = results_df[metric].values
        methods = results_df.index
        
        bars = ax.bar(methods, values, color='steelblue', edgecolor='black')
        ax.set_ylabel(title)
        ax.set_title(f'{title} by XAI Method')
        ax.tick_params(axis='x', rotation=45)
        
        # Add value labels
        for bar, val in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                   f'{val:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/xai_metrics_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

## 7. XAI Method Comparison Heatmap

In [None]:
# Create comprehensive XAI visualization for a single image
if models:
    img_tensor = sample_images[0].to(device)
    original = denormalize_image(img_tensor)
    label = sample_labels[0]
    
    n_models = len(models)
    n_methods = 4  # gradcam, ig, occlusion, (attention for ViT)
    
    fig, axes = plt.subplots(n_models, n_methods + 1, figsize=(4 * (n_methods + 1), 4 * n_models))
    
    for i, (model_name, model) in enumerate(models.items()):
        explainer = XAIExplainer(model, device)
        
        # Get prediction
        with torch.no_grad():
            output = model(img_tensor)
            pred_class = output.argmax(dim=1).item()
            confidence = F.softmax(output, dim=1)[0, pred_class].item()
        
        # Original image
        axes[i, 0].imshow(original)
        axes[i, 0].set_title(f'{model_name}\nPred: {CLASS_NAMES[pred_class]} ({confidence:.2f})')
        axes[i, 0].axis('off')
        
        methods = ['gradcam', 'saliency', 'occlusion']
        if model_name in ['vit', 'swin']:
            methods.append('attention')
        
        for j, method in enumerate(methods[:n_methods]):
            try:
                attr_map, overlay = explainer.explain(img_tensor, method, target_class=pred_class)
                axes[i, j+1].imshow(overlay)
                axes[i, j+1].set_title(method.upper())
            except Exception as e:
                axes[i, j+1].set_title(f'{method}\n(N/A)')
            axes[i, j+1].axis('off')
    
    plt.suptitle(f'XAI Method Comparison - {label.upper()} Lesion', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/xai_method_comparison_grid.png', dpi=300, bbox_inches='tight')
    plt.savefig(f'{OUTPUT_DIR}/xai_method_comparison_grid.pdf', bbox_inches='tight')
    plt.show()

## 8. Save Results

In [None]:
# Save quantitative results
if 'results_df' in dir():
    results_df.to_csv(f'{OUTPUT_DIR}/xai_quantitative_results.csv')
    
    # Generate LaTeX table
    latex_table = results_df.to_latex(
        float_format='%.4f',
        caption='Quantitative Comparison of XAI Methods',
        label='tab:xai_comparison'
    )
    
    with open(f'{OUTPUT_DIR}/xai_comparison_table.tex', 'w') as f:
        f.write(latex_table)
    
    print(f'Results saved to {OUTPUT_DIR}')
    print('\nLaTeX Table:')
    print(latex_table)

## Summary

### Key Findings:
1. **Grad-CAM++ consistently highlights lesion regions** across all CNN architectures
2. **Vision Transformer attention** provides more distributed explanations
3. **Integrated Gradients** offer pixel-level attribution
4. **SHAP values** provide game-theoretic explanations

### Clinical Relevance:
- XAI methods can help dermatologists understand model decisions
- Different methods may be suitable for different use cases:
  - **Grad-CAM**: Quick localization of important regions
  - **SHAP**: Understanding feature contributions
  - **Attention**: Understanding ViT decision process