# Comprehensive Model Interpretation Tutorial

This notebook provides a complete guide to using AutoTimm's interpretation capabilities.

## What You'll Learn

1. **Interpretation Methods** - 6 different explanation techniques
2. **Quality Metrics** - Quantitative evaluation of explanations
3. **Interactive Visualizations** - Plotly-based exploration tools
4. **Performance Optimization** - Caching, batching, and profiling
5. **Production Best Practices** - Real-world deployment tips

## Prerequisites

```bash
pip install "autotimm[all]"  # Includes plotly for interactive visualizations; quotes stop zsh from expanding the brackets
```

In [None]:
# Imports
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
import time

# AutoTimm imports
from autotimm import ImageClassifier
from autotimm.interpretation import (
    GradCAM,
    GradCAMPlusPlus,
    IntegratedGradients,
    SmoothGrad,
    AttentionRollout,
    AttentionFlow,
    quick_explain,
    compare_methods,
    ExplanationMetrics,
    InteractiveVisualizer,
    ExplanationCache,
    BatchProcessor,
    PerformanceProfiler,
    optimize_for_inference,
)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✓ All imports successful!")

## 1. Setup: Load Model and Data

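We'll use a ResNet-50 backbone for this tutorial. With `pretrained=False` the weights are random, so the heatmaps below only demonstrate the API; use trained weights to get meaningful explanations.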

In [None]:
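# Build a ResNet-50 classifier (random weights unless pretrained=True)
BACKBONE = "resnet50"
NUM_CLASSES = 10

model = ImageClassifier(
    backbone=BACKBONE,
    num_classes=NUM_CLASSES,
    pretrained=False  # Set True to start from pre-trained backbone weights
)
model.eval()  # Inference mode: disables dropout, uses running batch-norm statistics

print(f"Model: {model.__class__.__name__}")
print(f"Backbone: {BACKBONE}")
print(f"Number of classes: {NUM_CLASSES}")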

In [None]:
# Create sample images for demonstration
def create_sample_image(seed=42, size=(224, 224)):
    """Create a random RGB noise image; size is (height, width)."""
    rng = np.random.default_rng(seed)  # Local generator: does not reset the global NumPy seed
    img_array = rng.integers(0, 256, size=(size[0], size[1], 3), dtype=np.uint8)
    return Image.fromarray(img_array)

# Create test image
test_image = create_sample_image(seed=42)

# Display
plt.figure(figsize=(4, 4))
plt.imshow(test_image)
plt.title("Test Image")
plt.axis('off')
plt.show()

print(f"Image size (width, height): {test_image.size}")

## 2. Quick Start: Simplest Way to Explain

The fastest way to get an explanation is using `quick_explain()`.

In [None]:
# Quick explanation with default settings
fig = quick_explain(
    model,
    test_image,
    method="gradcam",
    show_plot=True
)

print("\n✓ Quick explanation generated!")
print("  This shows which parts of the image the model focuses on.")

## 3. Interpretation Methods: Deep Dive

AutoTimm provides 6 different interpretation methods. Let's explore each one.

### 3.1 GradCAM (Gradient-weighted Class Activation Mapping)

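**How it works:** Averages the gradients of the class score over each feature map of the target (typically the last) convolutional layer, uses those averages to weight the feature maps, sums them, and applies a ReLU to keep only regions that support the class.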

**Best for:** General-purpose visualization, CNNs

**Pros:** Fast, interpretable, works well for most CNNs

**Cons:** Needs a convolutional target layer; the map is coarse (7×7 for ResNet-50 at 224×224 input) until it is upsampled

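To make the weighting concrete, here is a minimal NumPy sketch of the Grad-CAM combination step. It assumes you already have the target layer's activations and the gradients of the class score with respect to them, both of shape `(K, H, W)`; capturing those (presumably with hooks) is what `GradCAM` does for you. The helper name `gradcam_from_grads` is hypothetical and not part of AutoTimm.

```python
import numpy as np

def gradcam_from_grads(activations: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Grad-CAM map from (K, H, W) activations and matching gradients."""
    # One weight per channel: the spatial mean of that channel's gradient
    weights = grads.mean(axis=(1, 2))                                   # shape (K,)
    # Weighted sum of feature maps, keeping only positive evidence (ReLU)
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # shape (H, W)
    # Scale to [0, 1] for display; the epsilon guards against an all-zero map
    return cam / (cam.max() + 1e-8)

# Toy example: 2 channels on a 3x3 grid
acts = np.zeros((2, 3, 3))
acts[0, 0, 0] = 1.0   # channel 0 fires top-left
acts[1, 2, 2] = 1.0   # channel 1 fires bottom-right
grads = np.stack([np.full((3, 3), 0.5), np.full((3, 3), -0.5)])
cam = gradcam_from_grads(acts, grads)
```

Channel 1 gets a negative weight, so its bottom-right activation is removed by the ReLU and only the top-left cell lights up.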
In [None]:
# Initialize GradCAM
gradcam = GradCAM(model)

# Generate explanation
heatmap_gradcam = gradcam.explain(
    test_image,
    target_class=5  # Explain prediction for class 5
)

# Visualize
gradcam.visualize(
    test_image,
    heatmap_gradcam,
    alpha=0.6,
    title="GradCAM Explanation"
)
plt.show()

print(f"Heatmap shape: {heatmap_gradcam.shape}")
print(f"Value range: [{heatmap_gradcam.min():.3f}, {heatmap_gradcam.max():.3f}]")

### 3.2 GradCAM++ (Improved GradCAM)

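**How it works:** Refines GradCAM by weighting each spatial position's positive gradients with coefficients derived from higher-order derivatives, instead of taking a plain gradient average.

**Best for:** Images with multiple objects, better localization

**Pros:** Often covers multiple instances of a class more completely than GradCAM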

**Cons:** Slightly slower than GradCAM

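The extra cost comes from the per-pixel coefficients. Below is a NumPy sketch of the common closed-form GradCAM++ weighting, which assumes an exponential link on the class score so that second- and third-order derivatives reduce to powers of the first-order gradient (the simplification used in the original paper and in popular open-source implementations). It is illustrative only: AutoTimm's implementation may differ, and `gradcam_pp_from_grads` is a hypothetical name.

```python
import numpy as np

def gradcam_pp_from_grads(activations, grads, eps=1e-8):
    """GradCAM++ map from (K, H, W) activations and first-order gradients."""
    g2, g3 = grads ** 2, grads ** 3
    # Sum of each channel's activations over all spatial positions, shape (K, 1, 1)
    sum_acts = activations.sum(axis=(1, 2), keepdims=True)
    # Pixel-wise coefficients alpha (closed form under the exponential assumption)
    alpha = g2 / (2.0 * g2 + sum_acts * g3 + eps)
    alpha = np.where(grads != 0, alpha, 0.0)
    # Channel weights: alpha-weighted sum of the *positive* gradients only
    weights = (alpha * np.maximum(grads, 0.0)).sum(axis=(1, 2))   # shape (K,)
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)
    return cam / (cam.max() + eps)

acts = np.random.default_rng(0).random((4, 7, 7))
grads = np.random.default_rng(1).normal(size=(4, 7, 7))
cam = gradcam_pp_from_grads(acts, grads)
```

Because only positive gradients contribute, a layer whose gradients are all negative yields an all-zero map.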
In [None]:
# Initialize GradCAM++
gradcam_pp = GradCAMPlusPlus(model)

# Generate explanation
heatmap_gradcam_pp = gradcam_pp.explain(test_image, target_class=5)

# Visualize
gradcam_pp.visualize(
    test_image,
    heatmap_gradcam_pp,
    alpha=0.6,
    title="GradCAM++ Explanation"
)
plt.show()

### 3.3 Integrated Gradients

**How it works:** Integrates gradients along a path from a baseline to the input.

**Best for:** Pixel-level attributions, research

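**Pros:** Theoretically grounded: satisfies the sensitivity and implementation-invariance axioms, and the attributions sum to the difference between the model output at the input and at the baseline (completeness)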

**Cons:** Slower, requires baseline selection

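The integral is approximated with a Riemann sum over `n_steps` points on the straight line from the baseline to the input. This self-contained NumPy sketch uses a toy function with an analytic gradient in place of a network (an assumption for illustration), so you can check the completeness property: the attributions sum to `f(x) - f(baseline)`. The function name is hypothetical, not AutoTimm's API.

```python
import numpy as np

def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    # Midpoints of n_steps equal intervals along the path parameter alpha in [0, 1]
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    # Average gradient along the straight-line path from baseline to x
    avg_grad = np.mean([grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0)
    # Scale by the input difference to get per-feature attributions
    return (x - baseline) * avg_grad

# Toy "model": f(x) = sum(w * x**2), whose gradient is 2 * w * x
w = np.array([1.0, -2.0, 0.5])
f = lambda x: np.sum(w * x ** 2)
grad_f = lambda x: 2 * w * x

x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(grad_f, x, baseline)
```

Here the gradient is linear in alpha, so the midpoint rule is exact. For a real network the sum only approaches `f(x) - f(baseline)` as `n_steps` grows, which is why more steps are slower but more accurate.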
In [None]:
# Initialize Integrated Gradients
ig = IntegratedGradients(model)

# Generate explanation (this may take longer)
print("Computing Integrated Gradients (may take 10-20 seconds)...")
heatmap_ig = ig.explain(
    test_image,
    target_class=5,
    n_steps=50  # Number of integration steps (more = better but slower)
)

# Visualize
ig.visualize(
    test_image,
    heatmap_ig,
    alpha=0.6,
    title="Integrated Gradients Explanation"
)
plt.show()

print("✓ Integrated Gradients complete!")

### 3.4 SmoothGrad

**How it works:** Averages gradients over multiple noisy versions of the input.

**Best for:** Reducing noise in gradient-based explanations

**Pros:** Smoother, more stable explanations

**Cons:** Slower (requires multiple forward passes)

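SmoothGrad itself is a short loop: add Gaussian noise, take the gradient, average. The sketch below follows the SmoothGrad paper's convention of expressing the noise standard deviation as a fraction of the input's value range; whether AutoTimm's `noise_level` uses the same convention is an assumption. `grad_fn` is a hypothetical stand-in for the model's input gradient.

```python
import numpy as np

def smoothgrad(grad_fn, x, n_samples=50, noise_level=0.15, seed=0):
    """Average input gradients over n_samples Gaussian-perturbed copies of x."""
    rng = np.random.default_rng(seed)
    # Noise std relative to the input's range, as in the SmoothGrad paper
    sigma = noise_level * (x.max() - x.min())
    grads = [grad_fn(x + rng.normal(0.0, sigma, size=x.shape)) for _ in range(n_samples)]
    return np.mean(grads, axis=0)

# Toy nonlinear "model": f(x) = sum(sin(x)), whose gradient is cos(x)
x = np.linspace(0.0, 3.0, 6)
sg = smoothgrad(np.cos, x)
```

For a linear model the gradient is the same everywhere, so SmoothGrad returns it unchanged; averaging only matters where the gradient fluctuates, which is exactly the noise it is meant to suppress.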
In [None]:
# Initialize SmoothGrad
smoothgrad = SmoothGrad(model)

# Generate explanation
print("Computing SmoothGrad (may take 10-20 seconds)...")
heatmap_smoothgrad = smoothgrad.explain(
    test_image,
    target_class=5,
    n_samples=50,  # Number of noisy samples
    noise_level=0.15  # Standard deviation of noise
)

# Visualize
smoothgrad.visualize(
    test_image,
    heatmap_smoothgrad,
    alpha=0.6,
    title="SmoothGrad Explanation"
)
plt.show()

print("✓ SmoothGrad complete!")

### 3.5 & 3.6 Attention Methods (Vision Transformers)

**AttentionRollout** and **AttentionFlow** are designed for Vision Transformers (ViT).

Since our example uses ResNet (CNN), we'll skip these for now. They work similarly but require a Transformer-based model.

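Although we skip them here, the idea behind attention rollout (Abnar & Zuidema, 2020) is compact enough to sketch: average each layer's attention over heads, add the identity to account for residual connections, re-normalize the rows, and multiply the layers together. This NumPy sketch takes a list of per-layer attention tensors of shape `(heads, tokens, tokens)`; it is not AutoTimm's implementation. For a ViT you would read the CLS-token row (excluding the CLS column) as the patch relevance.

```python
import numpy as np

def attention_rollout(attentions):
    """Roll out a list of (heads, tokens, tokens) attention maps, first layer first."""
    n_tokens = attentions[0].shape[-1]
    rollout = np.eye(n_tokens)
    for attn in attentions:
        a = attn.mean(axis=0)                    # average over heads
        a = a + np.eye(n_tokens)                 # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)    # keep rows as probability distributions
        rollout = a @ rollout                    # propagate through this layer
    return rollout

rng = np.random.default_rng(0)
layers = []
for _ in range(4):
    logits = rng.normal(size=(3, 5, 5))
    # Softmax over the last axis, like real attention weights
    layers.append(np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True))
R = attention_rollout(layers)
cls_relevance = R[0, 1:]   # how much the CLS token (index 0) draws on each patch
```

Since every factor is row-stochastic, each row of the rollout still sums to 1, so the CLS row can be read as a distribution over tokens.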
## 4. Comparing Methods Side-by-Side

Let's compare multiple methods visually using `compare_methods()`.

In [None]:
# Compare multiple methods
fig = compare_methods(
    model,
    test_image,
    methods=["gradcam", "gradcam++"],  # Add more: "integrated_gradients", "smoothgrad"
    target_class=5,
    figsize=(12, 4)
)
plt.tight_layout()
plt.show()

print("\n✓ Method comparison complete!")
print("  Notice the differences in highlighted regions between methods.")

## 5. Explanation Quality Metrics

How do we know if an explanation is good? Use quantitative metrics!

In [None]:
# Initialize metrics
metrics = ExplanationMetrics(model, gradcam)

print("Metrics initialized. The metric cells below may take 1-2 minutes in total.\n")

### 5.1 Deletion Metric (Faithfulness)

**What it measures:** How quickly does the class score drop as the pixels the heatmap ranks as most important are progressively removed?

**Interpretation:** Faster drop (lower area under the curve) = more faithful explanation

**Expected:** Good explanations cause large prediction drops when important regions are deleted.

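Under the hood, a deletion curve is simple to compute. This NumPy sketch ranks pixels by heatmap value, zeroes them out in `steps` equal chunks, records the score after each chunk, and takes the area under the curve with the trapezoid rule (x-axis scaled to [0, 1]). The `score_fn` is a toy stand-in for the model's class probability, the image is grayscale, and zero is assumed as the "deleted" value; AutoTimm's `deletion` may use a different baseline or step schedule.

```python
import numpy as np

def deletion_curve(score_fn, image, heatmap, steps=20):
    """Return (scores, auc) for removing the most important pixels first."""
    order = np.argsort(heatmap.ravel())[::-1]      # pixel indices, most important first
    x = image.astype(float)                        # astype returns a copy
    scores = [score_fn(x)]
    for chunk in np.array_split(order, steps):
        x.flat[chunk] = 0.0                        # "delete" this chunk of pixels
        scores.append(score_fn(x))
    scores = np.array(scores)
    # Trapezoid-rule area with the deleted fraction on the x-axis scaled to [0, 1]
    auc = float(np.sum((scores[1:] + scores[:-1]) / 2.0) / steps)
    return scores, auc

# Toy setup: the "model" only looks at the central 4x4 patch of an 8x8 image
image = np.ones((8, 8))
center = np.zeros((8, 8), dtype=bool)
center[2:6, 2:6] = True
score_fn = lambda x: x[center].mean()

good_map = center.astype(float)    # highlights exactly what the model uses
bad_map = 1.0 - good_map           # highlights everything else
_, auc_good = deletion_curve(score_fn, image, good_map, steps=4)
_, auc_bad = deletion_curve(score_fn, image, bad_map, steps=4)
```

The faithful map removes the central patch in the first step, so the score collapses immediately and the AUC is low (0.125 versus 0.875 for the misleading map).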
In [None]:
# Deletion metric
deletion_result = metrics.deletion(
    test_image,
    target_class=5,
    steps=20  # Number of deletion steps
)

print("Deletion Metric Results:")
print(f"  AUC: {deletion_result['auc']:.4f}")
print(f"  Final drop: {deletion_result['final_drop']:.4f}")
print(f"\n  Interpretation: Lower AUC = better (faster drop in prediction)")

# Plot deletion curve
plt.figure(figsize=(8, 5))
plt.plot(deletion_result['scores'], marker='o')
plt.xlabel('Deletion Step')
plt.ylabel('Prediction Score')
plt.title('Deletion Curve (Should Decrease)')
plt.grid(True, alpha=0.3)
plt.show()

### 5.2 Insertion Metric (Faithfulness)

**What it measures:** How quickly does the class score rise as the most important pixels are progressively added back to an uninformative starting image?

**Interpretation:** Faster rise (higher area under the curve) = more faithful explanation

**Expected:** Good explanations cause rapid prediction increases when important regions are added.

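Insertion is the mirror image of deletion: start from an uninformative canvas and restore pixels in order of importance. This sketch assumes an all-zero starting canvas (implementations often use a blurred copy instead) and a toy grayscale "model" that only looks at the central patch, standing in for a real classifier; the helper is hypothetical, not AutoTimm's `insertion`.

```python
import numpy as np

def insertion_curve(score_fn, image, heatmap, steps=20):
    """Return (scores, auc) for restoring the most important pixels first."""
    order = np.argsort(heatmap.ravel())[::-1]      # most important first
    image = image.astype(float)
    x = np.zeros_like(image)                       # uninformative starting canvas
    scores = [score_fn(x)]
    for chunk in np.array_split(order, steps):
        x.flat[chunk] = image.flat[chunk]          # restore this chunk of pixels
        scores.append(score_fn(x))
    scores = np.array(scores)
    # Trapezoid-rule area with the inserted fraction on the x-axis scaled to [0, 1]
    auc = float(np.sum((scores[1:] + scores[:-1]) / 2.0) / steps)
    return scores, auc

# Toy setup: the "model" only looks at the central 4x4 patch of an 8x8 image
image = np.ones((8, 8))
center = np.zeros((8, 8), dtype=bool)
center[2:6, 2:6] = True
score_fn = lambda x: x[center].mean()

_, auc_good = insertion_curve(score_fn, image, center.astype(float), steps=4)
_, auc_bad = insertion_curve(score_fn, image, (~center).astype(float), steps=4)
```

Restoring the central patch first makes the score jump to its maximum after one step, giving a high AUC (0.875 versus 0.125 when that patch is restored last).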
In [None]:
# Insertion metric
insertion_result = metrics.insertion(
    test_image,
    target_class=5,
    steps=20
)

print("Insertion Metric Results:")
print(f"  AUC: {insertion_result['auc']:.4f}")
print(f"  Final rise: {insertion_result['final_rise']:.4f}")
print(f"\n  Interpretation: Higher AUC = better (faster rise in prediction)")

# Plot insertion curve
plt.figure(figsize=(8, 5))
plt.plot(insertion_result['scores'], marker='o', color='green')
plt.xlabel('Insertion Step')
plt.ylabel('Prediction Score')
plt.title('Insertion Curve (Should Increase)')
plt.grid(True, alpha=0.3)
plt.show()

### 5.3 Sensitivity-N (Stability)

**What it measures:** How stable is the explanation under small input perturbations?

**Interpretation:** Lower sensitivity = more stable explanation

**Expected:** Good explanations shouldn't change drastically with tiny noise.

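A stability score of this kind can be sketched as the relative change of the explanation under small input noise. The helper below is a hypothetical illustration, not AutoTimm's `sensitivity_n`: it perturbs the input `n_samples` times and reports the mean and maximum of `norm(E(x + noise) - E(x)) / norm(E(x))`, where `explain_fn` plays the role of `E`.

```python
import numpy as np

def explanation_sensitivity(explain_fn, x, n_samples=20, noise_level=0.15, seed=0):
    """Mean and max relative change of explain_fn(x) under Gaussian input noise."""
    rng = np.random.default_rng(seed)
    base = explain_fn(x)
    sigma = noise_level * (x.max() - x.min())   # noise scaled to the input range
    changes = []
    for _ in range(n_samples):
        noisy = x + rng.normal(0.0, sigma, size=x.shape)
        diff = np.linalg.norm(explain_fn(noisy) - base)
        changes.append(diff / (np.linalg.norm(base) + 1e-8))
    changes = np.array(changes)
    return {"sensitivity": changes.mean(), "max_change": changes.max(), "changes": changes}

x = np.linspace(0.0, 1.0, 16).reshape(4, 4)
stable = explanation_sensitivity(lambda z: np.ones_like(z), x)   # ignores the input entirely
unstable = explanation_sensitivity(lambda z: z, x)               # tracks every pixel
```

The "stable" explainer scores a perfect 0 while telling you nothing about the input, which is why stability should always be read alongside faithfulness metrics such as deletion and insertion.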
In [None]:
# Sensitivity metric
print("Computing sensitivity (may take 20-30 seconds)...")
sensitivity_result = metrics.sensitivity_n(
    test_image,
    target_class=5,
    n_samples=20,  # Number of noisy samples
    noise_level=0.15
)

print("\nSensitivity-N Results:")
print(f"  Sensitivity: {sensitivity_result['sensitivity']:.4f}")
print(f"  Std deviation: {sensitivity_result['std']:.4f}")
print(f"  Max change: {sensitivity_result['max_change']:.4f}")
print(f"\n  Interpretation: Lower values = more stable explanation")

# Plot sensitivity distribution
plt.figure(figsize=(8, 5))
plt.hist(sensitivity_result['changes'], bins=15, edgecolor='black', alpha=0.7)
plt.xlabel('Explanation Change')
plt.ylabel('Frequency')
plt.title('Distribution of Explanation Changes Under Noise')
plt.axvline(sensitivity_result['sensitivity'], color='red', 
           linestyle='--', label=f"Mean: {sensitivity_result['sensitivity']:.3f}")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

### 5.4 Sanity Checks

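Two sanity checks test whether an explanation actually depends on the model and the target:

1. **Model Parameter Randomization**: The explanation should change substantially when the model's weights are randomized; if it does not, it reflects the input more than what the model learned
2. **Data Randomization** (as implemented here): The explanation should differ when a different target class is requested. The original test of Adebayo et al. (2018) instead retrains the model on randomly permuted labels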

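Both tests boil down to comparing two heatmaps. Adebayo et al. (2018) use rank correlation for this. Below is a minimal NumPy version of Spearman correlation that gives tied values their average rank, which matters for ReLU-style maps with many exact zeros. The helper names and the 0.5 pass threshold are hypothetical assumptions, not AutoTimm's.

```python
import numpy as np

def average_ranks(a):
    """Ranks of a flattened array, with tied values sharing their mean rank."""
    a = np.asarray(a, dtype=float).ravel()
    ranks = np.empty(a.size)
    ranks[np.argsort(a, kind="mergesort")] = np.arange(a.size)
    _, inverse, counts = np.unique(a, return_inverse=True, return_counts=True)
    inverse = inverse.ravel()
    # Sum the ranks within each group of equal values, then divide by group size
    return (np.bincount(inverse, weights=ranks) / counts)[inverse]

def spearman(map_a, map_b):
    """Spearman rank correlation between two heatmaps of the same shape."""
    return float(np.corrcoef(average_ranks(map_a), average_ranks(map_b))[0, 1])

rng = np.random.default_rng(0)
original = np.maximum(rng.normal(size=(7, 7)), 0.0)    # ReLU-style map with many ties at 0
randomized = np.maximum(rng.normal(size=(7, 7)), 0.0)  # stand-in for the randomized-model map
corr = spearman(original, randomized)
passes = corr < 0.5   # hypothetical threshold: low similarity means the map depends on the weights
```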
In [None]:
# Model parameter randomization test
param_test = metrics.model_parameter_randomization_test(
    test_image,
    target_class=5
)

print("Model Parameter Randomization Test:")
print(f"  Correlation with randomized model: {param_test['correlation']:.4f}")
print(f"  Change: {param_test['change']:.4f}")
print(f"  Passes: {param_test['passes']}")
print(f"\n  ✓ Should PASS: Explanation changes with random model")

# Data randomization test
data_test = metrics.data_randomization_test(
    test_image,
    target_class=5
)

print("\nData Randomization Test:")
print(f"  Correlation with different class: {data_test['correlation']:.4f}")
print(f"  Change: {data_test['change']:.4f}")
print(f"  Passes: {data_test['passes']}")
print(f"\n  ✓ Should PASS: Explanation differs for different classes")

### 5.5 Pointing Game (Localization)

**What it measures:** Does the maximum attention fall within the ground-truth bounding box?

**Interpretation:** Hit = good localization

**Note:** Requires ground-truth bounding boxes (we'll simulate one here)

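The check itself is a few lines. This sketch assumes the heatmap is at image resolution and the box is `(x1, y1, x2, y2)` in pixels with inclusive edges. It adds an optional pixel tolerance like the one in the original pointing game (Zhang et al., 2016, which used 15 pixels); whether AutoTimm's `pointing_game` exposes such a parameter is unknown, and the helper here is hypothetical.

```python
import numpy as np

def pointing_game(heatmap, bbox, tolerance=0):
    """Hit if the heatmap maximum lies inside bbox (x1, y1, x2, y2), grown by tolerance pixels."""
    row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    x1, y1, x2, y2 = bbox
    # Columns are x coordinates, rows are y coordinates
    hit = (x1 - tolerance <= col <= x2 + tolerance) and (y1 - tolerance <= row <= y2 + tolerance)
    return {"hit": bool(hit), "max_location": (int(row), int(col)), "bbox": bbox}

heatmap = np.zeros((224, 224))
heatmap[100, 120] = 1.0                          # peak at row 100, column 120
result = pointing_game(heatmap, (50, 50, 150, 150))
```

The location is returned as `(row, col)`, matching how the plotting code in the next cell passes `max_location[1]` as x and `max_location[0]` as y.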
In [None]:
# Simulate a bounding box (x1, y1, x2, y2)
bbox = (50, 50, 150, 150)  # Example bbox

# Pointing game
pointing_result = metrics.pointing_game(
    test_image,
    bbox,
    target_class=5
)

print("Pointing Game Results:")
print(f"  Hit: {pointing_result['hit']}")
print(f"  Max location: {pointing_result['max_location']}")
print(f"  Bounding box: {pointing_result['bbox']}")
print(f"\n  Interpretation: Hit = max attention inside bbox")

# Visualize
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.imshow(heatmap_gradcam, cmap='jet', alpha=0.6)
ax.imshow(test_image, alpha=0.4)

# Draw bbox
from matplotlib.patches import Rectangle
rect = Rectangle(
    (bbox[0], bbox[1]), 
    bbox[2] - bbox[0], 
    bbox[3] - bbox[1],
    linewidth=2, 
    edgecolor='green', 
    facecolor='none'
)
ax.add_patch(rect)

# Mark max location
ax.plot(
    pointing_result['max_location'][1], 
    pointing_result['max_location'][0], 
    'r*', 
    markersize=15, 
    label='Max Attention'
)

ax.set_title('Pointing Game: Max Attention vs Ground Truth')
ax.legend()
ax.axis('off')
plt.show()

### 5.6 Metrics Summary

Let's compile all metrics into a summary.

In [None]:
# Create metrics summary
import pandas as pd

metrics_summary = pd.DataFrame([
    {
        'Metric': 'Deletion AUC',
        'Value': f"{deletion_result['auc']:.4f}",
        'Interpretation': 'Lower = better',
        'Category': 'Faithfulness'
    },
    {
        'Metric': 'Insertion AUC',
        'Value': f"{insertion_result['auc']:.4f}",
        'Interpretation': 'Higher = better',
        'Category': 'Faithfulness'
    },
    {
        'Metric': 'Sensitivity-N',
        'Value': f"{sensitivity_result['sensitivity']:.4f}",
        'Interpretation': 'Lower = better',
        'Category': 'Stability'
    },
    {
        'Metric': 'Model Param Test',
        'Value': 'PASS' if param_test['passes'] else 'FAIL',
        'Interpretation': 'Should pass',
        'Category': 'Sanity Check'
    },
    {
        'Metric': 'Data Randomization',
        'Value': 'PASS' if data_test['passes'] else 'FAIL',
        'Interpretation': 'Should pass',
        'Category': 'Sanity Check'
    },
    {
        'Metric': 'Pointing Game',
        'Value': 'HIT' if pointing_result['hit'] else 'MISS',
        'Interpretation': 'Hit = good',
        'Category': 'Localization'
    },
])

print("\n" + "="*70)
print("EXPLANATION QUALITY METRICS SUMMARY")
print("="*70)
print(metrics_summary.to_string(index=False))
print("="*70)

## 6. Interactive Visualizations

Static images are great, but interactive visualizations let you explore in detail!

In [None]:
# Check if plotly is available
try:
    viz = InteractiveVisualizer(model)
    print("✓ Interactive visualizations available!")
    INTERACTIVE_AVAILABLE = True
except Exception as e:
    print(f"✗ Interactive visualizations not available: {e}")
    print("  Install with: pip install plotly")
    INTERACTIVE_AVAILABLE = False

### 6.1 Basic Interactive Visualization

In [None]:
if INTERACTIVE_AVAILABLE:
    # Create interactive visualization
    fig = viz.visualize_explanation(
        test_image,
        gradcam,
        target_class=5,
        title="Interactive GradCAM",
        colorscale="Viridis",
        opacity=0.6,
        save_path="tutorial_interactive.html"
    )
    
    # Display in notebook
    fig.show()
    
    print("\n✓ Interactive visualization created!")
    print("  Try: Zoom (scroll), Pan (drag), Hover (see values)")
    print("  Saved to: tutorial_interactive.html")
else:
    print("Skipping interactive visualizations (plotly not installed)")

### 6.2 Method Comparison (Interactive)

In [None]:
if INTERACTIVE_AVAILABLE:
    # Compare methods interactively
    explainers = {
        'GradCAM': gradcam,
        'GradCAM++': gradcam_pp,
    }
    
    fig = viz.compare_methods(
        test_image,
        explainers,
        target_class=5,
        title="Interactive Method Comparison",
        save_path="tutorial_comparison.html",
        width=1400,
        height=500
    )
    
    fig.show()
    
    print("\n✓ Interactive comparison created!")
    print("  Saved to: tutorial_comparison.html")

### 6.3 Comprehensive HTML Report

In [None]:
if INTERACTIVE_AVAILABLE:
    # Generate comprehensive report
    report_path = viz.create_report(
        test_image,
        gradcam,
        target_class=5,
        include_statistics=True,
        save_path="tutorial_report.html",
        title="Model Interpretation Report"
    )
    
    print(f"\n✓ Comprehensive report created!")
    print(f"  Saved to: {report_path}")
    print(f"  Open in browser to view full report with:")
    print(f"    - Prediction information")
    print(f"    - Top-5 classes")
    print(f"    - Heatmap statistics")
    print(f"    - Interactive visualization")
    print(f"    - Distribution plots")

## 7. Performance Optimization

For production systems, speed matters! Let's optimize.

### 7.1 Caching for Repeated Explanations

In [None]:
# Create cache
cache = ExplanationCache(
    cache_dir="./tutorial_cache",
    max_size_mb=100,  # 100 MB cache
    enabled=True
)

# Without caching
print("Without caching:")
start = time.time()
for i in range(5):
    heatmap = gradcam.explain(test_image, target_class=5)
time_no_cache = time.time() - start
print(f"  5 explanations: {time_no_cache:.3f}s ({time_no_cache/5:.3f}s each)")

# With caching
print("\nWith caching:")
start = time.time()
for i in range(5):
    # Check cache
    heatmap = cache.get(test_image, method="gradcam", target_class=5)
    if heatmap is None:
        # Cache miss - compute and store
        heatmap = gradcam.explain(test_image, target_class=5)
        cache.put(test_image, method="gradcam", explanation=heatmap, target_class=5)
time_with_cache = time.time() - start
print(f"  5 explanations: {time_with_cache:.3f}s ({time_with_cache/5:.3f}s each)")

# Speedup
speedup = time_no_cache / time_with_cache
print(f"\n✓ Speedup: {speedup:.1f}x faster with caching!")

# Cache statistics
stats = cache.stats()
print(f"\nCache stats:")
print(f"  Entries: {stats['num_entries']}")
print(f"  Size: {stats['total_size_mb']:.2f} MB")
print(f"  Utilization: {stats['utilization']:.1%}")

# Cleanup
cache.clear()

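The cache is looked up by image, method and target class. One plausible way to build such a key, shown here as a hypothetical helper (`ExplanationCache` may hash differently), is a SHA-256 digest over the raw pixel bytes plus the lookup parameters, so identical inputs hit the same entry regardless of where they came from. For a PIL image, pass `np.asarray(image)`.

```python
import hashlib
import numpy as np

def explanation_cache_key(image: np.ndarray, method: str, target_class=None) -> str:
    """Content-based cache key: same pixels + method + class -> same key."""
    h = hashlib.sha256()
    h.update(np.ascontiguousarray(image).tobytes())
    # Include shape and dtype so arrays with equal bytes but different layouts don't collide
    h.update(f"{image.shape}|{image.dtype}|{method}|{target_class}".encode())
    return h.hexdigest()

img = np.zeros((224, 224, 3), dtype=np.uint8)
k1 = explanation_cache_key(img, "gradcam", 5)
k2 = explanation_cache_key(img.copy(), "gradcam", 5)   # same content, different object
k3 = explanation_cache_key(img, "gradcam", 3)          # different target class
```

Hashing content rather than object identity is what lets a cache survive across requests that reload the same image.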
### 7.2 Batch Processing

In [None]:
# Create multiple test images
test_images = [create_sample_image(seed=i) for i in range(20)]

# Sequential processing
print("Sequential processing (one-by-one):")
start = time.time()
heatmaps_seq = [gradcam.explain(img) for img in test_images]
time_seq = time.time() - start
print(f"  20 images: {time_seq:.3f}s ({time_seq/20:.3f}s per image)")

# Batch processing
print("\nBatch processing:")
processor = BatchProcessor(
    model,
    gradcam,
    batch_size=8,
    show_progress=False,
    use_cuda=False
)

start = time.time()
heatmaps_batch = processor.process_batch(test_images)
time_batch = time.time() - start
print(f"  20 images: {time_batch:.3f}s ({time_batch/20:.3f}s per image)")

# Speedup
speedup = time_seq / time_batch
print(f"\n✓ Speedup: {speedup:.1f}x faster with batching!")

### 7.3 Performance Profiling

In [None]:
# Create profiler
profiler = PerformanceProfiler(enabled=True)

# Profile interpretation pipeline
with profiler.profile("total"):
    with profiler.profile("preprocessing"):
        # Simulate preprocessing
        tensor = torch.from_numpy(np.array(test_image)).permute(2, 0, 1).float() / 255.0
        tensor = tensor.unsqueeze(0)
    
    with profiler.profile("forward_pass"):
        with torch.no_grad():
            output = model(tensor)
    
    with profiler.profile("explanation"):
        heatmap = gradcam.explain(test_image, target_class=5)
    
    with profiler.profile("postprocessing"):
        normalized = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

# Print statistics
profiler.print_stats()

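# Identify the bottleneck ("total" encloses every other section, so exclude it)
stats = profiler.get_stats()
sections = {name: s for name, s in stats.items() if name != "total"}
slowest = max(sections.items(), key=lambda x: x[1]['mean'])
print(f"\n⚠ Bottleneck: {slowest[0]} ({slowest[1]['mean']:.3f}s)")
print("  Focus optimization efforts here!")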

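If you want the same per-section timing outside AutoTimm, a minimal profiler is a context manager around `time.perf_counter()`. This is a hypothetical stand-in for `PerformanceProfiler`, not its implementation. Note that nested sections such as `total` include the time of their children, which is why a bottleneck search should skip them.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class SimpleProfiler:
    """Collects wall-clock durations per named section."""

    def __init__(self):
        self.timings = defaultdict(list)

    @contextmanager
    def profile(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Record the duration even if the profiled block raises
            self.timings[name].append(time.perf_counter() - start)

    def get_stats(self):
        return {name: {"mean": sum(t) / len(t), "count": len(t)}
                for name, t in self.timings.items()}

prof = SimpleProfiler()
with prof.profile("total"):
    with prof.profile("sleep"):
        time.sleep(0.01)
stats = prof.get_stats()
```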
### 7.4 Model Optimization

In [None]:
# Benchmark original model
print("Original model:")
tensor = torch.from_numpy(np.array(test_image)).permute(2, 0, 1).float() / 255.0
tensor = tensor.unsqueeze(0)

with torch.no_grad():
    model(tensor)  # Warm-up pass so one-time setup costs are not timed
    start = time.perf_counter()
    for _ in range(10):
        output = model(tensor)
time_original = time.perf_counter() - start
print(f"  10 forward passes: {time_original:.3f}s ({time_original/10:.3f}s each)")

# Optimize model
model_opt = optimize_for_inference(model, use_fp16=False)

# Benchmark optimized model
print("\nOptimized model:")
with torch.no_grad():
    model_opt(tensor)  # Warm-up pass
    start = time.perf_counter()
    for _ in range(10):
        output = model_opt(tensor)
time_optimized = time.perf_counter() - start
print(f"  10 forward passes: {time_optimized:.3f}s ({time_optimized/10:.3f}s each)")

# Speedup (may be small on CPU: the baseline already runs in eval mode under no_grad)
speedup = time_original / time_optimized
print(f"\n✓ Speedup: {speedup:.1f}x")

print("\nOptimizations applied:")
print("  ✓ Disabled gradient computation")
print("  ✓ Enabled cudnn benchmarking (affects CUDA only)")
print("  ✓ Set to eval mode")

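Single-run wall-clock timings like the ones above are noisy. A slightly more robust pattern, sketched here as a hypothetical helper that works with any callable, runs a few warm-up iterations first and reports the median of the timed runs. For GPU models you would also call `torch.cuda.synchronize()` inside the callable before it returns, because CUDA kernels run asynchronously.

```python
import statistics
import time

def benchmark(fn, n_warmup=3, n_runs=10):
    """Median wall-clock seconds per call of fn(), measured after warm-up calls."""
    for _ in range(n_warmup):
        fn()                      # warm-up: caches, lazy initialization, autotuning
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)

median_s = benchmark(lambda: sum(range(10_000)))
```

For the model above this would be `benchmark(lambda: model(tensor))` called inside a `with torch.no_grad():` block; the median is less sensitive than the mean to an occasional slow run.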
### 7.5 Combined Optimization Strategy

In [None]:
print("\n" + "="*70)
print("PERFORMANCE OPTIMIZATION SUMMARY")
print("="*70)
print("\nOptimization Strategies (speedups are indicative; measure on your own hardware):")
print("\n1. Caching:")
print("   - Use case: Repeated explanations for same images")
print("   - Speedup: 10-50x")
print("   - Trade-off: Disk space")
print("\n2. Batch Processing:")
print("   - Use case: Multiple images at once")
print("   - Speedup: 2-5x")
print("   - Trade-off: Memory usage")
print("\n3. Model Optimization:")
print("   - Use case: Forward-only inference (predictions)")
print("   - Speedup: 1.5-3x")
print("   - Trade-off: FP16 may reduce precision; gradient-based explainers still need gradients")
print("\n4. Profiling:")
print("   - Use case: Identifying bottlenecks")
print("   - Benefit: Targeted optimization")
print("   - Trade-off: Small overhead when enabled")
print("\n✓ Combined: Up to 100x speedup possible!")
print("="*70)

## 8. Production Best Practices

Deploying interpretations in production? Follow these guidelines.

### 8.1 Recommended Production Setup

In [None]:
# Example production configuration
print("""
PRODUCTION SETUP CHECKLIST:

✓ 1. Model Optimization
   model = optimize_for_inference(model, use_fp16=True)  # GPU only

✓ 2. Enable Caching
   cache = ExplanationCache(
       cache_dir='/var/cache/explanations',
       max_size_mb=5000,  # 5GB
       enabled=True
   )

✓ 3. Batch Processing
   processor = BatchProcessor(
       model, explainer,
       batch_size=32,  # Tune for your hardware
       use_cuda=True
   )

✓ 4. Performance Monitoring
   profiler = PerformanceProfiler(enabled=True)
   
   with profiler.profile('request'):
       explanation = generate_explanation(image)
   
   # Alert if the mean request time exceeds 1 second
   if profiler.get_stats()['request']['mean'] > 1.0:
       log_warning('Slow explanation detected')

✓ 5. Error Handling
   try:
       heatmap = cache.get(image, method='gradcam')
       if heatmap is None:
           heatmap = explainer.explain(image)
           cache.put(image, method='gradcam', explanation=heatmap)
   except Exception as e:
       log_error(f'Explanation failed: {e}')
       # Fallback: return default or retry

✓ 6. Resource Limits
   - Cap the cache size (max_size_mb)
   - Monitor disk usage
   - Implement request throttling
   - Use async processing for slow methods

✓ 7. Logging & Monitoring
   - Log cache hit rates
   - Monitor explanation latency
   - Track error rates
   - Alert on anomalies
""")

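The error-handling and monitoring items of the checklist can be combined into one wrapper. This is a sketch under stated assumptions: `cache` is anything with `get`/`put` methods shaped like the `ExplanationCache` calls above, `explain_fn` produces the heatmap, and the one-second latency budget is arbitrary. It uses the standard `logging` module in place of the hypothetical `log_warning`/`log_error` helpers, and `DictCache` is a toy stand-in for testing only.

```python
import logging
import time

logger = logging.getLogger("explanations")

def explain_with_fallback(image, explain_fn, cache, method="gradcam", latency_budget_s=1.0):
    """Return a cached or freshly computed explanation, or None if it fails."""
    start = time.perf_counter()
    try:
        heatmap = cache.get(image, method=method)
        if heatmap is None:
            heatmap = explain_fn(image)
            cache.put(image, method=method, explanation=heatmap)
    except Exception:
        # Log with traceback and fall back instead of failing the whole request
        logger.exception("Explanation failed (method=%s)", method)
        return None
    elapsed = time.perf_counter() - start
    if elapsed > latency_budget_s:
        logger.warning("Slow explanation: %.3fs (budget %.1fs)", elapsed, latency_budget_s)
    return heatmap

class DictCache:
    """Toy in-memory cache with the same get/put call shape (keys on object identity)."""
    def __init__(self):
        self._store = {}
    def get(self, image, method):
        return self._store.get((id(image), method))
    def put(self, image, method, explanation):
        self._store[(id(image), method)] = explanation

demo_cache = DictCache()
img = [1, 2, 3]
first = explain_with_fallback(img, lambda im: sum(im), demo_cache)
second = explain_with_fallback(img, lambda im: 1 / 0, demo_cache)   # cache hit, so no error
failed = explain_with_fallback([4], lambda im: 1 / 0, demo_cache)   # error is logged, returns None
```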
### 8.2 Method Selection Guide

In [None]:
# Method selection guide
method_guide = pd.DataFrame([
    {
        'Method': 'GradCAM',
        'Speed': 'Fast',
        'Quality': 'Good',
        'Best For': 'General CNNs, Production',
        'Limitation': 'CNN only'
    },
    {
        'Method': 'GradCAM++',
        'Speed': 'Fast',
        'Quality': 'Better',
        'Best For': 'Multiple objects, CNNs',
        'Limitation': 'CNN only'
    },
    {
        'Method': 'Integrated Gradients',
        'Speed': 'Slow',
        'Quality': 'Excellent',
        'Best For': 'Research, Pixel-level',
        'Limitation': 'Computationally expensive'
    },
    {
        'Method': 'SmoothGrad',
        'Speed': 'Slow',
        'Quality': 'Very Good',
        'Best For': 'Stable explanations',
        'Limitation': 'Multiple forward passes'
    },
    {
        'Method': 'AttentionRollout',
        'Speed': 'Fast',
        'Quality': 'Good',
        'Best For': 'Vision Transformers',
        'Limitation': 'ViT only'
    },
    {
        'Method': 'AttentionFlow',
        'Speed': 'Fast',
        'Quality': 'Good',
        'Best For': 'Vision Transformers',
        'Limitation': 'ViT only'
    },
])

print("\n" + "="*80)
print("METHOD SELECTION GUIDE")
print("="*80)
print(method_guide.to_string(index=False))
print("="*80)

print("\nRecommendations:")
print("  • Production: GradCAM or GradCAM++ (fast, reliable)")
print("  • Research: Integrated Gradients or SmoothGrad (thorough)")
print("  • CNNs: GradCAM, GradCAM++, Integrated Gradients, SmoothGrad")
print("  • ViTs: AttentionRollout, AttentionFlow")
print("  • Real-time: GradCAM (fastest)")
print("  • Highest quality: Integrated Gradients (slowest)")

### 8.3 Common Pitfalls and Solutions

In [None]:
print("""
COMMON PITFALLS AND SOLUTIONS:

1. PITFALL: Slow explanations in production
   SOLUTION: Enable caching, use GradCAM, optimize model

2. PITFALL: Different explanations for same image
   SOLUTION: Call model.eval(), which disables dropout and uses running batch-norm statistics

3. PITFALL: Explanations look noisy
   SOLUTION: Use SmoothGrad or GradCAM++ instead of raw gradients

4. PITFALL: Out of memory errors
   SOLUTION: Reduce batch size, use FP16, clear cache regularly

5. PITFALL: Explanations don't match intuition
   SOLUTION: Validate with metrics (deletion, insertion), try multiple methods

6. PITFALL: Wrong target layer selected
   SOLUTION: Use layer_name parameter or automatic detection

7. PITFALL: Cache grows too large
   SOLUTION: Set max_size_mb, monitor utilization, implement cleanup

8. PITFALL: Metrics show poor explanation quality
   SOLUTION: Try different methods, check if model is properly trained

9. PITFALL: Interactive visualizations too large
   SOLUTION: Reduce image resolution, limit number of methods

10. PITFALL: Inconsistent results across runs
    SOLUTION: Set random seed, check for model randomness (dropout, etc.)
""")

## 9. Complete Workflow Example

Let's put it all together in a realistic workflow.

In [None]:
def complete_interpretation_workflow(
    model,
    image,
    target_class=None,
    enable_cache=True,
    enable_profiling=True,
    save_interactive=True
):
    """
    Complete interpretation workflow demonstrating best practices.
    
    Args:
        model: The model to interpret
        image: Input image
        target_class: Target class (None = predicted class)
        enable_cache: Whether to use caching
        enable_profiling: Whether to profile performance
        save_interactive: Whether to save interactive visualization
    
    Returns:
        dict: Results including explanation, metrics, and timings
    """
    results = {}
    
    # 1. Setup
    print("\n" + "="*70)
    print("COMPLETE INTERPRETATION WORKFLOW")
    print("="*70)
    
    # Initialize components
    from contextlib import nullcontext  # no-op context manager when profiling is off
    explainer = GradCAM(model)
    
    if enable_cache:
        cache = ExplanationCache(cache_dir="./workflow_cache", max_size_mb=100)
        print("✓ Cache enabled")
    
    if enable_profiling:
        profiler = PerformanceProfiler(enabled=True)
        print("✓ Profiling enabled")
    
    def timed(section):
        # Profile the section only when profiling is enabled
        return profiler.profile(section) if enable_profiling else nullcontext()
    
    # 2. Generate Explanation (the cache is used whether or not profiling is on)
    print("\nStep 1: Generating explanation...")
    
    with timed("explanation"):
        heatmap = None
        if enable_cache:
            heatmap = cache.get(image, method="gradcam", target_class=target_class)
            print("  Cache HIT - loaded from cache" if heatmap is not None else "  Cache MISS - computing")
        if heatmap is None:
            heatmap = explainer.explain(image, target_class=target_class)
            if enable_cache:
                cache.put(image, method="gradcam", explanation=heatmap, target_class=target_class)
    
    results['heatmap'] = heatmap
    print("✓ Explanation generated")
    
    # 3. Evaluate Quality
    print("\nStep 2: Evaluating explanation quality...")
    metrics = ExplanationMetrics(model, explainer)
    
    with timed("metrics"):
        deletion_result = metrics.deletion(image, target_class=target_class, steps=10)
        insertion_result = metrics.insertion(image, target_class=target_class, steps=10)
    
    results['metrics'] = {
        'deletion_auc': deletion_result['auc'],
        'insertion_auc': insertion_result['auc']
    }
    
    print(f"  Deletion AUC: {deletion_result['auc']:.4f}")
    print(f"  Insertion AUC: {insertion_result['auc']:.4f}")
    print("✓ Metrics computed")
    
    # 4. Visualize
    print("\nStep 3: Creating visualizations...")
    
    # Static visualization
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    explainer.visualize(image, heatmap, alpha=0.6, title="Explanation", ax=ax)
    plt.tight_layout()
    plt.savefig("workflow_static.png", dpi=150, bbox_inches='tight')
    plt.close()
    print("  Static visualization saved: workflow_static.png")
    
    # Interactive visualization
    if save_interactive and INTERACTIVE_AVAILABLE:
        viz = InteractiveVisualizer(model)
        fig = viz.visualize_explanation(
            image,
            explainer,
            target_class=target_class,
            save_path="workflow_interactive.html"
        )
        print("  Interactive visualization saved: workflow_interactive.html")
    
    print("✓ Visualizations created")
    
    # 5. Performance Summary
    if enable_profiling:
        print("\nStep 4: Performance summary...")
        profiler.print_stats()
        results['profiling'] = profiler.get_stats()
    
    # 6. Cache Statistics
    if enable_cache:
        print("\nCache statistics:")
        cache_stats = cache.stats()
        print(f"  Entries: {cache_stats['num_entries']}")
        print(f"  Size: {cache_stats['total_size_mb']:.2f} MB")
        print(f"  Utilization: {cache_stats['utilization']:.1%}")
        results['cache'] = cache_stats
        
        # Cleanup
        cache.clear()
    
    print("\n" + "="*70)
    print("✓ WORKFLOW COMPLETE")
    print("="*70)
    
    return results

# Run complete workflow
workflow_results = complete_interpretation_workflow(
    model,
    test_image,
    target_class=5,
    enable_cache=True,
    enable_profiling=True,
    save_interactive=INTERACTIVE_AVAILABLE
)

## 10. Summary and Next Steps

Congratulations! You've learned the complete AutoTimm interpretation toolkit.

In [None]:
print("""
╔════════════════════════════════════════════════════════════════════════╗
║                    TUTORIAL COMPLETE!                                  ║
╚════════════════════════════════════════════════════════════════════════╝

What You Learned:

✓ 1. INTERPRETATION METHODS (6 methods)
     • GradCAM, GradCAM++
     • Integrated Gradients, SmoothGrad
     • AttentionRollout, AttentionFlow

✓ 2. QUALITY METRICS (6 metrics)
     • Deletion, Insertion (faithfulness)
     • Sensitivity-N (stability)
     • Sanity checks (2)
     • Pointing game (localization)

✓ 3. INTERACTIVE VISUALIZATIONS
     • Plotly-based exploration
     • Method comparisons
     • HTML reports

✓ 4. PERFORMANCE OPTIMIZATION
     • Caching (10-50x speedup)
     • Batch processing (2-5x speedup)
     • Model optimization (1.5-3x speedup)
     • Profiling tools

✓ 5. PRODUCTION BEST PRACTICES
     • Method selection
     • Error handling
     • Monitoring
     • Common pitfalls

Next Steps:

1. Try with your own model and data
2. Experiment with different methods
3. Evaluate explanations with metrics
4. Deploy to production with optimizations
5. Read the full documentation at: docs/user-guide/interpretation/

Resources:

• Documentation: docs/
• Examples: examples/
• Tests: tests/test_interpretation*.py
• API Reference: https://autotimm.readthedocs.io/

Questions? Issues?

• GitHub Issues: https://github.com/yourusername/autotimm/issues
• Documentation: https://autotimm.readthedocs.io/

Happy Interpreting! 🎉
""")