# 📝 DiET vs Integrated Gradients: Token Attribution for Text Classification

## A Comprehensive Comparison Study

---

**Author:** Machine Learning Research Team  
**Date:** 2025-2026 Academic Year  
**Course:** Advanced Machine Learning

---

### 📋 Abstract

This notebook presents a comprehensive experimental comparison between **DiET (Distractor Erasure Tuning)** and **Integrated Gradients (IG)** for explainable AI in text classification tasks. We evaluate both methods on multiple text datasets (SST-2, IMDB, AG News) using BERT-based models and token-level attribution analysis.

### 🎯 Research Questions

1. How do DiET token attributions compare with Integrated Gradients?
2. Do both methods identify similar important tokens for classification?
3. How consistent are the attributions across different text classification tasks?

### 📚 Reference

Bhalla, U., et al. (2023). **"Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability."** *NeurIPS 2023.*

---

## 1. Environment Setup

### 1.1 Check GPU Availability

This notebook targets **Google Colab Pro** with a GPU runtime. Fine-tuning BERT on a CPU works but is much slower.

In [None]:
# Check GPU availability and type
import torch

print("=" * 60)
print("🖥️  HARDWARE CONFIGURATION")
print("=" * 60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"✅ GPU Available: {gpu_name}")
    print(f"   Memory: {gpu_memory:.1f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   PyTorch Version: {torch.__version__}")
    
    if gpu_memory >= 15:
        print("\n🚀 High-memory GPU detected! Using optimal settings.")
        GPU_CONFIG = "high"
    elif gpu_memory >= 8:
        print("\n✨ Standard GPU detected. Using balanced settings.")
        GPU_CONFIG = "standard"
    else:
        print("\n⚠️  Low-memory GPU detected. Using memory-efficient settings.")
        GPU_CONFIG = "low"
else:
    print("❌ No GPU available. Training will be very slow.")
    print("   Please enable GPU: Runtime → Change runtime type → GPU")
    GPU_CONFIG = "cpu"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\n📍 Using device: {DEVICE.upper()}")
print("=" * 60)

### 1.2 Clone Repository and Install Dependencies

In [None]:
# Clone the repository
import os

REPO_URL = "https://github.com/xMOROx/Machine-Learning-Project-2025-2026.git"
REPO_DIR = "Machine-Learning-Project-2025-2026"

if not os.path.exists(REPO_DIR):
    print("📥 Cloning repository...")
    !git clone --recursive {REPO_URL}
    print("✅ Repository cloned successfully!")
else:
    print("📁 Repository already exists. Pulling latest changes...")
    %cd {REPO_DIR}
    !git pull
    !git submodule update --init --recursive
    %cd ..

%cd {REPO_DIR}

In [None]:
# Install required packages
print("📦 Installing dependencies...")
print("This may take a few minutes on first run.\n")

!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers datasets tqdm matplotlib seaborn pandas numpy scikit-learn captum

print("\n✅ All dependencies installed!")

In [None]:
# Add project to path and import modules
import sys
sys.path.insert(0, './scripts/xai_experiments')

# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

print("✅ All modules imported successfully!")
print(f"📅 Experiment started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

---

## 2. Experimental Configuration

### 2.1 Hyperparameters and Settings

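We scale the experiment to the available GPU memory: smaller GPUs get smaller batches, shorter sequences, and fewer samples. BERT's memory use grows with both batch size and sequence length, so these are the first settings to reduce.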

In [None]:
# Configuration based on GPU capabilities
if GPU_CONFIG == "high":  # >= 15 GB reported, e.g. A100, V100
    CONFIG = {
        "batch_size": 32,
        "epochs": 3,
        "max_length": 256,
        "max_samples": 3000,
        "comparison_samples": 100,
        "datasets": ["sst2", "imdb", "ag_news"],
        "top_k": 10,  # Number of top tokens to compare
    }
elif GPU_CONFIG == "standard":  # 8 GB up to just under 15 GB reported, e.g. T4
    CONFIG = {
        "batch_size": 16,
        "epochs": 2,
        "max_length": 128,
        "max_samples": 2000,
        "comparison_samples": 50,
        "datasets": ["sst2", "imdb", "ag_news"],
        "top_k": 5,
    }
elif GPU_CONFIG == "low":  # under 8 GB reported
    CONFIG = {
        "batch_size": 8,
        "epochs": 2,
        "max_length": 64,
        "max_samples": 1000,
        "comparison_samples": 30,
        "datasets": ["sst2", "ag_news"],  # Skip IMDB (long texts)
        "top_k": 5,
    }
else:  # CPU
    CONFIG = {
        "batch_size": 4,
        "epochs": 1,
        "max_length": 64,
        "max_samples": 500,
        "comparison_samples": 20,
        "datasets": ["sst2"],  # Single dataset for demo
        "top_k": 5,
    }

# Display configuration
print("=" * 60)
print("⚙️  EXPERIMENT CONFIGURATION")
print("=" * 60)
for key, value in CONFIG.items():
    print(f"   {key}: {value}")
print("=" * 60)

### 2.2 Dataset Overview

We evaluate on the following text classification datasets:

| Dataset | Type | Classes | Avg Length | Description |
|---------|------|---------|------------|-------------|
| **SST-2** | Sentiment | 2 | ~20 words | Movie review sentences |
| **IMDB** | Sentiment | 2 | ~250 words | Full movie reviews |
| **AG News** | Topic | 4 | ~40 words | News articles |

In [None]:
# Load and preview datasets
from datasets import load_dataset

def preview_dataset(dataset_name, num_samples=3):
    """Display sample texts from a dataset."""
    print(f"\n{'='*60}")
    print(f"📚 {dataset_name.upper()} Dataset")
    print("=" * 60)
    
    if dataset_name == "sst2":
        dataset = load_dataset("glue", "sst2", split="train")
        text_col = "sentence"
        label_col = "label"
        labels = ["Negative", "Positive"]
    elif dataset_name == "imdb":
        dataset = load_dataset("imdb", split="train")
        text_col = "text"
        label_col = "label"
        labels = ["Negative", "Positive"]
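    elif dataset_name == "ag_news":
        dataset = load_dataset("ag_news", split="train")
        text_col = "text"
        label_col = "label"
        labels = ["World", "Sports", "Business", "Sci/Tech"]
    else:
        # Fail early instead of hitting a NameError on `dataset` below
        raise ValueError(f"Unknown dataset: {dataset_name}")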
    
    print(f"Total samples: {len(dataset):,}")
    print(f"Classes: {labels}")
    print(f"\nSample texts:")
    
    for i in range(num_samples):
        text = dataset[i][text_col]
        label = labels[dataset[i][label_col]]
        text_preview = text[:100] + "..." if len(text) > 100 else text
        print(f"\n  [{label}] {text_preview}")

print("📊 Dataset Samples Preview")
for dataset in CONFIG["datasets"]:
    preview_dataset(dataset)

---

## 3. DiET vs Integrated Gradients Comparison Framework

### 3.1 Method Overview

#### Integrated Gradients (IG)
- **Type:** Post-hoc attribution method
- **Approach:** Integrates gradients along path from baseline to input
- **Advantage:** Satisfies axioms of sensitivity and implementation invariance
- **Limitation:** May highlight features correlated but not causal
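
To make the path integral concrete, here is a minimal sketch of IG on a toy logistic model, using a midpoint Riemann sum. The weights, input, and all-zero baseline are hypothetical values chosen for illustration; the experiments in this notebook run on BERT through the project's comparison framework, not through this function.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def model(x, w, b):
    """Toy classifier: logistic regression over a small feature vector."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)

def model_grad(x, w, b):
    """Analytic gradient of the toy model with respect to its input."""
    p = model(x, w, b)
    return [p * (1.0 - p) * wi for wi in w]

def integrated_gradients(x, baseline, w, b, steps=200):
    """Approximate IG by averaging gradients along the straight path
    from the baseline to the input (midpoint Riemann sum)."""
    n = len(x)
    avg_grad = [0.0] * n
    for s in range(steps):
        alpha = (s + 0.5) / steps
        point = [bi + alpha * (xi - bi) for xi, bi in zip(x, baseline)]
        grad = model_grad(point, w, b)
        for i in range(n):
            avg_grad[i] += grad[i] / steps
    # Scale the averaged gradient by the input-baseline difference
    return [(xi - bi) * gi for xi, bi, gi in zip(x, baseline, avg_grad)]

# Hypothetical weights for three "tokens"
w, b = [2.0, -1.0, 0.1], 0.0
x = [1.0, 1.0, 1.0]
baseline = [0.0, 0.0, 0.0]

attributions = integrated_gradients(x, baseline, w, b)
# Completeness: attributions should sum to f(x) - f(baseline)
completeness_gap = sum(attributions) - (model(x, w, b) - model(baseline, w, b))
print(attributions, completeness_gap)
```

The completeness check at the end is a useful sanity test when choosing the number of integration steps: a large gap usually means the Riemann sum is too coarse.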

#### DiET for Text
- **Type:** Inherent interpretability via fine-tuning
- **Approach:** Learns token-level masks that preserve predictions
- **Advantage:** Produces discriminative token attributions
- **Focus:** Identifies truly necessary tokens for classification
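
The idea of a mask that preserves the prediction can be illustrated with a deliberately simplified sketch. This is not the DiET algorithm from the paper, which also fine-tunes the model: here a frozen, hypothetical linear token scorer is used, and only a per-token mask is optimized by projected gradient descent under a sparsity penalty. The token weights and hyperparameters are assumptions for illustration.

```python
def learn_mask(weights, sparsity=0.1, lr=0.05, steps=500):
    """Learn a mask in [0, 1] per token so that the masked score stays
    close to the full score while the mask stays as sparse as possible.

    Loss: (sum(m_i * w_i) - sum(w_i))**2 + sparsity * sum(m_i)
    """
    target = sum(weights)
    mask = [1.0] * len(weights)
    for _ in range(steps):
        diff = sum(m * w for m, w in zip(mask, weights)) - target
        grads = [2.0 * diff * w + sparsity for w in weights]
        # Gradient step followed by projection back onto [0, 1]
        mask = [min(1.0, max(0.0, m - lr * g)) for m, g in zip(mask, grads)]
    return mask

# Hypothetical per-token contributions to a "positive" score
tokens = ["great", "the", "fun", "movie"]
weights = [2.0, 0.0, 1.5, 0.0]

mask = learn_mask(weights)
for token, m in zip(tokens, mask):
    print(f"{token:>6}: {m:.3f}")
```

In this toy setting the tokens that contribute nothing to the score are masked out, while the tokens the score depends on keep a mask close to 1.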

In [None]:
# Import the comparison framework
from experiments.xai_comparison import XAIMethodsComparison, ComparisonConfig

# Create configuration
comparison_config = ComparisonConfig(
    device=DEVICE,
    text_datasets=CONFIG["datasets"],
    text_batch_size=CONFIG["batch_size"],
    text_epochs=CONFIG["epochs"],
    text_max_length=CONFIG["max_length"],
    text_max_samples=CONFIG["max_samples"],
    text_comparison_samples=CONFIG["comparison_samples"],
    text_top_k=CONFIG["top_k"],
    low_vram=(GPU_CONFIG == "low"),
    output_dir="./outputs/colab_experiments/text_comparison",
)

print("✅ Comparison framework initialized!")
print(f"📁 Output directory: {comparison_config.output_dir}")
print(f"🔤 Model: BERT-base-uncased")

### 3.2 Run Experiments

⏱️ **Estimated Time:**
- High-memory GPU: ~45-60 minutes
- Standard GPU: ~30-45 minutes
- Low-memory GPU: ~20-30 minutes
- CPU: ~2+ hours (not recommended for BERT)

In [None]:
# Initialize comparison
comparison = XAIMethodsComparison(comparison_config)

# Run full comparison (text only)
print("\n" + "=" * 70)
print("🚀 STARTING DiET vs INTEGRATED GRADIENTS COMPARISON")
print("=" * 70)
print(f"\n📚 Datasets: {CONFIG['datasets']}")
print(f"🔢 Samples per dataset: {CONFIG['max_samples']}")
print(f"📈 Training epochs: {CONFIG['epochs']}")
print(f"🔤 Max sequence length: {CONFIG['max_length']}")
print(f"🎯 Top-k tokens to compare: {CONFIG['top_k']}")
print(f"\n⏳ This may take a while. Progress will be shown below...\n")

start_time = datetime.now()
results = comparison.run_full_comparison(run_images=False, run_text=True)
end_time = datetime.now()

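duration = end_time - start_time
# total_seconds() covers the whole run; .seconds alone wraps around after 24 hours
minutes, seconds = divmod(int(duration.total_seconds()), 60)
print(f"\n✅ Experiment completed in {minutes} minutes {seconds} seconds!")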

---

## 4. Results Analysis

### 4.1 Quantitative Results: Token Overlap Analysis

In [None]:
# Get results as DataFrame
df = comparison.get_results_dataframe()

# Filter text results only
text_df = df[df["Modality"] == "Text"].copy()

# Display results table
print("\n" + "=" * 70)
print("📊 QUANTITATIVE RESULTS SUMMARY")
print("=" * 70)

if len(text_df) > 0:
    print("\n  Token Attribution Comparison (IG vs DiET)")
    print("  " + "-" * 60)
    
    for _, row in text_df.iterrows():
        overlap = row.get("IG-DiET Overlap", 0)
        samples = row.get("Samples Compared", 0)
        accuracy = row.get("Baseline Accuracy", 0)
        
        # Interpret overlap
        if overlap >= 0.7:
            interpretation = "🟢 High agreement"
        elif overlap >= 0.4:
            interpretation = "🟡 Moderate agreement"
        else:
            interpretation = "🔴 Low agreement (methods find different features)"
        
        print(f"\n  📚 {row['Dataset']}")
        print(f"     Baseline Accuracy: {accuracy:.1f}%")
        print(f"     IG-DiET Top-{CONFIG['top_k']} Overlap: {overlap:.4f}")
        print(f"     Samples Compared: {samples}")
        print(f"     Interpretation: {interpretation}")
    
    # Summary
    avg_overlap = text_df["IG-DiET Overlap"].mean()
    print("\n" + "-" * 70)
    print(f"📈 Average IG-DiET Overlap: {avg_overlap:.4f}")
    
    if avg_overlap >= 0.5:
        print("\n✅ Methods show good agreement on important tokens")
    else:
        print("\n🔍 Methods identify different features - DiET may find more discriminative tokens")
else:
    print("No text results available. Please run the experiment first.")
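
The framework's exact overlap definition is not shown in this notebook. A common choice, assumed here, is the fraction of the top-k tokens under one method that also appear in the top-k under the other. The sketch below implements that assumption for two per-token importance vectors; the function names and example scores are hypothetical, and the scores are taken to be non-negative importances (for signed IG values, one would typically pass absolute values).

```python
def top_k_indices(scores, k):
    """Indices of the k largest scores (ties broken by position)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return set(order[:k])

def top_k_overlap(scores_a, scores_b, k):
    """Share of top-k token positions that both methods agree on."""
    k = min(k, len(scores_a), len(scores_b))
    if k == 0:
        return 0.0
    shared = top_k_indices(scores_a, k) & top_k_indices(scores_b, k)
    return len(shared) / k

# Hypothetical importances for a five-token sentence
ig_scores = [0.9, 0.1, 0.5, 0.05, 0.3]
diet_scores = [0.8, 0.0, 0.1, 0.6, 0.4]

overlap = top_k_overlap(ig_scores, diet_scores, k=3)
print(f"Top-3 overlap: {overlap:.3f}")
```

Under this definition an overlap of 1.0 means both methods select the same k positions, regardless of the order within the top-k.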

### 4.2 Visualization: Method Agreement Across Datasets

In [None]:
# Create comparison visualizations
if len(text_df) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Bar chart: IG-DiET Overlap by dataset
    datasets = text_df["Dataset"].tolist()
    overlaps = text_df["IG-DiET Overlap"].tolist()
    
    colors = ['#4CAF50' if o >= 0.5 else '#FF9800' if o >= 0.3 else '#F44336' for o in overlaps]
    
    bars = axes[0].bar(datasets, overlaps, color=colors, alpha=0.8, edgecolor='black')
    axes[0].axhline(y=0.5, color='gray', linestyle='--', linewidth=1, label='50% threshold')
    axes[0].set_ylabel(f'Top-{CONFIG["top_k"]} Token Overlap', fontsize=12)
    axes[0].set_title('IG-DiET Token Attribution Agreement\n(Higher = More Agreement)', fontsize=14, fontweight='bold')
    axes[0].set_ylim(0, 1)
    axes[0].legend()
    
    # Add value labels
    for bar, val in zip(bars, overlaps):
        axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                    f'{val:.3f}', ha='center', va='bottom', fontsize=11, fontweight='bold')
    
    # Accuracy comparison
    accuracies = text_df["Baseline Accuracy"].tolist()
    
    bars2 = axes[1].barh(datasets, accuracies, color='#2196F3', alpha=0.8, edgecolor='black')
    axes[1].set_xlim(0, 100)
    axes[1].set_xlabel('Accuracy (%)', fontsize=12)
    axes[1].set_title('BERT Baseline Classification Accuracy', fontsize=14, fontweight='bold')
    
    for bar, acc in zip(bars2, accuracies):
        axes[1].text(acc + 1, bar.get_y() + bar.get_height()/2,
                    f'{acc:.1f}%', va='center', fontsize=11, fontweight='bold')
    
    plt.tight_layout()
    # Make sure the output directory exists before saving the figure
    fig_dir = './outputs/colab_experiments/text_comparison'
    os.makedirs(fig_dir, exist_ok=True)
    plt.savefig(f'{fig_dir}/text_method_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\n📁 Figure saved to: outputs/colab_experiments/text_comparison/text_method_comparison.png")

### 4.3 Visualization: Token Attribution Examples

Comparing which tokens IG and DiET consider most important.

In [None]:
# Display token comparison HTML if available
import glob
from IPython.display import HTML, display

html_files = glob.glob('./outputs/colab_experiments/text_comparison/**/text_comparison.html', recursive=True)
if html_files:
    print("\n🔤 Token Attribution Comparisons:")
    for path in html_files[:2]:  # Show first 2 datasets
        dataset_name = path.split('/')[-3]
        print(f"\n--- {dataset_name.upper()} ---")
        # Read as UTF-8 so non-ASCII tokens render correctly
        with open(path, 'r', encoding='utf-8') as f:
            html_content = f.read()
        display(HTML(html_content))
else:
    print("\n📝 Token comparison visualizations will be generated after experiment completion.")

In [None]:
# Generate visualization
try:
    viz_files = comparison.visualize_results(save_plots=True, show=True)
    print("\n📊 Generated visualization files:")
    for name, path in viz_files.items():
        print(f"   • {name}: {path}")
except Exception as e:
    print(f"Visualization generation note: {e}")

### 4.4 Interpretation Analysis

In [None]:
# Detailed interpretation analysis
if len(text_df) > 0:
    print("\n" + "=" * 70)
    print("🔍 INTERPRETATION ANALYSIS")
    print("=" * 70)
    
    avg_overlap = text_df["IG-DiET Overlap"].mean()
    
    print(f"\n📊 Overall IG-DiET Agreement: {avg_overlap:.1%}")
    
    print("\n📝 What does this mean?")
    print("="*50)
    
    if avg_overlap >= 0.7:
        print("""
   ✅ HIGH AGREEMENT:
   Both IG and DiET identify similar important tokens.
   This suggests:
   • The model relies on genuinely discriminative features
   • Post-hoc explanations (IG) align with inherent importance (DiET)
   • Explanations are likely faithful to model reasoning
        """)
    elif avg_overlap >= 0.4:
        print("""
   🟡 MODERATE AGREEMENT:
   IG and DiET partially agree on important tokens.
   This suggests:
   • Some tokens are universally recognized as important
   • DiET may identify additional discriminative features
   • IG may include correlated but not causal features
        """)
    else:
        print("""
   🔴 LOW AGREEMENT:
   IG and DiET identify different tokens as important.
   This suggests:
   • IG may highlight spurious correlations
   • DiET focuses on truly discriminative features
   • Post-hoc explanations may not reflect true model reasoning
        """)
    
    print("\n📌 Key Insight:")
    print("   DiET is designed to find the minimal set of tokens")
    print("   that are NECESSARY for classification, while IG")
    print("   measures SENSITIVITY to each token.")

---

## 5. Discussion and Conclusions

### 5.1 Key Findings

In [None]:
# Generate comprehensive summary report
if len(text_df) > 0:
    print("\n" + "=" * 70)
    print("📋 EXPERIMENT SUMMARY REPORT")
    print("=" * 70)
    
    avg_overlap = text_df["IG-DiET Overlap"].mean()
    avg_accuracy = text_df["Baseline Accuracy"].mean()
    
    print(f"\n📊 OVERALL STATISTICS:")
    print(f"   • Datasets evaluated: {len(text_df)}")
    print(f"   • Average BERT accuracy: {avg_accuracy:.1f}%")
    print(f"   • Average IG-DiET overlap: {avg_overlap:.4f}")
    print(f"   • Top-k tokens compared: {CONFIG['top_k']}")
    
    print(f"\n🔍 PER-DATASET BREAKDOWN:")
    for _, row in text_df.iterrows():
        overlap = row.get("IG-DiET Overlap", 0)
        status = "🟢" if overlap >= 0.5 else "🟡" if overlap >= 0.3 else "🔴"
        print(f"   {status} {row['Dataset']}: Overlap={overlap:.4f}, Accuracy={row['Baseline Accuracy']:.1f}%")
    
    print(f"\n📝 KEY CONCLUSIONS:")
    print(f"   1. BERT achieves strong baseline performance ({avg_accuracy:.1f}% avg)")
    print(f"   2. IG-DiET agreement varies by dataset complexity")
    if avg_overlap >= 0.5:
        print(f"   3. Both methods generally agree on important tokens")
    else:
        print(f"   3. Methods identify different features - suggests IG may include spurious tokens")
    
    print("\n" + "=" * 70)

### 5.2 Theoretical Implications

**Why might IG and DiET disagree?**

1. **Sensitivity vs. Necessity:**
   - IG measures how sensitive predictions are to each token
   - DiET identifies which tokens are *necessary* for the prediction

2. **Spurious Correlations:**
   - IG may highlight tokens correlated with labels but not causal
   - DiET's distractor erasure removes such spurious features

3. **Baseline Dependency:**
   - IG attributions depend on the choice of baseline (zero embedding)
   - DiET learns a data-driven importance measure
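
The baseline dependency in point 3 is easiest to see on a linear model, where IG has a closed form: the gradient is constant along the path, so the attribution of feature i is exactly w_i * (x_i - baseline_i). The sketch below uses hypothetical weights and a hypothetical alternative baseline to show that the top-ranked feature can change when only the baseline changes.

```python
def linear_ig(weights, x, baseline):
    """Exact IG for f(x) = sum(w_i * x_i): w_i * (x_i - baseline_i)."""
    return [w * (xi - bi) for w, xi, bi in zip(weights, x, baseline)]

def top_feature(attributions):
    """Index of the feature with the largest absolute attribution."""
    return max(range(len(attributions)), key=lambda i: abs(attributions[i]))

# Hypothetical model and input
weights = [1.0, -2.0, 0.5]
x = [1.0, 1.0, 1.0]

zero_baseline = [0.0, 0.0, 0.0]
# Hypothetical alternative baseline that already contains most of feature 1
shifted_baseline = [0.0, 0.9, 0.0]

attr_zero = linear_ig(weights, x, zero_baseline)
attr_shifted = linear_ig(weights, x, shifted_baseline)

print("zero baseline:   ", attr_zero, "top feature:", top_feature(attr_zero))
print("shifted baseline:", attr_shifted, "top feature:", top_feature(attr_shifted))
```

With the zero baseline the second feature dominates; with the shifted baseline the first one does, even though the model and the input are unchanged.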

### 5.3 Limitations and Future Work

**Limitations:**
- Limited fine-tuning epochs due to computational constraints
- Top-k comparison may miss nuanced differences
- Results may vary with different random seeds
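
One possible complement to the top-k comparison, not part of the experiments above, is a rank correlation over the full attribution vectors, which also reflects agreement outside the top-k. The sketch below computes Spearman's rank correlation with average ranks for ties; the example scores are hypothetical.

```python
def average_ranks(values):
    """1-based ranks in ascending order; tied values share their mean rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    pos = 0
    while pos < len(order):
        end = pos
        while end + 1 < len(order) and values[order[end + 1]] == values[order[pos]]:
            end += 1
        shared = (pos + end) / 2.0 + 1.0
        for j in range(pos, end + 1):
            ranks[order[j]] = shared
        pos = end + 1
    return ranks

def spearman(a, b):
    """Spearman correlation: Pearson correlation of the rank vectors."""
    ra, rb = average_ranks(a), average_ranks(b)
    n = len(ra)
    mean_a, mean_b = sum(ra) / n, sum(rb) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(ra, rb))
    var_a = sum((x - mean_a) ** 2 for x in ra)
    var_b = sum((y - mean_b) ** 2 for y in rb)
    if var_a == 0 or var_b == 0:
        return 0.0  # a constant ranking carries no ordering information
    return cov / (var_a * var_b) ** 0.5

# Hypothetical per-token importances from the two methods
ig_scores = [0.9, 0.1, 0.5, 0.05, 0.3]
diet_scores = [0.8, 0.0, 0.1, 0.6, 0.4]
print(f"Spearman correlation: {spearman(ig_scores, diet_scores):.3f}")
```

A value near 1 means the two methods order tokens similarly across the whole sentence, while a value near 0 means their orderings are largely unrelated.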

**Future Work:**
- Extend to other transformer architectures (RoBERTa, DistilBERT)
- Compare with other attribution methods (LIME, SHAP, Attention)
- Conduct human evaluation of explanations
- Investigate relationship between overlap and model faithfulness

---

## 6. Save Results and Export Report

In [None]:
# Save all results
import json

# Create comprehensive results dictionary
full_results = {
    "experiment": "DiET vs Integrated Gradients Text Comparison",
    "date": datetime.now().isoformat(),
    "configuration": CONFIG,
    "device": DEVICE,
    "gpu_config": GPU_CONFIG,
    "results": results,
}

# Save to JSON
results_path = "./outputs/colab_experiments/text_comparison/full_results.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)

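def make_serializable(obj):
    """Recursively convert NumPy values and containers into JSON-friendly types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.floating):  # any NumPy float width
        return float(obj)
    elif isinstance(obj, np.integer):  # any NumPy integer width
        return int(obj)
    elif isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, dict):
        # JSON object keys must be plain strings
        return {str(k): make_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [make_serializable(item) for item in obj]
    return obj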

with open(results_path, 'w') as f:
    json.dump(make_serializable(full_results), f, indent=2)

# Save DataFrame to CSV
if len(text_df) > 0:
    text_df.to_csv('./outputs/colab_experiments/text_comparison/results_summary.csv', index=False)

print("\n✅ Results saved successfully!")
print(f"   📄 JSON: {results_path}")
print(f"   📊 CSV: ./outputs/colab_experiments/text_comparison/results_summary.csv")

In [None]:
# Generate printable report
report_content = f"""
================================================================================
                    DiET vs INTEGRATED GRADIENTS COMPARISON
                         Text Classification Analysis
================================================================================

Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Device: {DEVICE} ({GPU_CONFIG})

CONFIGURATION:
  • Datasets: {', '.join(CONFIG['datasets'])}
  • Model: BERT-base-uncased
  • Max sequence length: {CONFIG['max_length']}
  • Training epochs: {CONFIG['epochs']}
  • Batch size: {CONFIG['batch_size']}
  • Comparison samples: {CONFIG['comparison_samples']}
  • Top-k tokens: {CONFIG['top_k']}

RESULTS:
"""

if len(text_df) > 0:
    for _, row in text_df.iterrows():
        report_content += f"""
  {row['Dataset'].upper()}:
    • Baseline Accuracy: {row['Baseline Accuracy']:.1f}%
    • IG-DiET Overlap: {row.get('IG-DiET Overlap', 0):.4f}
"""
    
    avg_overlap = text_df["IG-DiET Overlap"].mean()
    avg_accuracy = text_df["Baseline Accuracy"].mean()
    
    report_content += f"""
SUMMARY:
  • Average Accuracy: {avg_accuracy:.1f}%
  • Average IG-DiET Overlap: {avg_overlap:.4f}
  
INTERPRETATION:
  {'High agreement - Both methods identify similar important tokens' if avg_overlap >= 0.5 else 'Lower agreement - Methods may identify different features'}

================================================================================
"""

# Save report
with open('./outputs/colab_experiments/text_comparison/experiment_report.txt', 'w') as f:
    f.write(report_content)

print(report_content)

In [None]:
# Download results (for Colab)
try:
    from google.colab import files
    
    # Create zip of all results
    !zip -r text_comparison_results.zip ./outputs/colab_experiments/text_comparison/
    
    print("\n📥 Download your results:")
    files.download('text_comparison_results.zip')
except ImportError:
    # google.colab is only importable inside a Colab runtime
    print("\n📁 Results are saved locally in: ./outputs/colab_experiments/text_comparison/")
    print("   (Download option only available in Google Colab)")

---

## 📚 References

1. Bhalla, U., et al. (2023). "Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability." *NeurIPS 2023.*

2. Sundararajan, M., Taly, A., & Yan, Q. (2017). "Axiomatic Attribution for Deep Networks." *ICML 2017.*

3. Devlin, J., et al. (2019). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." *NAACL 2019.*

4. Socher, R., et al. (2013). "Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank." *EMNLP 2013.* (SST-2)

5. Maas, A., et al. (2011). "Learning Word Vectors for Sentiment Analysis." *ACL 2011.* (IMDB)

---

**Notebook Version:** 1.0  
**Last Updated:** 2025-2026 Academic Year  
**Repository:** [github.com/xMOROx/Machine-Learning-Project-2025-2026](https://github.com/xMOROx/Machine-Learning-Project-2025-2026)