# 🔬 DiET vs Basic XAI Methods - Comprehensive Comparison Framework

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/xMOROx/Machine-Learning-Project-2025-2026/blob/main/notebooks/diet_comparison_colab.ipynb)

**A Hands-on Tutorial for Reproducing XAI Comparison Results**

---

## 📋 What You'll Learn

This notebook provides a comprehensive, reproducible comparison of:

| Modality | Methods Compared | Datasets |
|----------|------------------|----------|
| **Images** | DiET vs GradCAM | CIFAR-10, CIFAR-100, SVHN, Fashion-MNIST |
| **Text** | DiET vs Integrated Gradients | SST-2, IMDB, AG News |

### 📖 Topics Covered

1. **Environment Setup** - Install dependencies and clone repository
2. **Quick Start** - Run comparison with minimal code
3. **Dataset Exploration** - Explore all supported datasets
4. **Metrics Deep Dive** - Understand evaluation metrics
5. **Full Comparison** - Run multi-dataset experiments
6. **Visualization** - Create publication-ready plots
7. **Custom Experiments** - Advanced configuration
8. **Results Analysis** - Interpret and export results

### 📚 References

- **DiET Paper**: Bhalla et al., "Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability", NeurIPS 2023
- **GradCAM**: Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization", ICCV 2017
- **Integrated Gradients**: Sundararajan et al., "Axiomatic Attribution for Deep Networks", ICML 2017

---

# 🚀 Part 1: Environment Setup

First, let's set up the environment for Google Colab.

In [None]:
# Check GPU availability
import torch

if torch.cuda.is_available():
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    DEVICE = "cuda"
else:
    print("⚠️ No GPU detected. Training will be slower.")
    print("   Go to Runtime > Change runtime type > Hardware accelerator > GPU")
    DEVICE = "cpu"

In [None]:
# Clone the repository
import os

REPO_URL = "https://github.com/xMOROx/Machine-Learning-Project-2025-2026.git"
REPO_DIR = "Machine-Learning-Project-2025-2026"

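if os.path.basename(os.getcwd()) == REPO_DIR:
    # Re-running this cell after the chdir below: do not clone a nested copy
    print(f"✅ Already inside {REPO_DIR}")
    !git pull
elif not os.path.exists(REPO_DIR):
    !git clone {REPO_URL}
    print(f"✅ Repository cloned to {REPO_DIR}")
else:
    print(f"✅ Repository already exists at {REPO_DIR}")
    # Pull latest changes
    !cd {REPO_DIR} && git pull

# Change to repository directory (skipped if we are already inside it)
if os.path.basename(os.getcwd()) != REPO_DIR:
    os.chdir(REPO_DIR)
print(f"📁 Working directory: {os.getcwd()}")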

In [None]:
# Install dependencies
print("📦 Installing dependencies...")

# Core dependencies
!pip install -q torch torchvision torchaudio
!pip install -q transformers datasets
!pip install -q captum  # For Integrated Gradients
!pip install -q matplotlib seaborn pandas numpy
!pip install -q tqdm scikit-learn tabulate  # tabulate is required by DataFrame.to_markdown (Part 8)

print("✅ Dependencies installed!")

In [None]:
# Add project to Python path
import sys
sys.path.insert(0, '.')

# Verify imports work
try:
    from scripts.xai_experiments import XAIMethodsComparison, ComparisonConfig
    from scripts.xai_experiments.datasets import SUPPORTED_IMAGE_DATASETS, SUPPORTED_TEXT_DATASETS
    from scripts.xai_experiments.metrics import AttributionMetrics, PixelPerturbation, AOPC
    from scripts.xai_experiments.visualization import ComparisonVisualizer
    print("✅ All imports successful!")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("   Make sure you're in the repository directory")

In [None]:
# Create output directories
import os

OUTPUT_DIR = "./outputs/colab_experiments"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/checkpoints", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/visualizations", exist_ok=True)

print(f"📁 Output directory: {OUTPUT_DIR}")

---

# ⚡ Part 2: Quick Start

Run a complete comparison with just a few lines of code.

In [None]:
# Quick start - minimal configuration for a fast demo
from scripts.xai_experiments import XAIMethodsComparison, ComparisonConfig

# Use reduced settings for quick demo (full comparison later)
quick_config = ComparisonConfig(
    device=DEVICE,
    
    # Use only one dataset each for quick demo
    image_datasets=["cifar10"],
    text_datasets=["sst2"],
    
    # Reduced training for speed
    image_epochs=2,
    image_max_samples=1000,
    image_comparison_samples=20,
    
    text_epochs=1,
    text_max_samples=500,
    text_comparison_samples=10,
    text_top_k=5,  # Show top 5 important tokens
    
    output_dir=f"{OUTPUT_DIR}/quick_demo"
)

print("üìã Quick Demo Configuration:")
print(f"   Device: {quick_config.device}")
print(f"   Image datasets: {quick_config.image_datasets}")
print(f"   Text datasets: {quick_config.text_datasets}")

In [None]:
# Run quick demo (images only for speed)
quick_comparison = XAIMethodsComparison(quick_config)

print("üöÄ Running quick image comparison (DiET vs GradCAM)...")
quick_results = quick_comparison.run_full_comparison(
    run_images=True,
    run_text=False,  # Skip text for quick demo
    skip_training=False
)

print("\n✅ Quick demo complete!")

In [None]:
# View results as DataFrame
df = quick_comparison.get_results_dataframe()
display(df)
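The framework's own GradCAM code is not shown in this notebook. For intuition about what the GradCAM side of the comparison computes, here is a minimal sketch of the algorithm from Selvaraju et al. (2017): each channel of a convolutional feature map is weighted by the spatial average of the class score's gradient, and the heatmap is the ReLU of the weighted channel sum, upsampled to the input size. `TinyCNN` and `grad_cam` are hypothetical names used only for illustration; they are not part of `scripts.xai_experiments`, and the framework's ResNet-based version may pick a different target layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCNN(nn.Module):
    """Hypothetical stand-in for the framework's ResNet: conv features + linear head."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.head = nn.Linear(16, num_classes)

    def forward(self, x):
        # Global average pooling over the spatial dimensions, then classify
        return self.head(self.features(x).mean(dim=(2, 3)))


def grad_cam(model, target_layer, images, target_classes):
    """Grad-CAM heatmaps of shape (B, H, W), scaled to [0, 1] per image."""
    store = {}

    def save_grad(grad):
        store["gradients"] = grad

    def forward_hook(module, inputs, output):
        store["activations"] = output
        output.register_hook(save_grad)  # captures d(score)/d(activations) during backward

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        logits = model(images)
        # Sum of target-class logits; samples are independent, so each image's
        # gradient only depends on its own logit
        score = logits.gather(1, target_classes.view(-1, 1)).sum()
        model.zero_grad()
        score.backward()
    finally:
        handle.remove()

    acts, grads = store["activations"], store["gradients"]
    weights = grads.mean(dim=(2, 3), keepdim=True)   # per-channel pooled gradients
    cam = F.relu((weights * acts).sum(dim=1))        # weighted channel sum, positive evidence only
    cam = F.interpolate(cam.unsqueeze(1), size=images.shape[-2:],
                        mode="bilinear", align_corners=False).squeeze(1)
    cam = cam / cam.amax(dim=(1, 2), keepdim=True).clamp_min(1e-8)
    return cam.detach()


torch.manual_seed(0)
cam_model = TinyCNN().eval()
cam_images = torch.randn(2, 3, 32, 32)
cam_targets = cam_model(cam_images).argmax(dim=1)  # explain the predicted classes
cam_maps = grad_cam(cam_model, cam_model.features, cam_images, cam_targets)
print(cam_maps.shape)
```

Hooking the last convolutional block is the usual choice because it keeps spatial layout while carrying the most class-specific features.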

---

# 📊 Part 3: Dataset Exploration

Let's explore all the datasets supported by the framework.

In [None]:
# List all supported datasets
from scripts.xai_experiments.datasets import SUPPORTED_IMAGE_DATASETS, SUPPORTED_TEXT_DATASETS

print("🖼️ Supported Image Datasets:")
print("-" * 50)
for name, info in SUPPORTED_IMAGE_DATASETS.items():
    print(f"  • {name}: {info['description']} ({info['num_classes']} classes)")

print("\n📝 Supported Text Datasets:")
print("-" * 50)
for name, info in SUPPORTED_TEXT_DATASETS.items():
    print(f"  • {name}: {info['description']} ({info['num_classes']} classes)")

In [None]:
# Load and visualize image datasets
from scripts.xai_experiments.datasets import get_image_dataset
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
image_datasets = ["cifar10", "cifar100", "svhn", "fashion_mnist"]

for idx, dataset_name in enumerate(image_datasets):
    try:
        # Load dataset
        train_loader, test_loader, num_classes = get_image_dataset(
            dataset_name,
            batch_size=16,
            max_samples=100
        )

        # Get sample batch
        images, labels = next(iter(train_loader))

        # Plot first image, rescaled to [0, 1] for display
        ax = axes[0, idx]
        img = images[0].permute(1, 2, 0).numpy()
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        if img.shape[-1] == 1:  # grayscale (e.g. Fashion-MNIST): imshow needs (H, W)
            ax.imshow(img[..., 0], cmap='gray')
        else:
            ax.imshow(img)
        ax.set_title(f"{dataset_name.upper()}\n(Class: {labels[0].item()})")
        ax.axis('off')

        # Plot a 4x4 grid of the batch (make_grid expands 1-channel images to RGB)
        ax = axes[1, idx]
        grid_img = make_grid(images[:16], nrow=4, normalize=True)
        ax.imshow(grid_img.permute(1, 2, 0).numpy())
        ax.set_title(f"{num_classes} classes")
        ax.axis('off')

        print(f"✅ {dataset_name}: {len(train_loader.dataset)} train samples, {num_classes} classes")

    except Exception as e:
        print(f"❌ Error loading {dataset_name}: {e}")
        axes[0, idx].text(0.5, 0.5, f"Error loading\n{dataset_name}", ha='center', va='center')
        axes[0, idx].axis('off')
        axes[1, idx].axis('off')

plt.suptitle("Supported Image Datasets", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/visualizations/dataset_samples.png", dpi=150)
plt.show()

In [None]:
# Load and explore text datasets
from scripts.xai_experiments.datasets import get_text_dataset

text_datasets = ["sst2", "imdb", "ag_news"]

print("📝 Text Dataset Samples:")
print("=" * 80)

for dataset_name in text_datasets:
    try:
        train_data, test_data, num_classes = get_text_dataset(
            dataset_name,
            max_samples=10
        )

        print(f"\n🔤 {dataset_name.upper()} ({num_classes} classes)")
        print("-" * 60)

        # Show up to three sample texts. Integer indexing works for both a list of
        # dicts and a Hugging Face Dataset (slicing a Dataset returns columns instead)
        for i in range(min(3, len(train_data))):
            sample = train_data[i]
            text = sample['text'][:100] + "..." if len(sample['text']) > 100 else sample['text']
            label = sample['label']
            print(f"  [{label}] {text}")

    except Exception as e:
        print(f"❌ Error loading {dataset_name}: {e}")

---

# 📏 Part 4: Metrics Deep Dive

Understand the evaluation metrics used to compare XAI methods.

In [None]:
# Display metrics explanation
from IPython.display import Markdown, display

metrics_explanation = """
## 📊 Evaluation Metrics

### Image Attribution Metrics

| Metric | Description | Higher = Better? |
|--------|-------------|------------------|
| **Pixel Perturbation (Keep)** | Keep only the top-k% most important pixels, measure accuracy | ✅ Yes |
| **Pixel Perturbation (Remove)** | Remove the top-k% most important pixels, measure remaining accuracy | ❌ No (lower = better) |
| **AOPC** | Area Over the Perturbation Curve: mean drop in the model's score as the most important pixels are removed | ✅ Yes |
| **Insertion** | Progressively add pixels in importance order, track the model's score | ✅ Yes |
| **Deletion** | Progressively remove pixels in importance order, track the model's score | ❌ No (lower = better) |
| **Faithfulness** | Correlation between attribution values and the model's sensitivity to perturbing those features | ✅ Yes |

### Text Attribution Metrics

| Metric | Description | Range |
|--------|-------------|-------|
| **Top-k Token Overlap** | Agreement between IG and DiET on the top-k most important tokens | 0-1 |
| **Token Attribution Score** | Importance score assigned to each token | Unbounded |

### Interpretation

- **DiET better**: DiET outperforms the baseline on the perturbation metrics, taking each metric's direction (higher or lower is better) into account
- **High overlap**: IG and DiET agree on the important tokens, which cross-validates both explanations
- **Low overlap**: DiET highlights different discriminative tokens than IG
"""

display(Markdown(metrics_explanation))
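The Top-k Token Overlap metric from the table can be sketched as follows. This assumes overlap is defined as (number of shared tokens) / k and that tokens are ranked by absolute attribution; the framework may rank by signed scores or normalize differently. `top_k_overlap` is a hypothetical helper, not an import from `scripts.xai_experiments`, and the scores below are made up.

```python
import numpy as np


def top_k_overlap(attr_a, attr_b, k):
    """Fraction of the top-k tokens (by |attribution|) that two methods share."""
    attr_a = np.asarray(attr_a, dtype=float)
    attr_b = np.asarray(attr_b, dtype=float)
    if attr_a.shape != attr_b.shape:
        raise ValueError("both attributions must cover the same tokens")
    k = min(k, attr_a.size)
    # Rank by absolute value so strongly negative evidence also counts as important
    top_a = set(np.argsort(-np.abs(attr_a))[:k].tolist())
    top_b = set(np.argsort(-np.abs(attr_b))[:k].tolist())
    return len(top_a & top_b) / k


# Hypothetical IG and DiET scores for a 6-token sentence
toy_ig = [0.9, -0.1, 0.05, 0.7, 0.0, -0.6]
toy_diet = [0.8, 0.2, 0.1, 0.1, 0.0, 0.5]
print(top_k_overlap(toy_ig, toy_diet, k=3))  # IG top-3 = {0, 3, 5}, DiET top-3 = {0, 5, 1}
```

Here two of the three top tokens agree, so the overlap is 2/3.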

In [None]:
# Demonstrate metrics computation
from scripts.xai_experiments.metrics import PixelPerturbation, AOPC, InsertionDeletion
import torch
import numpy as np

print("🔬 Metrics Demonstration")
print("=" * 50)

# Create dummy data for demonstration
batch_size = 4
dummy_images = torch.randn(batch_size, 3, 32, 32)
dummy_attributions = torch.rand(batch_size, 32, 32)  # Random attributions

print("\n📊 Sample shapes:")
print(f"   Images: {dummy_images.shape}")
print(f"   Attributions: {dummy_attributions.shape}")

# Show how metrics work conceptually
print("\n📐 Metric Computation:")
print("\n   1. Pixel Perturbation:")
print("      - Rank pixels by attribution importance")
print("      - Keep/remove the top k% of pixels")
print("      - Measure the change in model accuracy")

print("\n   2. AOPC (Area Over Perturbation Curve):")
print("      - Remove the most important pixels at multiple thresholds")
print("      - Average the drop in the model's score across thresholds")
print("      - Higher AOPC = more faithful attributions")

print("\n   3. Insertion/Deletion:")
print("      - Start with a blank image (insertion) or the full image (deletion)")
print("      - Progressively add/remove pixels in importance order")
print("      - Track prediction confidence")
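The three procedures printed above can be made concrete with a small, self-contained sketch on a toy linear classifier. It follows common formulations (AOPC as the mean drop in the target-class probability over removal steps, deletion and insertion scores as areas under the removal and keep-only curves), but the fill value of 0, the 10% step schedule, and the helper names `perturbation_curve` and `auc` are assumptions; the framework's `PixelPerturbation`, `AOPC`, and `InsertionDeletion` classes may differ.

```python
import numpy as np
import torch
import torch.nn as nn


def perturbation_curve(model, image, attribution, fractions, mode="remove", fill=0.0):
    """Probability of the originally predicted class as pixels are perturbed.

    image: (C, H, W); attribution: (H, W), larger = more important.
    mode="remove": replace the top-f most important pixels with `fill` (deletion curve).
    mode="keep":   keep only the top-f pixels and fill the rest (insertion curve).
    """
    model.eval()
    with torch.no_grad():
        target = int(model(image.unsqueeze(0)).argmax(dim=1))
    order = torch.argsort(attribution.flatten(), descending=True)  # most important first
    n_pixels = order.numel()
    scores = []
    for f in fractions:
        top = torch.zeros(n_pixels, dtype=torch.bool)
        top[order[: int(round(f * n_pixels))]] = True
        perturb = top if mode == "remove" else ~top
        mask = perturb.view(1, *attribution.shape)  # broadcast the mask over channels
        perturbed = torch.where(mask, torch.full_like(image, fill), image)
        with torch.no_grad():
            prob = torch.softmax(model(perturbed.unsqueeze(0)), dim=1)[0, target]
        scores.append(prob.item())
    return np.array(scores)


def auc(xs, ys):
    """Trapezoidal area under a curve (np.trapz is deprecated in NumPy 2.0)."""
    xs, ys = np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)
    return float(np.sum(np.diff(xs) * (ys[1:] + ys[:-1]) / 2))


torch.manual_seed(0)
pert_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
pert_image = torch.randn(3, 8, 8)
pert_attr = torch.rand(8, 8)  # stand-in for a GradCAM or DiET map
fractions = np.linspace(0.0, 1.0, 11)

deletion = perturbation_curve(pert_model, pert_image, pert_attr, fractions, mode="remove")
insertion = perturbation_curve(pert_model, pert_image, pert_attr, fractions, mode="keep")
aopc = float(np.mean(deletion[0] - deletion))  # deletion[0] is the unperturbed score
print(f"AOPC {aopc:.4f} | deletion AUC {auc(fractions, deletion):.4f} | "
      f"insertion AUC {auc(fractions, insertion):.4f}")
```

With a random attribution map these numbers are only a reference point; a faithful map should give a higher AOPC and insertion AUC and a lower deletion AUC than random.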

---

# 🔬 Part 5: Full Multi-Dataset Comparison

Run the complete comparison across all datasets for robust results.

In [None]:
# Full comparison configuration
from scripts.xai_experiments import XAIMethodsComparison, ComparisonConfig

# Configure for comprehensive comparison
# Adjust these based on your GPU memory and time constraints

full_config = ComparisonConfig(
    device=DEVICE,
    
    # === Image Datasets ===
    # Use all 4 datasets for robust comparison
    image_datasets=["cifar10", "cifar100", "svhn", "fashion_mnist"],
    image_model_type="resnet",
    image_batch_size=64 if DEVICE == "cuda" else 16,
    image_epochs=5,
    image_max_samples=5000,  # Training samples per dataset
    image_comparison_samples=100,  # Samples for XAI comparison
    
    # === Text Datasets ===
    # Use all 3 datasets
    text_datasets=["sst2", "imdb", "ag_news"],
    text_model_name="bert-base-uncased",
    text_max_length=128,
    text_epochs=3,
    text_max_samples=2000,
    text_comparison_samples=50,
    text_top_k=10,  # Show top 10 tokens
    
    # === DiET Settings ===
    diet_upsample_factor=4,
    diet_rounding_steps=2,
    
    # === Metric Settings ===
    perturbation_percentages=[5, 10, 20, 30, 50, 70, 90],
    insertion_deletion_steps=50,
    aopc_steps=10,
    faithfulness_samples=30,
    
    # === Output ===
    output_dir=f"{OUTPUT_DIR}/full_comparison",
    save_visualizations=True
)

print("📋 Full Comparison Configuration")
print("=" * 50)
print(f"Device: {full_config.device}")
print(f"\n🖼️ Image Datasets: {', '.join(full_config.image_datasets)}")
print(f"   Epochs: {full_config.image_epochs}")
print(f"   Samples/dataset: {full_config.image_max_samples}")
print(f"   Comparison samples: {full_config.image_comparison_samples}")
print(f"\n📝 Text Datasets: {', '.join(full_config.text_datasets)}")
print(f"   Epochs: {full_config.text_epochs}")
print(f"   Top-k tokens: {full_config.text_top_k}")

In [None]:
# Run the full comparison
# ⚠️ This may take 30-60 minutes depending on GPU

full_comparison = XAIMethodsComparison(full_config)

print("🚀 Starting full comparison...")
print("   This may take 30-60 minutes.")
print("   Training checkpoints are saved automatically.")
print("   If interrupted, re-run this cell to resume.\n")

# Run both image and text experiments
full_results = full_comparison.run_full_comparison(
    run_images=True,
    run_text=True,
    skip_training=False  # Set to True to use cached models
)

print("\n" + "=" * 50)
print("✅ FULL COMPARISON COMPLETE!")
print("=" * 50)

In [None]:
# View comprehensive results as DataFrame
import pandas as pd

df = full_comparison.get_results_dataframe()

print("📊 Results Summary")
print("=" * 50)
display(df)

# Save to CSV
csv_path = f"{OUTPUT_DIR}/full_comparison/results.csv"
df.to_csv(csv_path, index=False)
print(f"\n📁 Results saved to: {csv_path}")

---

# 📈 Part 6: Visualization

Create publication-ready visualizations of the comparison results.

In [None]:
# Generate all visualizations
viz_files = full_comparison.visualize_results(save_plots=True, show=True)

print("\n📊 Generated Visualization Files:")
for name, path in viz_files.items():
    print(f"   • {name}: {path}")

In [None]:
# Create custom comparison bar chart
import matplotlib.pyplot as plt
import numpy as np

# Extract results for plotting
image_results = full_results.get("image_experiments", {})

if image_results:
    datasets = []
    gradcam_scores = []
    diet_scores = []
    
    for ds_name, ds_data in image_results.items():
        if "error" not in ds_data:
            datasets.append(ds_name.upper())
            gradcam_scores.append(ds_data.get("gradcam_mean_score", 0))
            diet_scores.append(ds_data.get("diet_mean_score", 0))
    
    # Create grouped bar chart
    x = np.arange(len(datasets))
    width = 0.35
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    bars1 = ax.bar(x - width/2, gradcam_scores, width, label='GradCAM', color='#2196F3', edgecolor='black')
    bars2 = ax.bar(x + width/2, diet_scores, width, label='DiET', color='#4CAF50', edgecolor='black')
    
    ax.set_xlabel('Dataset', fontsize=12)
    ax.set_ylabel('Pixel Perturbation Score', fontsize=12)
    ax.set_title('DiET vs GradCAM: Image Attribution Quality\n(Higher = Better)', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(datasets)
    ax.legend(loc='upper right')
    ax.set_ylim(0, 1)
    ax.grid(axis='y', alpha=0.3)
    
    # Add value labels
    for bar in bars1:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
               f'{height:.3f}', ha='center', va='bottom', fontsize=9)
    for bar in bars2:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
               f'{height:.3f}', ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/visualizations/image_comparison_bar.png", dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No image results available for visualization")

In [None]:
# Create text comparison visualization
text_results = full_results.get("text_experiments", {})

if text_results:
    datasets = []
    overlaps = []
    accuracies = []
    
    for ds_name, ds_data in text_results.items():
        if "error" not in ds_data:
            datasets.append(ds_name.upper())
            overlaps.append(ds_data.get("ig_diet_overlap", 0))
            accuracies.append(ds_data.get("baseline_accuracy", 0))
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Token overlap
    colors = ['#FF9800' if o > 0.5 else '#F44336' for o in overlaps]
    bars = axes[0].barh(datasets, overlaps, color=colors, edgecolor='black')
    axes[0].set_xlim(0, 1)
    axes[0].set_xlabel('Top-k Token Overlap', fontsize=12)
    axes[0].set_title('IG vs DiET Token Agreement\n(Higher = More Agreement)', fontsize=12, fontweight='bold')
    axes[0].axvline(x=0.5, color='gray', linestyle='--', alpha=0.5, label='50% threshold')
    
    for bar, val in zip(bars, overlaps):
        axes[0].text(val + 0.02, bar.get_y() + bar.get_height()/2.,
                    f'{val:.3f}', va='center', fontsize=10)
    
    # Baseline accuracy
    bars = axes[1].barh(datasets, accuracies, color='#9C27B0', edgecolor='black')
    axes[1].set_xlim(0, 100)
    axes[1].set_xlabel('Accuracy (%)', fontsize=12)
    axes[1].set_title('Model Baseline Accuracy', fontsize=12, fontweight='bold')
    
    for bar, val in zip(bars, accuracies):
        axes[1].text(val + 1, bar.get_y() + bar.get_height()/2.,
                    f'{val:.1f}%', va='center', fontsize=10)
    
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/visualizations/text_comparison.png", dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No text results available for visualization")

In [None]:
# Create summary dashboard
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

# Rebuild the per-modality data here: `datasets` and `overlaps` were overwritten by the
# text cell above, and failed datasets must be excluded consistently from every panel
valid_image = {name: r for name, r in image_results.items() if "error" not in r}
valid_text = {name: r for name, r in text_results.items() if "error" not in r}

fig = plt.figure(figsize=(16, 10))
gs = GridSpec(2, 3, figure=fig, hspace=0.3, wspace=0.3)

# Title
fig.suptitle('DiET vs Basic XAI Methods - Comparison Dashboard', fontsize=16, fontweight='bold', y=1.02)

# 1. Image comparison (top left)
ax1 = fig.add_subplot(gs[0, 0:2])
if valid_image:
    x = np.arange(len(valid_image))
    ax1.bar(x - 0.2, [r.get('gradcam_mean_score', 0) for r in valid_image.values()], 0.4,
            label='GradCAM', color='#2196F3')
    ax1.bar(x + 0.2, [r.get('diet_mean_score', 0) for r in valid_image.values()], 0.4,
            label='DiET', color='#4CAF50')
    ax1.set_xticks(x)
    ax1.set_xticklabels([name.upper() for name in valid_image])
    ax1.set_ylabel('Score')
    ax1.set_title('Image Attribution Quality')
    ax1.legend()

# 2. Win/Loss summary (top right)
ax2 = fig.add_subplot(gs[0, 2])
if valid_image:
    diet_wins = sum(1 for r in valid_image.values() if r.get('diet_better', False))
    ax2.pie([diet_wins, len(valid_image) - diet_wins],
            labels=['DiET Better', 'GradCAM Better'],
            colors=['#4CAF50', '#2196F3'],
            autopct='%1.0f%%',
            startangle=90)
    ax2.set_title('Image: Win Rate')

# 3. Text overlap (bottom left)
ax3 = fig.add_subplot(gs[1, 0:2])
if valid_text:
    ax3.barh([name.upper() for name in valid_text],
             [r.get('ig_diet_overlap', 0) for r in valid_text.values()], color='#FF9800')
    ax3.set_xlim(0, 1)
    ax3.set_xlabel('Token Overlap')
    ax3.set_title('Text: IG-DiET Agreement')

# 4. Summary stats (bottom right)
ax4 = fig.add_subplot(gs[1, 2])
ax4.axis('off')
summary_text = "📊 Summary Statistics\n\n"
if valid_image:
    avg_improvement = np.mean([r.get('improvement', 0) for r in valid_image.values()])
    summary_text += f"Image Avg Improvement: {avg_improvement:.4f}\n"
if valid_text:
    avg_overlap = np.mean([r.get('ig_diet_overlap', 0) for r in valid_text.values()])
    summary_text += f"Text Avg Overlap: {avg_overlap:.4f}\n"
summary_text += f"\nDatasets with results: {len(valid_image) + len(valid_text)}"
ax4.text(0.1, 0.5, summary_text, fontsize=12, verticalalignment='center',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.savefig(f"{OUTPUT_DIR}/visualizations/comparison_dashboard.png", dpi=300, bbox_inches='tight')
plt.show()

---

# ⚙️ Part 7: Custom Experiments

Learn how to customize experiments for specific research needs.

In [None]:
# Run comparison on a single dataset
from scripts.xai_experiments import XAIMethodsComparison, ComparisonConfig

# Focus on CIFAR-100 only
cifar100_config = ComparisonConfig(
    device=DEVICE,
    image_datasets=["cifar100"],
    image_epochs=3,
    image_max_samples=3000,
    image_comparison_samples=50,
    output_dir=f"{OUTPUT_DIR}/cifar100_only"
)

print("Running CIFAR-100 focused experiment...")
cifar100_comparison = XAIMethodsComparison(cifar100_config)
cifar100_results = cifar100_comparison.run_full_comparison(run_images=True, run_text=False)

In [None]:
# Text-only experiment with higher top-k
text_config = ComparisonConfig(
    device=DEVICE,
    text_datasets=["sst2", "imdb"],
    text_epochs=2,
    text_max_samples=1000,
    text_comparison_samples=30,
    text_top_k=15,  # Show top 15 tokens
    output_dir=f"{OUTPUT_DIR}/text_only"
)

print("Running text-only experiment with top-15 tokens...")
text_comparison = XAIMethodsComparison(text_config)
# Use a new name so the `text_results` dict from Part 6 is not overwritten
text_only_results = text_comparison.run_full_comparison(run_images=False, run_text=True)

In [None]:
# Demonstrate resumable training (checkpoint support)
from scripts.xai_experiments import CheckpointManager

# List available checkpoints
ckpt_dir = f"{OUTPUT_DIR}/full_comparison/checkpoints"
ckpt_manager = CheckpointManager(ckpt_dir)

print("📁 Available Checkpoints:")
checkpoints = ckpt_manager.list_checkpoints()
if checkpoints:
    for ckpt in checkpoints:
        print(f"   • {ckpt}")
else:
    print("   No checkpoints found.")

print("\n💡 Tip: If training is interrupted, re-running the cell resumes from the last checkpoint.")
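The notebook does not show what `CheckpointManager` stores, so the following is a generic PyTorch sketch of the resume pattern that the tip above relies on, not the framework's implementation. `save_checkpoint`, `load_checkpoint`, and `train_with_resume` are hypothetical helpers; the key ideas are saving the optimizer state alongside the weights and recording the last completed epoch.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch):
    """Save weights, optimizer state, and the last completed epoch."""
    tmp_path = path + ".tmp"
    torch.save({"epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict()}, tmp_path)
    os.replace(tmp_path, path)  # atomic swap: a crash mid-save leaves the previous checkpoint intact


def load_checkpoint(path, model, optimizer):
    """Restore state if a checkpoint exists and return the epoch to start from."""
    if not os.path.exists(path):
        return 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"] + 1


def train_with_resume(num_epochs, ckpt_path):
    """Toy training loop; returns the epoch it (re)started from."""
    torch.manual_seed(0)
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    start = load_checkpoint(ckpt_path, model, optimizer)
    for epoch in range(start, num_epochs):
        x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
        loss = nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        save_checkpoint(ckpt_path, model, optimizer, epoch)  # checkpoint after every epoch
    return start


demo_ckpt = os.path.join(tempfile.mkdtemp(), "demo.pt")
print(train_with_resume(2, demo_ckpt))  # fresh run trains epochs 0-1, prints 0
print(train_with_resume(5, demo_ckpt))  # resumed run trains epochs 2-4, prints 2
```

Writing to a temporary file and renaming it matters on Colab, where a session can be killed mid-save.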

---

# 📝 Part 8: Results Analysis & Export

Analyze results and export for further use.

In [None]:
# Statistical analysis of results
import pandas as pd
import numpy as np

df = full_comparison.get_results_dataframe()

print("📊 Statistical Summary")
print("=" * 50)

# Image results statistics
image_df = df[df['Modality'] == 'Image']
if not image_df.empty:
    print("\n🖼️ Image Results:")
    print(f"   Datasets tested: {len(image_df)}")
    print(f"   DiET wins: {image_df['DiET Better'].sum()} / {len(image_df)}")
    print(f"   Mean GradCAM score: {image_df['GradCAM Score'].mean():.4f} ± {image_df['GradCAM Score'].std():.4f}")
    print(f"   Mean DiET score: {image_df['DiET Score'].mean():.4f} ± {image_df['DiET Score'].std():.4f}")
    print(f"   Mean improvement: {image_df['Improvement'].mean():.4f}")

# Text results statistics
text_df = df[df['Modality'] == 'Text']
if not text_df.empty:
    print("\n📝 Text Results:")
    print(f"   Datasets tested: {len(text_df)}")
    print(f"   Mean IG-DiET overlap: {text_df['IG-DiET Overlap'].mean():.4f} ± {text_df['IG-DiET Overlap'].std():.4f}")
    print(f"   Mean accuracy: {text_df['Baseline Accuracy'].mean():.2f}%")

In [None]:
# Export results in multiple formats
import json

export_dir = f"{OUTPUT_DIR}/exports"
os.makedirs(export_dir, exist_ok=True)

# 1. CSV export
csv_path = f"{export_dir}/results.csv"
df.to_csv(csv_path, index=False)
print(f"✅ CSV saved: {csv_path}")

# 2. JSON export (full results)
json_path = f"{export_dir}/results.json"
# Convert NumPy/PyTorch types that json cannot serialize natively
def convert_for_json(obj):
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):  # np.float32, np.int64, np.bool_, ...
        return obj.item()
    elif isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    elif isinstance(obj, dict):
        return {str(k): convert_for_json(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [convert_for_json(item) for item in obj]
    return obj

with open(json_path, 'w') as f:
    # default=str is a last-resort fallback for any other non-serializable object
    json.dump(convert_for_json(full_results), f, indent=2, default=str)
print(f"✅ JSON saved: {json_path}")

# 3. LaTeX table
latex_path = f"{export_dir}/results_table.tex"
latex_table = df.to_latex(index=False, float_format="%.4f")
with open(latex_path, 'w') as f:
    f.write(latex_table)
print(f"✅ LaTeX table saved: {latex_path}")

# 4. Markdown summary
md_path = f"{export_dir}/results_summary.md"
md_content = f"""# DiET vs Basic XAI Methods - Results Summary

## Overview

| Modality | Datasets | DiET Wins | Avg Improvement |
|----------|----------|-----------|------------------|
| Image    | {len(image_df)} | {image_df['DiET Better'].sum() if not image_df.empty else 0} | {image_df['Improvement'].mean():.4f if not image_df.empty else 'N/A'} |
| Text     | {len(text_df)} | N/A | {text_df['IG-DiET Overlap'].mean():.4f if not text_df.empty else 'N/A'} (overlap) |

## Detailed Results

{df.to_markdown(index=False)}

Generated: {full_results.get('timestamp', 'N/A')}
"""
with open(md_path, 'w') as f:
    f.write(md_content)
print(f"‚úÖ Markdown summary saved: {md_path}")

In [None]:
# Download results (for Colab)
try:
    from google.colab import files
    
    # Zip all exports
    !cd {OUTPUT_DIR} && zip -r exports.zip exports visualizations
    
    # Download
    files.download(f"{OUTPUT_DIR}/exports.zip")
    print("📥 Download started!")
except ImportError:
    print("💡 Not running in Colab. Results are saved locally at:")
    print(f"   {OUTPUT_DIR}/exports/")
    print(f"   {OUTPUT_DIR}/visualizations/")

---

# 🎯 Part 9: Hands-On Exercises

Try these exercises to deepen your understanding.

In [None]:
# Exercise 1: Vary the number of comparison samples
# Question: How does the number of samples affect result stability?

# TODO: Uncomment and modify
# sample_sizes = [10, 25, 50, 100]
# results_by_samples = {}
# 
# for n_samples in sample_sizes:
#     config = ComparisonConfig(
#         device=DEVICE,
#         image_datasets=["cifar10"],
#         image_comparison_samples=n_samples,
#         image_epochs=2,
#         output_dir=f"{OUTPUT_DIR}/exercise1/samples_{n_samples}"
#     )
#     comparison = XAIMethodsComparison(config)
#     results = comparison.run_full_comparison(run_text=False)
#     results_by_samples[n_samples] = results

print("💡 Exercise 1: Vary sample sizes to study result stability")
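Answering Exercise 1 needs a concrete notion of stability. One common choice, assumed here rather than taken from the framework, is a percentile-bootstrap confidence interval over per-sample score differences (DiET minus GradCAM): if the interval for a given number of comparison samples straddles zero, that sample size is too small to declare a winner. Whether `XAIMethodsComparison` exposes per-sample scores is not shown in this notebook, so the data below is synthetic and `bootstrap_ci` is a hypothetical helper.

```python
import numpy as np


def bootstrap_ci(values, n_boot=2000, alpha=0.05, seed=0):
    """Percentile-bootstrap confidence interval for the mean of `values`."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample indices with replacement: one row per bootstrap replicate
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(values.mean()), float(lo), float(hi)


# SYNTHETIC per-sample score differences (DiET minus GradCAM), not real results
rng = np.random.default_rng(42)
synthetic_diffs = rng.normal(loc=0.02, scale=0.1, size=1_000)

ci_widths = {}
for n in [10, 25, 50, 100]:
    mean, lo, hi = bootstrap_ci(synthetic_diffs[:n])
    ci_widths[n] = hi - lo
    print(f"n={n:>3}: mean diff {mean:+.3f}, 95% CI [{lo:+.3f}, {hi:+.3f}], width {hi - lo:.3f}")
```

The interval width shrinks roughly with the square root of the sample size, which gives a principled way to choose `image_comparison_samples`.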

In [None]:
# Exercise 2: Compare different model architectures
# Question: Does the model architecture affect XAI method performance?

# TODO: Uncomment and modify
# model_types = ["resnet", "vgg", "densenet"]  # If supported
# results_by_model = {}
# 
# for model_type in model_types:
#     config = ComparisonConfig(
#         device=DEVICE,
#         image_model_type=model_type,
#         ...
#     )
#     ...

print("💡 Exercise 2: Compare XAI methods across different model architectures")

In [None]:
# Exercise 3: Analyze text top-k sensitivity
# Question: How does changing top-k affect IG-DiET overlap?

# TODO: Uncomment and modify
# top_k_values = [3, 5, 10, 15, 20]
# overlaps_by_k = []
# 
# for k in top_k_values:
#     config = ComparisonConfig(
#         device=DEVICE,
#         text_datasets=["sst2"],
#         text_top_k=k,
#         ...
#     )
#     ...
# 
# Plot: top-k vs overlap

print("💡 Exercise 3: Study how top-k affects IG-DiET token agreement")
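When plotting top-k against overlap in Exercise 3, part of any upward trend is mechanical: two unrelated rankings of n tokens share on average k·k/n of their top-k tokens (the mean of a hypergeometric distribution), that is, an overlap fraction of k/n. Comparing the measured IG-DiET overlap against this chance level is a useful control. The sketch assumes overlap is defined as shared tokens / k, consistent with the 0-1 range in Part 4 but not confirmed by the framework's code; `random_topk_overlap` is a hypothetical helper and the sequence length is illustrative.

```python
import numpy as np


def random_topk_overlap(n_tokens, k, trials=5000, seed=0):
    """Mean top-k overlap |A ∩ B| / k between two independent random token rankings."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(trials):
        top_a = set(rng.permutation(n_tokens)[:k].tolist())
        top_b = set(rng.permutation(n_tokens)[:k].tolist())
        total += len(top_a & top_b) / k
    return total / trials


n_tokens = 40  # illustrative sequence length, not taken from the datasets
for k in [3, 5, 10, 15, 20]:
    simulated = random_topk_overlap(n_tokens, k)
    print(f"k={k:>2}: chance overlap {simulated:.3f} (analytic k/n = {k / n_tokens:.3f})")
```

Because IMDB reviews are much longer than SST-2 sentences, the same k gives a much lower chance level there, so raw overlaps are not directly comparable across datasets.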

---

# 🎓 Conclusion

## What We Learned

1. **DiET** (Distractor Erasure Tuning, from the Discriminative Feature Attributions paper) provides an alternative to traditional post hoc XAI methods
2. **GradCAM** remains a strong baseline for image attribution
3. **Integrated Gradients** provides theoretically grounded text attributions
4. **Multi-dataset evaluation** is crucial for robust conclusions
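The "theoretically grounded" label on Integrated Gradients refers mainly to its completeness axiom: the attributions sum to F(x) - F(baseline) for the explained output. The sketch below approximates the path integral with a midpoint Riemann sum on a hypothetical toy network and checks completeness numerically. The framework presumably relies on captum (installed in Part 1), whose `IntegratedGradients` and `LayerIntegratedGradients` classes apply the same idea (for text, typically to token embeddings); this function is an illustration, not the framework's code.

```python
import torch
import torch.nn as nn


def integrated_gradients(model, x, baseline, target, steps=200):
    """Midpoint Riemann-sum approximation of Integrated Gradients for one input vector.

    IG_i = (x_i - x'_i) * integral over a in [0, 1] of dF_target(x' + a (x - x')) / dx_i
    """
    alphas = ((torch.arange(steps, dtype=x.dtype) + 0.5) / steps).unsqueeze(1)  # (steps, 1)
    path = baseline + alphas * (x - baseline)  # (steps, features): points on the straight-line path
    path.requires_grad_(True)
    outputs = model(path)[:, target].sum()     # one backward pass yields every path gradient
    (grads,) = torch.autograd.grad(outputs, path)
    return (x - baseline) * grads.mean(dim=0)  # average gradient times input difference


torch.manual_seed(0)
ig_model = nn.Sequential(nn.Linear(5, 8), nn.Tanh(), nn.Linear(8, 3)).double()
ig_x = torch.randn(5, dtype=torch.float64)
ig_baseline = torch.zeros(5, dtype=torch.float64)  # all-zero baseline
ig_attr = integrated_gradients(ig_model, ig_x, ig_baseline, target=1)
with torch.no_grad():
    ig_gap = (ig_model(ig_x) - ig_model(ig_baseline))[1].item()
print(f"sum of attributions: {ig_attr.sum().item():.6f}   F(x) - F(baseline): {ig_gap:.6f}")
```

The two printed numbers agree up to the discretization error of the Riemann sum, which shrinks as `steps` grows.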

## Key Findings

- DiET often identifies different important features compared to GradCAM/IG
- The agreement between methods varies significantly across datasets
- Perturbation-based metrics provide actionable evaluation

## Next Steps

1. Try the framework on your own datasets
2. Extend with additional XAI methods (SHAP, LIME, etc.)
3. Add new evaluation metrics
4. Contribute improvements to the repository!

## Resources

- 📚 [DiET Paper (NeurIPS 2023)](https://arxiv.org/abs/2305.04249)
- 📁 [Repository](https://github.com/xMOROx/Machine-Learning-Project-2025-2026)
- 📖 [Framework Documentation](../README.md)

---

**Happy Experimenting! 🚀**