# DiET vs Basic XAI Methods: Comprehensive Comparison Framework

## A Complete Experimental Pipeline for Image and Text Classification

---

**Author:** Machine Learning Research Team  
**Date:** 2025-2026 Academic Year  
**Course:** Advanced Machine Learning

---

### Abstract

This notebook presents a comprehensive experimental comparison between **DiET (Distractor Erasure Tuning)** and standard XAI methods:

- **Image Classification:** DiET vs GradCAM on CIFAR-10, CIFAR-100, SVHN, Fashion-MNIST
- **Text Classification:** DiET vs Integrated Gradients on SST-2, IMDB, AG News

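The notebook reports a pixel-perturbation score for the image experiments and the top-k token overlap between Integrated Gradients and DiET for the text experiments. It also produces visual summaries and basic statistical tests suitable for academic presentations.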

### Reference

Bhalla, U., et al. (2023). **"Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability."** *NeurIPS 2023.*

---

## Table of Contents

1. [Environment Setup](#1-environment-setup)
2. [Experimental Configuration](#2-experimental-configuration)
3. [Image Experiments: DiET vs GradCAM](#3-image-experiments-diet-vs-gradcam)
4. [Text Experiments: DiET vs Integrated Gradients](#4-text-experiments-diet-vs-integrated-gradients)
5. [Combined Results Summary](#5-combined-results-summary)
6. [Statistical Analysis](#6-statistical-analysis)
7. [Export Results](#7-export-results)

---

## 1. Environment Setup

### 1.1 Hardware Configuration

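This notebook targets Google Colab Pro with a GPU runtime. The cell below detects the GPU and assigns a memory tier (`GPU_CONFIG`: `high`, `standard`, `low`, or `cpu`) that Section 2 uses to scale the experiments.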
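As a sketch (not part of the repository), the tier thresholds used in the next cell can be factored into a pure function, which makes them easy to check without a GPU. The helper name `classify_gpu_tier` is hypothetical. Note that the thresholds are applied to memory in GiB (bytes / `1024**3`), so a card sold as "16 GB" can report slightly less and land in the lower tier.

```python
def classify_gpu_tier(total_memory_bytes=None):
    """Map total GPU memory in bytes to the tier names used in this notebook.

    Hypothetical helper mirroring Section 1.1: None (no CUDA device) -> "cpu",
    >= 15 GiB -> "high", >= 8 GiB -> "standard", anything smaller -> "low".
    """
    if total_memory_bytes is None:
        return "cpu"
    gib = total_memory_bytes / (1024 ** 3)
    if gib >= 15:
        return "high"
    if gib >= 8:
        return "standard"
    return "low"


# Possible use in the notebook (requires torch):
#   mem = torch.cuda.get_device_properties(0).total_memory if torch.cuda.is_available() else None
#   GPU_CONFIG = classify_gpu_tier(mem)
for mem in (40 * 1024 ** 3, int(14.75 * 1024 ** 3), 4 * 1024 ** 3, None):
    print(mem, "->", classify_gpu_tier(mem))
```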

In [None]:
# Check GPU availability and type
import torch

print("=" * 70)
print("HARDWARE CONFIGURATION")
print("=" * 70)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU Available: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
    
    if gpu_memory >= 15:
        print("\nConfiguration: HIGH-MEMORY GPU")
        GPU_CONFIG = "high"
    elif gpu_memory >= 8:
        print("\nConfiguration: STANDARD GPU")
        GPU_CONFIG = "standard"
    else:
        print("\nConfiguration: LOW-MEMORY GPU")
        GPU_CONFIG = "low"
else:
    print("WARNING: No GPU available. Training will be slow.")
    print("Enable GPU: Runtime -> Change runtime type -> GPU")
    GPU_CONFIG = "cpu"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\nUsing device: {DEVICE.upper()}")
print("=" * 70)

### 1.2 Clone Repository and Install Dependencies

In [None]:
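# Clone the repository, or update it if it is already present
import os

REPO_URL = "https://github.com/xMOROx/Machine-Learning-Project-2025-2026.git"
REPO_DIR = "Machine-Learning-Project-2025-2026"

if os.path.basename(os.getcwd()) == REPO_DIR:
    # Cell re-run after the %cd below: we are already inside the repository
    print("Already inside the repository. Pulling latest changes...")
    !git pull
    !git submodule update --init --recursive
elif not os.path.exists(REPO_DIR):
    print("Cloning repository...")
    !git clone --recursive {REPO_URL}
    %cd {REPO_DIR}
else:
    print("Repository already exists. Pulling latest changes...")
    %cd {REPO_DIR}
    !git pull
    !git submodule update --init --recursive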

In [None]:
# Install required packages
print("Installing dependencies...")

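# torch/torchvision come preinstalled on Colab, and torch was already imported in
# Section 1.1, so reinstalling them here would only take effect after a runtime
# restart. Uncomment the next line only if you specifically need the CUDA 11.8 build.
# !pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118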
!pip install -q transformers datasets tqdm matplotlib seaborn pandas numpy pillow scikit-learn captum scipy

print("All dependencies installed.")

In [None]:
# Add project to path and import modules
import sys
sys.path.insert(0, './scripts/xai_experiments')

# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.auto import tqdm
from scipy import stats
import json
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

print("All modules imported successfully.")
print(f"Experiment started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

---

## 2. Experimental Configuration

### 2.1 Configuration Parameters

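Parameters are scaled to the memory tier detected in Section 1.1: `high` (at least 15 GiB of GPU memory), `standard` (8 to 15 GiB), `low` (under 8 GiB), and `cpu` (no GPU). Smaller tiers use fewer datasets, fewer samples and epochs, and shorter text sequences.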
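To get a feel for how much work each tier implies, the image settings (copied from the configuration cell below) can be turned into a rough count of optimizer steps per dataset. This assumes one optimizer step per mini-batch over `image_max_samples` training examples; the framework's actual training loop (for example, how DiET fine-tuning is scheduled) is not shown in this notebook, so treat the numbers as relative rather than exact. `IMAGE_TIERS` and `image_training_steps` are hypothetical names.

```python
import math

# Image settings copied from the configuration cell (one entry per tier)
IMAGE_TIERS = {
    "high": {"image_batch_size": 128, "image_epochs": 10, "image_max_samples": 10000},
    "standard": {"image_batch_size": 64, "image_epochs": 5, "image_max_samples": 5000},
    "low": {"image_batch_size": 32, "image_epochs": 3, "image_max_samples": 2000},
    "cpu": {"image_batch_size": 16, "image_epochs": 2, "image_max_samples": 1000},
}


def image_training_steps(cfg):
    """Rough optimizer-step count per dataset, assuming one step per mini-batch
    over `image_max_samples` training examples (last partial batch included)."""
    batches_per_epoch = math.ceil(cfg["image_max_samples"] / cfg["image_batch_size"])
    return batches_per_epoch * cfg["image_epochs"]


for tier, cfg in IMAGE_TIERS.items():
    print(f"{tier:>8}: {image_training_steps(cfg):>4} steps per dataset")
```

Because the batch size shrinks together with the sample count, batches per epoch stay nearly constant across tiers (79 or 63); most of the difference between 790 steps (high) and 126 steps (CPU) comes from the number of epochs.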

In [None]:
# Configuration based on GPU capabilities
if GPU_CONFIG == "high":  # >= 15 GiB (e.g. A100, V100)
    CONFIG = {
        # Image settings
        "image_batch_size": 128,
        "image_epochs": 10,
        "image_max_samples": 10000,
        "image_comparison_samples": 200,
        "image_datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
        # Text settings
        "text_batch_size": 32,
        "text_epochs": 3,
        "text_max_length": 256,
        "text_max_samples": 3000,
        "text_comparison_samples": 100,
        "text_datasets": ["sst2", "imdb", "ag_news"],
        "text_top_k": 10,
    }
elif GPU_CONFIG == "standard":  # 8 to 15 GiB (e.g. T4)
    CONFIG = {
        "image_batch_size": 64,
        "image_epochs": 5,
        "image_max_samples": 5000,
        "image_comparison_samples": 100,
        "image_datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
        "text_batch_size": 16,
        "text_epochs": 2,
        "text_max_length": 128,
        "text_max_samples": 2000,
        "text_comparison_samples": 50,
        "text_datasets": ["sst2", "imdb", "ag_news"],
        "text_top_k": 5,
    }
elif GPU_CONFIG == "low":  # < 8 GiB
    CONFIG = {
        "image_batch_size": 32,
        "image_epochs": 3,
        "image_max_samples": 2000,
        "image_comparison_samples": 50,
        "image_datasets": ["cifar10", "svhn"],
        "text_batch_size": 8,
        "text_epochs": 2,
        "text_max_length": 64,
        "text_max_samples": 1000,
        "text_comparison_samples": 30,
        "text_datasets": ["sst2", "ag_news"],
        "text_top_k": 5,
    }
else:  # CPU
    CONFIG = {
        "image_batch_size": 16,
        "image_epochs": 2,
        "image_max_samples": 1000,
        "image_comparison_samples": 20,
        "image_datasets": ["cifar10"],
        "text_batch_size": 4,
        "text_epochs": 1,
        "text_max_length": 64,
        "text_max_samples": 500,
        "text_comparison_samples": 20,
        "text_datasets": ["sst2"],
        "text_top_k": 5,
    }

# Display configuration
print("=" * 70)
print("EXPERIMENT CONFIGURATION")
print("=" * 70)
print("\nImage Experiments:")
print(f"  Datasets: {CONFIG['image_datasets']}")
print(f"  Batch size: {CONFIG['image_batch_size']}")
print(f"  Epochs: {CONFIG['image_epochs']}")
print(f"  Max samples: {CONFIG['image_max_samples']}")
print(f"  Comparison samples: {CONFIG['image_comparison_samples']}")
print("\nText Experiments:")
print(f"  Datasets: {CONFIG['text_datasets']}")
print(f"  Batch size: {CONFIG['text_batch_size']}")
print(f"  Epochs: {CONFIG['text_epochs']}")
print(f"  Max length: {CONFIG['text_max_length']}")
print(f"  Max samples: {CONFIG['text_max_samples']}")
print(f"  Top-k tokens: {CONFIG['text_top_k']}")
print("=" * 70)

### 2.2 Initialize Comparison Framework

In [None]:
# Import and configure the comparison framework
from experiments.xai_comparison import XAIMethodsComparison, ComparisonConfig

# Create output directory
OUTPUT_DIR = "./outputs/colab_experiments/comprehensive_comparison"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize configuration
comparison_config = ComparisonConfig(
    device=DEVICE,
    # Image settings
    image_datasets=CONFIG["image_datasets"],
    image_batch_size=CONFIG["image_batch_size"],
    image_epochs=CONFIG["image_epochs"],
    image_max_samples=CONFIG["image_max_samples"],
    image_comparison_samples=CONFIG["image_comparison_samples"],
    # Text settings
    text_datasets=CONFIG["text_datasets"],
    text_batch_size=CONFIG["text_batch_size"],
    text_epochs=CONFIG["text_epochs"],
    text_max_length=CONFIG["text_max_length"],
    text_max_samples=CONFIG["text_max_samples"],
    text_comparison_samples=CONFIG["text_comparison_samples"],
    text_top_k=CONFIG["text_top_k"],
    low_vram=(GPU_CONFIG == "low"),
    output_dir=OUTPUT_DIR,
)

# Initialize comparison object
comparison = XAIMethodsComparison(comparison_config)

print("Comparison framework initialized.")
print(f"Output directory: {OUTPUT_DIR}")

---

## 3. Image Experiments: DiET vs GradCAM

### 3.1 Run Image Comparison Experiments

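This section trains a baseline classifier on each image dataset, applies DiET, and scores both DiET and GradCAM attributions with a pixel-perturbation metric. It also reports model accuracy before and after DiET.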
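For reference, the core Grad-CAM computation (Selvaraju et al., 2017) reduces to a few array operations once the target layer's activations and the gradients of the class score with respect to them are available (in PyTorch these are usually captured with forward and backward hooks). The NumPy sketch below is not the framework's implementation: it omits hook registration and the upsampling of the map to the input resolution, and the name `gradcam_from_arrays` is hypothetical.

```python
import numpy as np


def gradcam_from_arrays(activations, gradients):
    """Grad-CAM heatmap for one image.

    activations: (K, H, W) feature maps of the target convolutional layer.
    gradients:   (K, H, W) gradients of the target class score w.r.t. them.
    Returns an (H, W) map scaled to [0, 1].
    """
    # alpha_k: global-average-pool the gradients over the spatial dimensions
    weights = gradients.mean(axis=(1, 2))                          # (K,)
    # Weighted sum of the feature maps, then ReLU to keep positive evidence only
    cam = np.maximum(np.einsum("k,khw->hw", weights, activations), 0.0)
    # Scale to [0, 1] for visualization; an all-zero map is returned unchanged
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy example: one channel whose gradient is positive everywhere
acts = np.array([[[1.0, 2.0], [3.0, 4.0]]])
grads = np.ones_like(acts)
print(gradcam_from_arrays(acts, grads))
```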

In [None]:
print("=" * 70)
print("IMAGE EXPERIMENTS: DiET vs GradCAM")
print("=" * 70)
print(f"\nDatasets: {CONFIG['image_datasets']}")
print(f"Max samples per dataset: {CONFIG['image_max_samples']}")
print(f"Training epochs: {CONFIG['image_epochs']}")
print("\nStarting experiments...\n")

image_start_time = datetime.now()

# Run image experiments
image_results = comparison.run_all_image_comparisons(skip_training=False)

image_end_time = datetime.now()
image_duration = int((image_end_time - image_start_time).total_seconds())

print(f"\nImage experiments completed in {image_duration // 60} minutes {image_duration % 60} seconds.")

### 3.2 Image Results: Visual Summary
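The tables and plots below report a "pixel perturbation score" (higher is better), whose exact definition lives in the framework's code and is not reproduced in this notebook. As an illustration only, one common variant keeps the top-attributed fraction of pixels, replaces the rest with a baseline value, and measures how much of the model's confidence in its original prediction survives. The functions `keep_top_pixels` and `retained_confidence`, the keep fraction, and the zero baseline are all assumptions.

```python
import numpy as np


def keep_top_pixels(image, attribution, keep_fraction, baseline=0.0):
    """Keep the highest-attributed pixels of an (H, W, C) image and set the
    rest to `baseline`. `attribution` has shape (H, W)."""
    k = max(1, int(round(keep_fraction * attribution.size)))
    top = np.argpartition(attribution.ravel(), -k)[-k:]    # indices of the k largest scores
    mask = np.zeros(attribution.size, dtype=bool)
    mask[top] = True
    mask = mask.reshape(attribution.shape)[..., None]      # broadcast over channels
    return np.where(mask, image, baseline)


def retained_confidence(predict_proba, image, attribution, keep_fraction=0.2):
    """Ratio of the predicted-class probability after vs. before perturbation.
    Can exceed 1 if the perturbed image is classified more confidently."""
    probs = predict_proba(image)
    target = int(np.argmax(probs))
    perturbed = keep_top_pixels(image, attribution, keep_fraction)
    return float(predict_proba(perturbed)[target] / probs[target])


# Toy "model" that only looks at pixel (0, 0): a good attribution ranks it first
def toy_proba(img):
    p = 1.0 / (1.0 + np.exp(-4.0 * img[0, 0, 0]))
    return np.array([1.0 - p, p])


image = np.ones((2, 2, 1))
good = np.array([[1.0, 0.0], [0.0, 0.0]])
bad = np.array([[0.0, 0.0], [0.0, 1.0]])
print(retained_confidence(toy_proba, image, good, keep_fraction=0.25))
print(retained_confidence(toy_proba, image, bad, keep_fraction=0.25))
```

The attribution that points at the pixel the toy model actually uses keeps all of its confidence, while the misleading one keeps only about half.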

In [None]:
# Extract image results
image_data = []
for dataset_name, result in image_results.items():
    if "error" not in result:
        image_data.append({
            "Dataset": dataset_name.upper(),
            "Baseline Accuracy": result.get("baseline_accuracy", 0),
            "DiET Accuracy": result.get("diet_accuracy", 0),
            "GradCAM Score": result.get("gradcam_mean_score", 0),
            "DiET Score": result.get("diet_mean_score", 0),
            "Improvement": result.get("improvement", 0),
            "DiET Better": result.get("diet_better", False),
        })

image_df = pd.DataFrame(image_data)

# Display table
print("\n" + "=" * 70)
print("IMAGE EXPERIMENTS: QUANTITATIVE RESULTS")
print("=" * 70)
print("\nPixel Perturbation Scores (higher = better attribution quality):\n")
print(image_df.to_string(index=False))

# Summary statistics
if len(image_df) > 0:
    diet_wins = image_df["DiET Better"].sum()
    total = len(image_df)
    avg_improvement = image_df["Improvement"].mean()
    print(f"\nSummary: DiET outperforms GradCAM on {diet_wins}/{total} datasets")
    print(f"Average improvement: {avg_improvement:+.4f}")

In [None]:
# Visual Summary: Image Experiments
if len(image_df) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("Image Experiments: DiET vs GradCAM - Visual Summary", fontsize=16, fontweight='bold')
    
    # Plot 1: Attribution Quality Comparison
    datasets = image_df["Dataset"].tolist()
    x = np.arange(len(datasets))
    width = 0.35
    
    bars1 = axes[0, 0].bar(x - width/2, image_df["GradCAM Score"], width, label='GradCAM', color='#2196F3')
    bars2 = axes[0, 0].bar(x + width/2, image_df["DiET Score"], width, label='DiET', color='#4CAF50')
    axes[0, 0].set_ylabel('Pixel Perturbation Score')
    axes[0, 0].set_title('Attribution Quality Comparison (Higher = Better)')
    axes[0, 0].set_xticks(x)
    axes[0, 0].set_xticklabels(datasets)
    axes[0, 0].legend()
    axes[0, 0].set_ylim(0, 1)
    for bar in bars1:
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, f'{bar.get_height():.2f}', ha='center', fontsize=9)
    for bar in bars2:
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, f'{bar.get_height():.2f}', ha='center', fontsize=9)
    
    # Plot 2: Improvement Over GradCAM
    improvements = image_df["Improvement"].tolist()
    colors = ['#4CAF50' if imp > 0 else '#F44336' for imp in improvements]
    axes[0, 1].barh(datasets, improvements, color=colors)
    axes[0, 1].axvline(x=0, color='black', linestyle='-', linewidth=0.5)
    axes[0, 1].set_xlabel('Improvement (DiET - GradCAM)')
    axes[0, 1].set_title('DiET Improvement Over GradCAM')
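    for i, v in enumerate(improvements):
        # Place each label just beyond the bar end, on the correct side of zero
        axes[0, 1].text(v + 0.005 if v >= 0 else v - 0.005, i, f'{v:+.3f}',
                        va='center', ha='left' if v >= 0 else 'right', fontsize=10)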
    
    # Plot 3: Model Accuracy Comparison
    x = np.arange(len(datasets))
    bars3 = axes[1, 0].bar(x - width/2, image_df["Baseline Accuracy"], width, label='Baseline', color='#FF9800')
    bars4 = axes[1, 0].bar(x + width/2, image_df["DiET Accuracy"], width, label='After DiET', color='#9C27B0')
    axes[1, 0].set_ylabel('Accuracy (%)')
    axes[1, 0].set_title('Model Accuracy Before and After DiET')
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(datasets)
    axes[1, 0].legend()
    axes[1, 0].set_ylim(0, 100)
    
    # Plot 4: Summary Pie Chart
    diet_wins = image_df["DiET Better"].sum()
    gradcam_wins = len(image_df) - diet_wins
    axes[1, 1].pie([diet_wins, gradcam_wins], labels=['DiET Better', 'GradCAM Better'], 
                   autopct='%1.0f%%', colors=['#4CAF50', '#2196F3'], startangle=90)
    axes[1, 1].set_title('Method Performance Summary')
    
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/image_visual_summary.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\nFigure saved: {OUTPUT_DIR}/image_visual_summary.png")

---

## 4. Text Experiments: DiET vs Integrated Gradients

### 4.1 Run Text Comparison Experiments

This section trains a BERT-base-uncased classifier on each text dataset and compares the tokens that DiET and Integrated Gradients rank as most important, measured as top-k token overlap.
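Integrated Gradients (Sundararajan et al., 2017) attributes a prediction $F(x)$ to each input feature by accumulating gradients along the straight path from a baseline $x'$ to $x$: $\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F(x' + \alpha(x - x'))}{\partial x_i}\,d\alpha$. The NumPy sketch below approximates the integral with a midpoint Riemann sum on a toy model with an analytic gradient, and checks the completeness property (attributions sum to $F(x) - F(x')$). It is illustrative only: for BERT, the same idea is typically applied to the token embeddings (for example with Captum, which this notebook installs) and reduced to one score per token by summing over the embedding dimension; how the framework does this is not shown here.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline, steps=64):
    """Midpoint Riemann approximation of Integrated Gradients.

    grad_fn(z) must return dF/dz evaluated at z (same shape as z).
    """
    alphas = (np.arange(steps) + 0.5) / steps             # midpoints of [0, 1]
    # Gradients at evenly spaced points on the path baseline -> x
    path_grads = np.stack([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * path_grads.mean(axis=0)


# Toy differentiable model: F(z) = sigmoid(w . z), so dF/dz = F(z) * (1 - F(z)) * w
w = np.array([2.0, -1.0, 0.5])


def F(z):
    return 1.0 / (1.0 + np.exp(-(w @ z)))


def grad_F(z):
    s = F(z)
    return s * (1.0 - s) * w


x = np.array([1.0, 0.5, 2.0])
baseline = np.zeros_like(x)
attributions = integrated_gradients(grad_F, x, baseline, steps=256)
print("attributions:", attributions)
print("sum:", attributions.sum(), "F(x) - F(baseline):", F(x) - F(baseline))
```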

In [None]:
print("=" * 70)
print("TEXT EXPERIMENTS: DiET vs Integrated Gradients")
print("=" * 70)
print(f"\nDatasets: {CONFIG['text_datasets']}")
print(f"Model: BERT-base-uncased")
print(f"Max sequence length: {CONFIG['text_max_length']}")
print(f"Training epochs: {CONFIG['text_epochs']}")
print("\nStarting experiments...\n")

text_start_time = datetime.now()

# Run text experiments
text_results = comparison.run_all_text_comparisons(skip_training=False)

text_end_time = datetime.now()
text_duration = int((text_end_time - text_start_time).total_seconds())

print(f"\nText experiments completed in {text_duration // 60} minutes {text_duration % 60} seconds.")

### 4.2 Text Results: Visual Summary
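The text results below report a "top-k token overlap" between IG and DiET and compare it against a 0.5 threshold. A minimal sketch of one plausible definition follows: the fraction of the k highest-scoring token positions that the two methods share. This definition, the ranking by signed scores (use `np.abs(...)` to rank by magnitude instead), and the helper names are assumptions, not the framework's code. The second helper gives the expected overlap when both rankings are random, k/n for n tokens: with k = 5 and a 128-token input this is about 0.04, so 0.5 is a demanding agreement threshold rather than a chance level.

```python
import numpy as np


def top_k_overlap(scores_a, scores_b, k):
    """Fraction of the k top-scoring token positions shared by two attribution
    vectors over the same tokens (1.0 = identical top-k sets, 0.0 = disjoint)."""
    top_a = set(np.argsort(scores_a)[-k:].tolist())
    top_b = set(np.argsort(scores_b)[-k:].tolist())
    return len(top_a & top_b) / k


def chance_overlap(n_tokens, k):
    """Expected top_k_overlap for two independent, uniformly random rankings:
    each of the k positions in one set is in the other with probability k / n."""
    return k / n_tokens


# Illustrative attribution scores for a 5-token sentence (made-up numbers)
ig_scores = np.array([0.9, 0.1, 0.5, -0.3, 0.7])
diet_scores = np.array([0.8, 0.6, 0.2, 0.0, 0.9])
print(top_k_overlap(ig_scores, diet_scores, k=2))
print(top_k_overlap(ig_scores, diet_scores, k=3))
print(chance_overlap(n_tokens=128, k=5))
```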

In [None]:
# Extract text results
text_data = []
for dataset_name, result in text_results.items():
    if "error" not in result:
        text_data.append({
            "Dataset": dataset_name.upper(),
            "Baseline Accuracy": result.get("baseline_accuracy", 0),
            "IG-DiET Overlap": result.get("ig_diet_overlap", 0),
            "Samples Compared": result.get("samples_compared", 0),
        })

text_df = pd.DataFrame(text_data)

# Display table
print("\n" + "=" * 70)
print("TEXT EXPERIMENTS: QUANTITATIVE RESULTS")
print("=" * 70)
print(f"\nTop-{CONFIG['text_top_k']} Token Overlap between IG and DiET:\n")
print(text_df.to_string(index=False))

# Summary statistics
if len(text_df) > 0:
    avg_overlap = text_df["IG-DiET Overlap"].mean()
    avg_accuracy = text_df["Baseline Accuracy"].mean()
    print(f"\nSummary:")
    print(f"  Average BERT accuracy: {avg_accuracy:.1f}%")
    print(f"  Average IG-DiET overlap: {avg_overlap:.4f}")
    
    if avg_overlap >= 0.5:
        print("  Interpretation: Methods largely agree on the top-ranked tokens (overlap >= 0.5)")
    else:
        print("  Interpretation: Methods largely disagree on the top-ranked tokens (overlap < 0.5)")

In [None]:
# Visual Summary: Text Experiments
if len(text_df) > 0:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle("Text Experiments: DiET vs Integrated Gradients - Visual Summary", fontsize=16, fontweight='bold')
    
    datasets = text_df["Dataset"].tolist()
    
    # Plot 1: IG-DiET Token Overlap
    overlaps = text_df["IG-DiET Overlap"].tolist()
    colors = ['#4CAF50' if o >= 0.5 else '#FF9800' if o >= 0.3 else '#F44336' for o in overlaps]
    bars = axes[0].bar(datasets, overlaps, color=colors, edgecolor='black')
    axes[0].axhline(y=0.5, color='gray', linestyle='--', linewidth=1, label='50% threshold')
    axes[0].set_ylabel(f'Top-{CONFIG["text_top_k"]} Token Overlap')
    axes[0].set_title('IG-DiET Attribution Agreement')
    axes[0].set_ylim(0, 1)
    axes[0].legend()
    for bar, val in zip(bars, overlaps):
        axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, f'{val:.3f}', ha='center', fontsize=11)
    
    # Plot 2: BERT Accuracy
    accuracies = text_df["Baseline Accuracy"].tolist()
    bars2 = axes[1].bar(datasets, accuracies, color='#2196F3', edgecolor='black')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('BERT Classification Accuracy')
    axes[1].set_ylim(0, 100)
    for bar, acc in zip(bars2, accuracies):
        axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, f'{acc:.1f}%', ha='center', fontsize=11)
    
    # Plot 3: Agreement Level Distribution
    high_agreement = sum(1 for o in overlaps if o >= 0.5)
    medium_agreement = sum(1 for o in overlaps if 0.3 <= o < 0.5)
    low_agreement = sum(1 for o in overlaps if o < 0.3)
    
    labels = ['High (>=0.5)', 'Medium (0.3-0.5)', 'Low (<0.3)']
    sizes = [high_agreement, medium_agreement, low_agreement]
    colors_pie = ['#4CAF50', '#FF9800', '#F44336']
    
    # Only show non-zero segments
    non_zero = [(l, s, c) for l, s, c in zip(labels, sizes, colors_pie) if s > 0]
    if non_zero:
        labels, sizes, colors_pie = zip(*non_zero)
        axes[2].pie(sizes, labels=labels, autopct='%1.0f%%', colors=colors_pie, startangle=90)
    axes[2].set_title('Agreement Level Distribution')
    
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/text_visual_summary.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\nFigure saved: {OUTPUT_DIR}/text_visual_summary.png")

---

## 5. Combined Results Summary

### 5.1 Complete Results Table

In [None]:
# Generate combined summary report
print("=" * 70)
print("COMBINED EXPERIMENT SUMMARY")
print("=" * 70)

# Get full results dataframe
full_df = comparison.get_results_dataframe()
print("\nComplete Results Table:\n")
print(full_df.to_string(index=False))

In [None]:
# Generate summary report text
report = comparison.generate_summary_report()
print(report)

### 5.2 Combined Visual Summary

In [None]:
# Combined Visual Summary
fig = plt.figure(figsize=(16, 12))
fig.suptitle("DiET vs Basic XAI Methods: Complete Comparison Summary", fontsize=18, fontweight='bold', y=0.98)

# Create grid
gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.25)

# Image comparison subplot
if len(image_df) > 0:
    ax1 = fig.add_subplot(gs[0, 0])
    datasets = image_df["Dataset"].tolist()
    x = np.arange(len(datasets))
    width = 0.35
    ax1.bar(x - width/2, image_df["GradCAM Score"], width, label='GradCAM', color='#2196F3')
    ax1.bar(x + width/2, image_df["DiET Score"], width, label='DiET', color='#4CAF50')
    ax1.set_ylabel('Score')
    ax1.set_title('Image: DiET vs GradCAM')
    ax1.set_xticks(x)
    ax1.set_xticklabels(datasets, rotation=45, ha='right')
    ax1.legend()
    ax1.set_ylim(0, 1)

# Text comparison subplot
if len(text_df) > 0:
    ax2 = fig.add_subplot(gs[0, 1])
    datasets = text_df["Dataset"].tolist()
    overlaps = text_df["IG-DiET Overlap"].tolist()
    # Same three-level color scheme as the text summary figure
    colors = ['#4CAF50' if o >= 0.5 else '#FF9800' if o >= 0.3 else '#F44336' for o in overlaps]
    ax2.bar(datasets, overlaps, color=colors, edgecolor='black')
    ax2.axhline(y=0.5, color='gray', linestyle='--', linewidth=1)
    ax2.set_ylabel('Token Overlap')
    ax2.set_title('Text: IG-DiET Agreement')
    ax2.set_ylim(0, 1)

# Image accuracy subplot
if len(image_df) > 0:
    ax3 = fig.add_subplot(gs[1, 0])
    datasets = image_df["Dataset"].tolist()
    x = np.arange(len(datasets))
    ax3.bar(x - width/2, image_df["Baseline Accuracy"], width, label='Baseline', color='#FF9800')
    ax3.bar(x + width/2, image_df["DiET Accuracy"], width, label='DiET', color='#9C27B0')
    ax3.set_ylabel('Accuracy (%)')
    ax3.set_title('Image Model Accuracy')
    ax3.set_xticks(x)
    ax3.set_xticklabels(datasets, rotation=45, ha='right')
    ax3.legend()

# Text accuracy subplot
if len(text_df) > 0:
    ax4 = fig.add_subplot(gs[1, 1])
    datasets = text_df["Dataset"].tolist()
    accuracies = text_df["Baseline Accuracy"].tolist()
    ax4.bar(datasets, accuracies, color='#2196F3', edgecolor='black')
    ax4.set_ylabel('Accuracy (%)')
    ax4.set_title('Text Model Accuracy (BERT)')
    ax4.set_ylim(0, 100)

# Summary statistics subplot
ax5 = fig.add_subplot(gs[2, :])
ax5.axis('off')

# Create summary text
summary_text = "EXPERIMENT SUMMARY\n" + "="*50 + "\n\n"

if len(image_df) > 0:
    diet_wins = image_df["DiET Better"].sum()
    avg_imp = image_df["Improvement"].mean()
    summary_text += f"IMAGE EXPERIMENTS (DiET vs GradCAM):\n"
    summary_text += f"  - DiET outperforms GradCAM: {diet_wins}/{len(image_df)} datasets\n"
    summary_text += f"  - Average improvement: {avg_imp:+.4f}\n\n"

if len(text_df) > 0:
    avg_overlap = text_df["IG-DiET Overlap"].mean()
    avg_acc = text_df["Baseline Accuracy"].mean()
    summary_text += f"TEXT EXPERIMENTS (DiET vs Integrated Gradients):\n"
    summary_text += f"  - Average IG-DiET overlap: {avg_overlap:.4f}\n"
    summary_text += f"  - Average BERT accuracy: {avg_acc:.1f}%\n"

ax5.text(0.5, 0.5, summary_text, transform=ax5.transAxes, fontsize=12,
         verticalalignment='center', horizontalalignment='center',
         fontfamily='monospace', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.savefig(f'{OUTPUT_DIR}/combined_visual_summary.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFigure saved: {OUTPUT_DIR}/combined_visual_summary.png")

---

## 6. Statistical Analysis

### 6.1 Statistical Tests
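Because DiET and GradCAM are evaluated on the same datasets, the image comparison is paired. The cell below reports Cohen's d with a pooled standard deviation; a common paired alternative is Cohen's d_z, the mean of the per-dataset differences divided by their sample standard deviation. The sketch below (hypothetical helpers, illustrative numbers) computes d_z and applies Cohen's conventional cut-offs of 0.2, 0.5, and 0.8. With only three or four datasets, any effect-size estimate is very noisy, so read it alongside the per-dataset results.

```python
import numpy as np


def paired_cohens_dz(a, b):
    """Cohen's d_z for paired samples: mean of the differences divided by
    their sample standard deviation (ddof=1). Returns NaN if all differences
    are identical."""
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    sd = diff.std(ddof=1)
    return float(diff.mean() / sd) if sd > 0 else float("nan")


def effect_size_label(d):
    """Cohen's conventional cut-offs: 0.2 small, 0.5 medium, 0.8 large."""
    d = abs(d)
    if d < 0.2:
        return "negligible"
    if d < 0.5:
        return "small"
    if d < 0.8:
        return "medium"
    return "large"


# In this notebook: paired_cohens_dz(image_df["DiET Score"], image_df["GradCAM Score"])
d = paired_cohens_dz([1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 1.0, 3.0])
print(d, effect_size_label(d))
```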

In [None]:
print("=" * 70)
print("STATISTICAL ANALYSIS")
print("=" * 70)

# Image experiments statistical analysis
if len(image_df) >= 3:
    gradcam_scores = image_df["GradCAM Score"].values
    diet_scores = image_df["DiET Score"].values
    
    # Paired t-test (one DiET/GradCAM pair per dataset; with only 3-4 datasets
    # the test has very little statistical power)
    t_stat, p_value = stats.ttest_rel(diet_scores, gradcam_scores)
    
    # Effect size (Cohen's d with pooled sample standard deviation)
    pooled_std = np.sqrt((np.var(gradcam_scores, ddof=1) + np.var(diet_scores, ddof=1)) / 2)
    cohens_d = (np.mean(diet_scores) - np.mean(gradcam_scores)) / pooled_std if pooled_std > 0 else 0
    
    print("\nImage Experiments (DiET vs GradCAM):")
    print("  Paired t-test:")
    print(f"    t-statistic: {t_stat:.4f}")
    print(f"    p-value: {p_value:.4f}")
    
    if p_value < 0.05:
        print("    Result: Statistically significant (p < 0.05)")
    else:
        print("    Result: Not statistically significant (p >= 0.05)")
    
    # Cohen's conventional cut-offs: 0.2 small, 0.5 medium, 0.8 large
    print(f"\n  Effect Size (Cohen's d): {cohens_d:.4f}")
    if abs(cohens_d) < 0.2:
        print("    Interpretation: Negligible effect")
    elif abs(cohens_d) < 0.5:
        print("    Interpretation: Small effect")
    elif abs(cohens_d) < 0.8:
        print("    Interpretation: Medium effect")
    else:
        print("    Interpretation: Large effect")
else:
    print("\nImage Experiments: Not enough data points for statistical testing (need >= 3)")

# Text experiments statistical analysis
if len(text_df) >= 2:
    overlaps = text_df["IG-DiET Overlap"].values
    
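    # One-sample t-test against the 0.5 agreement threshold used in the plots.
    # 0.5 is a chosen threshold, not chance level: for random rankings of n tokens
    # the expected top-k overlap is k / n.
    t_stat_text, p_value_text = stats.ttest_1samp(overlaps, 0.5)
    
    print("\nText Experiments (IG-DiET Overlap):")
    print("  One-sample t-test (vs 0.5 threshold):")
    print(f"    Mean overlap: {np.mean(overlaps):.4f}")
    print(f"    t-statistic: {t_stat_text:.4f}")
    print(f"    p-value: {p_value_text:.4f}")
    
    if np.isnan(p_value_text) or p_value_text >= 0.05:
        print("    Interpretation: Mean overlap is not significantly different from 0.5")
    elif np.mean(overlaps) > 0.5:
        print("    Interpretation: Mean overlap is significantly above the 0.5 threshold")
    else:
        print("    Interpretation: Mean overlap is significantly below 0.5; methods largely disagree on top tokens")
else:
    print("\nText Experiments: Not enough data points for statistical testing (need >= 2)")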

---

## 7. Export Results

### 7.1 Save All Results
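Summary values computed with pandas and NumPy in this notebook, such as `image_df["DiET Better"].sum()`, are NumPy scalars, and `json.dump` raises `TypeError` for types like `np.int64` and `np.bool_`. If you add such values to a JSON export, a `default=` hook can convert them to built-in Python types. The hook name `to_builtin` and the example values are illustrative.

```python
import json

import numpy as np


def to_builtin(obj):
    """`default=` hook for json.dump/json.dumps: convert NumPy values to Python types.
    (np.float64 subclasses float, so json serializes it without calling this hook.)"""
    if isinstance(obj, np.generic):      # np.int64, np.float32, np.bool_, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Illustrative summary values, not experiment results
summary = {"diet_wins": np.int64(3), "avg_improvement": np.float64(0.0125),
           "per_dataset": np.array([0.1, -0.02])}
print(json.dumps(summary, default=to_builtin))
```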

In [None]:
# Save results
comparison.save_results()

# Save additional CSV files
if len(image_df) > 0:
    image_df.to_csv(f'{OUTPUT_DIR}/image_results.csv', index=False)
if len(text_df) > 0:
    text_df.to_csv(f'{OUTPUT_DIR}/text_results.csv', index=False)
if len(full_df) > 0:
    full_df.to_csv(f'{OUTPUT_DIR}/all_results.csv', index=False)

# Save configuration
with open(f'{OUTPUT_DIR}/experiment_config.json', 'w') as f:
    json.dump(CONFIG, f, indent=2)

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/comparison_results.json")
print(f"  - {OUTPUT_DIR}/image_results.csv")
print(f"  - {OUTPUT_DIR}/text_results.csv")
print(f"  - {OUTPUT_DIR}/all_results.csv")
print(f"  - {OUTPUT_DIR}/experiment_config.json")
print(f"  - {OUTPUT_DIR}/image_visual_summary.png")
print(f"  - {OUTPUT_DIR}/text_visual_summary.png")
print(f"  - {OUTPUT_DIR}/combined_visual_summary.png")

In [None]:
# Generate visualizations from framework
try:
    viz_files = comparison.visualize_results(save_plots=True, show=False)
    print("\nAdditional visualizations generated:")
    for name, path in viz_files.items():
        print(f"  - {name}: {path}")
except Exception as e:
    print(f"Note: Some visualizations could not be generated: {e}")

In [None]:
# Download results (for Colab)
try:
    from google.colab import files
    
    # Create zip of all results
    !zip -r comprehensive_results.zip {OUTPUT_DIR}/
    
    print("\nDownload your results:")
    files.download('comprehensive_results.zip')
except ImportError:  # not running in Google Colab
    print(f"\nResults are saved locally in: {OUTPUT_DIR}/")
    print("(Download option only available in Google Colab)")

### 7.2 Final Report

In [None]:
# Generate final report
total_duration = globals().get('image_duration', 0) + globals().get('text_duration', 0)

final_report = f"""
================================================================================
                    DiET vs BASIC XAI METHODS
                    COMPREHENSIVE COMPARISON REPORT
================================================================================

Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Device: {DEVICE} ({GPU_CONFIG})
Total Duration: {total_duration // 60} minutes {total_duration % 60} seconds

--------------------------------------------------------------------------------
IMAGE EXPERIMENTS: DiET vs GradCAM
--------------------------------------------------------------------------------
Datasets: {CONFIG['image_datasets']}
Model: ResNet
Training epochs: {CONFIG['image_epochs']}
Comparison samples: {CONFIG['image_comparison_samples']}

"""

if len(image_df) > 0:
    for _, row in image_df.iterrows():
        status = "[+]" if row["DiET Better"] else "[-]"
        final_report += f"{status} {row['Dataset']}: GradCAM={row['GradCAM Score']:.4f}, DiET={row['DiET Score']:.4f}, Improvement={row['Improvement']:+.4f}\n"
    
    diet_wins = image_df["DiET Better"].sum()
    final_report += f"\nSummary: DiET outperforms GradCAM on {diet_wins}/{len(image_df)} datasets\n"

final_report += f"""
--------------------------------------------------------------------------------
TEXT EXPERIMENTS: DiET vs Integrated Gradients
--------------------------------------------------------------------------------
Datasets: {CONFIG['text_datasets']}
Model: BERT-base-uncased
Max length: {CONFIG['text_max_length']}
Training epochs: {CONFIG['text_epochs']}

"""

if len(text_df) > 0:
    for _, row in text_df.iterrows():
        level = "[HIGH]" if row["IG-DiET Overlap"] >= 0.5 else "[MED]" if row["IG-DiET Overlap"] >= 0.3 else "[LOW]"
        final_report += f"{level} {row['Dataset']}: Accuracy={row['Baseline Accuracy']:.1f}%, Overlap={row['IG-DiET Overlap']:.4f}\n"
    
    avg_overlap = text_df["IG-DiET Overlap"].mean()
    final_report += f"\nSummary: Average IG-DiET overlap = {avg_overlap:.4f}\n"

final_report += f"""
================================================================================
Output Files:
  - {OUTPUT_DIR}/comparison_results.json
  - {OUTPUT_DIR}/image_results.csv
  - {OUTPUT_DIR}/text_results.csv
  - {OUTPUT_DIR}/all_results.csv
  - {OUTPUT_DIR}/experiment_config.json
  - {OUTPUT_DIR}/image_visual_summary.png
  - {OUTPUT_DIR}/text_visual_summary.png
  - {OUTPUT_DIR}/combined_visual_summary.png
  - {OUTPUT_DIR}/final_report.txt
================================================================================
"""

# Save final report
with open(f'{OUTPUT_DIR}/final_report.txt', 'w') as f:
    f.write(final_report)

print(final_report)

---

## References

1. Bhalla, U., et al. (2023). "Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability." *NeurIPS 2023.*

2. Selvaraju, R. R., et al. (2017). "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." *ICCV 2017.*

3. Sundararajan, M., Taly, A., & Yan, Q. (2017). "Axiomatic Attribution for Deep Networks." *ICML 2017.*

4. Devlin, J., et al. (2019). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." *NAACL 2019.*

---

**Notebook Version:** 1.0  
**Last Updated:** 2025-2026 Academic Year  
**Repository:** https://github.com/xMOROx/Machine-Learning-Project-2025-2026