# DiET vs Basic XAI Methods - Comparison Framework

This notebook demonstrates how to use the XAI comparison framework to compare:
- **Images**: DiET vs GradCAM on CIFAR-10
- **Text**: DiET vs Integrated Gradients on SST-2

## Overview

The framework provides:
1. Easy configuration via `ComparisonConfig`
2. Attribution-quality metrics (Pixel Perturbation, Insertion/Deletion, AOPC, Faithfulness Correlation)
3. Rich visualizations (bar charts, radar plots, dashboards)
4. DataFrame output for analysis
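
The perturbation metrics in item 2 rest on a simple idea. Remove the features an attribution method ranks as most important, then watch how fast the model's score drops. AOPC averages the drop f(x) - f(x_k) over L+1 steps, where step 0 is the unperturbed input. The sketch below is a rough illustration of how a "remove"-style Pixel Perturbation curve and AOPC fit together, not the framework's implementation. The model interface (a plain function over a flat feature list), the zero baseline, and the step fractions are all assumptions.

```python
from typing import Callable, List, Sequence

# Assumed interface: a model maps a flat list of feature values
# (e.g. pixels) to the score of the target class.
Model = Callable[[Sequence[float]], float]


def remove_top_fraction(x: Sequence[float], attribution: Sequence[float],
                        fraction: float, baseline: float = 0.0) -> List[float]:
    """Replace the `fraction` most-attributed features of x with `baseline`."""
    k = int(round(fraction * len(x)))
    # Feature indices ordered from most to least attributed.
    order = sorted(range(len(x)), key=lambda i: attribution[i], reverse=True)
    removed = set(order[:k])
    return [baseline if i in removed else v for i, v in enumerate(x)]


def aopc(model: Model, x: Sequence[float], attribution: Sequence[float],
         fractions: Sequence[float]) -> float:
    """Mean score drop over the unperturbed input plus each removal step."""
    original = model(x)
    drops = [0.0]  # step 0: nothing removed, so no drop
    for fraction in fractions:
        perturbed = remove_top_fraction(x, attribution, fraction)
        drops.append(original - model(perturbed))
    return sum(drops) / len(drops)


# Toy linear "model": only the first three features matter.
WEIGHTS = [0.5, 0.3, 0.2, 0.0, 0.0]


def toy_model(features: Sequence[float]) -> float:
    return sum(w * f for w, f in zip(WEIGHTS, features))


x = [1.0] * 5
good_attr = [0.5, 0.3, 0.2, 0.0, 0.0]  # ranks the truly important features first
bad_attr = [0.0, 0.0, 0.2, 0.3, 0.5]   # ranks irrelevant features first
steps = [0.2, 0.4, 0.6]                 # remove 1, 2, then 3 of the 5 features

print(aopc(toy_model, x, good_attr, steps))  # higher means a faster, more faithful drop
print(aopc(toy_model, x, bad_attr, steps))
```

A version that keeps only the top fraction and blanks the rest would roughly correspond to a "keep" mode. How the framework defines its modes is not shown here.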

In [None]:
# Make the project root importable (assumes the notebook runs from a direct subdirectory of it)
import sys
sys.path.insert(0, '..')

In [None]:
# Import the comparison framework
from scripts.xai_experiments import XAIMethodsComparison, ComparisonConfig

## 1. Configure the Comparison

Use `ComparisonConfig` to set up your experiment parameters.
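
The cell below sets every parameter explicitly. You may want to keep a cheaper variant alongside it, such as a CPU smoke test. An immutable copy with a few fields changed is convenient for this. This notebook does not show whether `ComparisonConfig` is a dataclass, so the sketch uses a hypothetical stand-in, `ConfigSketch`, which holds a subset of the fields with the values from that cell. If the real class is a dataclass, `dataclasses.replace(config, ...)` would work the same way.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in for a few ComparisonConfig fields (values from this notebook)."""
    device: str = "cuda"
    image_epochs: int = 3
    image_comparison_samples: int = 50
    text_epochs: int = 2
    text_comparison_samples: int = 20
    output_dir: str = "./outputs/notebook_comparison"


full_run = ConfigSketch()

# replace() returns a new instance; full_run itself is left unchanged.
smoke_test = replace(
    full_run,
    device="cpu",
    image_epochs=1,
    image_comparison_samples=5,
    text_epochs=1,
    text_comparison_samples=5,
    output_dir="./outputs/smoke_test",  # separate directory so results don't mix
)
print(smoke_test)
```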

In [None]:
# Create configuration
config = ComparisonConfig(
    # Device settings
    device="cuda",  # Use "cpu" if no GPU
    
    # Image experiment settings
    image_model_type="resnet",
    image_batch_size=32,
    image_epochs=3,
    image_max_samples=3000,
    image_comparison_samples=50,  # More samples give more stable metrics but take longer
    
    # Text experiment settings
    text_model_name="bert-base-uncased",
    text_epochs=2,
    text_max_samples=1000,
    text_comparison_samples=20,
    
    # DiET settings
    diet_upsample_factor=4,
    diet_rounding_steps=2,
    
    # Output directory
    output_dir="./outputs/notebook_comparison"
)

print("Configuration created!")
print(f"Device: {config.device}")
print(f"Image samples: {config.image_comparison_samples}")
print(f"Text samples: {config.text_comparison_samples}")

## 2. Initialize the Comparison Framework

In [None]:
# Initialize the comparison module
comparison = XAIMethodsComparison(config)
print("Framework initialized!")

## 3. Run Comparisons

Set `run_images` and `run_text` to run the image comparison, the text comparison, or both.

In [None]:
# Run the full comparison (this may take a while)
results = comparison.run_full_comparison(
    run_images=True,   # DiET vs GradCAM on CIFAR-10
    run_text=True,     # DiET vs IG on SST-2
    skip_training=False  # Set to True to use saved models
)

## 4. View Results as DataFrame
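
This notebook does not document the column layout returned by `get_results_dataframe()`. If you need a different shape, you can flatten the raw `results` dict yourself. The sketch below assumes that dict maps experiment names to dicts of metrics, as the keys read in Section 6 suggest. The example input uses dummy numbers only to show the output shape.

```python
from numbers import Real
from typing import Any, Dict, List


def flatten_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
    """One row per numeric metric: {"experiment", "metric", "value"}."""
    rows = []
    for experiment, metrics in results.items():
        if not isinstance(metrics, dict):
            continue  # e.g. an experiment that was skipped or returned None
        for name, value in metrics.items():
            # Keep numeric metrics; skip booleans such as "diet_better".
            if isinstance(value, Real) and not isinstance(value, bool):
                rows.append({"experiment": experiment, "metric": name, "value": value})
    return rows


# Placeholder structure with dummy numbers, only to show the output shape.
example = {
    "image_experiments": {"gradcam_mean_score": 0.1, "diet_mean_score": 0.2, "diet_better": True},
    "text_experiments": None,
}
for row in flatten_results(example):
    print(row)
```

With pandas installed, `pandas.DataFrame(rows)` turns these rows into a long-format table. `DataFrame.pivot(index="metric", columns="experiment", values="value")` then gives one column per experiment.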

In [None]:
# Get results as pandas DataFrame for easy analysis
df = comparison.get_results_dataframe()
df

## 5. Generate Visualizations

In [None]:
# Generate all visualizations
viz_files = comparison.visualize_results(save_plots=True, show=True)

print("\nGenerated visualizations:")
for name, path in viz_files.items():
    print(f"  {name}: {path}")

## 6. Explore Results
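
This notebook does not show how `ig_diet_overlap` is computed. One common way to measure agreement between two token attributions is top-k overlap: the share of tokens that both methods place among their k most important. That definition is assumed here purely for illustration. Ranking by absolute value is a further assumption, since IG scores are signed. If the framework used this definition, the 0.5 threshold in the text cell would mean the two methods share more than half of their top-k tokens.

```python
from typing import Sequence, Set


def top_k_indices(attribution: Sequence[float], k: int) -> Set[int]:
    """Indices of the k tokens with the largest absolute attribution."""
    ranked = sorted(range(len(attribution)), key=lambda i: abs(attribution[i]), reverse=True)
    return set(ranked[:k])


def top_k_overlap(attr_a: Sequence[float], attr_b: Sequence[float], k: int) -> float:
    """Fraction of method A's top-k tokens that also appear in method B's top-k."""
    if len(attr_a) != len(attr_b):
        raise ValueError("both attributions must cover the same tokens")
    k = min(k, len(attr_a))
    if k == 0:
        return 0.0
    return len(top_k_indices(attr_a, k) & top_k_indices(attr_b, k)) / k


tokens = ["the", "movie", "was", "wonderful"]  # tokens the scores refer to
ig_attr = [0.01, 0.40, -0.02, 0.90]            # made-up scores
diet_attr = [0.30, 0.00, 0.05, 0.80]           # made-up scores
print(top_k_overlap(ig_attr, diet_attr, k=2))  # 0.5: only "wonderful" is shared
```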

In [None]:
# Access raw results
from numbers import Real


def fmt(value, spec):
    """Format a numeric value; return 'N/A' if it is missing or not a number."""
    return format(value, spec) if isinstance(value, Real) else "N/A"


print("=" * 50)
print("IMAGE RESULTS (DiET vs GradCAM)")
print("=" * 50)

if results.get("image_experiments"):
    img = results["image_experiments"]
    print(f"Baseline Accuracy: {fmt(img.get('baseline_accuracy'), '.2f')}%")
    print(f"DiET Accuracy: {fmt(img.get('diet_accuracy'), '.2f')}%")
    print(f"GradCAM Score: {fmt(img.get('gradcam_mean_score'), '.4f')}")
    print(f"DiET Score: {fmt(img.get('diet_mean_score'), '.4f')}")

    if img.get('diet_better'):
        print(f"\n✓ DiET improves the attribution score by {fmt(img.get('improvement'), '.4f')}")
    else:
        print("\n→ DiET does not improve on GradCAM's attribution score")

In [None]:
print("=" * 50)
print("TEXT RESULTS (DiET vs IG)")
print("=" * 50)

if results.get("text_experiments"):
    txt = results["text_experiments"]
    print(f"Baseline Accuracy: {fmt(txt.get('baseline_accuracy'), '.2f')}%")
    print(f"IG-DiET Token Overlap: {fmt(txt.get('ig_diet_overlap'), '.4f')}")
    print(f"Samples Compared: {txt.get('samples_compared', 'N/A')}")

    overlap = txt.get('ig_diet_overlap') or 0
    if overlap > 0.5:
        print("\n✓ High agreement between IG and DiET")
    else:
        print("\n→ Low agreement: DiET and IG highlight different tokens")

## 7. Quick Start - Alternative API

For even simpler usage, use the `run_comparison` convenience function:

In [None]:
# Alternative: use convenience function
from scripts.xai_experiments.run_xai_experiments import run_comparison

# Run with sensible defaults
quick_results = run_comparison(
    run_images=True,
    run_text=False,  # Only images for quick demo
    low_vram=True,
    output_dir="./outputs/quick_demo"
)

## 8. Using Individual Metrics
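
This notebook does not show the exact procedure behind `FaithfulnessCorrelation`. One widely used formulation, sketched here under that assumption, works in three steps:

1. Perturb random feature subsets.
2. Record the attribution mass in each subset and the resulting drop in the model's score.
3. Correlate the two.

A faithful attribution should give a correlation near 1. The toy linear model, zero baseline, subset size, and trial count are illustrative choices, not the framework's defaults.

```python
import random
from typing import Callable, List, Sequence


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation; returns 0.0 when either input is constant."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(xs, ys))
    var_x = sum((a - mean_x) ** 2 for a in xs)
    var_y = sum((b - mean_y) ** 2 for b in ys)
    if var_x == 0 or var_y == 0:
        return 0.0
    return cov / (var_x * var_y) ** 0.5


def faithfulness_correlation(model: Callable[[Sequence[float]], float],
                             x: Sequence[float], attribution: Sequence[float],
                             subset_size: int, n_trials: int = 50,
                             baseline: float = 0.0, seed: int = 0) -> float:
    """Correlate the attribution mass of random feature subsets with the
    score drop caused by replacing those features with `baseline`."""
    rng = random.Random(seed)  # seeded so results are reproducible
    original = model(x)
    attribution_sums: List[float] = []
    score_drops: List[float] = []
    for _ in range(n_trials):
        subset = set(rng.sample(range(len(x)), subset_size))
        perturbed = [baseline if i in subset else v for i, v in enumerate(x)]
        attribution_sums.append(sum(attribution[i] for i in subset))
        score_drops.append(original - model(perturbed))
    return pearson(attribution_sums, score_drops)


WEIGHTS = [0.5, 0.3, 0.2, 0.0, 0.0]


def toy_model(features: Sequence[float]) -> float:
    # Linear toy model: feature i contributes exactly WEIGHTS[i] * features[i].
    return sum(w * f for w, f in zip(WEIGHTS, features))


x = [1.0] * 5
good_attr = [0.5, 0.3, 0.2, 0.0, 0.0]  # matches the true contributions
bad_attr = [0.0, 0.0, 0.2, 0.3, 0.5]   # credits the irrelevant features
print(faithfulness_correlation(toy_model, x, good_attr, subset_size=2))  # close to 1.0
print(faithfulness_correlation(toy_model, x, bad_attr, subset_size=2))
```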

You can also use the metrics module directly for custom evaluation:

In [None]:
from scripts.xai_experiments.metrics import (
    AttributionMetrics,
    PixelPerturbation,
    InsertionDeletion,
    AOPC,
    FaithfulnessCorrelation
)

# Example: Create pixel perturbation metric
# pixel_pert = PixelPerturbation(
#     model=your_model,
#     device="cuda",
#     percentages=[5, 10, 20, 30, 50, 70, 90],
#     perturbation_type="keep"  # or "remove"
# )

# result = pixel_pert.compute(images, labels, attribution_maps)
# print(result.to_dict())

print("Metrics available:")
print("  - PixelPerturbation: keep or remove the most important pixels and measure the prediction change")
print("  - InsertionDeletion: progressive insertion/deletion curves")
print("  - AOPC: Area Over the Perturbation Curve")
print("  - FaithfulnessCorrelation: correlation between attributions and prediction changes")

## 9. Using Visualization Module Directly

In [None]:
from scripts.xai_experiments.visualization import (
    ComparisonVisualizer,
    plot_metric_comparison,
    create_comparison_report
)

# Create a custom visualizer
# viz = ComparisonVisualizer(output_dir="./custom_viz")

# Create bar chart
# fig = viz.plot_metric_comparison_bar(
#     results={"GradCAM": {"score": 0.7}, "DiET": {"score": 0.85}},
#     title="Custom Comparison"
# )

print("Visualization functions available:")
print("  - plot_metric_comparison_bar")
print("  - plot_radar_comparison")
print("  - plot_image_attribution_comparison")
print("  - plot_text_attribution_comparison")
print("  - create_summary_dashboard")
print("  - generate_html_report")

## Summary

This framework provides a comprehensive toolkit for comparing DiET with basic XAI methods:

1. **Easy Configuration**: Use `ComparisonConfig` to customize experiments
2. **Unified API**: `XAIMethodsComparison` handles all experiments
3. **Rich Metrics**: Pixel perturbation, AOPC, faithfulness, etc.
4. **Rich Visualizations**: Bar charts, radar plots, dashboards, HTML reports
5. **DataFrame Output**: Easy integration with pandas for analysis

For more details, see the [README](../README.md).