# OCT Image Classification - Explainable AI (XAI)

## Overview
This notebook provides explainability for the OCT classification model using multiple XAI techniques:
- **Grad-CAM**: Gradient-weighted Class Activation Mapping
- **LIME**: Local Interpretable Model-agnostic Explanations
- **Integrated Gradients**: Attribution scores showing pixel importance

## Purpose
Make the deep learning model's decisions interpretable by:
1. Highlighting which regions of the retinal image influenced the prediction
2. Providing quantitative metrics about model focus
3. Comparing different explanation methods
4. Building trust in model predictions for clinical use


## 1. Setup and Installation


In [None]:
# Install XAI requirements (uncomment if needed)
# !pip install -r requirements_xai.txt
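
If you are unsure whether the environment is ready, a quick import check can save a failed run later in the notebook. This is a minimal sketch using only the standard library. The module list is an assumption about what `requirements_xai.txt` and `xai_utils` need, so adjust it to match the actual file.

```python
import importlib.util

# Hypothetical module list: the real requirements_xai.txt may differ.
# Import names can differ from pip names (opencv-python installs `cv2`).
XAI_MODULES = ["torch", "torchvision", "numpy", "cv2", "matplotlib", "lime"]


def check_dependencies(module_names):
    """Return {module_name: True/False} depending on whether it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in module_names}


for name, ok in check_dependencies(XAI_MODULES).items():
    print(f"  {name:12s}: {'found' if ok else 'MISSING'}")
```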


## 2. Import Libraries


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import json
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Import XAI utilities
from xai_utils import (
    preprocess_image_for_xai,
    generate_gradcam,
    generate_integrated_gradients,
    generate_lime_explanation,
    apply_colormap_on_image,
    create_comparison_plot,
    calculate_explanation_metrics,
    visualize_with_bounding_boxes,
    CLASS_NAMES
)

print("✓ Libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


## 3. Configuration


In [None]:
# Configuration
CONFIG = {
    'model_path': 'classification_models/best_oct_classifier.pth',
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'num_classes': 4,
    'img_size': 224,
    'output_dir': 'xai_explanations'
}

# Create output directory
os.makedirs(CONFIG['output_dir'], exist_ok=True)

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

print(f"\nClass names: {CLASS_NAMES}")


## 4. Load Pre-trained Classification Model


In [None]:
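def create_classification_model(num_classes=4, pretrained=False):
    """Create the ResNet50 classifier (must match the architecture used in training)."""
    # torchvision >= 0.13 replaces the deprecated `pretrained=` argument with `weights=`
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    model = models.resnet50(weights=weights)
    num_features = model.fc.in_features
    # Replace the 1000-class ImageNet head with the custom classification head
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )
    return model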

# Load model
print(f"Loading model from: {CONFIG['model_path']}")
model = create_classification_model(num_classes=CONFIG['num_classes'])

if not os.path.exists(CONFIG['model_path']):
    # Stop here: the explanation cells below would otherwise run on untrained weights
    raise FileNotFoundError(
        f"Model not found at {CONFIG['model_path']}. Please train the classification model first."
    )

checkpoint = torch.load(CONFIG['model_path'], map_location=CONFIG['device'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(CONFIG['device'])
model.eval()  # inference mode: disables dropout and uses BatchNorm running statistics
print("✓ Model loaded successfully!")
if 'val_acc' in checkpoint:
    print(f"  Validation accuracy: {checkpoint['val_acc']:.2f}%")


## 5. Single Image Explanation Function

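`explain_prediction` runs Grad-CAM, Integrated Gradients, and LIME on a single image, computes summary metrics for the Grad-CAM and Integrated Gradients maps, plots the three methods side by side, and optionally saves the figures and a JSON metrics file to `CONFIG['output_dir']`.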


In [None]:
def explain_prediction(image_path, save_results=True):
    """
    Generate all XAI explanations for a single image
    
    Args:
        image_path: Path to input image
        save_results: Whether to save results to disk
    
    Returns:
        Dictionary containing all explanations and metrics
    """
    print(f"\n{'='*70}")
    print(f"Explaining prediction for: {os.path.basename(image_path)}")
    print(f"{'='*70}")
    
    # Preprocess image
    image_tensor, original_image = preprocess_image_for_xai(image_path)
    image_tensor = image_tensor.to(CONFIG['device'])
    
    # 1. Generate Grad-CAM
    print("\n[1/3] Generating Grad-CAM...")
    gradcam_heatmap, pred_class, confidence, class_probs = generate_gradcam(
        model, image_tensor
    )
    
    # Resize the original image to the model input size so the heatmaps can be overlaid
    img_size = CONFIG['img_size']
    original_resized = cv2.resize(original_image, (img_size, img_size))
    gradcam_overlay = apply_colormap_on_image(original_resized, gradcam_heatmap, alpha=0.4)
    
    print(f"  ✓ Predicted: {CLASS_NAMES[pred_class]} (Confidence: {confidence*100:.1f}%)")
    
    # 2. Generate Integrated Gradients
    print("\n[2/3] Generating Integrated Gradients...")
    ig_map, _, _, _ = generate_integrated_gradients(model, image_tensor, n_steps=50)
    ig_overlay = apply_colormap_on_image(original_resized, ig_map, alpha=0.4)
    print("  ✓ Done")
    
    # 3. Generate LIME Explanation
    print("\n[3/3] Generating LIME explanation (this may take a minute)...")
    lime_image, lime_mask, _, _, _ = generate_lime_explanation(
        model, image_tensor, original_image, num_samples=500, num_features=5
    )
    print("  ✓ Done")
    
    # Calculate metrics
    print("\nCalculating explanation metrics...")
    gradcam_metrics = calculate_explanation_metrics(gradcam_heatmap)
    ig_metrics = calculate_explanation_metrics(ig_map)
    
    # Create bounding box visualization
    gradcam_bbox = visualize_with_bounding_boxes(original_resized, gradcam_heatmap, num_regions=5)
    
    # Create comparison plot
    print("\nCreating comparison visualization...")
    fig = create_comparison_plot(
        original_resized, gradcam_overlay, lime_image, ig_overlay,
        pred_class, confidence, class_probs
    )
    
    # Save results if requested
    if save_results:
        base_name = Path(image_path).stem
        output_prefix = os.path.join(CONFIG['output_dir'], f"{base_name}_{CLASS_NAMES[pred_class]}")
        
        # Save comparison plot
        comparison_path = f"{output_prefix}_comparison.png"
        fig.savefig(comparison_path, dpi=150, bbox_inches='tight')
        
        # Save individual visualizations
        cv2.imwrite(f"{output_prefix}_gradcam.png", cv2.cvtColor(gradcam_overlay, cv2.COLOR_RGB2BGR))
        cv2.imwrite(f"{output_prefix}_lime.png", cv2.cvtColor(lime_image, cv2.COLOR_RGB2BGR))
        cv2.imwrite(f"{output_prefix}_ig.png", cv2.cvtColor(ig_overlay, cv2.COLOR_RGB2BGR))
        cv2.imwrite(f"{output_prefix}_gradcam_bbox.png", cv2.cvtColor(gradcam_bbox, cv2.COLOR_RGB2BGR))
        
        # Save metrics
        metrics_dict = {
            'image_path': image_path,
            'predicted_class': CLASS_NAMES[pred_class],
            'confidence': float(confidence),
            'class_probabilities': {CLASS_NAMES[i]: float(class_probs[i]) for i in range(len(CLASS_NAMES))},
            'gradcam_metrics': gradcam_metrics,
            'integrated_gradients_metrics': ig_metrics
        }
        
        metrics_path = f"{output_prefix}_metrics.json"
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        
        print(f"\n✓ Results saved to: {CONFIG['output_dir']}/")
    
    # Display results
    plt.show()
    
    # Print detailed metrics
    print(f"\n{'='*70}")
    print("EXPLANATION METRICS")
    print(f"{'='*70}")
    print(f"\nPrediction: {CLASS_NAMES[pred_class]} (Confidence: {confidence*100:.1f}%)")
    print("\nClass Probabilities:")
    for name, prob in zip(CLASS_NAMES, class_probs):
        bar = '█' * int(prob * 50)  # text bar: 50 characters = 100%
        print(f"  {name:8s}: {prob*100:5.1f}% {bar}")

    # Same summary for each attribution method
    for method, metrics in [("Grad-CAM", gradcam_metrics), ("Integrated Gradients", ig_metrics)]:
        peak = metrics['peak_activation_location']
        print(f"\n{method} Metrics:")
        print(f"  Attribution coverage: {metrics['attribution_coverage_percent']:.1f}%")
        print(f"  Peak location: ({peak['x']}, {peak['y']})")
        print(f"  Mean attribution: {metrics['mean_attribution']:.3f}")

    print(f"{'='*70}\n")
    
    return {
        'predicted_class': CLASS_NAMES[pred_class],
        'confidence': confidence,
        'class_probabilities': class_probs,
        'gradcam_overlay': gradcam_overlay,
        'lime_image': lime_image,
        'ig_overlay': ig_overlay,
        'gradcam_metrics': gradcam_metrics,
        'ig_metrics': ig_metrics
    }

print("✓ Explanation function defined")
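
The function prints `attribution_coverage_percent`, `peak_activation_location`, and `mean_attribution` without defining them. The sketch below shows one plausible way to compute such metrics from a 2-D heatmap. The min-max normalization, the 0.5 threshold, and the name `explanation_metrics_sketch` are assumptions for illustration, not the actual implementation of `calculate_explanation_metrics` in `xai_utils`.

```python
import numpy as np


def explanation_metrics_sketch(heatmap, threshold=0.5):
    """Illustrative metrics for a 2-D attribution map (hypothetical definitions).

    Only the dictionary keys are taken from the notebook; the map is
    min-max normalised to [0, 1] before any metric is computed.
    """
    h = np.asarray(heatmap, dtype=np.float64)
    span = h.max() - h.min()
    h = (h - h.min()) / span if span > 0 else np.zeros_like(h)

    # Percentage of pixels whose normalised attribution exceeds the threshold
    coverage = 100.0 * float((h > threshold).mean())
    # Row/column of the strongest attribution (argmax over the flattened map)
    y, x = np.unravel_index(int(np.argmax(h)), h.shape)
    return {
        "attribution_coverage_percent": coverage,
        "peak_activation_location": {"x": int(x), "y": int(y)},
        "mean_attribution": float(h.mean()),
    }
```

Normalizing first puts Grad-CAM and Integrated Gradients maps on a common [0, 1] scale, so their coverage and mean values can be compared even though their raw magnitudes differ.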


## 6. Explain a Single Image


In [None]:
# Example: explain a single image
image_path = 'uploads/example_image.jpg'  # change this to the path of your OCT image

# Check if file exists
if not os.path.exists(image_path):
    print(f"⚠️ Image not found: {image_path}")
    print("\nPlease update the image_path variable to point to a valid OCT image.")
else:
    # Generate explanations
    results = explain_prediction(image_path, save_results=True)


## 7. Interpretation Guide

### Understanding the Explanations

#### Grad-CAM (Gradient-weighted Class Activation Mapping)
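- **What it shows**: Regions whose features most increase the score of the predicted class. The map is usually computed from the last convolutional block (7×7 for ResNet50 at 224×224 input) and upsampled, so it localizes regions rather than individual pixels
- **Colors**: Red = high importance, Blue = low importance
- **Interpretation**: Areas highlighted in red are where the model "looks" to make its decision
- **Clinical use**: Verify that the model focuses on pathological features (fluid, deposits, etc.) rather than on artifacts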
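
Grad-CAM weights each channel of the chosen convolutional layer by the spatial average of the class-score gradient over that channel, sums the weighted channels, and applies a ReLU so that only positive evidence remains. The NumPy sketch below implements only this combination step. It assumes the activations and gradients have already been captured (in PyTorch, for example, with forward and backward hooks on `model.layer4`); how `generate_gradcam` in `xai_utils` captures them is not shown in this notebook.

```python
import numpy as np


def gradcam_from_activations(activations, gradients):
    """Combine feature maps and their gradients into a Grad-CAM map.

    activations, gradients: arrays of shape (K, H, W) for one image, taken
    from the chosen convolutional layer. Returns an (H, W) map scaled to [0, 1].
    """
    A = np.asarray(activations, dtype=np.float64)
    G = np.asarray(gradients, dtype=np.float64)
    # Channel weights: global average of the class-score gradient per channel
    weights = G.mean(axis=(1, 2))
    # Weighted sum over channels, then ReLU to keep only positive evidence
    cam = np.maximum(np.tensordot(weights, A, axes=1), 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

The resulting low-resolution map is then resized to the input size (224×224 here) before it is overlaid on the image.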

#### LIME (Local Interpretable Model-agnostic Explanations)
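- **What it shows**: Superpixels (contiguous image regions) that most support the predicted class
- **Colors**: Highlighted boundaries outline the top-ranked superpixels (`num_features=5` in this notebook)
- **Interpretation**: LIME switches superpixels on and off at random, fits a simple weighted linear model to the classifier's outputs, and ranks superpixels by their weights. The result is a local approximation and can vary between runs, especially with few samples (`num_samples=500` here)
- **Clinical use**: Understand which anatomical regions drive the diagnosis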
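
The surrogate-model idea behind LIME can be sketched in a few lines of NumPy. Each random sample switches superpixels on (1) or off (0); samples are weighted by their cosine distance to the unperturbed image through an exponential kernel (the `lime` package's image explainer uses a kernel of this form with a default width of 0.25); a weighted linear model is fitted, and its coefficients rank the superpixels. `lime_weights_sketch` and its `predict_fn` interface are hypothetical, and plain weighted least squares stands in for the ridge regression and feature selection that the `lime` package performs, so treat this as an illustration rather than a reimplementation of `generate_lime_explanation`.

```python
import numpy as np


def lime_weights_sketch(predict_fn, num_superpixels, num_samples=500,
                        kernel_width=0.25, seed=0):
    """Fit a LIME-style local linear surrogate over superpixel on/off masks.

    predict_fn (hypothetical): maps a (num_samples, num_superpixels) 0/1 array
    to the predicted probability of the target class for each masked image.
    Returns (coefficient per superpixel, intercept).
    """
    rng = np.random.default_rng(seed)
    Z = rng.integers(0, 2, size=(num_samples, num_superpixels)).astype(np.float64)
    Z[0] = 1.0  # the first sample is the unperturbed image
    y = np.asarray(predict_fn(Z), dtype=np.float64)

    # Cosine distance of each mask to the all-ones mask (all-zero masks -> distance 1)
    norms = np.linalg.norm(Z, axis=1)
    sim = np.divide(Z.sum(axis=1), norms * np.sqrt(num_superpixels),
                    out=np.zeros(num_samples), where=norms > 0)
    dist = 1.0 - sim
    # Exponential kernel: masks closer to the original image get more weight
    w = np.sqrt(np.exp(-(dist ** 2) / kernel_width ** 2))

    # Weighted least squares with an intercept column (rows scaled by sqrt(weight))
    X = np.hstack([Z, np.ones((num_samples, 1))])
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return coef[:-1], coef[-1]
```

In real use, `predict_fn` would build each masked image (for example, filling switched-off superpixels with a constant) and return the classifier's probability for the predicted class. Because the samples are random, rerunning with a different `seed` is a cheap way to check how stable the top superpixels are.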

#### Integrated Gradients
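- **What it shows**: Pixel-level attributions, computed by accumulating the gradient of the class score along a straight path from a baseline image (often an all-zero image) to the input and scaling by each pixel's difference from the baseline
- **Colors**: Red = strong attribution. If the map keeps its sign, blue marks negative contributions; if it is converted to absolute values before plotting, blue simply means low attribution
- **Interpretation**: The attributions approximately sum to the difference between the model's output for the input and for the baseline (exactly in the limit of many integration steps; `n_steps=50` here), so each pixel's value is its share of the prediction
- **Clinical use**: Fine-grained analysis of which features matter most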
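
Integrated Gradients averages the gradient along the straight path from a baseline to the input and multiplies the result by the input-minus-baseline difference. A useful consequence is completeness: the attributions sum to the model output at the input minus the output at the baseline. The sketch below uses a midpoint Riemann sum with `n_steps=50` (the value passed in this notebook) and a toy function with a hand-written gradient in place of a network; `grad_fn` is a hypothetical stand-in for the autograd call a real implementation would make.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint Riemann approximation of Integrated Gradients.

    grad_fn(z) must return dF/dz at point z (same shape as z).
    attribution_i = (x_i - baseline_i) * average of dF/dz_i along the path.
    """
    x = np.asarray(x, dtype=np.float64)
    baseline = np.asarray(baseline, dtype=np.float64)
    alphas = (np.arange(n_steps) + 0.5) / n_steps  # midpoints of [0, 1]
    total = np.zeros_like(x)
    for a in alphas:
        total += grad_fn(baseline + a * (x - baseline))
    return (x - baseline) * total / n_steps


# Toy model F(z) = sum(w * z**2) with gradient 2 * w * z
w = np.array([1.0, -2.0, 0.5])
F = lambda z: float(np.sum(w * z ** 2))
x, baseline = np.array([1.0, 2.0, -3.0]), np.zeros(3)
attributions = integrated_gradients(lambda z: 2 * w * z, x, baseline)
print(attributions, attributions.sum(), F(x) - F(baseline))
```

Comparing `attributions.sum()` with the output difference is a quick check, on a real model, that `n_steps` is large enough.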

### Clinical Validation Checklist
✓ Does the model focus on relevant anatomical structures?  
✓ Are pathological features (fluid, deposits) highlighted?  
✓ Is the model ignoring irrelevant artifacts or edges?  
✓ Do all three XAI methods show consistent focus areas?  
✓ Does the explanation align with clinical reasoning?
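
The question of whether the methods agree can be made quantitative. One simple option, sketched below under the assumption that both maps have the same shape (for example, 224×224 Grad-CAM and Integrated Gradients maps), is the intersection-over-union of each map's top 20% pixels plus a rank correlation over all pixels. The 20% cut-off and the function names are illustrative choices, not part of `xai_utils`.

```python
import numpy as np


def topk_iou(map_a, map_b, top_percent=20.0):
    """IoU of the top `top_percent`% most-attributed pixels of two same-shape maps."""
    a = np.asarray(map_a, dtype=np.float64).ravel()
    b = np.asarray(map_b, dtype=np.float64).ravel()
    k = max(1, int(round(a.size * top_percent / 100.0)))
    top_a = set(np.argsort(a)[-k:].tolist())
    top_b = set(np.argsort(b)[-k:].tolist())
    return len(top_a & top_b) / len(top_a | top_b)


def rank_correlation(map_a, map_b):
    """Spearman-style correlation: Pearson correlation of pixel ranks (ties broken arbitrarily)."""
    ra = np.argsort(np.argsort(np.ravel(map_a))).astype(np.float64)
    rb = np.argsort(np.argsort(np.ravel(map_b))).astype(np.float64)
    return float(np.corrcoef(ra, rb)[0, 1])
```

For LIME, whose output is a binary superpixel mask, compare the set of highlighted pixels directly with a map's top pixels instead of ranking it, since ties in a binary mask make the ranks arbitrary.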
