# 🔬 DiET vs GradCAM: Discriminative Feature Attribution for Image Classification

## A Comprehensive Comparison Study

---

**Author:** Machine Learning Research Team  
**Date:** 2025-2026 Academic Year  
**Course:** Advanced Machine Learning

---

### 📋 Abstract

This notebook presents a comprehensive experimental comparison between **DiET (Distractor Erasure Tuning)** and **GradCAM** for explainable AI in image classification tasks. We evaluate both methods on multiple image datasets (CIFAR-10, CIFAR-100, SVHN, Fashion-MNIST) using robust evaluation metrics including pixel perturbation, AOPC, and faithfulness correlation.

### 🎯 Research Questions

1. Does DiET produce more discriminative feature attributions than GradCAM?
2. How do both methods perform across different image classification datasets?
3. What is the trade-off between model accuracy and attribution quality?

### 📚 Reference

Bhalla, U., et al. (2023). **"Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability."** *NeurIPS 2023.*

---

## 1. Environment Setup

### 1.1 Check GPU Availability

This notebook is optimized for **Google Colab Pro** with GPU acceleration. We recommend using a T4 or A100 GPU for faster training.

In [None]:
# Check GPU availability and type
import torch

print("=" * 60)
print("🖥️  HARDWARE CONFIGURATION")
print("=" * 60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"✅ GPU Available: {gpu_name}")
    print(f"   Memory: {gpu_memory:.1f} GB")
    
    # Check CUDA version
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   PyTorch Version: {torch.__version__}")
    
    if gpu_memory >= 15:
        print("\n🚀 High-memory GPU detected! Using optimal settings.")
        GPU_CONFIG = "high"
    elif gpu_memory >= 8:
        print("\n✨ Standard GPU detected. Using balanced settings.")
        GPU_CONFIG = "standard"
    else:
        print("\n⚠️  Low-memory GPU detected. Using memory-efficient settings.")
        GPU_CONFIG = "low"
else:
    print("❌ No GPU available. Training will be slow.")
    print("   Please enable GPU: Runtime → Change runtime type → GPU")
    GPU_CONFIG = "cpu"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\n📍 Using device: {DEVICE.upper()}")
print("=" * 60)

### 1.2 Clone Repository and Install Dependencies

In [None]:
# Clone the repository
import os

REPO_URL = "https://github.com/xMOROx/Machine-Learning-Project-2025-2026.git"
REPO_DIR = "Machine-Learning-Project-2025-2026"

if not os.path.exists(REPO_DIR):
    print("📥 Cloning repository...")
    !git clone --recursive {REPO_URL}
    print("✅ Repository cloned successfully!")
else:
    print("📁 Repository already exists. Pulling latest changes...")
    %cd {REPO_DIR}
    !git pull
    !git submodule update --init --recursive
    %cd ..

%cd {REPO_DIR}

In [None]:
# Install required packages
print("📦 Installing dependencies...")
print("This may take a few minutes on first run.\n")

!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
# scipy is listed explicitly because the statistical analysis in Section 4.4 imports it
!pip install -q transformers datasets tqdm matplotlib seaborn pandas numpy scipy pillow scikit-learn captum

print("\n✅ All dependencies installed!")

In [None]:
# Add project to path and import modules
import sys
sys.path.insert(0, './scripts/xai_experiments')

# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

print("✅ All modules imported successfully!")
print(f"📅 Experiment started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

---

## 2. Experimental Configuration

### 2.1 Hyperparameters and Settings

We configure the experiment based on available GPU memory to optimize training speed while maintaining result quality.

In [None]:
# Configuration based on GPU capabilities
if GPU_CONFIG == "high":  # A100, V100, etc.
    CONFIG = {
        "batch_size": 128,
        "epochs": 10,
        "max_samples": 10000,
        "comparison_samples": 200,
        "datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
    }
elif GPU_CONFIG == "standard":  # T4, P100
    CONFIG = {
        "batch_size": 64,
        "epochs": 5,
        "max_samples": 5000,
        "comparison_samples": 100,
        "datasets": ["cifar10", "cifar100", "svhn", "fashion_mnist"],
    }
elif GPU_CONFIG == "low":  # K80, older GPUs
    CONFIG = {
        "batch_size": 32,
        "epochs": 3,
        "max_samples": 2000,
        "comparison_samples": 50,
        "datasets": ["cifar10", "svhn"],  # Fewer datasets for speed
    }
else:  # CPU
    CONFIG = {
        "batch_size": 16,
        "epochs": 2,
        "max_samples": 1000,
        "comparison_samples": 20,
        "datasets": ["cifar10"],  # Single dataset for demo
    }

# Display configuration
print("=" * 60)
print("⚙️  EXPERIMENT CONFIGURATION")
print("=" * 60)
for key, value in CONFIG.items():
    print(f"   {key}: {value}")
print("=" * 60)

### 2.2 Dataset Overview

We evaluate on the following image classification datasets:

| Dataset | Images (train + test) | Classes | Image Size | Description |
|---------|-----------------------|---------|------------|-------------|
| **CIFAR-10** | 60,000 | 10 | 32×32 | Natural images |
| **CIFAR-100** | 60,000 | 100 | 32×32 | Natural images, fine-grained classes |
| **SVHN** | 99,289 (73,257 train + 26,032 test) | 10 | 32×32 | Street View house-number digits |
| **Fashion-MNIST** | 70,000 | 10 | 28×28 (grayscale) | Fashion products |

In [None]:
# Visualize sample images from datasets
import torchvision
import torchvision.transforms as transforms

def show_dataset_samples(dataset_name, num_samples=8):
    """Display sample images from a dataset."""
    transform = transforms.Compose([transforms.ToTensor()])
    
    if dataset_name == "cifar10":
        dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    elif dataset_name == "cifar100":
        dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        classes = None  # Too many to display
    elif dataset_name == "svhn":
        dataset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform)
        classes = [str(i) for i in range(10)]
    elif dataset_name == "fashion_mnist":
        dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
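        classes = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Boot']
    else:
        # Fail early instead of hitting a NameError on `dataset` below
        raise ValueError(f"Unknown dataset: {dataset_name}")
    
    fig, axes = plt.subplots(1, num_samples, figsize=(16, 2))
    for i in range(num_samples):
        img, label = dataset[i]
        if img.shape[0] == 1:
            # Grayscale (e.g. Fashion-MNIST): drop the channel axis
            axes[i].imshow(img.squeeze(0).numpy(), cmap='gray')
        else:
            # RGB: reorder from (C, H, W) to (H, W, C) for matplotlib
            axes[i].imshow(img.permute(1, 2, 0).numpy())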
        if classes:
            axes[i].set_title(classes[label], fontsize=10)
        else:
            axes[i].set_title(f"Class {label}", fontsize=10)
        axes[i].axis('off')
    plt.suptitle(f"{dataset_name.upper()} Sample Images", fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

print("📊 Dataset Samples Preview")
print("=" * 60)
for dataset in CONFIG["datasets"]:
    show_dataset_samples(dataset)

---

## 3. DiET vs GradCAM Comparison Framework

### 3.1 Method Overview

#### GradCAM (Gradient-weighted Class Activation Mapping)
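- **Type:** Post-hoc explanation method
- **Approach:** Weights the final convolutional layer's feature maps by the gradients of the class score flowing into that layer, then applies a ReLU to the weighted sum
- **Limitation:** May highlight regions that are correlated with the class but not discriminative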
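
The combination step described above can be illustrated with a minimal NumPy sketch. It assumes the feature maps and their gradients have already been extracted (in practice via hooks on the model); the function name `grad_cam_map` and the toy values are hypothetical and are not taken from the project code.

```python
import numpy as np

def grad_cam_map(activations, gradients):
    """Combine feature maps and gradients into a Grad-CAM style heatmap.

    activations: array of shape (K, H, W), feature maps of the chosen conv layer.
    gradients:   array of shape (K, H, W), d(class score) / d(activations).
    """
    # Channel weights: global-average-pool the gradients over the spatial axes.
    weights = gradients.mean(axis=(1, 2))                       # shape (K,)
    # Weighted sum of feature maps, then ReLU to keep positive evidence only.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # shape (H, W)
    # Scale to [0, 1] for display; an all-zero map is returned unchanged.
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy example with two 2x2 feature maps (made-up values, for illustration only).
acts = np.array([[[1.0, 0.0], [0.0, 2.0]],
                 [[0.0, 3.0], [1.0, 0.0]]])
grads = np.array([[[1.0, 1.0], [1.0, 1.0]],       # mean +1: channel supports the class
                  [[-1.0, -1.0], [-1.0, -1.0]]])  # mean -1: channel opposes the class
heatmap = grad_cam_map(acts, grads)
```

Because the heatmap has the spatial size of the feature maps, it is usually upsampled to the input resolution before being overlaid on the image.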

#### DiET (Distractor Erasure Tuning)
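- **Type:** Inherent interpretability via fine-tuning
- **Approach:** Learns sparse masks and fine-tunes the model so that its predictions are preserved when the masked-out (distractor) pixels are erased
- **Advantage:** Attributions are intended to be discriminative by design, because the fine-tuned model is trained to rely on the unmasked pixels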
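
To make the trade-off between preserving predictions and keeping the mask sparse concrete, here is a simplified NumPy sketch. It is not the objective from the DiET paper or from this repository: the names `apply_mask`, `mask_objective`, and `sparsity_weight`, and the choice of an L1 prediction gap, are assumptions made for illustration.

```python
import numpy as np

def apply_mask(x, mask, distractor):
    """Keep pixels where mask is 1 and replace the rest with distractor values."""
    return mask * x + (1.0 - mask) * distractor

def mask_objective(p_full, p_masked, mask, sparsity_weight=0.1):
    """Prediction-preservation term plus a sparsity penalty on the mask.

    p_full:   class probabilities for the original input, shape (C,).
    p_masked: class probabilities for the masked input, shape (C,).
    mask:     values in [0, 1], same shape as the image.
    """
    preservation = np.abs(p_full - p_masked).sum()  # L1 gap between the two predictions
    sparsity = np.abs(mask).mean()                  # mean mask value (fraction kept if binary)
    return preservation + sparsity_weight * sparsity

# Toy 2x2 image: keep the top row, replace the bottom row with a constant distractor.
x = np.array([[0.2, 0.8], [0.5, 0.1]])
mask = np.array([[1.0, 1.0], [0.0, 0.0]])
noise = np.full_like(x, 0.5)
x_masked = apply_mask(x, mask, noise)

# Made-up probabilities standing in for model outputs on x and x_masked.
loss = mask_objective(np.array([0.9, 0.1]), np.array([0.8, 0.2]), mask)
```

A larger `sparsity_weight` pushes toward smaller masks at the cost of a larger prediction gap; in a real training loop both the mask and the model would be updated by gradient descent rather than evaluated once as here.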

In [None]:
# Import the comparison framework
from experiments.xai_comparison import XAIMethodsComparison, ComparisonConfig

# Create configuration
comparison_config = ComparisonConfig(
    device=DEVICE,
    image_datasets=CONFIG["datasets"],
    image_batch_size=CONFIG["batch_size"],
    image_epochs=CONFIG["epochs"],
    image_max_samples=CONFIG["max_samples"],
    image_comparison_samples=CONFIG["comparison_samples"],
    output_dir="./outputs/colab_experiments/image_comparison",
)

print("✅ Comparison framework initialized!")
print(f"📁 Output directory: {comparison_config.output_dir}")

### 3.2 Run Experiments

⏱️ **Estimated Time** (higher tiers run more epochs on more samples, so they take longer despite the faster hardware):
- High-memory GPU: ~30-45 minutes
- Standard GPU: ~20-30 minutes
- Low-memory GPU: ~15-20 minutes
- CPU: ~60+ minutes (not recommended)

In [None]:
# Initialize comparison
comparison = XAIMethodsComparison(comparison_config)

# Run full comparison (images only)
print("\n" + "=" * 70)
print("🚀 STARTING DiET vs GradCAM COMPARISON")
print("=" * 70)
print(f"\n📊 Datasets: {CONFIG['datasets']}")
print(f"🔢 Samples per dataset: {CONFIG['max_samples']}")
print(f"📈 Training epochs: {CONFIG['epochs']}")
print("\n⏳ This may take a while. Progress will be shown below...\n")

start_time = datetime.now()
results = comparison.run_full_comparison(run_images=True, run_text=False)
end_time = datetime.now()

# total_seconds() covers the whole duration; .seconds alone would drop whole days
elapsed = int((end_time - start_time).total_seconds())
print(f"\n✅ Experiment completed in {elapsed // 60} minutes {elapsed % 60} seconds!")

---

## 4. Results Analysis

### 4.1 Quantitative Results

In [None]:
# Get results as DataFrame
df = comparison.get_results_dataframe()

# Display results table
print("\n" + "=" * 70)
print("📊 QUANTITATIVE RESULTS SUMMARY")
print("=" * 70)

if len(df) > 0:
    # Format and display the DataFrame
    display_df = df[["Dataset", "GradCAM Score", "DiET Score", "Improvement", "DiET Better"]].copy()
    display_df["GradCAM Score"] = display_df["GradCAM Score"].apply(lambda x: f"{x:.4f}" if pd.notnull(x) else "N/A")
    display_df["DiET Score"] = display_df["DiET Score"].apply(lambda x: f"{x:.4f}" if pd.notnull(x) else "N/A")
    display_df["Improvement"] = display_df["Improvement"].apply(lambda x: f"{x:+.4f}" if pd.notnull(x) else "N/A")
    display_df["DiET Better"] = display_df["DiET Better"].apply(lambda x: "✅ Yes" if x else "❌ No")
    
    print(display_df.to_string(index=False))
    
    # Summary statistics
    diet_wins = (df["DiET Better"] == True).sum()
    total = len(df)
    avg_improvement = df["Improvement"].mean()
    
    print("\n" + "-" * 70)
    print(f"📈 DiET outperforms GradCAM on {diet_wins}/{total} datasets")
    print(f"📊 Average improvement: {avg_improvement:+.4f}")
else:
    print("No results available. Please run the experiment first.")

### 4.2 Visualization: Method Comparison Across Datasets

In [None]:
# Create comparison visualizations
if len(df) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Bar chart: GradCAM vs DiET scores
    datasets = df["Dataset"].tolist()
    gradcam_scores = df["GradCAM Score"].tolist()
    diet_scores = df["DiET Score"].tolist()
    
    x = np.arange(len(datasets))
    width = 0.35
    
    bars1 = axes[0].bar(x - width/2, gradcam_scores, width, label='GradCAM', color='#2196F3', alpha=0.8)
    bars2 = axes[0].bar(x + width/2, diet_scores, width, label='DiET', color='#4CAF50', alpha=0.8)
    
    axes[0].set_ylabel('Pixel Perturbation Score', fontsize=12)
    axes[0].set_title('Attribution Quality Comparison\n(Higher = Better)', fontsize=14, fontweight='bold')
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(datasets, rotation=45, ha='right')
    axes[0].legend(loc='upper right')
    axes[0].set_ylim(0, 1)
    
    # Add value labels
    for bar in bars1 + bars2:
        height = bar.get_height()
        axes[0].annotate(f'{height:.2f}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)
    
    # Improvement chart
    improvements = df["Improvement"].tolist()
    colors = ['#4CAF50' if imp > 0 else '#F44336' for imp in improvements]
    
    axes[1].barh(datasets, improvements, color=colors, alpha=0.8)
    axes[1].axvline(x=0, color='black', linestyle='-', linewidth=0.5)
    axes[1].set_xlabel('Improvement (DiET - GradCAM)', fontsize=12)
    axes[1].set_title('DiET Improvement Over GradCAM\n(Positive = DiET Better)', fontsize=14, fontweight='bold')
    
    for i, v in enumerate(improvements):
        axes[1].text(v + 0.01 if v >= 0 else v - 0.01, i, f'{v:+.3f}', 
                    va='center', ha='left' if v >= 0 else 'right', fontsize=10)
    
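    plt.tight_layout()
    # Make sure the output directory exists before saving the figure
    os.makedirs('./outputs/colab_experiments/image_comparison', exist_ok=True)
    plt.savefig('./outputs/colab_experiments/image_comparison/method_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\n📁 Figure saved to: outputs/colab_experiments/image_comparison/method_comparison.png")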

### 4.3 Visualization: Attribution Heatmaps

Comparing GradCAM and DiET attributions on sample images.

In [None]:
# Generate and display comparison visualizations
try:
    viz_files = comparison.visualize_results(save_plots=True, show=True)
    print("\n📊 Generated visualization files:")
    for name, path in viz_files.items():
        print(f"   • {name}: {path}")
except Exception as e:
    print(f"Visualization generation skipped: {e}")

In [None]:
# Display saved heatmap comparisons if available
import glob
from IPython.display import Image, display

viz_paths = glob.glob('./outputs/colab_experiments/image_comparison/**/diet_vs_gradcam.png', recursive=True)
if viz_paths:
    print("\n🖼️  Attribution Heatmap Comparisons:")
    for path in viz_paths[:2]:  # Show first 2 datasets
        dataset_name = path.split('/')[-3]
        print(f"\n--- {dataset_name.upper()} ---")
        display(Image(filename=path, width=800))

### 4.4 Statistical Analysis

In [None]:
# Perform statistical analysis
if len(df) > 0:
    from scipy import stats
    
    gradcam_scores = df["GradCAM Score"].values
    diet_scores = df["DiET Score"].values
    
    # Paired t-test (if enough samples)
    if len(gradcam_scores) >= 3:
        t_stat, p_value = stats.ttest_rel(diet_scores, gradcam_scores)
        
        print("\n" + "=" * 70)
        print("📊 STATISTICAL ANALYSIS")
        print("=" * 70)
        print("\n  Paired t-test (DiET vs GradCAM):")
        print(f"    t-statistic: {t_stat:.4f}")
        print(f"    p-value: {p_value:.4f}")
        
        if p_value < 0.05:
            print("\n  ✅ Result is statistically significant (p < 0.05)")
            if t_stat > 0:
                print("     → DiET significantly outperforms GradCAM")
            else:
                print("     → GradCAM significantly outperforms DiET")
        else:
            print(f"\n  ⚠️  Result is not statistically significant (p = {p_value:.4f})")
            print("     → Cannot conclude that the methods differ (few datasets give low power)")
    else:
        print("\n⚠️  Not enough datasets for statistical testing (need ≥ 3)")
    
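    # Effect size (Cohen's d, pooled from sample variances)
    if len(gradcam_scores) >= 2:
        # ddof=1 gives the unbiased sample variance; NumPy defaults to ddof=0
        pooled_std = np.sqrt((np.var(gradcam_scores, ddof=1) + np.var(diet_scores, ddof=1)) / 2)
        cohens_d = (np.mean(diet_scores) - np.mean(gradcam_scores)) / pooled_std if pooled_std > 0 else 0
        
        print(f"\n  Effect Size (Cohen's d): {cohens_d:.4f}")
        # Conventional thresholds: 0.2 small, 0.5 medium, 0.8 large
        if abs(cohens_d) < 0.2:
            print("     → Negligible effect")
        elif abs(cohens_d) < 0.5:
            print("     → Small effect")
        elif abs(cohens_d) < 0.8:
            print("     → Medium effect")
        else:
            print("     → Large effect")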
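
A paired t-test is a one-sample t-test on the per-dataset score differences, which the following standard-library sketch computes by hand. The helper name `paired_t_statistic` and the demo scores are made up for illustration; they are not results from this experiment.

```python
import math
import statistics

def paired_t_statistic(a, b):
    """Return (t, degrees of freedom) for a paired t-test of a against b."""
    diffs = [x - y for x, y in zip(a, b)]
    n = len(diffs)
    if n < 2:
        raise ValueError("need at least two paired observations")
    mean_diff = statistics.mean(diffs)
    sd_diff = statistics.stdev(diffs)   # sample standard deviation (divides by n - 1)
    if sd_diff == 0:
        raise ValueError("all differences are identical; t is undefined")
    # t = mean difference divided by its standard error
    return mean_diff / (sd_diff / math.sqrt(n)), n - 1

# Made-up scores for three datasets, for illustration only.
diet_demo = [0.60, 0.70, 0.80]
gradcam_demo = [0.50, 0.65, 0.60]
t_demo, dof_demo = paired_t_statistic(diet_demo, gradcam_demo)
```

With these demo values t is about 2.65 on 2 degrees of freedom, below the two-sided 5% critical value of roughly 4.30. Even a consistent advantage across three or four datasets can therefore fail to reach significance, which is worth keeping in mind when reading the test output.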

---

## 5. Discussion and Conclusions

### 5.1 Key Findings

In [None]:
# Generate summary report
if len(df) > 0:
    print("\n" + "=" * 70)
    print("📋 EXPERIMENT SUMMARY REPORT")
    print("=" * 70)
    
    diet_wins = (df["DiET Better"] == True).sum()
    total = len(df)
    avg_gradcam = df["GradCAM Score"].mean()
    avg_diet = df["DiET Score"].mean()
    avg_improvement = df["Improvement"].mean()
    
    print("\n📊 OVERALL PERFORMANCE:")
    print(f"   • DiET wins: {diet_wins}/{total} datasets ({100*diet_wins/total:.1f}%)")
    print(f"   • Average GradCAM Score: {avg_gradcam:.4f}")
    print(f"   • Average DiET Score: {avg_diet:.4f}")
    print(f"   • Average Improvement: {avg_improvement:+.4f}")
    
    print("\n🔍 KEY OBSERVATIONS:")
    if avg_improvement > 0:
        print("   1. DiET scores higher than GradCAM on average, suggesting more discriminative attributions")
        print(f"   2. DiET scores higher on {diet_wins} of {total} tested datasets")
    else:
        print("   1. GradCAM shows competitive or better performance in this experiment")
        print("   2. Consider running with more samples or epochs for definitive results")
    
    print("\n📈 PER-DATASET BREAKDOWN:")
    for _, row in df.iterrows():
        status = "✅" if row["DiET Better"] else "❌"
        print(f"   {status} {row['Dataset']}: GradCAM={row['GradCAM Score']:.4f}, DiET={row['DiET Score']:.4f} ({row['Improvement']:+.4f})")
    
    print("\n" + "=" * 70)

### 5.2 Limitations and Future Work

**Limitations:**
- Limited training epochs due to computational constraints
- Fixed DiET hyperparameters may not be optimal for all datasets
- Evaluation limited to pixel perturbation metric
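
For reference, a deletion-style pixel perturbation score of the kind named in the last limitation can be sketched as follows. This is one common variant, written under assumptions: the project's `XAIMethodsComparison` may compute its score differently, and `deletion_scores`, `toy_predict`, and the toy arrays are hypothetical.

```python
import numpy as np

def deletion_scores(image, attribution, predict, fractions=(0.1, 0.2, 0.5), fill=0.0):
    """Confidence drop after erasing the most-attributed pixels.

    image, attribution: 2-D arrays of the same shape.
    predict: callable mapping an image to the target-class confidence.
    Returns one drop per fraction; larger drops suggest a more faithful attribution.
    """
    order = np.argsort(attribution, axis=None)[::-1]  # flat indices, most important first
    baseline = predict(image)
    drops = []
    for frac in fractions:
        k = int(round(frac * image.size))
        perturbed = image.copy().ravel()
        perturbed[order[:k]] = fill                   # erase the top-k pixels
        drops.append(baseline - predict(perturbed.reshape(image.shape)))
    return np.array(drops)

def toy_predict(img):
    """Toy 'model': confidence is the mean intensity of the top-left 2x2 patch."""
    return float(img[:2, :2].mean())

img = np.ones((4, 4))
good_attr = np.zeros((4, 4))
good_attr[:2, :2] = 1.0   # points at the patch the toy model actually uses
bad_attr = np.zeros((4, 4))
bad_attr[2:, 2:] = 1.0    # points at an irrelevant patch

good_drops = deletion_scores(img, good_attr, toy_predict, fractions=(0.25,))
bad_drops = deletion_scores(img, bad_attr, toy_predict, fractions=(0.25,))
```

Erasing the pixels flagged by `good_attr` removes all of the toy model's confidence, while erasing those flagged by `bad_attr` leaves it unchanged. Averaging the drops over several perturbation fractions gives an AOPC-style summary.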

**Future Work:**
- Extend to larger datasets (ImageNet)
- Compare with other XAI methods (SHAP, LIME, Attention)
- Investigate DiET hyperparameter sensitivity
- Human evaluation of attribution quality

---

## 6. Save Results and Export Report

In [None]:
# Save all results
import json

# Create comprehensive results dictionary
full_results = {
    "experiment": "DiET vs GradCAM Image Comparison",
    "date": datetime.now().isoformat(),
    "configuration": CONFIG,
    "device": DEVICE,
    "gpu_config": GPU_CONFIG,
    "results": results,
}

# Save to JSON
results_path = "./outputs/colab_experiments/image_comparison/full_results.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)

def make_serializable(obj):
    """Recursively convert NumPy types into JSON-serializable Python types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.floating):  # covers float16/32/64
        return float(obj)
    elif isinstance(obj, np.integer):   # covers all signed and unsigned widths
        return int(obj)
    elif isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [make_serializable(item) for item in obj]
    return obj

with open(results_path, 'w') as f:
    json.dump(make_serializable(full_results), f, indent=2)

# Save DataFrame to CSV
if len(df) > 0:
    df.to_csv('./outputs/colab_experiments/image_comparison/results_summary.csv', index=False)

print("\n✅ Results saved successfully!")
print(f"   📄 JSON: {results_path}")
print("   📊 CSV: ./outputs/colab_experiments/image_comparison/results_summary.csv")

In [None]:
# Download results (for Colab)
try:
    from google.colab import files
    
    # Create zip of all results
    !zip -r image_comparison_results.zip ./outputs/colab_experiments/image_comparison/
    
    print("\n📥 Download your results:")
    files.download('image_comparison_results.zip')
except ImportError:
    # google.colab is only importable inside Colab
    print("\n📁 Results are saved locally in: ./outputs/colab_experiments/image_comparison/")
    print("   (Download option only available in Google Colab)")

---

## 📚 References

1. Bhalla, U., et al. (2023). "Discriminative Feature Attributions: Bridging Post Hoc Explainability and Inherent Interpretability." *NeurIPS 2023.*

2. Selvaraju, R. R., et al. (2017). "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." *ICCV 2017.*

3. Krizhevsky, A. (2009). "Learning Multiple Layers of Features from Tiny Images." *Technical Report, University of Toronto.*

---

**Notebook Version:** 1.0  
**Last Updated:** 2025-2026 Academic Year  
**Repository:** [github.com/xMOROx/Machine-Learning-Project-2025-2026](https://github.com/xMOROx/Machine-Learning-Project-2025-2026)