# DiET vs Basic XAI Methods: Comprehensive Comparison Framework

## A Complete Experimental Pipeline for Image and Text Classification

---

**Author:** Machine Learning Research Team  
**Date:** 2025-2026 Academic Year  
**Course:** Advanced Machine Learning

---

### Abstract

This notebook presents a comprehensive experimental comparison between **DiET (Distractor Erasure Tuning)** and standard XAI methods:

- **Image Classification:** DiET vs GradCAM on CIFAR-10, CIFAR-100, SVHN, Fashion-MNIST
- **Text Classification:** DiET vs Integrated Gradients on SST-2, IMDB, AG News

The notebook provides robust evaluation metrics, visual summaries, and statistical analysis suitable for academic presentations.

### Metrics Used

**Image Attribution Metrics:**
- Pixel Perturbation (keep/remove important regions)
- AOPC (Area Over the Perturbation Curve)
- Insertion/Deletion Curves
- Faithfulness Correlation
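
The perturbation-based metrics above share one idea: rank features by attribution, remove them in that order, and watch how quickly the model score falls. The sketch below illustrates a deletion curve and one common AOPC variant (the mean score drop over all perturbation steps). It is a minimal illustration, not the implementation in `xai_experiments`: the `deletion_curve` and `aopc` helpers, the linear `toy_score` model, and its `WEIGHTS` are all hypothetical stand-ins.

```python
from typing import Callable, List, Sequence


def deletion_curve(
    x: Sequence[float],
    attribution: Sequence[float],
    score_fn: Callable[[Sequence[float]], float],
    baseline_value: float = 0.0,
) -> List[float]:
    """Model score after deleting the k most important features, k = 0..n."""
    # Rank feature indices from most to least important.
    order = sorted(range(len(x)), key=lambda i: attribution[i], reverse=True)
    perturbed = list(x)
    scores = [score_fn(perturbed)]
    for idx in order:
        perturbed[idx] = baseline_value  # "delete" the feature
        scores.append(score_fn(perturbed))
    return scores


def aopc(scores: Sequence[float]) -> float:
    """Mean drop from the unperturbed score across all perturbation steps."""
    drops = [scores[0] - s for s in scores[1:]]
    return sum(drops) / len(drops)


# Hypothetical stand-in for a classifier: a fixed linear score.
WEIGHTS = [0.5, 0.25, 0.125, 0.125]


def toy_score(features: Sequence[float]) -> float:
    return sum(w * f for w, f in zip(WEIGHTS, features))


x = [1.0, 1.0, 1.0, 1.0]
good_attr = [0.5, 0.25, 0.125, 0.125]  # ranks features like the true weights
bad_attr = [0.125, 0.125, 0.25, 0.5]   # ranks features in the wrong order

good_curve = deletion_curve(x, good_attr, toy_score)
bad_curve = deletion_curve(x, bad_attr, toy_score)

print(good_curve)        # [1.0, 0.5, 0.25, 0.125, 0.0]
print(aopc(good_curve))  # 0.78125
print(aopc(bad_curve))   # 0.53125
```

Under this variant a more faithful attribution yields a higher AOPC, because deleting its top-ranked features lowers the score sooner. An insertion curve follows the same pattern in reverse, starting from the baseline input and restoring features in ranked order.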

**Text Attribution Metrics:**
- Top-K Token Overlap (K = 3, 5, 10, 15, 20)
- Attribution Correlation
- Token-level Agreement Analysis
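
As a rough illustration of the agreement metrics above, the sketch below computes Top-K token overlap and a Pearson correlation between two attribution vectors for the same sentence. The exact definitions used by the framework are not shown in this notebook, so the choices here are assumptions: ranking by absolute attribution, normalizing the overlap by K, and using Pearson rather than a rank correlation. The token list and scores are made-up values.

```python
import math
from typing import Sequence, Set


def top_k_indices(scores: Sequence[float], k: int) -> Set[int]:
    """Indices of the k tokens with the largest absolute attribution."""
    ranked = sorted(range(len(scores)), key=lambda i: abs(scores[i]), reverse=True)
    return set(ranked[:k])


def top_k_overlap(a: Sequence[float], b: Sequence[float], k: int) -> float:
    """Fraction of the top-k tokens shared by two attribution vectors."""
    k = min(k, len(a), len(b))
    if k == 0:
        return 0.0
    return len(top_k_indices(a, k) & top_k_indices(b, k)) / k


def pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation; returns 0.0 if either vector is constant."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0 or var_b == 0:
        return 0.0
    return cov / math.sqrt(var_a * var_b)


# Made-up attributions for a five-token sentence.
tokens = ["the", "movie", "was", "surprisingly", "good"]
diet_scores = [0.05, 0.30, 0.10, 0.60, 0.90]
ig_scores = [0.40, 0.20, -0.05, 0.70, 0.80]

for k in (1, 3, 5):
    print(f"top-{k} overlap: {top_k_overlap(diet_scores, ig_scores, k):.2f}")
print(f"pearson: {pearson(diet_scores, ig_scores):.2f}")
```

Overlap is trivially 1.0 once K reaches the sentence length, which is one reason to report several K values rather than a single one.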

### Reference

Bhalla, U., et al. (2023). **"Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability."** *NeurIPS 2023.*

---

## Table of Contents

1. [Environment Setup](#1-environment-setup)
2. [Experimental Configuration](#2-experimental-configuration)
3. [Image Experiments: DiET vs GradCAM](#3-image-experiments-diet-vs-gradcam)
4. [Text Experiments: DiET vs Integrated Gradients](#4-text-experiments-diet-vs-integrated-gradients)
5. [Combined Results Summary](#5-combined-results-summary)
6. [Statistical Analysis](#6-statistical-analysis)
7. [Export Results](#7-export-results)

---

## 1. Environment Setup

### 1.1 Hardware Configuration

This notebook is optimized for Google Colab Pro with GPU acceleration.

In [None]:
import torch
import gc

def cleanup_memory():
    """Clean up GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

print("=" * 70)
print("HARDWARE CONFIGURATION")
print("=" * 70)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU Available: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")

    if gpu_memory >= 15:
        print("\nConfiguration: HIGH-MEMORY GPU")
        GPU_CONFIG = "high"
    elif gpu_memory >= 8:
        print("\nConfiguration: STANDARD GPU")
        GPU_CONFIG = "standard"
    else:
        print("\nConfiguration: LOW-MEMORY GPU")
        GPU_CONFIG = "low"
else:
    print("WARNING: No GPU available. Training will be slow.")
    print("Enable GPU: Runtime -> Change runtime type -> GPU")
    GPU_CONFIG = "cpu"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\nUsing device: {DEVICE.upper()}")
print("=" * 70)

HARDWARE CONFIGURATION
GPU Available: NVIDIA A100-SXM4-80GB
Memory: 79.3 GB
CUDA Version: 12.6
PyTorch Version: 2.9.0+cu126

Configuration: HIGH-MEMORY GPU

Using device: CUDA


### 1.2 Clone Repository and Install Dependencies

In [2]:
import os

REPO_URL = "https://github.com/xMOROx/Machine-Learning-Project-2025-2026.git"
REPO_DIR = "Machine-Learning-Project-2025-2026"

if not os.path.exists(REPO_DIR):
    print("Cloning repository...")
    !git clone --recursive {REPO_URL}
    print("Repository cloned successfully.")
else:
    print("Repository already exists. Pulling latest changes...")
    %cd {REPO_DIR}
    !git pull
    !git submodule update --init --recursive
    %cd ..

%cd {REPO_DIR}

Cloning repository...
Cloning into 'Machine-Learning-Project-2025-2026'...
remote: Enumerating objects: 541, done.
remote: Counting objects: 100% (96/96), done.
remote: Compressing objects: 100% (63/63), done.
remote: Total 541 (delta 53), reused 55 (delta 31), pack-reused 445 (from 1)
Receiving objects: 100% (541/541), 1.82 MiB | 40.56 MiB/s, done.
Resolving deltas: 100% (281/281), done.
Submodule 'DiET' (https://github.com/AI4LIFE-GROUP/DiET.git) registered for path 'DiET'
Submodule 'how-to-probe' (https://github.com/sidgairo18/how-to-probe.git) registered for path 'how-to-probe'
Cloning into '/content/Machine-Learning-Project-2025-2026/DiET'...
remote: Enumerating objects: 17, done.        
remote: Counting objects: 100% (17/17), done.        
remote: Compressing objects: 100% (10/10), done.        
remote: Total 17 (delta 9), reused 13 (delta 7), pack-reused 0 (from 0)        
Receiving objects: 100% (17/17), 17.36 KiB | 17.36 MiB/s, done.
Resolving deltas: 100% (9/9), 

In [None]:
print("Installing dependencies...")

!pip install -q transformers datasets tqdm matplotlib seaborn pandas pillow scikit-learn captum scipy

print("All dependencies installed.")

In [None]:
import sys
import os

repo_base = "/content/Machine-Learning-Project-2025-2026/scripts"
if repo_base not in sys.path:
    sys.path.append(repo_base)

from xai_experiments.experiments.xai_comparison import XAIMethodsComparison, ComparisonConfig
from xai_experiments.visualization.comparison_plots import ComparisonVisualizer, VisualizationConfig

print("Import successful!")

# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.auto import tqdm
from scipy import stats
import json
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

print("All modules imported successfully.")
print(f"Experiment started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

Import successful!
All modules imported successfully.
Experiment started: 2026-01-18 08:15:30


---

## 2. Experimental Configuration

### 2.1 Configuration Parameters

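Parameters are selected automatically from the GPU memory tier detected in Section 1.1.
**Note:** Text experiments use reduced batch sizes and sequence lengths to prevent OOM errors. The recorded output below was captured with an earlier revision of this cell, so some of its values (for example, the image batch size) differ from the `CONFIG` defined here.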

In [None]:
# Configuration based on GPU capabilities
if GPU_CONFIG == "high":  # A100, V100
    CONFIG = {
        "image_batch_size": 512,
        "image_epochs": 25,
        "image_max_samples": 10000,
        "image_comparison_samples": 2000,
        "image_datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
        "text_batch_size": 128,
        "text_epochs": 25,
        "text_max_length": 256,
        "text_max_samples": 3000,
        "text_comparison_samples": 1000,
        "text_datasets": ["sst2", "imdb", "ag_news"],
        "text_top_k_values": [3, 5, 10, 15, 20],
    }
elif GPU_CONFIG == "standard":  # T4, P100
    CONFIG = {
        "image_batch_size": 256,
        "image_epochs": 15,
        "image_max_samples": 5000,
        "image_comparison_samples": 500,
        "image_datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
        "text_batch_size": 16,
        "text_epochs": 10,
        "text_max_length": 128,
        "text_max_samples": 1500,
        "text_comparison_samples": 300,
        "text_datasets": ["sst2", "imdb", "ag_news"],
        "text_top_k_values": [3, 5, 10, 15, 20],
    }
elif GPU_CONFIG == "low":  # K80, older GPUs
    CONFIG = {
        "image_batch_size": 128,
        "image_epochs": 10,
        "image_max_samples": 2000,
        "image_comparison_samples": 200,
        "image_datasets": ["cifar10", "svhn"],
        "text_batch_size": 8,
        "text_epochs": 5,
        "text_max_length": 64,
        "text_max_samples": 1000,
        "text_comparison_samples": 100,
        "text_datasets": ["sst2", "ag_news"],
        "text_top_k_values": [3, 5, 10],
    }
else:  # CPU
    CONFIG = {
        "image_batch_size": 16,
        "image_epochs": 2,
        "image_max_samples": 500,
        "image_comparison_samples": 20,
        "image_datasets": ["cifar10"],
        "text_batch_size": 4,
        "text_epochs": 1,
        "text_max_length": 64,
        "text_max_samples": 200,
        "text_comparison_samples": 20,
        "text_datasets": ["sst2"],
        "text_top_k_values": [3, 5],
    }

# Display configuration
print("=" * 70)
print("EXPERIMENT CONFIGURATION")
print("=" * 70)
print("\nImage Experiments:")
print(f"  Datasets: {CONFIG['image_datasets']}")
print(f"  Batch size: {CONFIG['image_batch_size']}")
print(f"  Epochs: {CONFIG['image_epochs']}")
print(f"  Max samples: {CONFIG['image_max_samples']}")
print(f"  Comparison samples: {CONFIG['image_comparison_samples']}")
print("\nText Experiments:")
print(f"  Datasets: {CONFIG['text_datasets']}")
print(f"  Batch size: {CONFIG['text_batch_size']}")
print(f"  Epochs: {CONFIG['text_epochs']}")
print(f"  Max length: {CONFIG['text_max_length']}")
print(f"  Max samples: {CONFIG['text_max_samples']}")
print(f"  Comparison samples: {CONFIG['text_comparison_samples']}")
print(f"  Top-k values: {CONFIG['text_top_k_values']}")
print("=" * 70)

EXPERIMENT CONFIGURATION

Image Experiments:
  Datasets: ['cifar10', 'cifar100', 'svhn', 'fashion_mnist']
  Batch size: 1024
  Epochs: 25
  Max samples: 30000
  Comparison samples: 5000

Text Experiments:
  Datasets: ['sst2', 'imdb', 'ag_news']
  Batch size: 256
  Epochs: 25
  Max length: 256
  Max samples: 3000
  Top-k tokens: 25


### 2.2 Initialize Comparison Framework

In [None]:
# Create output directory
OUTPUT_DIR = "./outputs/colab_experiments/comprehensive_comparison"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize configuration
comparison_config = ComparisonConfig(
    device=DEVICE,
    # Image settings
    image_datasets=CONFIG["image_datasets"],
    image_batch_size=CONFIG["image_batch_size"],
    image_epochs=CONFIG["image_epochs"],
    image_max_samples=CONFIG["image_max_samples"],
    image_comparison_samples=CONFIG["image_comparison_samples"],
    # Text settings
    text_datasets=CONFIG["text_datasets"],
    text_batch_size=CONFIG["text_batch_size"],
    text_epochs=CONFIG["text_epochs"],
    text_max_length=CONFIG["text_max_length"],
    text_max_samples=CONFIG["text_max_samples"],
    text_comparison_samples=CONFIG["text_comparison_samples"],
    text_top_k_values=CONFIG["text_top_k_values"],
    low_vram=(GPU_CONFIG in ["low", "standard"]),
    output_dir=OUTPUT_DIR,
    compute_all_metrics=True,
)

# Initialize comparison object
comparison = XAIMethodsComparison(comparison_config)

# Initialize visualizer with enhanced config
viz_config = VisualizationConfig(
    figsize=(14, 8),
    dpi=150,
    style="whitegrid",
    use_gradients=True,
)
visualizer = ComparisonVisualizer(output_dir=OUTPUT_DIR, config=viz_config)

print("Comparison framework initialized.")
print(f"Output directory: {OUTPUT_DIR}")

Comparison framework initialized.
Output directory: ./outputs/colab_experiments/comprehensive_comparison


---

## 3. Image Experiments: DiET vs GradCAM

### 3.1 Run Image Comparison Experiments

This section compares DiET and GradCAM on image classification datasets.
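
For orientation, the sketch below shows the core Grad-CAM computation on hard-coded toy values: each channel of a convolutional feature map is weighted by the spatial mean of the class-score gradient for that channel, the weighted maps are summed, and a ReLU keeps only positive evidence. In a real run the activations and gradients would come from a convolutional layer of the trained ResNet; the `grad_cam` helper and the numbers here are illustrative assumptions, not the framework's own code.

```python
from typing import List

Map2D = List[List[float]]


def grad_cam(activations: List[Map2D], gradients: List[Map2D]) -> Map2D:
    """Grad-CAM heatmap from per-channel activations and gradients."""
    height, width = len(activations[0]), len(activations[0][0])

    # Channel weights: global average of the gradient over spatial positions.
    alphas = [
        sum(sum(row) for row in grad) / (height * width) for grad in gradients
    ]

    # Weighted sum over channels, followed by ReLU.
    cam = [
        [
            max(0.0, sum(a * act[i][j] for a, act in zip(alphas, activations)))
            for j in range(width)
        ]
        for i in range(height)
    ]

    # Scale to [0, 1] for display; leave an all-zero map unchanged.
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[value / peak for value in row] for row in cam]
    return cam


# Two 2x2 feature-map channels with toy activations and gradients.
activations = [
    [[1.0, 0.0], [0.0, 2.0]],
    [[0.0, 4.0], [1.0, 0.0]],
]
gradients = [
    [[0.5, 0.5], [0.5, 0.5]],          # channel 0 supports the class
    [[-0.25, -0.25], [-0.25, -0.25]],  # channel 1 counts against it
]

heatmap = grad_cam(activations, gradients)
print(heatmap)  # [[0.5, 0.0], [0.0, 1.0]]
```

The resulting map has the spatial resolution of the chosen feature layer, so it is normally upsampled to the input size before being compared with a pixel-level mask such as the one DiET learns.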

In [None]:
print("=" * 70)
print("IMAGE EXPERIMENTS: DiET vs GradCAM")
print("=" * 70)
print(f"\nDatasets: {CONFIG['image_datasets']}")
print(f"Samples per dataset: {CONFIG['image_max_samples']}")
print(f"Training epochs: {CONFIG['image_epochs']}")
print("\nStarting experiments...\n")

# Clean up memory before starting
cleanup_memory()

image_start_time = datetime.now()

image_results = comparison.run_all_image_comparisons(skip_training=False)

image_end_time = datetime.now()
# total_seconds() covers the full elapsed time; .seconds drops whole days
image_duration = int((image_end_time - image_start_time).total_seconds())

cleanup_memory()

print(f"\nImage experiments completed in {image_duration // 60} minutes {image_duration % 60} seconds.")

IMAGE EXPERIMENTS: DiET vs GradCAM

Datasets: ['cifar10', 'cifar100', 'svhn', 'fashion_mnist']
Samples per dataset: 30000
Training epochs: 25

Starting experiments...


RUNNING IMAGE COMPARISONS ON 4 DATASETS
Datasets: cifar10, cifar100, svhn, fashion_mnist

IMAGE-BASED XAI COMPARISON (CIFAR-10)
DiET vs GradCAM - Discriminative Feature Attribution

IMAGE-BASED XAI COMPARISON (CIFAR10)
DiET vs GradCAM - Discriminative Feature Attribution
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 226MB/s]


DiET Experiment: Discriminative Feature Attribution

[Step 1/4] Preparing data...


100%|██████████| 170M/170M [00:13<00:00, 12.5MB/s]


Training samples (DiET): 30000
Test samples (DiET): 6000

[Step 2/4] Training baseline model...

Training baseline resnet model...


Epoch 1/25: 100%|██████████| 49/49 [00:07<00:00,  6.75it/s, loss=0.7489, acc=74.21%]


Epoch 1: Train Acc: 74.21%, Test Acc: 76.61%
Checkpoint saved: diet_image_resnet_baseline (epoch 0)


Epoch 2/25: 100%|██████████| 49/49 [00:06<00:00,  8.00it/s, loss=0.2760, acc=90.58%]


Epoch 2: Train Acc: 90.58%, Test Acc: 82.46%
Checkpoint saved: diet_image_resnet_baseline (epoch 1)


Epoch 3/25: 100%|██████████| 49/49 [00:06<00:00,  8.03it/s, loss=0.1347, acc=95.49%]


Epoch 3: Train Acc: 95.49%, Test Acc: 86.23%
Checkpoint saved: diet_image_resnet_baseline (epoch 2)


Epoch 4/25: 100%|██████████| 49/49 [00:06<00:00,  7.86it/s, loss=0.0766, acc=97.47%]


Epoch 4: Train Acc: 97.47%, Test Acc: 86.77%
Checkpoint saved: diet_image_resnet_baseline (epoch 3)


Epoch 5/25: 100%|██████████| 49/49 [00:06<00:00,  8.16it/s, loss=0.0476, acc=98.43%]


Epoch 5: Train Acc: 98.43%, Test Acc: 87.94%
Checkpoint saved: diet_image_resnet_baseline (epoch 4)


Epoch 6/25: 100%|██████████| 49/49 [00:06<00:00,  7.85it/s, loss=0.0352, acc=98.86%]


Epoch 6: Train Acc: 98.86%, Test Acc: 87.78%
Checkpoint saved: diet_image_resnet_baseline (epoch 5)


Epoch 7/25: 100%|██████████| 49/49 [00:06<00:00,  7.83it/s, loss=0.0289, acc=99.04%]


Epoch 7: Train Acc: 99.04%, Test Acc: 87.58%
Checkpoint saved: diet_image_resnet_baseline (epoch 6)


Epoch 8/25: 100%|██████████| 49/49 [00:05<00:00,  8.19it/s, loss=0.0211, acc=99.30%]


Epoch 8: Train Acc: 99.30%, Test Acc: 88.88%
Checkpoint saved: diet_image_resnet_baseline (epoch 7)


Epoch 9/25: 100%|██████████| 49/49 [00:06<00:00,  8.16it/s, loss=0.0112, acc=99.67%]


Epoch 9: Train Acc: 99.67%, Test Acc: 89.58%
Checkpoint saved: diet_image_resnet_baseline (epoch 8)


Epoch 10/25: 100%|██████████| 49/49 [00:06<00:00,  7.79it/s, loss=0.0049, acc=99.86%]


Epoch 10: Train Acc: 99.86%, Test Acc: 90.79%
Checkpoint saved: diet_image_resnet_baseline (epoch 9)


Epoch 11/25: 100%|██████████| 49/49 [00:06<00:00,  8.09it/s, loss=0.0015, acc=99.97%]


Epoch 11: Train Acc: 99.97%, Test Acc: 90.95%
Checkpoint saved: diet_image_resnet_baseline (epoch 10)


Epoch 12/25: 100%|██████████| 49/49 [00:06<00:00,  8.13it/s, loss=0.0004, acc=100.00%]


Epoch 12: Train Acc: 100.00%, Test Acc: 91.27%
Checkpoint saved: diet_image_resnet_baseline (epoch 11)


Epoch 13/25: 100%|██████████| 49/49 [00:06<00:00,  7.97it/s, loss=0.0002, acc=100.00%]


Epoch 13: Train Acc: 100.00%, Test Acc: 91.37%
Checkpoint saved: diet_image_resnet_baseline (epoch 12)


Epoch 14/25: 100%|██████████| 49/49 [00:06<00:00,  8.10it/s, loss=0.0001, acc=100.00%]


Epoch 14: Train Acc: 100.00%, Test Acc: 91.42%
Checkpoint saved: diet_image_resnet_baseline (epoch 13)


Epoch 15/25: 100%|██████████| 49/49 [00:06<00:00,  7.94it/s, loss=0.0001, acc=100.00%]


Epoch 15: Train Acc: 100.00%, Test Acc: 91.46%
Checkpoint saved: diet_image_resnet_baseline (epoch 14)


Epoch 16/25: 100%|██████████| 49/49 [00:06<00:00,  8.04it/s, loss=0.0001, acc=100.00%]


Epoch 16: Train Acc: 100.00%, Test Acc: 91.45%
Checkpoint saved: diet_image_resnet_baseline (epoch 15)


Epoch 17/25: 100%|██████████| 49/49 [00:06<00:00,  8.10it/s, loss=0.0001, acc=100.00%]


Epoch 17: Train Acc: 100.00%, Test Acc: 91.55%
Checkpoint saved: diet_image_resnet_baseline (epoch 16)


Epoch 18/25: 100%|██████████| 49/49 [00:05<00:00,  8.22it/s, loss=0.0001, acc=100.00%]


Epoch 18: Train Acc: 100.00%, Test Acc: 91.53%
Checkpoint saved: diet_image_resnet_baseline (epoch 17)


Epoch 19/25: 100%|██████████| 49/49 [00:06<00:00,  8.04it/s, loss=0.0001, acc=100.00%]


Epoch 19: Train Acc: 100.00%, Test Acc: 91.52%
Checkpoint saved: diet_image_resnet_baseline (epoch 18)


Epoch 20/25: 100%|██████████| 49/49 [00:06<00:00,  8.07it/s, loss=0.0001, acc=100.00%]


Epoch 20: Train Acc: 100.00%, Test Acc: 91.52%
Checkpoint saved: diet_image_resnet_baseline (epoch 19)


Epoch 21/25: 100%|██████████| 49/49 [00:06<00:00,  8.09it/s, loss=0.0001, acc=100.00%]


Epoch 21: Train Acc: 100.00%, Test Acc: 91.56%
Checkpoint saved: diet_image_resnet_baseline (epoch 20)


Epoch 22/25: 100%|██████████| 49/49 [00:06<00:00,  7.98it/s, loss=0.0001, acc=100.00%]


Epoch 22: Train Acc: 100.00%, Test Acc: 91.49%
Checkpoint saved: diet_image_resnet_baseline (epoch 21)


Epoch 23/25: 100%|██████████| 49/49 [00:06<00:00,  8.16it/s, loss=0.0001, acc=100.00%]


Epoch 23: Train Acc: 100.00%, Test Acc: 91.50%
Checkpoint saved: diet_image_resnet_baseline (epoch 22)


Epoch 24/25: 100%|██████████| 49/49 [00:06<00:00,  8.08it/s, loss=0.0001, acc=100.00%]


Epoch 24: Train Acc: 100.00%, Test Acc: 91.51%
Checkpoint saved: diet_image_resnet_baseline (epoch 23)


Epoch 25/25: 100%|██████████| 49/49 [00:06<00:00,  7.91it/s, loss=0.0001, acc=100.00%]


Epoch 25: Train Acc: 100.00%, Test Acc: 91.62%
Checkpoint saved: diet_image_resnet_baseline (epoch 24)
Baseline model saved to ./outputs/colab_experiments/comprehensive_comparison/cifar10/baseline_resnet.pth
Deleted checkpoint: diet_image_resnet_baseline

[Step 3/4] Running DiET distillation...

Running DiET Distillation
Getting baseline predictions...

--- DiET Step 1/2 ---
Training mask...
  Iter 0: loss=1.0000, faithful_acc=1.0000
  Iter 10: loss=0.5165, faithful_acc=0.9052
Training model...
  Iter 0: loss=0.1775, faithful_acc=0.9282
  Iter 10: loss=0.0550, faithful_acc=0.9815
Step 1 complete: Test Acc = 89.62%

--- DiET Step 2/2 ---
Training mask...
  Iter 0: loss=0.2775, faithful_acc=0.9863
  Iter 10: loss=0.1731, faithful_acc=0.9721
  Iter 20: loss=0.2356, faithful_acc=0.9446
  Iter 30: loss=0.3170, faithful_acc=0.9061
Training model...
  Iter 0: loss=0.3488, faithful_acc=0.8396
Step 2 complete: Test Acc = 88.90%

[Step 4/4] Comparing DiET vs GradCAM...

Comparing DiET vs GradCAM

Epoch 1/25: 100%|██████████| 49/49 [00:06<00:00,  7.47it/s, loss=0.8016, acc=72.63%]


Epoch 1: Train Acc: 72.63%, Test Acc: 76.39%
Checkpoint saved: diet_image_resnet_baseline (epoch 0)


Epoch 2/25: 100%|██████████| 49/49 [00:06<00:00,  7.78it/s, loss=0.3177, acc=89.32%]


Epoch 2: Train Acc: 89.32%, Test Acc: 85.29%
Checkpoint saved: diet_image_resnet_baseline (epoch 1)


Epoch 3/25: 100%|██████████| 49/49 [00:06<00:00,  7.72it/s, loss=0.1412, acc=95.27%]


Epoch 3: Train Acc: 95.27%, Test Acc: 85.60%
Checkpoint saved: diet_image_resnet_baseline (epoch 2)


Epoch 4/25: 100%|██████████| 49/49 [00:06<00:00,  7.85it/s, loss=0.0776, acc=97.38%]


Epoch 4: Train Acc: 97.38%, Test Acc: 87.52%
Checkpoint saved: diet_image_resnet_baseline (epoch 3)


Epoch 5/25: 100%|██████████| 49/49 [00:06<00:00,  7.82it/s, loss=0.0546, acc=98.20%]


Epoch 5: Train Acc: 98.20%, Test Acc: 86.69%
Checkpoint saved: diet_image_resnet_baseline (epoch 4)


Epoch 6/25: 100%|██████████| 49/49 [00:06<00:00,  7.85it/s, loss=0.0414, acc=98.63%]


Epoch 6: Train Acc: 98.63%, Test Acc: 88.79%
Checkpoint saved: diet_image_resnet_baseline (epoch 5)


Epoch 7/25: 100%|██████████| 49/49 [00:06<00:00,  7.82it/s, loss=0.0265, acc=99.14%]


Epoch 7: Train Acc: 99.14%, Test Acc: 87.35%
Checkpoint saved: diet_image_resnet_baseline (epoch 6)


Epoch 8/25: 100%|██████████| 49/49 [00:06<00:00,  7.86it/s, loss=0.0150, acc=99.54%]


Epoch 8: Train Acc: 99.54%, Test Acc: 89.26%
Checkpoint saved: diet_image_resnet_baseline (epoch 7)


Epoch 9/25: 100%|██████████| 49/49 [00:06<00:00,  7.89it/s, loss=0.0078, acc=99.80%]


Epoch 9: Train Acc: 99.80%, Test Acc: 89.89%
Checkpoint saved: diet_image_resnet_baseline (epoch 8)


Epoch 10/25: 100%|██████████| 49/49 [00:06<00:00,  7.74it/s, loss=0.0042, acc=99.90%]


Epoch 10: Train Acc: 99.90%, Test Acc: 90.74%
Checkpoint saved: diet_image_resnet_baseline (epoch 9)


Epoch 11/25: 100%|██████████| 49/49 [00:06<00:00,  7.64it/s, loss=0.0022, acc=99.95%]


Epoch 11: Train Acc: 99.95%, Test Acc: 90.90%
Checkpoint saved: diet_image_resnet_baseline (epoch 10)


Epoch 12/25: 100%|██████████| 49/49 [00:06<00:00,  7.86it/s, loss=0.0008, acc=99.99%]


Epoch 12: Train Acc: 99.99%, Test Acc: 91.19%
Checkpoint saved: diet_image_resnet_baseline (epoch 11)


Epoch 13/25: 100%|██████████| 49/49 [00:06<00:00,  8.00it/s, loss=0.0003, acc=100.00%]


Epoch 13: Train Acc: 100.00%, Test Acc: 91.38%
Checkpoint saved: diet_image_resnet_baseline (epoch 12)


Epoch 14/25: 100%|██████████| 49/49 [00:06<00:00,  7.86it/s, loss=0.0001, acc=100.00%]


Epoch 14: Train Acc: 100.00%, Test Acc: 91.56%
Checkpoint saved: diet_image_resnet_baseline (epoch 13)


Epoch 15/25: 100%|██████████| 49/49 [00:06<00:00,  7.55it/s, loss=0.0001, acc=100.00%]


Epoch 15: Train Acc: 100.00%, Test Acc: 91.56%
Checkpoint saved: diet_image_resnet_baseline (epoch 14)


Epoch 16/25: 100%|██████████| 49/49 [00:06<00:00,  7.89it/s, loss=0.0001, acc=100.00%]


Epoch 16: Train Acc: 100.00%, Test Acc: 91.57%
Checkpoint saved: diet_image_resnet_baseline (epoch 15)


Epoch 17/25: 100%|██████████| 49/49 [00:06<00:00,  7.62it/s, loss=0.0001, acc=100.00%]


Epoch 17: Train Acc: 100.00%, Test Acc: 91.58%
Checkpoint saved: diet_image_resnet_baseline (epoch 16)


Epoch 18/25: 100%|██████████| 49/49 [00:06<00:00,  7.84it/s, loss=0.0001, acc=100.00%]


Epoch 18: Train Acc: 100.00%, Test Acc: 91.61%
Checkpoint saved: diet_image_resnet_baseline (epoch 17)


Epoch 19/25: 100%|██████████| 49/49 [00:06<00:00,  7.76it/s, loss=0.0001, acc=100.00%]


Epoch 19: Train Acc: 100.00%, Test Acc: 91.66%
Checkpoint saved: diet_image_resnet_baseline (epoch 18)


Epoch 20/25: 100%|██████████| 49/49 [00:06<00:00,  7.89it/s, loss=0.0001, acc=100.00%]


Epoch 20: Train Acc: 100.00%, Test Acc: 91.60%
Checkpoint saved: diet_image_resnet_baseline (epoch 19)


Epoch 21/25: 100%|██████████| 49/49 [00:06<00:00,  7.87it/s, loss=0.0001, acc=100.00%]


Epoch 21: Train Acc: 100.00%, Test Acc: 91.59%
Checkpoint saved: diet_image_resnet_baseline (epoch 20)


Epoch 22/25: 100%|██████████| 49/49 [00:06<00:00,  7.77it/s, loss=0.0001, acc=100.00%]


Epoch 22: Train Acc: 100.00%, Test Acc: 91.59%
Checkpoint saved: diet_image_resnet_baseline (epoch 21)


Epoch 23/25: 100%|██████████| 49/49 [00:06<00:00,  7.81it/s, loss=0.0001, acc=100.00%]


Epoch 23: Train Acc: 100.00%, Test Acc: 91.53%
Checkpoint saved: diet_image_resnet_baseline (epoch 22)


Epoch 24/25: 100%|██████████| 49/49 [00:06<00:00,  7.80it/s, loss=0.0001, acc=100.00%]


Epoch 24: Train Acc: 100.00%, Test Acc: 91.61%
Checkpoint saved: diet_image_resnet_baseline (epoch 23)


Epoch 25/25: 100%|██████████| 49/49 [00:06<00:00,  7.65it/s, loss=0.0001, acc=100.00%]


Epoch 25: Train Acc: 100.00%, Test Acc: 91.61%
Checkpoint saved: diet_image_resnet_baseline (epoch 24)
Baseline model saved to ./outputs/colab_experiments/comprehensive_comparison/cifar100/baseline_resnet.pth
Deleted checkpoint: diet_image_resnet_baseline

[Step 3/4] Running DiET distillation...

Running DiET Distillation
Getting baseline predictions...

--- DiET Step 1/2 ---
Training mask...
  Iter 0: loss=1.0000, faithful_acc=1.0000
  Iter 10: loss=0.4700, faithful_acc=0.9184
  Iter 20: loss=0.4284, faithful_acc=0.9069
Training model...
  Iter 0: loss=0.1807, faithful_acc=0.9274
  Iter 10: loss=0.0692, faithful_acc=0.9722
Step 1 complete: Test Acc = 88.97%

--- DiET Step 2/2 ---
Training mask...
  Iter 0: loss=0.2677, faithful_acc=0.9764
  Iter 10: loss=0.1696, faithful_acc=0.9649
  Iter 20: loss=0.2773, faithful_acc=0.9186
Training model...
  Iter 0: loss=0.3753, faithful_acc=0.8253
  Iter 10: loss=0.2904, faithful_acc=0.8616
Step 2 complete: Test Acc = 88.69%

[Step 4/4] Comparing 

Epoch 1/25: 100%|██████████| 49/49 [00:06<00:00,  7.82it/s, loss=0.7166, acc=75.70%]


Epoch 1: Train Acc: 75.70%, Test Acc: 83.23%
Checkpoint saved: diet_image_resnet_baseline (epoch 0)


Epoch 2/25: 100%|██████████| 49/49 [00:06<00:00,  7.88it/s, loss=0.2574, acc=91.48%]


Epoch 2: Train Acc: 91.48%, Test Acc: 87.08%
Checkpoint saved: diet_image_resnet_baseline (epoch 1)


Epoch 3/25: 100%|██████████| 49/49 [00:06<00:00,  7.86it/s, loss=0.1127, acc=96.25%]


Epoch 3: Train Acc: 96.25%, Test Acc: 85.52%
Checkpoint saved: diet_image_resnet_baseline (epoch 2)


Epoch 4/25: 100%|██████████| 49/49 [00:06<00:00,  7.77it/s, loss=0.0613, acc=97.96%]


Epoch 4: Train Acc: 97.96%, Test Acc: 87.37%
Checkpoint saved: diet_image_resnet_baseline (epoch 3)


Epoch 5/25: 100%|██████████| 49/49 [00:06<00:00,  7.83it/s, loss=0.0460, acc=98.43%]


Epoch 5: Train Acc: 98.43%, Test Acc: 88.05%
Checkpoint saved: diet_image_resnet_baseline (epoch 4)


Epoch 6/25: 100%|██████████| 49/49 [00:06<00:00,  7.75it/s, loss=0.0449, acc=98.49%]


Epoch 6: Train Acc: 98.49%, Test Acc: 85.57%
Checkpoint saved: diet_image_resnet_baseline (epoch 5)


Epoch 7/25: 100%|██████████| 49/49 [00:06<00:00,  7.80it/s, loss=0.0271, acc=99.10%]


Epoch 7: Train Acc: 99.10%, Test Acc: 89.54%
Checkpoint saved: diet_image_resnet_baseline (epoch 6)


Epoch 8/25: 100%|██████████| 49/49 [00:06<00:00,  7.83it/s, loss=0.0190, acc=99.42%]


Epoch 8: Train Acc: 99.42%, Test Acc: 90.12%
Checkpoint saved: diet_image_resnet_baseline (epoch 7)


Epoch 9/25: 100%|██████████| 49/49 [00:06<00:00,  7.94it/s, loss=0.0074, acc=99.81%]


Epoch 9: Train Acc: 99.81%, Test Acc: 91.03%
Checkpoint saved: diet_image_resnet_baseline (epoch 8)


Epoch 10/25: 100%|██████████| 49/49 [00:06<00:00,  7.80it/s, loss=0.0032, acc=99.92%]


Epoch 10: Train Acc: 99.92%, Test Acc: 91.60%
Checkpoint saved: diet_image_resnet_baseline (epoch 9)


Epoch 11/25: 100%|██████████| 49/49 [00:06<00:00,  7.74it/s, loss=0.0009, acc=99.99%]


Epoch 11: Train Acc: 99.99%, Test Acc: 92.05%
Checkpoint saved: diet_image_resnet_baseline (epoch 10)


Epoch 12/25: 100%|██████████| 49/49 [00:06<00:00,  7.81it/s, loss=0.0002, acc=100.00%]


Epoch 12: Train Acc: 100.00%, Test Acc: 92.14%
Checkpoint saved: diet_image_resnet_baseline (epoch 11)


Epoch 13/25: 100%|██████████| 49/49 [00:06<00:00,  7.78it/s, loss=0.0001, acc=100.00%]


Epoch 13: Train Acc: 100.00%, Test Acc: 92.17%
Checkpoint saved: diet_image_resnet_baseline (epoch 12)


Epoch 14/25: 100%|██████████| 49/49 [00:06<00:00,  7.84it/s, loss=0.0001, acc=100.00%]


Epoch 14: Train Acc: 100.00%, Test Acc: 92.14%
Checkpoint saved: diet_image_resnet_baseline (epoch 13)


Epoch 15/25: 100%|██████████| 49/49 [00:06<00:00,  7.88it/s, loss=0.0001, acc=100.00%]


Epoch 15: Train Acc: 100.00%, Test Acc: 92.21%
Checkpoint saved: diet_image_resnet_baseline (epoch 14)


Epoch 16/25: 100%|██████████| 49/49 [00:06<00:00,  7.89it/s, loss=0.0001, acc=100.00%]


Epoch 16: Train Acc: 100.00%, Test Acc: 92.20%
Checkpoint saved: diet_image_resnet_baseline (epoch 15)


Epoch 17/25: 100%|██████████| 49/49 [00:06<00:00,  7.87it/s, loss=0.0001, acc=100.00%]


Epoch 17: Train Acc: 100.00%, Test Acc: 92.17%
Checkpoint saved: diet_image_resnet_baseline (epoch 16)


Epoch 18/25: 100%|██████████| 49/49 [00:06<00:00,  7.74it/s, loss=0.0001, acc=100.00%]


Epoch 18: Train Acc: 100.00%, Test Acc: 92.13%
Checkpoint saved: diet_image_resnet_baseline (epoch 17)


Epoch 19/25: 100%|██████████| 49/49 [00:06<00:00,  7.79it/s, loss=0.0001, acc=100.00%]


Epoch 19: Train Acc: 100.00%, Test Acc: 92.18%
Checkpoint saved: diet_image_resnet_baseline (epoch 18)


Epoch 20/25: 100%|██████████| 49/49 [00:06<00:00,  7.89it/s, loss=0.0001, acc=100.00%]


Epoch 20: Train Acc: 100.00%, Test Acc: 92.18%
Checkpoint saved: diet_image_resnet_baseline (epoch 19)


Epoch 21/25: 100%|██████████| 49/49 [00:06<00:00,  7.82it/s, loss=0.0001, acc=100.00%]


Epoch 21: Train Acc: 100.00%, Test Acc: 92.13%
Checkpoint saved: diet_image_resnet_baseline (epoch 20)


Epoch 22/25: 100%|██████████| 49/49 [00:06<00:00,  7.94it/s, loss=0.0001, acc=100.00%]


Epoch 22: Train Acc: 100.00%, Test Acc: 92.13%
Checkpoint saved: diet_image_resnet_baseline (epoch 21)


Epoch 23/25: 100%|██████████| 49/49 [00:06<00:00,  7.82it/s, loss=0.0001, acc=100.00%]


Epoch 23: Train Acc: 100.00%, Test Acc: 92.11%
Checkpoint saved: diet_image_resnet_baseline (epoch 22)


Epoch 24/25: 100%|██████████| 49/49 [00:06<00:00,  7.78it/s, loss=0.0001, acc=100.00%]


Epoch 24: Train Acc: 100.00%, Test Acc: 92.13%
Checkpoint saved: diet_image_resnet_baseline (epoch 23)


Epoch 25/25: 100%|██████████| 49/49 [00:06<00:00,  7.65it/s, loss=0.0001, acc=100.00%]


Epoch 25: Train Acc: 100.00%, Test Acc: 92.13%
Checkpoint saved: diet_image_resnet_baseline (epoch 24)
Baseline model saved to ./outputs/colab_experiments/comprehensive_comparison/svhn/baseline_resnet.pth
Deleted checkpoint: diet_image_resnet_baseline

[Step 3/4] Running DiET distillation...

Running DiET Distillation
Getting baseline predictions...

--- DiET Step 1/2 ---
Training mask...
  Iter 0: loss=1.0000, faithful_acc=1.0000
  Iter 10: loss=0.5024, faithful_acc=0.9084
Training model...
  Iter 0: loss=0.1909, faithful_acc=0.9219
Step 1 complete: Test Acc = 90.58%

--- DiET Step 2/2 ---
Training mask...
  Iter 0: loss=0.2969, faithful_acc=0.9672
  Iter 10: loss=0.2023, faithful_acc=0.9544
  Iter 20: loss=0.2454, faithful_acc=0.9393
  Iter 30: loss=0.3021, faithful_acc=0.9135
  Iter 40: loss=0.3183, faithful_acc=0.8987
Training model...
  Iter 0: loss=0.3577, faithful_acc=0.8349
Step 2 complete: Test Acc = 89.82%

[Step 4/4] Comparing DiET vs GradCAM...

Comparing DiET vs GradCAM

Epoch 1/25: 100%|██████████| 49/49 [00:06<00:00,  7.74it/s, loss=0.7688, acc=73.57%]


Epoch 1: Train Acc: 73.57%, Test Acc: 81.57%
Checkpoint saved: diet_image_resnet_baseline (epoch 0)


Epoch 2/25: 100%|██████████| 49/49 [00:06<00:00,  7.49it/s, loss=0.2933, acc=90.12%]


Epoch 2: Train Acc: 90.12%, Test Acc: 84.00%
Checkpoint saved: diet_image_resnet_baseline (epoch 1)


Epoch 3/25: 100%|██████████| 49/49 [00:06<00:00,  7.81it/s, loss=0.1379, acc=95.33%]


Epoch 3: Train Acc: 95.33%, Test Acc: 85.26%
Checkpoint saved: diet_image_resnet_baseline (epoch 2)


Epoch 4/25: 100%|██████████| 49/49 [00:06<00:00,  7.80it/s, loss=0.0850, acc=97.15%]


Epoch 4: Train Acc: 97.15%, Test Acc: 86.41%
Checkpoint saved: diet_image_resnet_baseline (epoch 3)


Epoch 5/25: 100%|██████████| 49/49 [00:06<00:00,  7.77it/s, loss=0.0552, acc=98.12%]


Epoch 5: Train Acc: 98.12%, Test Acc: 87.92%
Checkpoint saved: diet_image_resnet_baseline (epoch 4)


Epoch 6/25: 100%|██████████| 49/49 [00:06<00:00,  7.81it/s, loss=0.0370, acc=98.79%]


Epoch 6: Train Acc: 98.79%, Test Acc: 88.25%
Checkpoint saved: diet_image_resnet_baseline (epoch 5)


Epoch 7/25: 100%|██████████| 49/49 [00:06<00:00,  7.50it/s, loss=0.0254, acc=99.14%]


Epoch 7: Train Acc: 99.14%, Test Acc: 87.34%
Checkpoint saved: diet_image_resnet_baseline (epoch 6)


Epoch 8/25: 100%|██████████| 49/49 [00:06<00:00,  7.48it/s, loss=0.0200, acc=99.36%]


Epoch 8: Train Acc: 99.36%, Test Acc: 88.32%
Checkpoint saved: diet_image_resnet_baseline (epoch 7)


Epoch 9/25: 100%|██████████| 49/49 [00:06<00:00,  7.69it/s, loss=0.0151, acc=99.49%]


Epoch 9: Train Acc: 99.49%, Test Acc: 88.91%
Checkpoint saved: diet_image_resnet_baseline (epoch 8)


Epoch 10/25: 100%|██████████| 49/49 [00:06<00:00,  7.67it/s, loss=0.0104, acc=99.67%]


Epoch 10: Train Acc: 99.67%, Test Acc: 89.12%
Checkpoint saved: diet_image_resnet_baseline (epoch 9)


Epoch 11/25: 100%|██████████| 49/49 [00:06<00:00,  7.67it/s, loss=0.0037, acc=99.92%]


Epoch 11: Train Acc: 99.92%, Test Acc: 90.61%
Checkpoint saved: diet_image_resnet_baseline (epoch 10)


Epoch 12/25: 100%|██████████| 49/49 [00:06<00:00,  7.79it/s, loss=0.0013, acc=99.98%]


Epoch 12: Train Acc: 99.98%, Test Acc: 90.85%
Checkpoint saved: diet_image_resnet_baseline (epoch 11)


Epoch 13/25: 100%|██████████| 49/49 [00:06<00:00,  7.70it/s, loss=0.0004, acc=100.00%]


Epoch 13: Train Acc: 100.00%, Test Acc: 90.98%
Checkpoint saved: diet_image_resnet_baseline (epoch 12)


Epoch 14/25: 100%|██████████| 49/49 [00:06<00:00,  7.77it/s, loss=0.0002, acc=100.00%]


Epoch 14: Train Acc: 100.00%, Test Acc: 90.99%
Checkpoint saved: diet_image_resnet_baseline (epoch 13)


Epoch 15/25: 100%|██████████| 49/49 [00:06<00:00,  7.76it/s, loss=0.0001, acc=100.00%]


Epoch 15: Train Acc: 100.00%, Test Acc: 91.01%
Checkpoint saved: diet_image_resnet_baseline (epoch 14)


Epoch 16/25: 100%|██████████| 49/49 [00:06<00:00,  7.74it/s, loss=0.0001, acc=100.00%]


Epoch 16: Train Acc: 100.00%, Test Acc: 91.09%
Checkpoint saved: diet_image_resnet_baseline (epoch 15)


Epoch 17/25: 100%|██████████| 49/49 [00:06<00:00,  7.66it/s, loss=0.0001, acc=100.00%]


Epoch 17: Train Acc: 100.00%, Test Acc: 91.11%
Checkpoint saved: diet_image_resnet_baseline (epoch 16)


Epoch 18/25: 100%|██████████| 49/49 [00:06<00:00,  7.51it/s, loss=0.0001, acc=100.00%]


Epoch 18: Train Acc: 100.00%, Test Acc: 91.16%
Checkpoint saved: diet_image_resnet_baseline (epoch 17)


Epoch 19/25: 100%|██████████| 49/49 [00:06<00:00,  7.79it/s, loss=0.0001, acc=100.00%]


Epoch 19: Train Acc: 100.00%, Test Acc: 91.11%
Checkpoint saved: diet_image_resnet_baseline (epoch 18)


Epoch 20/25: 100%|██████████| 49/49 [00:06<00:00,  7.82it/s, loss=0.0001, acc=100.00%]


Epoch 20: Train Acc: 100.00%, Test Acc: 91.16%
Checkpoint saved: diet_image_resnet_baseline (epoch 19)


Epoch 21/25: 100%|██████████| 49/49 [00:06<00:00,  7.75it/s, loss=0.0001, acc=100.00%]


Epoch 21: Train Acc: 100.00%, Test Acc: 91.12%
Checkpoint saved: diet_image_resnet_baseline (epoch 20)


Epoch 22/25: 100%|██████████| 49/49 [00:06<00:00,  7.71it/s, loss=0.0001, acc=100.00%]


Epoch 22: Train Acc: 100.00%, Test Acc: 91.16%
Checkpoint saved: diet_image_resnet_baseline (epoch 21)


Epoch 23/25: 100%|██████████| 49/49 [00:06<00:00,  7.55it/s, loss=0.0001, acc=100.00%]


Epoch 23: Train Acc: 100.00%, Test Acc: 91.15%
Checkpoint saved: diet_image_resnet_baseline (epoch 22)


Epoch 24/25: 100%|██████████| 49/49 [00:06<00:00,  7.72it/s, loss=0.0001, acc=100.00%]


Epoch 24: Train Acc: 100.00%, Test Acc: 91.18%
Checkpoint saved: diet_image_resnet_baseline (epoch 23)


Epoch 25/25: 100%|██████████| 49/49 [00:06<00:00,  7.72it/s, loss=0.0001, acc=100.00%]


Epoch 25: Train Acc: 100.00%, Test Acc: 91.17%
Checkpoint saved: diet_image_resnet_baseline (epoch 24)
Baseline model saved to ./outputs/colab_experiments/comprehensive_comparison/fashion_mnist/baseline_resnet.pth
Deleted checkpoint: diet_image_resnet_baseline

[Step 3/4] Running DiET distillation...

Running DiET Distillation
Getting baseline predictions...

--- DiET Step 1/2 ---
Training mask...
  Iter 0: loss=1.0000, faithful_acc=1.0000
  Iter 10: loss=0.4699, faithful_acc=0.9234
  Iter 20: loss=0.4319, faithful_acc=0.9107
Training model...
  Iter 0: loss=0.1871, faithful_acc=0.9238
  Iter 10: loss=0.0652, faithful_acc=0.9765
Step 1 complete: Test Acc = 89.03%

--- DiET Step 2/2 ---
Training mask...
  Iter 0: loss=0.2643, faithful_acc=0.9797
Training model...
  Iter 0: loss=0.0746, faithful_acc=0.9708
Step 2 complete: Test Acc = 88.51%

[Step 4/4] Comparing DiET vs GradCAM...

Comparing DiET vs GradCAM
Comparison visualizations saved to ./outputs/colab_experiments/comprehensive_comp

### 3.2 Image Results: Visual Summary

In [None]:
# Extract image results with all metrics
image_data = []
for dataset_name, result in image_results.items():
    if "error" not in result:
        row = {
            "Dataset": dataset_name.upper(),
            "Baseline Accuracy": result.get("baseline_accuracy", 0),
            "DiET Accuracy": result.get("diet_accuracy", 0),
            "GradCAM Score": result.get("gradcam_mean_score", 0),
            "DiET Score": result.get("diet_mean_score", 0),
            "Improvement": result.get("improvement", 0),
            "DiET Better": result.get("diet_better", False),
        }
        # Add additional metrics if available
        if result.get("gradcam_aopc") is not None:
            row["GradCAM AOPC"] = result.get("gradcam_aopc", 0)
            row["DiET AOPC"] = result.get("diet_aopc", 0)
        if result.get("gradcam_faithfulness") is not None:
            row["GradCAM Faithfulness"] = result.get("gradcam_faithfulness", 0)
            row["DiET Faithfulness"] = result.get("diet_faithfulness", 0)
        if result.get("gradcam_insertion_auc") is not None:
            row["GradCAM Insertion AUC"] = result.get("gradcam_insertion_auc", 0)
            row["DiET Insertion AUC"] = result.get("diet_insertion_auc", 0)
            row["GradCAM Deletion AUC"] = result.get("gradcam_deletion_auc", 0)
            row["DiET Deletion AUC"] = result.get("diet_deletion_auc", 0)
        image_data.append(row)

image_df = pd.DataFrame(image_data)

# Display table
print("\n" + "=" * 70)
print("IMAGE EXPERIMENTS: QUANTITATIVE RESULTS")
print("=" * 70)
print("\nAll Metrics (higher = better attribution quality, except Deletion AUC, where lower is better):\n")
print(image_df.to_string(index=False))

# Summary statistics
if len(image_df) > 0:
    diet_wins = image_df["DiET Better"].sum()
    total = len(image_df)
    avg_improvement = image_df["Improvement"].mean()
    print(f"\n{'='*50}")
    print(f"Summary: DiET outperforms GradCAM on {diet_wins}/{total} datasets")
    print(f"Average improvement: {avg_improvement:+.4f}")
    print(f"{'='*50}")


IMAGE EXPERIMENTS: QUANTITATIVE RESULTS

Pixel Perturbation Scores (higher = better attribution quality):

      Dataset  Baseline Accuracy  DiET Accuracy  GradCAM Score  DiET Score  Improvement  DiET Better
      CIFAR10              91.62          88.90       0.457200    0.459467     0.002267         True
     CIFAR100              91.61          88.69       0.436867    0.456333     0.019467         True
         SVHN              92.13          89.82       0.436933    0.468800     0.031867         True
FASHION_MNIST              91.17          88.51       0.420067    0.337933    -0.082133        False

Summary: DiET outperforms GradCAM on 3/4 datasets
Average improvement: -0.0071
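
The summary above pairs a 3/4 win count with a negative average improvement. A minimal sketch of how the two numbers can diverge, recomputed from the scores printed in the table (the variable names are illustrative and not taken from the notebook's pipeline):

```python
# Scores copied from the results table above: (GradCAM, DiET) per dataset.
scores = {
    "CIFAR10": (0.457200, 0.459467),
    "CIFAR100": (0.436867, 0.456333),
    "SVHN": (0.436933, 0.468800),
    "FASHION_MNIST": (0.420067, 0.337933),
}

# Improvement is defined as DiET score minus GradCAM score.
improvements = {name: diet - gradcam for name, (gradcam, diet) in scores.items()}

# A "win" is any dataset where the improvement is positive.
diet_wins = sum(delta > 0 for delta in improvements.values())
mean_improvement = sum(improvements.values()) / len(improvements)

print(f"DiET wins: {diet_wins}/{len(scores)}")
print(f"Mean improvement: {mean_improvement:+.4f}")
```

The single large deficit on Fashion-MNIST (about -0.082) outweighs the three small gains, so the win count and the mean point in opposite directions. Reporting both, as the notebook does, avoids overstating either.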


In [None]:
if len(image_df) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle("Image Experiments: DiET vs GradCAM - Comprehensive Visual Summary", 
                 fontsize=18, fontweight='bold', y=1.02)
    
    for ax in axes.flat:
        ax.set_facecolor('#fafafa')

    # Plot 1: Attribution Quality Comparison
    datasets = image_df["Dataset"].tolist()
    x = np.arange(len(datasets))
    width = 0.35

    bars1 = axes[0, 0].bar(x - width/2, image_df["GradCAM Score"], width, 
                           label='GradCAM', color='#3498db', alpha=0.85, edgecolor='#333', linewidth=1.5)
    bars2 = axes[0, 0].bar(x + width/2, image_df["DiET Score"], width, 
                           label='DiET', color='#2ecc71', alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[0, 0].set_ylabel('Pixel Perturbation Score', fontweight='bold')
    axes[0, 0].set_title('Attribution Quality Comparison (Higher = Better)', fontweight='bold', fontsize=12)
    axes[0, 0].set_xticks(x)
    axes[0, 0].set_xticklabels(datasets, fontweight='bold')
    axes[0, 0].legend(framealpha=0.95)
    axes[0, 0].set_ylim(0, max(image_df["GradCAM Score"].max(), image_df["DiET Score"].max()) * 1.2)
    axes[0, 0].grid(True, alpha=0.3, linestyle='--')
    
    for bar in bars1:
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                       f'{bar.get_height():.3f}', ha='center', fontsize=9, fontweight='bold')
    for bar in bars2:
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                       f'{bar.get_height():.3f}', ha='center', fontsize=9, fontweight='bold')

    # Plot 2: Improvement Over GradCAM
    improvements = image_df["Improvement"].tolist()
    colors = ['#2ecc71' if imp > 0 else '#e74c3c' for imp in improvements]
    bars = axes[0, 1].barh(datasets, improvements, color=colors, alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[0, 1].axvline(x=0, color='black', linestyle='-', linewidth=1)
    axes[0, 1].set_xlabel('Improvement (DiET - GradCAM)', fontweight='bold')
    axes[0, 1].set_title('DiET Improvement Over GradCAM', fontweight='bold', fontsize=12)
    axes[0, 1].grid(True, alpha=0.3, linestyle='--')
    for bar, v in zip(bars, improvements):
        axes[0, 1].text(v + 0.005 if v >= 0 else v - 0.005, bar.get_y() + bar.get_height()/2,
                       f'{v:+.4f}', va='center', ha='left' if v >= 0 else 'right', 
                       fontsize=10, fontweight='bold')

    # Plot 3: Model Accuracy Comparison
    bars3 = axes[1, 0].bar(x - width/2, image_df["Baseline Accuracy"], width, 
                           label='Baseline', color='#f39c12', alpha=0.85, edgecolor='#333', linewidth=1.5)
    bars4 = axes[1, 0].bar(x + width/2, image_df["DiET Accuracy"], width, 
                           label='After DiET', color='#9b59b6', alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[1, 0].set_ylabel('Accuracy (%)', fontweight='bold')
    axes[1, 0].set_title('Model Accuracy Before and After DiET', fontweight='bold', fontsize=12)
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(datasets, fontweight='bold')
    axes[1, 0].legend(framealpha=0.95)
    axes[1, 0].set_ylim(0, 100)
    axes[1, 0].grid(True, alpha=0.3, linestyle='--')

    # Plot 4: Pie chart of per-dataset wins (DiET vs GradCAM)
    diet_wins = image_df["DiET Better"].sum()
    gradcam_wins = len(image_df) - diet_wins
    if diet_wins > 0 or gradcam_wins > 0:
        wedges, texts, autotexts = axes[1, 1].pie(
            [diet_wins, gradcam_wins], 
            labels=['DiET Better', 'GradCAM Better'],
            autopct='%1.0f%%', 
            colors=['#2ecc71', '#3498db'], 
            startangle=90,
            explode=(0.05, 0),
            shadow=True,
            textprops={'fontweight': 'bold'}
        )
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontsize(14)
    axes[1, 1].set_title('Method Performance Summary', fontweight='bold', fontsize=12)

    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/image_visual_summary.png', dpi=150, bbox_inches='tight', facecolor='white')
    plt.show()

    print(f"\nFigure saved: {OUTPUT_DIR}/image_visual_summary.png")


Figure saved: ./outputs/colab_experiments/comprehensive_comparison/image_visual_summary.png


---

## 4. Text Experiments: DiET vs Integrated Gradients

### 4.1 Run Text Comparison Experiments

This section compares DiET and Integrated Gradients on text classification datasets using BERT.

**Note:** Memory is cleaned up between datasets to reduce the risk of out-of-memory (OOM) errors; the AG News run in the output below still hit one during DiET Step 1.
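
A minimal sketch of the per-dataset cleanup pattern mentioned in the note above. The internals of `comparison.run_all_text_comparisons` are not shown in this notebook, so `run_per_dataset`, `run_fn`, and `cleanup_fn` are hypothetical names used only for illustration:

```python
import gc


def run_per_dataset(datasets, run_fn, cleanup_fn=gc.collect):
    """Run run_fn on each dataset, recording failures instead of aborting.

    Hypothetical helper: a failed dataset is stored as {"error": message},
    matching the `"error" not in result` filter used by the summary cells.
    """
    results = {}
    for name in datasets:
        try:
            results[name] = run_fn(name)
        except Exception as exc:  # e.g. a CUDA out-of-memory RuntimeError
            results[name] = {"error": str(exc)}
        finally:
            # Runs after success and failure alike, before the next dataset.
            cleanup_fn()
    return results
```

In this notebook, `cleanup_fn` could plausibly be the `cleanup_memory` function from Section 1. Note that garbage collection and cache clearing only release memory for tensors that are no longer referenced, so models or optimizers still held by a long-lived object would need to be dropped first.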

In [None]:
print("=" * 70)
print("TEXT EXPERIMENTS: DiET vs Integrated Gradients")
print("=" * 70)
print(f"\nDatasets: {CONFIG['text_datasets']}")
print("Model: BERT-base-uncased")
print(f"Max sequence length: {CONFIG['text_max_length']}")
print(f"Training epochs: {CONFIG['text_epochs']}")
print(f"Top-k values: {CONFIG['text_top_k_values']}")
print("\nStarting experiments...\n")

# Clean up memory before starting text experiments
cleanup_memory()

text_start_time = datetime.now()

# Run text experiments
text_results = comparison.run_all_text_comparisons(skip_training=False)

text_end_time = datetime.now()
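# total_seconds() stays correct for runs longer than a day; .seconds would wrap around
text_duration = int((text_end_time - text_start_time).total_seconds())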

# Clean up after text experiments
cleanup_memory()

print(f"\nText experiments completed in {text_duration // 60} minutes {text_duration % 60} seconds.")

TEXT EXPERIMENTS: DiET vs Integrated Gradients

Datasets: ['sst2', 'imdb', 'ag_news']
Model: BERT-base-uncased
Max sequence length: 256
Training epochs: 25

Starting experiments...


RUNNING TEXT COMPARISONS ON 3 DATASETS
Datasets: sst2, imdb, ag_news

TEXT-BASED XAI COMPARISON (SST2)
DiET vs Integrated Gradients - Token Attribution
DiET Text Experiment: Token Attribution

[Step 1/4] Preparing data...
Loading SST2 dataset...


README.md: 0.00B [00:00, ?B/s]

sst2/train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

sst2/validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

sst2/test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Tokenizing training data...


Tokenizing: 100%|██████████| 3000/3000 [00:02<00:00, 1385.58it/s]


Training samples: 3000
Test samples: 200
Max sequence length: 256

[Step 2/4] Training baseline BERT...

Fine-tuning BERT baseline...


Epoch 1/25: 100%|██████████| 12/12 [00:25<00:00,  2.09s/it, loss=0.6671, acc=59.90%]


Checkpoint saved: diet_text_sst2_baseline (epoch 0)


Epoch 2/25: 100%|██████████| 12/12 [00:23<00:00,  2.00s/it, loss=0.4488, acc=84.73%]


Checkpoint saved: diet_text_sst2_baseline (epoch 1)


Epoch 3/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.2801, acc=89.63%]


Checkpoint saved: diet_text_sst2_baseline (epoch 2)


Epoch 4/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.1993, acc=93.47%]


Checkpoint saved: diet_text_sst2_baseline (epoch 3)


Epoch 5/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.1395, acc=95.50%]


Checkpoint saved: diet_text_sst2_baseline (epoch 4)


Epoch 6/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0892, acc=97.60%]


Checkpoint saved: diet_text_sst2_baseline (epoch 5)


Epoch 7/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0634, acc=98.30%]


Checkpoint saved: diet_text_sst2_baseline (epoch 6)


Epoch 8/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0408, acc=98.90%]


Checkpoint saved: diet_text_sst2_baseline (epoch 7)


Epoch 9/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0354, acc=99.03%]


Checkpoint saved: diet_text_sst2_baseline (epoch 8)


Epoch 10/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0249, acc=99.30%]


Checkpoint saved: diet_text_sst2_baseline (epoch 9)


Epoch 11/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0256, acc=99.37%]


Checkpoint saved: diet_text_sst2_baseline (epoch 10)


Epoch 12/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0166, acc=99.47%]


Checkpoint saved: diet_text_sst2_baseline (epoch 11)


Epoch 13/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0139, acc=99.63%]


Checkpoint saved: diet_text_sst2_baseline (epoch 12)


Epoch 14/25: 100%|██████████| 12/12 [00:23<00:00,  2.00s/it, loss=0.0114, acc=99.60%]


Checkpoint saved: diet_text_sst2_baseline (epoch 13)


Epoch 15/25: 100%|██████████| 12/12 [00:23<00:00,  2.00s/it, loss=0.0115, acc=99.67%]


Checkpoint saved: diet_text_sst2_baseline (epoch 14)


Epoch 16/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0075, acc=99.83%]


Checkpoint saved: diet_text_sst2_baseline (epoch 15)


Epoch 17/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0131, acc=99.70%]


Checkpoint saved: diet_text_sst2_baseline (epoch 16)


Epoch 18/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0040, acc=99.93%]


Checkpoint saved: diet_text_sst2_baseline (epoch 17)


Epoch 19/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0081, acc=99.87%]


Checkpoint saved: diet_text_sst2_baseline (epoch 18)


Epoch 20/25: 100%|██████████| 12/12 [00:23<00:00,  2.00s/it, loss=0.0059, acc=99.87%]


Checkpoint saved: diet_text_sst2_baseline (epoch 19)


Epoch 21/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0051, acc=99.90%]


Checkpoint saved: diet_text_sst2_baseline (epoch 20)


Epoch 22/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0052, acc=99.87%]


Checkpoint saved: diet_text_sst2_baseline (epoch 21)


Epoch 23/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0084, acc=99.63%]


Checkpoint saved: diet_text_sst2_baseline (epoch 22)


Epoch 24/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0066, acc=99.83%]


Checkpoint saved: diet_text_sst2_baseline (epoch 23)


Epoch 25/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0048, acc=99.83%]


Checkpoint saved: diet_text_sst2_baseline (epoch 24)
Deleted checkpoint: diet_text_sst2_baseline
Baseline validation accuracy: 88.00%

[Step 3/4] Running DiET distillation...

Running DiET for Text
Getting baseline predictions...

--- DiET Step 1/2 ---
  Iter 0: loss=0.0040, faithful_acc=1.0000

--- DiET Step 2/2 ---
  Iter 0: loss=0.0022, faithful_acc=1.0000

[Step 4/4] Comparing DiET vs IG...

Comparing DiET vs Integrated Gradients


Comparing methods: 100%|██████████| 1000/1000 [07:02<00:00,  2.37it/s]



Comparison Results:
  Mean top-k token overlap: 0.7198
  (1.0 = perfect agreement, 0.0 = no overlap)
Text comparison saved to ./outputs/colab_experiments/comprehensive_comparison/sst2/comparison_visualizations

DiET Text Experiment Summary
Baseline Val Accuracy: 88.00%
DiET-IG Top-k Overlap: 0.7198

Total time: 1324.8 seconds
Results saved to ./outputs/colab_experiments/comprehensive_comparison/sst2/diet_text_results.json

TEXT-BASED XAI COMPARISON (IMDB)
DiET vs Integrated Gradients - Token Attribution
DiET Text Experiment: Token Attribution

[Step 1/4] Preparing data...
Loading IMDB dataset...


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

plain_text/unsupervised-00000-of-00001.p(…):   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tokenizing training data...


Tokenizing: 100%|██████████| 3000/3000 [00:12<00:00, 232.81it/s]


Training samples: 3000
Test samples: 200
Max sequence length: 256

[Step 2/4] Training baseline BERT...

Fine-tuning BERT baseline...


Epoch 1/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.2481, acc=95.50%]


Checkpoint saved: diet_text_imdb_baseline (epoch 0)


Epoch 2/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0576, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 1)


Epoch 3/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0209, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 2)


Epoch 4/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0102, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 3)


Epoch 5/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0065, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 4)


Epoch 6/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0048, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 5)


Epoch 7/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0037, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 6)


Epoch 8/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0031, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 7)


Epoch 9/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0025, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 8)


Epoch 10/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0022, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 9)


Epoch 11/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0019, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 10)


Epoch 12/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0016, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 11)


Epoch 13/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0014, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 12)


Epoch 14/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0012, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 13)


Epoch 15/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0011, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 14)


Epoch 16/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0010, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 15)


Epoch 17/25: 100%|██████████| 12/12 [00:24<00:00,  2.04s/it, loss=0.0008, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 16)


Epoch 18/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0007, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 17)


Epoch 19/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0007, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 18)


Epoch 20/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0006, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 19)


Epoch 21/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0005, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 20)


Epoch 22/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0004, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 21)


Epoch 23/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0004, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 22)


Epoch 24/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0003, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 23)


Epoch 25/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0003, acc=100.00%]


Checkpoint saved: diet_text_imdb_baseline (epoch 24)
Deleted checkpoint: diet_text_imdb_baseline
Baseline validation accuracy: 100.00%

[Step 3/4] Running DiET distillation...

Running DiET for Text
Getting baseline predictions...

--- DiET Step 1/2 ---
  Iter 0: loss=0.0040, faithful_acc=1.0000

--- DiET Step 2/2 ---
  Iter 0: loss=0.0022, faithful_acc=1.0000

[Step 4/4] Comparing DiET vs IG...

Comparing DiET vs Integrated Gradients


Comparing methods: 100%|██████████| 1000/1000 [07:09<00:00,  2.33it/s]



Comparison Results:
  Mean top-k token overlap: 0.2360
  (1.0 = perfect agreement, 0.0 = no overlap)
Text comparison saved to ./outputs/colab_experiments/comprehensive_comparison/imdb/comparison_visualizations

DiET Text Experiment Summary
Baseline Val Accuracy: 100.00%
DiET-IG Top-k Overlap: 0.2360

Total time: 1367.7 seconds
Results saved to ./outputs/colab_experiments/comprehensive_comparison/imdb/diet_text_results.json

TEXT-BASED XAI COMPARISON (AG_NEWS)
DiET vs Integrated Gradients - Token Attribution
DiET Text Experiment: Token Attribution

[Step 1/4] Preparing data...
Loading AG_NEWS dataset...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tokenizing training data...


Tokenizing: 100%|██████████| 3000/3000 [00:03<00:00, 771.54it/s]


Training samples: 3000
Test samples: 200
Max sequence length: 256

[Step 2/4] Training baseline BERT...

Fine-tuning BERT baseline...


Epoch 1/25: 100%|██████████| 12/12 [00:23<00:00,  2.00s/it, loss=1.2302, acc=47.30%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 0)


Epoch 2/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.8530, acc=79.67%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 1)


Epoch 3/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.5835, acc=87.03%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 2)


Epoch 4/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.3988, acc=90.27%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 3)


Epoch 5/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.2917, acc=92.27%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 4)


Epoch 6/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.2202, acc=94.03%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 5)


Epoch 7/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.1588, acc=96.53%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 6)


Epoch 8/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.1183, acc=97.23%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 7)


Epoch 9/25: 100%|██████████| 12/12 [00:24<00:00,  2.01s/it, loss=0.0883, acc=98.20%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 8)


Epoch 10/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0684, acc=98.37%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 9)


Epoch 11/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0578, acc=98.83%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 10)


Epoch 12/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0433, acc=99.17%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 11)


Epoch 13/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0354, acc=99.37%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 12)


Epoch 14/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0336, acc=99.40%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 13)


Epoch 15/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0280, acc=99.47%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 14)


Epoch 16/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0249, acc=99.53%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 15)


Epoch 17/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0202, acc=99.67%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 16)


Epoch 18/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0184, acc=99.63%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 17)


Epoch 19/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0172, acc=99.70%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 18)


Epoch 20/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0160, acc=99.70%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 19)


Epoch 21/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0154, acc=99.70%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 20)


Epoch 22/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0132, acc=99.83%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 21)


Epoch 23/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0106, acc=99.80%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 22)


Epoch 24/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0081, acc=99.83%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 23)


Epoch 25/25: 100%|██████████| 12/12 [00:24<00:00,  2.00s/it, loss=0.0090, acc=99.83%]


Checkpoint saved: diet_text_ag_news_baseline (epoch 24)
Deleted checkpoint: diet_text_ag_news_baseline
Baseline validation accuracy: 96.00%

[Step 3/4] Running DiET distillation...

Running DiET for Text
Getting baseline predictions...

--- DiET Step 1/2 ---
Error running ag_news: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 0 has a total capacity of 79.32 GiB of which 737.88 MiB is free. Process 4736 has 78.59 GiB memory in use. Of the allocated memory 76.07 GiB is allocated by PyTorch, and 2.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Text experiments completed in 56 minutes 33 seconds.


Traceback (most recent call last):
  File "/content/Machine-Learning-Project-2025-2026/scripts/xai_experiments/experiments/xai_comparison.py", line 358, in run_all_text_comparisons
    results = self.run_text_comparison(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/Machine-Learning-Project-2025-2026/scripts/xai_experiments/experiments/xai_comparison.py", line 324, in run_text_comparison
    results = experiment.run_full_experiment(skip_training=skip_training)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/Machine-Learning-Project-2025-2026/scripts/xai_experiments/experiments/diet_text_experiment.py", line 793, in run_full_experiment
    self.run_diet(rounding_steps)
  File "/content/Machine-Learning-Project-2025-2026/scripts/xai_experiments/experiments/diet_text_experiment.py", line 567, in run_diet
    metrics = diet.train_token_mask(
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/content/Machine-Learning-Project-2025-2026/scripts
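
The out-of-memory message above itself suggests `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. A minimal sketch of applying it in a notebook, assuming the variable is set before PyTorch initializes CUDA (safest is before the first `import torch`, in a fresh runtime). The log reports roughly 76 GiB allocated and only about 2 GiB reserved but unallocated, so fragmentation may not be the main cause here and this setting alone may not be sufficient:

```python
import os

# Place this in the very first cell and restart the runtime so it is in
# effect before CUDA is initialized. setdefault keeps any value that the
# environment already provides.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```

If most of the memory is genuinely allocated, as the message indicates, reducing the batch size or sequence length, or releasing the previous dataset's model before the next run, is more likely to help than allocator tuning.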

### 4.2 Text Results: Comprehensive Visual Summary

In [None]:
# Extract text results with all metrics
text_data = []
for dataset_name, result in text_results.items():
    if "error" not in result:
        row = {
            "Dataset": dataset_name.upper(),
            "Baseline Accuracy": result.get("baseline_accuracy", 0),
            "Samples Compared": result.get("samples_compared", 0),
            "Mean Correlation": result.get("mean_correlation", 0),
        }
        # Add all top-k overlap metrics
        for k in CONFIG["text_top_k_values"]:
            key = f"top_{k}_overlap"
            if key in result:
                row[f"Top-{k} Overlap"] = result[key]
                row[f"Top-{k} Std"] = result.get(f"{key}_std", 0)
        text_data.append(row)

text_df = pd.DataFrame(text_data)

print("\n" + "=" * 70)
print("TEXT EXPERIMENTS: QUANTITATIVE RESULTS")
print("=" * 70)
print("\nToken Overlap between IG and DiET for various K values:\n")
if len(text_df) > 0:
    print(text_df.to_string(index=False))

    print(f"\n{'='*50}")
    print("Summary Statistics:")
    for k in CONFIG["text_top_k_values"]:
        col = f"Top-{k} Overlap"
        if col in text_df.columns:
            avg = text_df[col].mean()
            std = text_df[col].std() if len(text_df) > 1 else 0
            print(f"  Average Top-{k} Overlap: {avg:.4f} (±{std:.4f})")
    if "Mean Correlation" in text_df.columns:
        avg_corr = text_df["Mean Correlation"].mean()
        print(f"  Average Correlation: {avg_corr:.4f}")
    print(f"{'='*50}")
else:
    print("No successful text experiments to display.")


TEXT EXPERIMENTS: QUANTITATIVE RESULTS

Top-25 Token Overlap between IG and DiET:

Dataset  Baseline Accuracy  IG-DiET Overlap  Samples Compared
   SST2               88.0           0.7198              1000
   IMDB              100.0           0.2360              1000

Summary:
  Average BERT accuracy: 94.0%
  Average IG-DiET overlap: 0.4779
  Interpretation: Methods identify different features
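
The overlap scores above range from 0.0 (no shared tokens) to 1.0 (identical top-k sets). A minimal sketch of one common way to compute such a score, assuming tokens are ranked by absolute attribution and the overlap is the shared fraction of the two top-k sets. The notebook's own implementation is not shown and may rank or normalize differently; `top_k_indices` and `top_k_overlap` are illustrative names:

```python
def top_k_indices(scores, k):
    """Return the set of positions holding the k largest-magnitude scores."""
    order = sorted(range(len(scores)), key=lambda i: abs(scores[i]), reverse=True)
    return set(order[:k])


def top_k_overlap(attr_a, attr_b, k):
    """Fraction of top-k token positions shared by two attribution vectors."""
    if len(attr_a) != len(attr_b):
        raise ValueError("attribution vectors must cover the same tokens")
    k = min(k, len(attr_a))  # short sequences may have fewer than k tokens
    if k == 0:
        return 0.0
    shared = top_k_indices(attr_a, k) & top_k_indices(attr_b, k)
    return len(shared) / k


# Toy attributions over four tokens (made-up values, not experiment data).
ig_scores = [0.9, 0.8, 0.1, 0.0]
diet_scores = [0.9, 0.0, 0.8, 0.1]
print(top_k_overlap(ig_scores, diet_scores, k=2))
```

In the toy example the two methods agree on one of their two top tokens. Under this definition, an overlap such as the 0.2360 reported for IMDB would mean that, on average, fewer than a quarter of the top-ranked tokens are shared.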


In [None]:
if len(text_df) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle("Text Experiments: DiET vs Integrated Gradients - Comprehensive Visual Summary", 
                 fontsize=18, fontweight='bold', y=1.02)
    
    for ax in axes.flat:
        ax.set_facecolor('#fafafa')

    datasets = text_df["Dataset"].tolist()

    # Plot 1: Top-K Overlap Comparison across K values
    k_values = [k for k in CONFIG["text_top_k_values"] if f"Top-{k} Overlap" in text_df.columns]
    if k_values:
        x = np.arange(len(datasets))
        width = 0.8 / len(k_values)
        colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(k_values)))
        
        for i, k in enumerate(k_values):
            col = f"Top-{k} Overlap"
            offset = (i - len(k_values)/2 + 0.5) * width
            bars = axes[0, 0].bar(x + offset, text_df[col], width, label=f'Top-{k}', 
                                 color=colors[i], alpha=0.85, edgecolor='#333', linewidth=0.5)
        
        axes[0, 0].axhline(y=0.5, color='#e74c3c', linestyle='--', linewidth=2, alpha=0.7, label='50% Threshold')
        axes[0, 0].set_ylabel('Token Overlap Score', fontweight='bold')
        axes[0, 0].set_title('IG-DiET Token Overlap Across K Values', fontweight='bold', fontsize=12)
        axes[0, 0].set_xticks(x)
        axes[0, 0].set_xticklabels(datasets, fontweight='bold')
        axes[0, 0].legend(loc='upper right', ncol=3, framealpha=0.95)
        axes[0, 0].set_ylim(0, 1.1)
        axes[0, 0].grid(True, alpha=0.3, linestyle='--')

    # Plot 2: Top-K Overlap Line Chart (trend across K)
    if k_values:
        for idx, dataset in enumerate(datasets):
            overlaps = [text_df[text_df["Dataset"] == dataset][f"Top-{k} Overlap"].values[0] 
                       for k in k_values if f"Top-{k} Overlap" in text_df.columns]
            axes[0, 1].plot(k_values, overlaps, 'o-', linewidth=2.5, markersize=10, 
                           label=dataset, alpha=0.8)
        
        axes[0, 1].axhline(y=0.5, color='#e74c3c', linestyle='--', linewidth=2, alpha=0.7)
        axes[0, 1].set_xlabel('K (Top-K Tokens)', fontweight='bold')
        axes[0, 1].set_ylabel('Overlap Score', fontweight='bold')
        axes[0, 1].set_title('Overlap Trend Across K Values', fontweight='bold', fontsize=12)
        axes[0, 1].legend(framealpha=0.95)
        axes[0, 1].set_ylim(0, 1.05)
        axes[0, 1].grid(True, alpha=0.3, linestyle='--')

    # Plot 3: BERT Accuracy
    bars2 = axes[1, 0].bar(datasets, text_df["Baseline Accuracy"], 
                           color='#3498db', alpha=0.85, edgecolor='#333', linewidth=1.5)
    axes[1, 0].set_ylabel('Accuracy (%)', fontweight='bold')
    axes[1, 0].set_title('BERT Classification Accuracy', fontweight='bold', fontsize=12)
    axes[1, 0].set_ylim(0, 105)
    axes[1, 0].grid(True, alpha=0.3, linestyle='--')
    for bar, acc in zip(bars2, text_df["Baseline Accuracy"]):
        axes[1, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
                       f'{acc:.1f}%', ha='center', fontsize=11, fontweight='bold')

    # Plot 4: Correlation scores
    if "Mean Correlation" in text_df.columns:
        correlations = text_df["Mean Correlation"].tolist()
        colors_corr = ['#2ecc71' if c > 0 else '#e74c3c' for c in correlations]
        bars3 = axes[1, 1].bar(datasets, correlations, color=colors_corr, 
                              alpha=0.85, edgecolor='#333', linewidth=1.5)
        axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=1)
        axes[1, 1].set_ylabel('Correlation', fontweight='bold')
        axes[1, 1].set_title('IG-DiET Attribution Correlation', fontweight='bold', fontsize=12)
        axes[1, 1].set_ylim(-1, 1)
        axes[1, 1].grid(True, alpha=0.3, linestyle='--')
        for bar, val in zip(bars3, correlations):
            axes[1, 1].text(bar.get_x() + bar.get_width()/2, val + 0.05 if val >= 0 else val - 0.1, 
                           f'{val:.3f}', ha='center', fontsize=10, fontweight='bold')
    else:
        axes[1, 1].text(0.5, 0.5, "Correlation data not available", ha='center', va='center', fontsize=12)
        axes[1, 1].axis('off')

    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/text_visual_summary.png', dpi=150, bbox_inches='tight', facecolor='white')
    plt.show()

    print(f"\nFigure saved: {OUTPUT_DIR}/text_visual_summary.png")
else:
    print("No text experiment data to visualize.")


Figure saved: ./outputs/colab_experiments/comprehensive_comparison/text_visual_summary.png
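
The "Mean Correlation" panel presumably averages a per-sample correlation between the IG and DiET attribution vectors. How the framework computes it is not shown in this notebook, so the sketch below assumes a Pearson correlation per sample followed by a plain mean; `pearson_correlation`, `mean_attribution_correlation`, and the toy vectors are hypothetical.

```python
import math


def pearson_correlation(xs, ys):
    """Pearson correlation of two equal-length score vectors."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two vectors of equal length >= 2")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        # Correlation is undefined for a constant vector; returning 0.0 is an
        # assumption of this sketch, not a convention taken from the notebook.
        return 0.0
    return cov / math.sqrt(var_x * var_y)


def mean_attribution_correlation(pairs):
    """Average the per-sample correlation over (ig_scores, diet_scores) pairs."""
    values = [pearson_correlation(ig, diet) for ig, diet in pairs]
    return sum(values) / len(values)


# Toy example (made-up numbers): one perfectly aligned and one reversed sample.
toy_pairs = [
    ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]),
    ([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]),
]
toy_mean = mean_attribution_correlation(toy_pairs)
print(f"Mean correlation: {toy_mean:.3f}")
```

A rank-based (Spearman) variant would be less sensitive to the different score scales of the two methods; which variant the reported numbers use should be checked against the framework code.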


---

## 5. Combined Results Summary

### 5.1 Complete Results Table

In [None]:
print("=" * 70)
print("COMBINED EXPERIMENT SUMMARY")
print("=" * 70)

full_df = comparison.get_results_dataframe()
print("\nComplete Results Table:\n")
if len(full_df) > 0:
    print(full_df.to_string(index=False))
else:
    print("No results to display.")

COMBINED EXPERIMENT SUMMARY

Complete Results Table:

Modality       Dataset Method 1 Method 2  GradCAM Score  DiET Score  Improvement DiET Better  Baseline Accuracy  DiET Accuracy  IG-DiET Overlap  Samples Compared
   Image       CIFAR10  GradCAM     DiET       0.457200    0.459467     0.002267        True              91.62          88.90              NaN               NaN
   Image      CIFAR100  GradCAM     DiET       0.436867    0.456333     0.019467        True              91.61          88.69              NaN               NaN
   Image          SVHN  GradCAM     DiET       0.436933    0.468800     0.031867        True              92.13          89.82              NaN               NaN
   Image FASHION_MNIST  GradCAM     DiET       0.420067    0.337933    -0.082133       False              91.17          88.51              NaN               NaN
    Text          SST2       IG     DiET            NaN         NaN          NaN         NaN              88.00            NaN          

In [13]:
# Generate summary report text
report = comparison.generate_summary_report()
print(report)

XAI METHODS COMPARISON REPORT - MULTI-DATASET
DiET vs GradCAM (Images) & DiET vs IG (Text)

Generated: 2026-01-18T08:15:30.668488
Device: cuda

IMAGE CLASSIFICATION RESULTS

--- CIFAR10 ---
  Baseline Accuracy: 91.62%
  DiET Accuracy: 88.90%
  GradCAM Score: 0.4572
  DiET Score: 0.4595
  ✓ DiET IMPROVES by 0.0023

--- CIFAR100 ---
  Baseline Accuracy: 91.61%
  DiET Accuracy: 88.69%
  GradCAM Score: 0.4369
  DiET Score: 0.4563
  ✓ DiET IMPROVES by 0.0195

--- SVHN ---
  Baseline Accuracy: 92.13%
  DiET Accuracy: 89.82%
  GradCAM Score: 0.4369
  DiET Score: 0.4688
  ✓ DiET IMPROVES by 0.0319

--- FASHION_MNIST ---
  Baseline Accuracy: 91.17%
  DiET Accuracy: 88.51%
  GradCAM Score: 0.4201
  DiET Score: 0.3379
  → GradCAM performs better

--- SUMMARY ---
  DiET better on 3/4 datasets

TEXT CLASSIFICATION RESULTS

--- SST2 ---
  Baseline Accuracy: 88.00%
  IG-DiET Overlap: 0.7198
  Samples Compared: 1000
  → High agreement between methods

--- IMDB ---
  Baseline Accuracy: 100.00%
  IG-DiE

### 5.2 Combined Visual Summary with All Metrics

In [None]:
if len(image_df) > 0 or len(text_df) > 0:
    dashboard_fig = visualizer.create_summary_dashboard(
        image_results=image_results if len(image_df) > 0 else None,
        text_results=text_results if len(text_df) > 0 else None,
        save_name="comprehensive_dashboard",
        show=True
    )

    # Create top-k overlap comparison for text if we have data
    if len(text_df) > 0:
        text_overlap_data = {}
        for _, row in text_df.iterrows():
            dataset = row["Dataset"]
            text_overlap_data[dataset] = {}
            for k in CONFIG["text_top_k_values"]:
                col = f"Top-{k} Overlap"
                if col in row:
                    text_overlap_data[dataset][f"top_{k}_overlap"] = row[col]
                    std_col = f"Top-{k} Std"
                    if std_col in row:
                        text_overlap_data[dataset][f"top_{k}_overlap_std"] = row[std_col]
        
        if text_overlap_data:
            topk_fig = visualizer.plot_top_k_overlap_comparison(
                text_overlap_data,
                title="Token Overlap Across K Values by Dataset",
                save_name="topk_overlap_comparison",
                show=True
            )

    print(f"\nAll dashboard figures saved to: {OUTPUT_DIR}")
else:
    print("No data available for dashboard visualization.")


Figure saved: ./outputs/colab_experiments/comprehensive_comparison/combined_visual_summary.png


---

## 6. Statistical Analysis

### 6.1 Statistical Tests

In [None]:
print("=" * 70)
print("STATISTICAL ANALYSIS")
print("=" * 70)

# Image experiments statistical analysis
if len(image_df) >= 3:
    gradcam_scores = image_df["GradCAM Score"].values
    diet_scores = image_df["DiET Score"].values

    # Paired t-test
    t_stat, p_value = stats.ttest_rel(diet_scores, gradcam_scores)

    # Wilcoxon signed-rank test (non-parametric alternative)
    try:
        w_stat, w_pvalue = stats.wilcoxon(diet_scores, gradcam_scores)
    except Exception:
        # Skip the Wilcoxon output if SciPy cannot run the test on this sample
        w_stat, w_pvalue = None, None

    # Effect size (Cohen's d)
    pooled_std = np.sqrt((np.var(gradcam_scores) + np.var(diet_scores)) / 2)
    cohens_d = (np.mean(diet_scores) - np.mean(gradcam_scores)) / pooled_std if pooled_std > 0 else 0

    print("\nImage Experiments (DiET vs GradCAM):")
    print(f"  Paired t-test:")
    print(f"    t-statistic: {t_stat:.4f}")
    print(f"    p-value: {p_value:.4f}")

    if p_value < 0.05:
        print(f"    Result: Statistically significant (p < 0.05)")
    else:
        print(f"    Result: Not statistically significant")

    if w_pvalue is not None:
        print(f"\n  Wilcoxon signed-rank test:")
        print(f"    W-statistic: {w_stat:.4f}")
        print(f"    p-value: {w_pvalue:.4f}")

    print(f"\n  Effect Size (Cohen's d): {cohens_d:.4f}")
    # Conventional Cohen's d cut-offs: about 0.2 small, 0.5 medium, 0.8 large
    if abs(cohens_d) < 0.5:
        print("    Interpretation: Small effect")
    elif abs(cohens_d) < 0.8:
        print("    Interpretation: Medium effect")
    else:
        print("    Interpretation: Large effect")
else:
    print("\nImage Experiments: Not enough data points for statistical testing (need >= 3)")

# Text experiments statistical analysis
if len(text_df) >= 2:
    # Get overlap values for the default k (5)
    overlap_col = "Top-5 Overlap" if "Top-5 Overlap" in text_df.columns else None
    
    if overlap_col:
        overlaps = text_df[overlap_col].values

        # One-sample t-test against 0.5 threshold
        t_stat_text, p_value_text = stats.ttest_1samp(overlaps, 0.5)

        print("\nText Experiments (IG-DiET Overlap):")
        print(f"  One-sample t-test (vs 0.5 threshold):")
        print(f"    Mean overlap: {np.mean(overlaps):.4f}")
        print(f"    t-statistic: {t_stat_text:.4f}")
        print(f"    p-value: {p_value_text:.4f}")

        # 0.5 is a heuristic cut-off on the mean, not a chance level or a significance decision
        if np.mean(overlaps) >= 0.5:
            print("    Interpretation: Mean overlap is at or above the 0.5 threshold")
        else:
            print("    Interpretation: Methods identify different features")
    
    if "Mean Correlation" in text_df.columns:
        correlations = text_df["Mean Correlation"].values
        print(f"\n  Attribution Correlation:")
        print(f"    Mean: {np.mean(correlations):.4f}")
        print(f"    Std: {np.std(correlations):.4f}")
else:
    print("\nText Experiments: Not enough data points for statistical testing")

STATISTICAL ANALYSIS

Image Experiments (DiET vs GradCAM):
  Paired t-test:
    t-statistic: -0.2773
    p-value: 0.7996
    Result: Not statistically significant

  Effect Size (Cohen's d): -0.1824
    Interpretation: Small effect

Text Experiments (IG-DiET Overlap):
  One-sample t-test (vs 0.5 threshold):
    Mean overlap: 0.4779
    t-statistic: -0.0914
    p-value: 0.9420
    Interpretation: Methods identify different features
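
As a cross-check, the image statistics printed above can be reproduced by hand from the four per-dataset scores in the Complete Results Table. The sketch below uses only the Python standard library instead of SciPy. The additional paired effect size `d_z` (mean difference divided by the standard deviation of the differences) is not computed anywhere in the notebook and is included here as an assumption about what a paired design might report.

```python
import math
import statistics

# Scores copied from the Complete Results Table (Section 5.1).
gradcam = [0.457200, 0.436867, 0.436933, 0.420067]
diet = [0.459467, 0.456333, 0.468800, 0.337933]

diffs = [d - g for d, g in zip(diet, gradcam)]
n = len(diffs)
mean_diff = statistics.mean(diffs)

# Paired t-statistic: mean difference over its standard error (sample std, n - 1).
t_stat = mean_diff / (statistics.stdev(diffs) / math.sqrt(n))

# Cohen's d as in the cell above: average of the two population variances
# (np.var defaults to ddof=0, which statistics.pvariance matches).
pooled_std = math.sqrt((statistics.pvariance(gradcam) + statistics.pvariance(diet)) / 2)
cohens_d = mean_diff / pooled_std

# Paired-design alternative (assumption, not in the notebook).
d_z = mean_diff / statistics.stdev(diffs)

print(f"t-statistic: {t_stat:.4f}")
print(f"Cohen's d (pooled): {cohens_d:.4f}")
print(f"Cohen's d_z (paired): {d_z:.4f}")
```

With only four datasets the test has three degrees of freedom, so a non-significant result says little about whether a real difference exists; the single large negative difference on Fashion-MNIST dominates both the mean and the variance of the differences.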


---

## 7. Export Results

### 7.1 Save All Results

In [None]:
comparison.save_results()

if len(image_df) > 0:
    image_df.to_csv(f'{OUTPUT_DIR}/image_results.csv', index=False)
if len(text_df) > 0:
    text_df.to_csv(f'{OUTPUT_DIR}/text_results.csv', index=False)
if len(full_df) > 0:
    full_df.to_csv(f'{OUTPUT_DIR}/all_results.csv', index=False)

with open(f'{OUTPUT_DIR}/experiment_config.json', 'w') as f:
    json.dump(CONFIG, f, indent=2)

print("\nResults saved:")
print(f"  - {OUTPUT_DIR}/comparison_results.json")
print(f"  - {OUTPUT_DIR}/image_results.csv")
print(f"  - {OUTPUT_DIR}/text_results.csv")
print(f"  - {OUTPUT_DIR}/all_results.csv")
print(f"  - {OUTPUT_DIR}/experiment_config.json")
print(f"  - {OUTPUT_DIR}/image_visual_summary.png")
print(f"  - {OUTPUT_DIR}/text_visual_summary.png")
print(f"  - {OUTPUT_DIR}/combined_visual_summary.png")


Results saved to: ./outputs/colab_experiments/comprehensive_comparison

Results saved:
  - ./outputs/colab_experiments/comprehensive_comparison/comparison_results.json
  - ./outputs/colab_experiments/comprehensive_comparison/image_results.csv
  - ./outputs/colab_experiments/comprehensive_comparison/text_results.csv
  - ./outputs/colab_experiments/comprehensive_comparison/all_results.csv
  - ./outputs/colab_experiments/comprehensive_comparison/experiment_config.json
  - ./outputs/colab_experiments/comprehensive_comparison/image_visual_summary.png
  - ./outputs/colab_experiments/comprehensive_comparison/text_visual_summary.png
  - ./outputs/colab_experiments/comprehensive_comparison/combined_visual_summary.png


In [None]:
try:
    viz_files = comparison.visualize_results(save_plots=True, show=False)
    print("\nAdditional visualizations generated:")
    for name, path in viz_files.items():
        print(f"  - {name}: {path}")
except Exception as e:
    print(f"Note: Some visualizations could not be generated: {e}")

Saved: ./outputs/colab_experiments/comprehensive_comparison/image_metric_comparison.png
Saved: ./outputs/colab_experiments/comprehensive_comparison/comparison_dashboard.png
Saved HTML report: ./outputs/colab_experiments/comprehensive_comparison/comparison_report.html

Visualization files generated:
  - image_bar_chart: ./outputs/colab_experiments/comprehensive_comparison/image_metric_comparison.png
  - dashboard: ./outputs/colab_experiments/comprehensive_comparison/comparison_dashboard.png
  - html_report: ./outputs/colab_experiments/comprehensive_comparison/comparison_report.html

Additional visualizations generated:
  - image_bar_chart: ./outputs/colab_experiments/comprehensive_comparison/image_metric_comparison.png
  - dashboard: ./outputs/colab_experiments/comprehensive_comparison/comparison_dashboard.png
  - html_report: ./outputs/colab_experiments/comprehensive_comparison/comparison_report.html


In [None]:
try:
    from google.colab import files

    !zip -r comprehensive_results.zip {OUTPUT_DIR}/

    print("\nDownload your results:")
    files.download('comprehensive_results.zip')
except ImportError:
    print(f"\nResults are saved locally in: {OUTPUT_DIR}/")
    print("(Download option only available in Google Colab)")

  adding: outputs/colab_experiments/comprehensive_comparison/ (stored 0%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/ (stored 0%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/diet_mask_step0.pt (deflated 95%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/diet_resnet.pth (deflated 7%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/checkpoints/ (stored 0%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/comparison_visualizations/ (stored 0%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/comparison_visualizations/diet_vs_gradcam.png (deflated 24%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/diet_mask_step1.pt (deflated 96%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/baseline_resnet.pth (deflated 7%)
  adding: outputs/colab_experiments/comprehensive_comparison/svhn/diet_experiment_results.json (deflated 70%)
  adding: outputs/

### 7.2 Final Report

In [None]:
# Fall back to 0 when a timing variable was never set (e.g. a skipped experiment cell)
total_duration = globals().get('image_duration', 0) + globals().get('text_duration', 0)

final_report = f"""
================================================================================
                    DiET vs BASIC XAI METHODS
                    COMPREHENSIVE COMPARISON REPORT
================================================================================

Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Device: {DEVICE} ({GPU_CONFIG})
Total Duration: {total_duration // 60} minutes {total_duration % 60} seconds

--------------------------------------------------------------------------------
IMAGE EXPERIMENTS: DiET vs GradCAM
--------------------------------------------------------------------------------
Datasets: {CONFIG['image_datasets']}
Model: ResNet
Training epochs: {CONFIG['image_epochs']}
Comparison samples: {CONFIG['image_comparison_samples']}

"""

if len(image_df) > 0:
    for _, row in image_df.iterrows():
        status = "[+]" if row["DiET Better"] else "[-]"
        final_report += f"{status} {row['Dataset']}: GradCAM={row['GradCAM Score']:.4f}, DiET={row['DiET Score']:.4f}, Improvement={row['Improvement']:+.4f}\n"

    diet_wins = image_df["DiET Better"].sum()
    final_report += f"\nSummary: DiET outperforms GradCAM on {diet_wins}/{len(image_df)} datasets\n"

final_report += f"""
--------------------------------------------------------------------------------
TEXT EXPERIMENTS: DiET vs Integrated Gradients
--------------------------------------------------------------------------------
Datasets: {CONFIG['text_datasets']}
Model: BERT-base-uncased
Max length: {CONFIG['text_max_length']}
Training epochs: {CONFIG['text_epochs']}
Top-K values: {CONFIG['text_top_k_values']}

"""

if len(text_df) > 0:
    for _, row in text_df.iterrows():
        default_overlap = row.get("Top-5 Overlap", row.get("Top-3 Overlap", 0))
        level = "[HIGH]" if default_overlap >= 0.5 else "[MED]" if default_overlap >= 0.3 else "[LOW]"
        final_report += f"{level} {row['Dataset']}: Accuracy={row['Baseline Accuracy']:.1f}%"
        for k in CONFIG['text_top_k_values']:
            col = f"Top-{k} Overlap"
            if col in row:
                final_report += f", Top-{k}={row[col]:.3f}"
        final_report += "\n"

    if "Top-5 Overlap" in text_df.columns:
        avg_overlap = text_df["Top-5 Overlap"].mean()
        final_report += f"\nSummary: Average Top-5 IG-DiET overlap = {avg_overlap:.4f}\n"

final_report += f"""
================================================================================
Output Files:
  - {OUTPUT_DIR}/comparison_results.json
  - {OUTPUT_DIR}/image_results.csv
  - {OUTPUT_DIR}/text_results.csv
  - {OUTPUT_DIR}/all_results.csv
  - {OUTPUT_DIR}/image_visual_summary.png
  - {OUTPUT_DIR}/text_visual_summary.png
  - {OUTPUT_DIR}/combined_visual_summary.png
  - {OUTPUT_DIR}/comprehensive_dashboard.png
================================================================================
"""

# Save final report
with open(f'{OUTPUT_DIR}/final_report.txt', 'w') as f:
    f.write(final_report)

print(final_report)


                    DiET vs BASIC XAI METHODS
                    COMPREHENSIVE COMPARISON REPORT

Date: 2026-01-18 10:32:39
Device: cuda (high)
Total Duration: 135 minutes 28 seconds

--------------------------------------------------------------------------------
IMAGE EXPERIMENTS: DiET vs GradCAM
--------------------------------------------------------------------------------
Datasets: ['cifar10', 'cifar100', 'svhn', 'fashion_mnist']
Model: ResNet
Training epochs: 25
Comparison samples: 5000

[+] CIFAR10: GradCAM=0.4572, DiET=0.4595, Improvement=+0.0023
[+] CIFAR100: GradCAM=0.4369, DiET=0.4563, Improvement=+0.0195
[+] SVHN: GradCAM=0.4369, DiET=0.4688, Improvement=+0.0319
[-] FASHION_MNIST: GradCAM=0.4201, DiET=0.3379, Improvement=-0.0821

Summary: DiET outperforms GradCAM on 3/4 datasets

--------------------------------------------------------------------------------
TEXT EXPERIMENTS: DiET vs Integrated Gradients
------------------------------------------------------------------

---

## References

1. Bhalla, U., et al. (2023). "Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability." *NeurIPS 2023.*

2. Selvaraju, R. R., et al. (2017). "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." *ICCV 2017.*

3. Sundararajan, M., Taly, A., & Yan, Q. (2017). "Axiomatic Attribution for Deep Networks." *ICML 2017.*

4. Devlin, J., et al. (2019). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." *NAACL 2019.*

---

**Notebook Version:** 2.0  
**Last Updated:** 2025-2026 Academic Year  
**Repository:** https://github.com/xMOROx/Machine-Learning-Project-2025-2026