# DiET vs Basic XAI Methods: Comprehensive Comparison Framework

## A Complete Experimental Pipeline for Image and Text Classification

---

**Author:** Machine Learning Research Team  
**Date:** 2025-2026 Academic Year  
**Course:** Advanced Machine Learning

---

### Abstract

This notebook presents a comprehensive experimental comparison between **DiET (Distractor Erasure Tuning)** and standard XAI methods:

- **Image Classification:** DiET vs GradCAM on CIFAR-10, CIFAR-100, SVHN, Fashion-MNIST
- **Text Classification:** DiET vs Integrated Gradients on SST-2, IMDB, AG News

The notebook reports quantitative evaluation metrics, visual summaries, and statistical tests for each dataset, in a form suitable for academic presentations.

### Metrics Used

**Image Attribution Metrics:**
- Pixel Perturbation (keep/remove important regions)
- AOPC (Area Over Perturbation Curve)
- Insertion/Deletion Curves
- Faithfulness Correlation
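
To make the perturbation-based metrics concrete, the sketch below computes a deletion curve, its AOPC, and its area under the curve for a toy model. It is an illustration only: `deletion_curve`, `aopc`, `curve_auc`, and `toy_model` are hypothetical names, zero-filling is just one possible erasure strategy, and the repository's implementation may define or normalise these metrics differently.

```python
import numpy as np

def deletion_curve(model_fn, image, attribution, steps=10, fill_value=0.0):
    """Model score as the most-attributed pixels are progressively erased."""
    # Most important pixels first; a stable sort keeps ties in index order.
    order = np.argsort(-attribution.ravel(), kind="stable")
    perturbed = image.copy().ravel()
    scores = [model_fn(perturbed.reshape(image.shape))]
    chunk = int(np.ceil(order.size / steps))
    for start in range(0, order.size, chunk):
        perturbed[order[start:start + chunk]] = fill_value
        scores.append(model_fn(perturbed.reshape(image.shape)))
    return np.asarray(scores)

def aopc(scores):
    """Mean drop from the unperturbed score over all perturbation steps."""
    return float(np.mean(scores[0] - scores[1:]))

def curve_auc(scores):
    """Trapezoidal area under the curve with the x-axis normalised to [0, 1]."""
    x = np.linspace(0.0, 1.0, len(scores))
    return float(np.sum((scores[1:] + scores[:-1]) / 2.0 * np.diff(x)))

# Toy "model" (assumption): the score is the mean of the top-left 2x2 patch.
def toy_model(img):
    return float(img[:2, :2].mean())

image = np.ones((4, 4))
good_attr = np.zeros((4, 4))
good_attr[:2, :2] = 1.0  # highlights the patch the model actually uses
bad_attr = np.zeros((4, 4))
bad_attr[2:, 2:] = 1.0   # highlights an irrelevant patch

good = deletion_curve(toy_model, image, good_attr, steps=4)
bad = deletion_curve(toy_model, image, bad_attr, steps=4)
print("AOPC (good, bad):", aopc(good), aopc(bad))
print("Deletion AUC (good, bad):", curve_auc(good), curve_auc(bad))
```

A faithful attribution makes the score collapse early, which gives a higher AOPC and a lower deletion AUC. An insertion curve follows the same pattern in reverse: start from an erased image and restore pixels in order of importance.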

**Text Attribution Metrics:**
- Top-K Token Overlap (K = 3, 5, 10, 15, 20)
- Attribution Correlation
- Token-level Agreement Analysis
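
As a rough illustration of the text metrics, the sketch below computes a Top-K overlap and a rank correlation between two attribution vectors. The definitions are assumptions: overlap is taken as the shared fraction of the K highest-magnitude token positions, and correlation as a Spearman-style rank correlation without tie handling. The function names and the example values are made up, and the repository may use different definitions.

```python
import numpy as np

def top_k_overlap(attr_a, attr_b, k):
    """Fraction of the k highest-|attribution| token positions shared by both methods."""
    k = min(k, len(attr_a))  # short inputs may have fewer than k tokens
    top_a = set(np.argsort(-np.abs(attr_a), kind="stable")[:k].tolist())
    top_b = set(np.argsort(-np.abs(attr_b), kind="stable")[:k].tolist())
    return len(top_a & top_b) / k

def rank_correlation(attr_a, attr_b):
    """Pearson correlation of the rank vectors (Spearman-style, ties not averaged)."""
    rank_a = np.argsort(np.argsort(attr_a))
    rank_b = np.argsort(np.argsort(attr_b))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])

# Made-up per-token attributions for a six-token sentence.
ig_attr = np.array([0.9, 0.1, 0.5, 0.05, 0.3, 0.02])
diet_attr = np.array([0.8, 0.6, 0.4, 0.01, 0.05, 0.03])

for k in (1, 3, 6):
    print(f"Top-{k} overlap: {top_k_overlap(ig_attr, diet_attr, k):.3f}")
print(f"Rank correlation: {rank_correlation(ig_attr, diet_attr):.3f}")
```

Note that overlap necessarily reaches 1.0 once K is at least the sequence length, so large K values are only informative on sufficiently long inputs.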

### Reference

Bhalla, U., et al. (2023). **"Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability."** *NeurIPS 2023.*

---

## Table of Contents

1. [Environment Setup](#1-environment-setup)
2. [Experimental Configuration](#2-experimental-configuration)
3. [Image Experiments: DiET vs GradCAM](#3-image-experiments-diet-vs-gradcam)
4. [Text Experiments: DiET vs Integrated Gradients](#4-text-experiments-diet-vs-integrated-gradients)
5. [Combined Results Summary](#5-combined-results-summary)
6. [Statistical Analysis](#6-statistical-analysis)
7. [Export Results](#7-export-results)

---

## 1. Environment Setup

### 1.1 Hardware Configuration

This notebook is optimized for Google Colab Pro with GPU acceleration.

In [None]:
import torch
import gc

def cleanup_memory():
    """Clean up GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

print("=" * 70)
print("HARDWARE CONFIGURATION")
print("=" * 70)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU Available: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")

    if gpu_memory >= 15:
        print("\nConfiguration: HIGH-MEMORY GPU")
        GPU_CONFIG = "high"
    elif gpu_memory >= 8:
        print("\nConfiguration: STANDARD GPU")
        GPU_CONFIG = "standard"
    else:
        print("\nConfiguration: LOW-MEMORY GPU")
        GPU_CONFIG = "low"
else:
    print("WARNING: No GPU available. Training will be slow.")
    print("Enable GPU: Runtime -> Change runtime type -> GPU")
    GPU_CONFIG = "cpu"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\nUsing device: {DEVICE.upper()}")
print("=" * 70)

### 1.2 Clone Repository and Install Dependencies

In [None]:
import os

REPO_URL = "https://github.com/xMOROx/Machine-Learning-Project-2025-2026.git"
REPO_DIR = "Machine-Learning-Project-2025-2026"

if not os.path.exists(REPO_DIR):
    print("Cloning repository...")
    !git clone --recursive {REPO_URL}
    print("Repository cloned successfully.")
else:
    print("Repository already exists. Pulling latest changes...")
    %cd {REPO_DIR}
    !git pull
    !git submodule update --init --recursive
    %cd ..

%cd {REPO_DIR}

In [None]:
print("Installing dependencies...")

!pip install -q transformers datasets tqdm matplotlib seaborn pandas pillow scikit-learn captum scipy

print("All dependencies installed.")

In [None]:
import sys
import os

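# Resolve the scripts folder relative to the repository root (the working directory
# after the clone cell), so the import also works outside Colab's /content
repo_base = os.path.abspath("scripts")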
if repo_base not in sys.path:
    sys.path.append(repo_base)

from xai_experiments.experiments.xai_comparison import XAIMethodsComparison, ComparisonConfig
from xai_experiments.visualization.comparison_plots import ComparisonVisualizer, VisualizationConfig

print("Import successful!")

# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.auto import tqdm
from scipy import stats
import json
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

print("All modules imported successfully.")
print(f"Experiment started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

---

## 2. Experimental Configuration

### 2.1 Configuration Parameters

Parameters are automatically adjusted based on available GPU memory.
**Note:** Text experiments use smaller batch sizes than image experiments, and shorter sequence lengths on lower-memory tiers, to prevent out-of-memory (OOM) errors.

In [None]:
# Configuration based on GPU capabilities
if GPU_CONFIG == "high":  # >= 15 GB reported memory (e.g. A100, V100)
    CONFIG = {
        "image_batch_size": 512,
        "image_epochs": 25,
        "image_max_samples": 10000,
        "image_comparison_samples": 2000,
        "image_datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
        "text_batch_size": 128,  
        "text_epochs": 25,  
        "text_max_length": 256,  
        "text_max_samples": 3000,  
        "text_comparison_samples": 1000,  
        "text_datasets": ["sst2", "imdb", "ag_news"],
        "text_top_k_values": [3, 5, 10, 15, 20],
    }
elif GPU_CONFIG == "standard":  # 8-15 GB reported memory (e.g. T4)
    CONFIG = {
        "image_batch_size": 256,
        "image_epochs": 15,
        "image_max_samples": 5000,
        "image_comparison_samples": 500,
        "image_datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
        "text_batch_size": 16,
        "text_epochs": 10,
        "text_max_length": 128,
        "text_max_samples": 1500,
        "text_comparison_samples": 300,
        "text_datasets": ["sst2", "imdb", "ag_news"],
        "text_top_k_values": [3, 5, 10, 15, 20],
    }
elif GPU_CONFIG == "low":  # < 8 GB reported memory
    CONFIG = {
        "image_batch_size": 128,
        "image_epochs": 10,
        "image_max_samples": 2000,
        "image_comparison_samples": 200,
        "image_datasets": ["cifar10", "svhn"],
        "text_batch_size": 8,
        "text_epochs": 5,
        "text_max_length": 64,
        "text_max_samples": 1000,
        "text_comparison_samples": 100,
        "text_datasets": ["sst2", "ag_news"],
        "text_top_k_values": [3, 5, 10],
    }
else:  # CPU
    CONFIG = {
        "image_batch_size": 16,
        "image_epochs": 2,
        "image_max_samples": 500,
        "image_comparison_samples": 20,
        "image_datasets": ["cifar10"],
        "text_batch_size": 4,
        "text_epochs": 1,
        "text_max_length": 64,
        "text_max_samples": 200,
        "text_comparison_samples": 20,
        "text_datasets": ["sst2"],
        "text_top_k_values": [3, 5],
    }

# Display configuration
print("=" * 70)
print("EXPERIMENT CONFIGURATION")
print("=" * 70)
print("\nImage Experiments:")
print(f"  Datasets: {CONFIG['image_datasets']}")
print(f"  Batch size: {CONFIG['image_batch_size']}")
print(f"  Epochs: {CONFIG['image_epochs']}")
print(f"  Max samples: {CONFIG['image_max_samples']}")
print(f"  Comparison samples: {CONFIG['image_comparison_samples']}")
print("\nText Experiments:")
print(f"  Datasets: {CONFIG['text_datasets']}")
print(f"  Batch size: {CONFIG['text_batch_size']}")
print(f"  Epochs: {CONFIG['text_epochs']}")
print(f"  Max length: {CONFIG['text_max_length']}")
print(f"  Max samples: {CONFIG['text_max_samples']}")
print(f"  Comparison samples: {CONFIG['text_comparison_samples']}")
print(f"  Top-k values: {CONFIG['text_top_k_values']}")
print("=" * 70)

### 2.2 Initialize Comparison Framework

In [None]:
# Create output directory
OUTPUT_DIR = "./outputs/colab_experiments/comprehensive_comparison"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize configuration
comparison_config = ComparisonConfig(
    device=DEVICE,
    # Image settings
    image_datasets=CONFIG["image_datasets"],
    image_batch_size=CONFIG["image_batch_size"],
    image_epochs=CONFIG["image_epochs"],
    image_max_samples=CONFIG["image_max_samples"],
    image_comparison_samples=CONFIG["image_comparison_samples"],
    # Text settings
    text_datasets=CONFIG["text_datasets"],
    text_batch_size=CONFIG["text_batch_size"],
    text_epochs=CONFIG["text_epochs"],
    text_max_length=CONFIG["text_max_length"],
    text_max_samples=CONFIG["text_max_samples"],
    text_comparison_samples=CONFIG["text_comparison_samples"],
    text_top_k_values=CONFIG["text_top_k_values"],
    low_vram=(GPU_CONFIG in ["low", "standard"]),
    output_dir=OUTPUT_DIR,
    compute_all_metrics=True,
)

# Initialize comparison object
comparison = XAIMethodsComparison(comparison_config)

# Initialize visualizer with enhanced config
viz_config = VisualizationConfig(
    figsize=(14, 8),
    dpi=150,
    style="whitegrid",
    use_gradients=True,
)
visualizer = ComparisonVisualizer(output_dir=OUTPUT_DIR, config=viz_config)

print("Comparison framework initialized.")
print(f"Output directory: {OUTPUT_DIR}")

---

## 3. Image Experiments: DiET vs GradCAM

### 3.1 Run Image Comparison Experiments

This section compares DiET and GradCAM on image classification datasets.
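
For orientation, the core Grad-CAM computation can be sketched in a few lines, given the activations of one convolutional layer and the gradients of the class score with respect to them. This is a minimal sketch, not the repository's code: `grad_cam` is a hypothetical helper, the inputs are hand-made arrays rather than values captured from a real network, and the final rescaling to [0, 1] is a common display convention rather than part of the method itself.

```python
import numpy as np

def grad_cam(activations, gradients):
    """Grad-CAM heatmap from activations and gradients of shape (channels, H, W)."""
    # Channel weights: global average pooling of the gradients.
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of activation maps over channels -> shape (H, W).
    cam = np.tensordot(weights, activations, axes=1)
    # ReLU keeps only regions with positive influence on the class score.
    cam = np.maximum(cam, 0.0)
    # Optional rescaling to [0, 1] for visualisation.
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam

# Two hand-made channels: one supports the class, one opposes it.
activations = np.array([
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0], [0.0, 1.0]],
])
gradients = np.array([
    np.full((2, 2), 1.0),
    np.full((2, 2), -1.0),
])

cam = grad_cam(activations, gradients)
print(cam)
```

In practice the heatmap is upsampled to the input resolution before it is used to rank pixels for perturbation metrics.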

In [None]:
print("=" * 70)
print("IMAGE EXPERIMENTS: DiET vs GradCAM")
print("=" * 70)
print(f"\nDatasets: {CONFIG['image_datasets']}")
print(f"Samples per dataset: {CONFIG['image_max_samples']}")
print(f"Training epochs: {CONFIG['image_epochs']}")
print("\nStarting experiments...\n")

# Clean up memory before starting
cleanup_memory()

image_start_time = datetime.now()

image_results = comparison.run_all_image_comparisons(skip_training=False)

image_end_time = datetime.now()
# total_seconds() covers the full elapsed time; .seconds wraps around after 24 hours
image_duration = int((image_end_time - image_start_time).total_seconds())

cleanup_memory()

print(f"\nImage experiments completed in {image_duration // 60} minutes {image_duration % 60} seconds.")

### 3.2 Image Results: Visual Summary

In [None]:
# Extract image results with all metrics
image_data = []
for dataset_name, result in image_results.items():
    if "error" not in result:
        row = {
            "Dataset": dataset_name.upper(),
            "Baseline Accuracy": result.get("baseline_accuracy", 0),
            "DiET Accuracy": result.get("diet_accuracy", 0),
            "GradCAM Score": result.get("gradcam_mean_score", 0),
            "DiET Score": result.get("diet_mean_score", 0),
            "Improvement": result.get("improvement", 0),
            "DiET Better": result.get("diet_better", False),
        }
        # Add additional metrics if available
        if result.get("gradcam_aopc") is not None:
            row["GradCAM AOPC"] = result.get("gradcam_aopc", 0)
            row["DiET AOPC"] = result.get("diet_aopc", 0)
        if result.get("gradcam_faithfulness") is not None:
            row["GradCAM Faithfulness"] = result.get("gradcam_faithfulness", 0)
            row["DiET Faithfulness"] = result.get("diet_faithfulness", 0)
        if result.get("gradcam_insertion_auc") is not None:
            row["GradCAM Insertion AUC"] = result.get("gradcam_insertion_auc", 0)
            row["DiET Insertion AUC"] = result.get("diet_insertion_auc", 0)
            row["GradCAM Deletion AUC"] = result.get("gradcam_deletion_auc", 0)
            row["DiET Deletion AUC"] = result.get("diet_deletion_auc", 0)
        image_data.append(row)

image_df = pd.DataFrame(image_data)

# Display table
print("\n" + "=" * 70)
print("IMAGE EXPERIMENTS: QUANTITATIVE RESULTS")
print("=" * 70)
print("\nAll Metrics (higher is better, except Deletion AUC, where lower is conventionally better):\n")
print(image_df.to_string(index=False))

# Summary statistics
if len(image_df) > 0:
    diet_wins = image_df["DiET Better"].sum()
    total = len(image_df)
    avg_improvement = image_df["Improvement"].mean()
    print(f"\n{'='*50}")
    print(f"Summary: DiET outperforms GradCAM on {diet_wins}/{total} datasets")
    print(f"Average improvement: {avg_improvement:+.4f}")
    print(f"{'='*50}")

In [None]:
if len(image_df) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle("Image Experiments: DiET vs GradCAM - Comprehensive Visual Summary", 
                 fontsize=18, fontweight='bold', y=1.02)
    
    for ax in axes.flat:
        ax.set_facecolor('#fafafa')

    # Plot 1: Attribution Quality Comparison
    datasets = image_df["Dataset"].tolist()
    x = np.arange(len(datasets))
    width = 0.35

    bars1 = axes[0, 0].bar(x - width/2, image_df["GradCAM Score"], width, 
                           label='GradCAM', color='#3498db', alpha=0.85, edgecolor='#333', linewidth=1.5)
    bars2 = axes[0, 0].bar(x + width/2, image_df["DiET Score"], width, 
                           label='DiET', color='#2ecc71', alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[0, 0].set_ylabel('Pixel Perturbation Score', fontweight='bold')
    axes[0, 0].set_title('Attribution Quality Comparison (Higher = Better)', fontweight='bold', fontsize=12)
    axes[0, 0].set_xticks(x)
    axes[0, 0].set_xticklabels(datasets, fontweight='bold')
    axes[0, 0].legend(framealpha=0.95)
    axes[0, 0].set_ylim(0, max(image_df["GradCAM Score"].max(), image_df["DiET Score"].max()) * 1.2)
    axes[0, 0].grid(True, alpha=0.3, linestyle='--')
    
    for bar in bars1:
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                       f'{bar.get_height():.3f}', ha='center', fontsize=9, fontweight='bold')
    for bar in bars2:
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                       f'{bar.get_height():.3f}', ha='center', fontsize=9, fontweight='bold')

    # Plot 2: Improvement Over GradCAM
    improvements = image_df["Improvement"].tolist()
    colors = ['#2ecc71' if imp > 0 else '#e74c3c' for imp in improvements]
    bars = axes[0, 1].barh(datasets, improvements, color=colors, alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[0, 1].axvline(x=0, color='black', linestyle='-', linewidth=1)
    axes[0, 1].set_xlabel('Improvement (DiET - GradCAM)', fontweight='bold')
    axes[0, 1].set_title('DiET Improvement Over GradCAM', fontweight='bold', fontsize=12)
    axes[0, 1].grid(True, alpha=0.3, linestyle='--')
    for bar, v in zip(bars, improvements):
        axes[0, 1].text(v + 0.005 if v >= 0 else v - 0.005, bar.get_y() + bar.get_height()/2,
                       f'{v:+.4f}', va='center', ha='left' if v >= 0 else 'right', 
                       fontsize=10, fontweight='bold')

    # Plot 3: Model Accuracy Comparison
    bars3 = axes[1, 0].bar(x - width/2, image_df["Baseline Accuracy"], width, 
                           label='Baseline', color='#f39c12', alpha=0.85, edgecolor='#333', linewidth=1.5)
    bars4 = axes[1, 0].bar(x + width/2, image_df["DiET Accuracy"], width, 
                           label='After DiET', color='#9b59b6', alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[1, 0].set_ylabel('Accuracy (%)', fontweight='bold')
    axes[1, 0].set_title('Model Accuracy Before and After DiET', fontweight='bold', fontsize=12)
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(datasets, fontweight='bold')
    axes[1, 0].legend(framealpha=0.95)
    axes[1, 0].set_ylim(0, 100)
    axes[1, 0].grid(True, alpha=0.3, linestyle='--')

    # Plot 4: Summary Pie Chart with better styling
    diet_wins = image_df["DiET Better"].sum()
    gradcam_wins = len(image_df) - diet_wins
    if diet_wins > 0 or gradcam_wins > 0:
        wedges, texts, autotexts = axes[1, 1].pie(
            [diet_wins, gradcam_wins], 
            labels=['DiET Better', 'GradCAM Better'],
            autopct='%1.0f%%', 
            colors=['#2ecc71', '#3498db'], 
            startangle=90,
            explode=(0.05, 0),
            shadow=True,
            textprops={'fontweight': 'bold'}
        )
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontsize(14)
    axes[1, 1].set_title('Method Performance Summary', fontweight='bold', fontsize=12)

    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/image_visual_summary.png', dpi=150, bbox_inches='tight', facecolor='white')
    plt.show()

    print(f"\nFigure saved: {OUTPUT_DIR}/image_visual_summary.png")

---

## 4. Text Experiments: DiET vs Integrated Gradients

### 4.1 Run Text Comparison Experiments

This section compares DiET and Integrated Gradients on text classification datasets using BERT.

**Note:** GPU memory is freed before and after the text experiments to reduce the risk of out-of-memory (OOM) errors.
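
As background for the comparison, the sketch below approximates Integrated Gradients for a toy differentiable model and checks the completeness property (attributions sum to the difference between the model output at the input and at the baseline). Everything here is illustrative: the sigmoid model, its weights, the all-zeros baseline, and the function names are assumptions, and the notebook's actual IG computation on BERT is not shown in this document.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def model(x, w):
    """Toy differentiable classifier: sigmoid of a linear score."""
    return sigmoid(np.dot(w, x))

def model_grad(x, w):
    """Analytic gradient of the toy model with respect to the input x."""
    s = model(x, w)
    return s * (1.0 - s) * w

def integrated_gradients(x, baseline, w, steps=200):
    """Midpoint Riemann approximation of the IG path integral."""
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack([model_grad(baseline + a * (x - baseline), w) for a in alphas])
    # Average gradient along the straight path, scaled by the input difference.
    return (x - baseline) * grads.mean(axis=0)

w = np.array([2.0, -1.0, 0.5])
x = np.array([1.0, 1.0, 1.0])
baseline = np.zeros_like(x)

attr = integrated_gradients(x, baseline, w)
print("Attributions:", attr)
print("Sum of attributions:", attr.sum())
print("f(x) - f(baseline):", model(x, w) - model(baseline, w))
```

For transformer models, IG is typically applied to the input embeddings, and the per-dimension attributions are then summed to obtain one score per token.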

In [None]:
print("=" * 70)
print("TEXT EXPERIMENTS: DiET vs Integrated Gradients")
print("=" * 70)
print(f"\nDatasets: {CONFIG['text_datasets']}")
print("Model: BERT-base-uncased")
print(f"Max sequence length: {CONFIG['text_max_length']}")
print(f"Training epochs: {CONFIG['text_epochs']}")
print(f"Top-k values: {CONFIG['text_top_k_values']}")
print("\nStarting experiments...\n")

# Clean up memory before starting text experiments
cleanup_memory()

text_start_time = datetime.now()

# Run text experiments
text_results = comparison.run_all_text_comparisons(skip_training=False)

text_end_time = datetime.now()
# total_seconds() covers the full elapsed time; .seconds wraps around after 24 hours
text_duration = int((text_end_time - text_start_time).total_seconds())

# Clean up after text experiments
cleanup_memory()

print(f"\nText experiments completed in {text_duration // 60} minutes {text_duration % 60} seconds.")

### 4.2 Text Results: Comprehensive Visual Summary

In [None]:
# Extract text results with all metrics
text_data = []
for dataset_name, result in text_results.items():
    if "error" not in result:
        row = {
            "Dataset": dataset_name.upper(),
            "Baseline Accuracy": result.get("baseline_accuracy", 0),
            "Samples Compared": result.get("samples_compared", 0),
            "Mean Correlation": result.get("mean_correlation", 0),
        }
        # Add all top-k overlap metrics
        for k in CONFIG["text_top_k_values"]:
            key = f"top_{k}_overlap"
            if key in result:
                row[f"Top-{k} Overlap"] = result[key]
                row[f"Top-{k} Std"] = result.get(f"{key}_std", 0)
        text_data.append(row)

text_df = pd.DataFrame(text_data)

print("\n" + "=" * 70)
print("TEXT EXPERIMENTS: QUANTITATIVE RESULTS")
print("=" * 70)
print("\nToken Overlap between IG and DiET for various K values:\n")
if len(text_df) > 0:
    print(text_df.to_string(index=False))

    print(f"\n{'='*50}")
    print("Summary Statistics:")
    for k in CONFIG["text_top_k_values"]:
        col = f"Top-{k} Overlap"
        if col in text_df.columns:
            avg = text_df[col].mean()
            std = text_df[col].std() if len(text_df) > 1 else 0
            print(f"  Average Top-{k} Overlap: {avg:.4f} (±{std:.4f})")
    if "Mean Correlation" in text_df.columns:
        avg_corr = text_df["Mean Correlation"].mean()
        print(f"  Average Correlation: {avg_corr:.4f}")
    print(f"{'='*50}")
else:
    print("No successful text experiments to display.")

In [None]:
if len(text_df) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle("Text Experiments: DiET vs Integrated Gradients - Comprehensive Visual Summary", 
                 fontsize=18, fontweight='bold', y=1.02)
    
    for ax in axes.flat:
        ax.set_facecolor('#fafafa')

    datasets = text_df["Dataset"].tolist()

    # Plot 1: Top-K Overlap Comparison across K values
    k_values = [k for k in CONFIG["text_top_k_values"] if f"Top-{k} Overlap" in text_df.columns]
    if k_values:
        x = np.arange(len(datasets))
        width = 0.8 / len(k_values)
        colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(k_values)))
        
        for i, k in enumerate(k_values):
            col = f"Top-{k} Overlap"
            offset = (i - len(k_values)/2 + 0.5) * width
            bars = axes[0, 0].bar(x + offset, text_df[col], width, label=f'Top-{k}', 
                                 color=colors[i], alpha=0.85, edgecolor='#333', linewidth=0.5)
        
        axes[0, 0].axhline(y=0.5, color='#e74c3c', linestyle='--', linewidth=2, alpha=0.7, label='50% Threshold')
        axes[0, 0].set_ylabel('Token Overlap Score', fontweight='bold')
        axes[0, 0].set_title('IG-DiET Token Overlap Across K Values', fontweight='bold', fontsize=12)
        axes[0, 0].set_xticks(x)
        axes[0, 0].set_xticklabels(datasets, fontweight='bold')
        axes[0, 0].legend(loc='upper right', ncol=3, framealpha=0.95)
        axes[0, 0].set_ylim(0, 1.1)
        axes[0, 0].grid(True, alpha=0.3, linestyle='--')

    # Plot 2: Top-K Overlap Line Chart (trend across K)
    if k_values:
        for idx, dataset in enumerate(datasets):
            overlaps = [text_df[text_df["Dataset"] == dataset][f"Top-{k} Overlap"].values[0] 
                       for k in k_values if f"Top-{k} Overlap" in text_df.columns]
            axes[0, 1].plot(k_values, overlaps, 'o-', linewidth=2.5, markersize=10, 
                           label=dataset, alpha=0.8)
        
        axes[0, 1].axhline(y=0.5, color='#e74c3c', linestyle='--', linewidth=2, alpha=0.7)
        axes[0, 1].set_xlabel('K (Top-K Tokens)', fontweight='bold')
        axes[0, 1].set_ylabel('Overlap Score', fontweight='bold')
        axes[0, 1].set_title('Overlap Trend Across K Values', fontweight='bold', fontsize=12)
        axes[0, 1].legend(framealpha=0.95)
        axes[0, 1].set_ylim(0, 1.05)
        axes[0, 1].grid(True, alpha=0.3, linestyle='--')

    # Plot 3: BERT Accuracy
    bars2 = axes[1, 0].bar(datasets, text_df["Baseline Accuracy"], 
                           color='#3498db', alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[1, 0].set_ylabel('Accuracy (%)', fontweight='bold')
    axes[1, 0].set_title('BERT Classification Accuracy', fontweight='bold', fontsize=12)
    axes[1, 0].set_ylim(0, 105)
    axes[1, 0].grid(True, alpha=0.3, linestyle='--')
    for bar, acc in zip(bars2, text_df["Baseline Accuracy"]):
        axes[1, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
                       f'{acc:.1f}%', ha='center', fontsize=11, fontweight='bold')

    # Plot 4: Correlation scores
    if "Mean Correlation" in text_df.columns:
        correlations = text_df["Mean Correlation"].tolist()
        colors_corr = ['#2ecc71' if c > 0 else '#e74c3c' for c in correlations]
        bars3 = axes[1, 1].bar(datasets, correlations, color=colors_corr, 
                              alpha=0.85, edgecolor='#333', linewidth=1.5)
        axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=1)
        axes[1, 1].set_ylabel('Correlation', fontweight='bold')
        axes[1, 1].set_title('IG-DiET Attribution Correlation', fontweight='bold', fontsize=12)
        axes[1, 1].set_ylim(-1, 1)
        axes[1, 1].grid(True, alpha=0.3, linestyle='--')
        for bar, val in zip(bars3, correlations):
            axes[1, 1].text(bar.get_x() + bar.get_width()/2, val + 0.05 if val >= 0 else val - 0.1, 
                           f'{val:.3f}', ha='center', fontsize=10, fontweight='bold')
    else:
        axes[1, 1].text(0.5, 0.5, "Correlation data not available", ha='center', va='center', fontsize=12)
        axes[1, 1].axis('off')

    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/text_visual_summary.png', dpi=150, bbox_inches='tight', facecolor='white')
    plt.show()

    print(f"\nFigure saved: {OUTPUT_DIR}/text_visual_summary.png")
else:
    print("No text experiment data to visualize.")

---

## 5. Combined Results Summary

### 5.1 Complete Results Table

In [None]:
print("=" * 70)
print("COMBINED EXPERIMENT SUMMARY")
print("=" * 70)

full_df = comparison.get_results_dataframe()
print("\nComplete Results Table:\n")
if len(full_df) > 0:
    print(full_df.to_string(index=False))
else:
    print("No results to display.")

In [None]:
# Generate summary report text
report = comparison.generate_summary_report()
print(report)

### 5.2 Combined Visual Summary with All Metrics

In [None]:
if len(image_df) > 0 or len(text_df) > 0:
    dashboard_fig = visualizer.create_summary_dashboard(
        image_results=image_results if len(image_df) > 0 else None,
        text_results=text_results if len(text_df) > 0 else None,
        save_name="comprehensive_dashboard",
        show=True
    )

    # Create top-k overlap comparison for text if we have data
    if len(text_df) > 0:
        text_overlap_data = {}
        for _, row in text_df.iterrows():
            dataset = row["Dataset"]
            text_overlap_data[dataset] = {}
            for k in CONFIG["text_top_k_values"]:
                col = f"Top-{k} Overlap"
                if col in row:
                    text_overlap_data[dataset][f"top_{k}_overlap"] = row[col]
                    std_col = f"Top-{k} Std"
                    if std_col in row:
                        text_overlap_data[dataset][f"top_{k}_overlap_std"] = row[std_col]
        
        if text_overlap_data:
            topk_fig = visualizer.plot_top_k_overlap_comparison(
                text_overlap_data,
                title="Token Overlap Across K Values by Dataset",
                save_name="topk_overlap_comparison",
                show=True
            )

    print(f"\nAll dashboard figures saved to: {OUTPUT_DIR}")
else:
    print("No data available for dashboard visualization.")

---

## 6. Statistical Analysis

### 6.1 Statistical Tests

In [None]:
print("=" * 70)
print("STATISTICAL ANALYSIS")
print("=" * 70)

# Image experiments statistical analysis
if len(image_df) >= 3:
    gradcam_scores = image_df["GradCAM Score"].values
    diet_scores = image_df["DiET Score"].values

    # Paired t-test
    t_stat, p_value = stats.ttest_rel(diet_scores, gradcam_scores)

    # Wilcoxon signed-rank test (non-parametric alternative)
    try:
        w_stat, w_pvalue = stats.wilcoxon(diet_scores, gradcam_scores)
    except ValueError:
        # SciPy can raise ValueError when all paired differences are zero
        w_stat, w_pvalue = None, None

    # Effect size (Cohen's d with a pooled SD; ddof=1 gives the sample variance)
    pooled_std = np.sqrt((np.var(gradcam_scores, ddof=1) + np.var(diet_scores, ddof=1)) / 2)
    cohens_d = (np.mean(diet_scores) - np.mean(gradcam_scores)) / pooled_std if pooled_std > 0 else 0

    print("\nImage Experiments (DiET vs GradCAM):")
    print(f"  Paired t-test:")
    print(f"    t-statistic: {t_stat:.4f}")
    print(f"    p-value: {p_value:.4f}")

    if p_value < 0.05:
        print(f"    Result: Statistically significant (p < 0.05)")
    else:
        print(f"    Result: Not statistically significant")

    if w_pvalue is not None:
        print(f"\n  Wilcoxon signed-rank test:")
        print(f"    W-statistic: {w_stat:.4f}")
        print(f"    p-value: {w_pvalue:.4f}")

    print(f"\n  Effect Size (Cohen's d): {cohens_d:.4f}")
    if abs(cohens_d) < 0.2:
        print("    Interpretation: Negligible effect")
    elif abs(cohens_d) < 0.5:
        print("    Interpretation: Small effect")
    elif abs(cohens_d) < 0.8:
        print("    Interpretation: Medium effect")
    else:
        print("    Interpretation: Large effect")
else:
    print("\nImage Experiments: Not enough data points for statistical testing (need >= 3)")

# Text experiments statistical analysis
if len(text_df) >= 2:
    # Get overlap values for the default k (5)
    overlap_col = "Top-5 Overlap" if "Top-5 Overlap" in text_df.columns else None
    
    if overlap_col:
        overlaps = text_df[overlap_col].values

        # One-sample t-test against 0.5 threshold
        t_stat_text, p_value_text = stats.ttest_1samp(overlaps, 0.5)

        print("\nText Experiments (IG-DiET Overlap):")
        print(f"  One-sample t-test (vs 0.5 threshold):")
        print(f"    Mean overlap: {np.mean(overlaps):.4f}")
        print(f"    t-statistic: {t_stat_text:.4f}")
        print(f"    p-value: {p_value_text:.4f}")

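        # 0.5 is a reference threshold, not the chance level (which depends on K and sequence length)
        if np.mean(overlaps) >= 0.5:
            print("    Interpretation: Mean overlap is at or above the 0.5 threshold")
        else:
            print("    Interpretation: Mean overlap is below the 0.5 threshold; the methods often rank different tokens highest")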
    
    if "Mean Correlation" in text_df.columns:
        correlations = text_df["Mean Correlation"].values
        print(f"\n  Attribution Correlation:")
        print(f"    Mean: {np.mean(correlations):.4f}")
        print(f"    Std: {np.std(correlations):.4f}")
else:
    print("\nText Experiments: Not enough data points for statistical testing")
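
When reading the Wilcoxon result from the cell above, keep in mind how few datasets are being compared. The sketch below enumerates the exact null distribution of the signed-rank statistic to find the smallest two-sided p-value attainable for a given number of paired observations, assuming no ties and no zero differences. The helper name is hypothetical and the code is independent of SciPy.

```python
from itertools import product

def min_two_sided_wilcoxon_p(n):
    """Smallest exact two-sided p-value of the signed-rank test for n untied, non-zero differences."""
    ranks = range(1, n + 1)
    # Under H0 every sign pattern is equally likely; W+ is the sum of positively signed ranks.
    w_plus = [
        sum(r for r, s in zip(ranks, signs) if s)
        for signs in product([0, 1], repeat=n)
    ]
    w_max = n * (n + 1) // 2
    # The most extreme outcomes are "all differences negative" and "all differences positive".
    extreme = sum(1 for w in w_plus if w == 0 or w == w_max)
    return extreme / len(w_plus)

for n in (3, 4, 5, 6):
    print(n, min_two_sided_wilcoxon_p(n))
# 3 0.25
# 4 0.125
# 5 0.0625
# 6 0.03125
```

With at most four image datasets, the exact Wilcoxon test cannot reach p < 0.05 even if DiET wins on every dataset, so a non-significant result here says little about the methods. The paired t-test has similarly low power at this sample size.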

---

## 7. Export Results

### 7.1 Save All Results

In [None]:
comparison.save_results()

if len(image_df) > 0:
    image_df.to_csv(f'{OUTPUT_DIR}/image_results.csv', index=False)
if len(text_df) > 0:
    text_df.to_csv(f'{OUTPUT_DIR}/text_results.csv', index=False)
if len(full_df) > 0:
    full_df.to_csv(f'{OUTPUT_DIR}/all_results.csv', index=False)

with open(f'{OUTPUT_DIR}/experiment_config.json', 'w') as f:
    json.dump(CONFIG, f, indent=2)

print("\nResults saved:")
expected_files = [
    "comparison_results.json", "image_results.csv", "text_results.csv", "all_results.csv",
    "experiment_config.json", "image_visual_summary.png", "text_visual_summary.png",
    "combined_visual_summary.png",
]
# List only files that exist; some are skipped when an experiment produced no data
for name in expected_files:
    path = os.path.join(OUTPUT_DIR, name)
    if os.path.exists(path):
        print(f"  - {path}")

In [None]:
try:
    viz_files = comparison.visualize_results(save_plots=True, show=False)
    print("\nAdditional visualizations generated:")
    for name, path in viz_files.items():
        print(f"  - {name}: {path}")
except Exception as e:
    print(f"Note: Some visualizations could not be generated: {e}")

In [None]:
try:
    from google.colab import files

    !zip -r comprehensive_results.zip {OUTPUT_DIR}/

    print("\nDownload your results:")
    files.download('comprehensive_results.zip')
except ImportError:
    print(f"\nResults are saved locally in: {OUTPUT_DIR}/")
    print("(Download option only available in Google Colab)")

### 7.2 Final Report

In [None]:
# Durations default to 0 if the corresponding experiment cell was not run
total_duration = globals().get("image_duration", 0) + globals().get("text_duration", 0)

final_report = f"""
================================================================================
                    DiET vs BASIC XAI METHODS
                    COMPREHENSIVE COMPARISON REPORT
================================================================================

Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Device: {DEVICE} ({GPU_CONFIG})
Total Duration: {total_duration // 60} minutes {total_duration % 60} seconds

--------------------------------------------------------------------------------
IMAGE EXPERIMENTS: DiET vs GradCAM
--------------------------------------------------------------------------------
Datasets: {CONFIG['image_datasets']}
Model: ResNet
Training epochs: {CONFIG['image_epochs']}
Comparison samples: {CONFIG['image_comparison_samples']}

"""

if len(image_df) > 0:
    for _, row in image_df.iterrows():
        status = "[+]" if row["DiET Better"] else "[-]"
        final_report += f"{status} {row['Dataset']}: GradCAM={row['GradCAM Score']:.4f}, DiET={row['DiET Score']:.4f}, Improvement={row['Improvement']:+.4f}\n"

    diet_wins = image_df["DiET Better"].sum()
    final_report += f"\nSummary: DiET outperforms GradCAM on {diet_wins}/{len(image_df)} datasets\n"

final_report += f"""
--------------------------------------------------------------------------------
TEXT EXPERIMENTS: DiET vs Integrated Gradients
--------------------------------------------------------------------------------
Datasets: {CONFIG['text_datasets']}
Model: BERT-base-uncased
Max length: {CONFIG['text_max_length']}
Training epochs: {CONFIG['text_epochs']}
Top-K values: {CONFIG['text_top_k_values']}

"""

if len(text_df) > 0:
    for _, row in text_df.iterrows():
        default_overlap = row.get("Top-5 Overlap", row.get("Top-3 Overlap", 0))
        level = "[HIGH]" if default_overlap >= 0.5 else "[MED]" if default_overlap >= 0.3 else "[LOW]"
        final_report += f"{level} {row['Dataset']}: Accuracy={row['Baseline Accuracy']:.1f}%"
        for k in CONFIG['text_top_k_values']:
            col = f"Top-{k} Overlap"
            if col in row:
                final_report += f", Top-{k}={row[col]:.3f}"
        final_report += "\n"

    if "Top-5 Overlap" in text_df.columns:
        avg_overlap = text_df["Top-5 Overlap"].mean()
        final_report += f"\nSummary: Average Top-5 IG-DiET overlap = {avg_overlap:.4f}\n"

final_report += f"""
================================================================================
Output Files:
  - {OUTPUT_DIR}/comparison_results.json
  - {OUTPUT_DIR}/image_results.csv
  - {OUTPUT_DIR}/text_results.csv
  - {OUTPUT_DIR}/all_results.csv
  - {OUTPUT_DIR}/image_visual_summary.png
  - {OUTPUT_DIR}/text_visual_summary.png
  - {OUTPUT_DIR}/combined_visual_summary.png
  - {OUTPUT_DIR}/comprehensive_dashboard.png
================================================================================
"""

# Save final report
with open(f'{OUTPUT_DIR}/final_report.txt', 'w') as f:
    f.write(final_report)

print(final_report)

---

## References

1. Bhalla, U., et al. (2023). "Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability." *NeurIPS 2023.*

2. Selvaraju, R. R., et al. (2017). "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." *ICCV 2017.*

3. Sundararajan, M., Taly, A., & Yan, Q. (2017). "Axiomatic Attribution for Deep Networks." *ICML 2017.*

4. Devlin, J., et al. (2019). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." *NAACL 2019.*

---

**Notebook Version:** 2.0  
**Last Updated:** 2025-2026 Academic Year  
**Repository:** https://github.com/xMOROx/Machine-Learning-Project-2025-2026