# DiET XAI Framework

This notebook demonstrates how to use the DiET (Distractor Erasure Tuning) framework for explaining image and text classification models. It compares DiET attributions with GradCAM (for images) and Integrated Gradients (for text).
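The intuition behind distractor erasure fits in a toy example. We learn a mask over input features that is as sparse as possible while keeping the model's output unchanged, so features the model does not rely on are erased. The sketch below is a simplified, hypothetical illustration: it uses a fixed linear model and projected gradient descent, and `learn_sparse_mask` is not part of the repository. It also keeps the model fixed, whereas DiET tunes the model as well, so treat it as intuition only, not as the framework's algorithm.

```python
def learn_sparse_mask(w, x, lam=0.1, lr=0.1, steps=2000):
    """Toy distractor erasure on a fixed linear model f(z) = w . z.

    Minimises (f(x * m) - f(x))**2 + lam * mean(m) over a mask m in [0, 1]^d
    with projected gradient descent. Hypothetical helper for illustration only.
    """
    d = len(x)
    full = sum(wi * xi for wi, xi in zip(w, x))  # prediction on the unmasked input
    m = [1.0] * d  # start by keeping every feature
    for _ in range(steps):
        masked = sum(wi * xi * mi for wi, xi, mi in zip(w, x, m))
        r = masked - full  # drift of the masked prediction from the original
        for i in range(d):
            # Gradient of the fidelity term plus the L1 (sparsity) term
            grad = 2.0 * r * w[i] * x[i] + lam / d
            # Gradient step, then project back onto [0, 1]
            m[i] = min(1.0, max(0.0, m[i] - lr * grad))
    return m


w = [1.0, 0.5, 0.0, 0.0]  # features 2 and 3 are distractors: the model ignores them
x = [1.0, 1.0, 1.0, 1.0]
mask = learn_sparse_mask(w, x)
print([round(v, 3) for v in mask])
```

The distractor entries are driven to exactly zero, while the two features the model uses keep mask values close to 1. `lam` trades sparsity against fidelity to the original prediction.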

## 1. Setup

First, install the required dependencies.

In [None]:
# %pip installs into the environment of the running kernel
%pip install torch torchvision transformers datasets captum matplotlib scipy
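After installing, you can confirm that each package is importable in the current kernel before running the experiments. This is a small standard-library sketch. `missing_packages` is a hypothetical helper, and it assumes each package's import name matches its pip name, which holds for the packages listed above.

```python
import importlib.util


def missing_packages(names):
    """Return the import names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names of the packages installed above
required = ["torch", "torchvision", "transformers", "datasets",
            "captum", "matplotlib", "scipy"]
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All dependencies are importable.")
```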

## 2. Import Modules

Next, make the repository importable and import `run_diet_comparison` from it. The experiment classes are imported in the cells that use them.

In [None]:
import os
import sys

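# If the notebook was launched from notebooks/, move to the repository root so
# that `scripts` is importable and the relative ./outputs paths resolve there.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())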

from scripts.xai_experiments.experiments.xai_comparison import run_diet_comparison
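The `basename == "notebooks"` check above works only when the notebook is started from exactly that directory. A more general approach walks up the directory tree until it finds a marker directory. The sketch below assumes the `scripts/` package directory marks the repository root, which matches the import path used above. `find_repo_root` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def find_repo_root(start=None, marker="scripts"):
    """Walk up from `start` until a directory containing `marker` is found.

    Returns the matching directory as a Path, or None if no ancestor has it.
    """
    current = Path(start or Path.cwd()).resolve()
    for candidate in (current, *current.parents):
        if (candidate / marker).is_dir():
            return candidate
    return None


root = find_repo_root()
print("Repository root:", root)
```

If a root is found, `sys.path.insert(0, str(root))` makes the `scripts` package importable regardless of where the kernel was started.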

## 3. Image Experiments

Run DiET on a single image dataset. Set `dataset_name` to one of:
- `cifar10`
- `cifar100`
- `fashion_mnist`
- `mnist`
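A misspelled dataset name may only surface once the experiment starts loading data. A small guard like the following fails fast with the list of valid choices. `check_dataset` and `IMAGE_NUM_CLASSES` are hypothetical names, not part of the repository. The class counts are the standard values for these benchmarks.

```python
# Number of classes in each supported image dataset
IMAGE_NUM_CLASSES = {
    "cifar10": 10,
    "cifar100": 100,
    "fashion_mnist": 10,
    "mnist": 10,
}


def check_dataset(name, supported):
    """Return supported[name], or raise a ValueError listing the valid choices."""
    if name not in supported:
        raise ValueError(
            f"Unknown dataset {name!r}; choose one of: {', '.join(sorted(supported))}"
        )
    return supported[name]


print(check_dataset("cifar10", IMAGE_NUM_CLASSES))
```

The same pattern works for the text datasets with a table of their own.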

In [None]:
import torch

dataset_name = "cifar10"  # Change to 'cifar100', 'fashion_mnist', or 'mnist'

config = {
    # Fall back to CPU when no GPU is available (experiments will be much slower)
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "output_dir": f"./outputs/notebook_experiments/{dataset_name}",
    "dataset_name": dataset_name,
    "batch_size": 32,
    "max_samples_image": 1000,  # Reduced for demo speed
    "epochs_image": 2,
    "comparison_samples": 8
}

# Run only the image experiment
from scripts.xai_experiments.experiments.diet_experiment import DiETExperiment

exp = DiETExperiment(config)
image_results = exp.run_full_experiment()
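GradCAM, the image baseline, weights each feature map of a chosen convolutional layer by the spatial average of the class score's gradient with respect to that map. It then sums the weighted maps and keeps only the positive part. Once the activations and gradients are available (for example, captured with PyTorch hooks), the heatmap takes a few NumPy operations. This is a generic sketch of the technique, not the repository's GradCAM code. Upsampling the map to the input resolution is omitted.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM heatmap from one conv layer's activations and gradients.

    activations, gradients: arrays of shape (K, H, W), where gradients holds
    d(class score)/d(activations). Returns an (H, W) map scaled to [0, 1].
    """
    # Channel weights: global-average-pool the gradients over space
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum of feature maps, keeping only positive evidence (ReLU)
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


rng = np.random.default_rng(0)
acts = rng.random((4, 7, 7))             # stand-in activations: 4 channels, 7x7
grads = rng.standard_normal((4, 7, 7))   # stand-in gradients of the class score
heatmap = grad_cam(acts, grads)
print(heatmap.shape)
```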

## 4. Text Experiments

Run DiET on a single text dataset. Set `text_dataset` to one of:
- `sst2` (binary sentiment analysis)
- `imdb` (binary sentiment of movie reviews)
- `ag_news` (news topic classification)
- `yelp_review_full` (Yelp review star ratings)
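Integrated Gradients, the text baseline, attributes a prediction `F(x)` to input features. It averages the gradient along the straight-line path from a baseline `x'` to `x` and scales the result by `x - x'`. The sketch below approximates that path integral with a midpoint Riemann sum on a toy function whose gradient is known analytically. For a real text model, `x` would typically be the token embeddings, and the gradients would come from autograd, for example via Captum's `IntegratedGradients` (Captum is one of the installed dependencies). This is a generic illustration of the method, not the repository's code.

```python
def integrated_gradients(grad_fn, x, baseline, steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients.

    grad_fn(z) must return the gradient of the model output at point z.
    """
    d = len(x)
    avg_grad = [0.0] * d
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th interval on [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_fn(point)
        for i in range(d):
            avg_grad[i] += g[i] / steps
    # Scale the averaged path gradient by the input-baseline difference
    return [(xi - b) * gi for xi, b, gi in zip(x, baseline, avg_grad)]


# Toy model F(z) = sum_i w_i * z_i**2 with analytic gradient 2 * w_i * z_i
w = [0.5, -1.0, 2.0]


def F(z):
    return sum(wi * zi ** 2 for wi, zi in zip(w, z))


def grad_F(z):
    return [2 * wi * zi for wi, zi in zip(w, z)]


x = [1.0, 2.0, -1.0]
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(grad_F, x, baseline)
# Completeness: the attributions sum to F(x) - F(baseline)
print(sum(attributions), F(x) - F(baseline))
```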

In [None]:
import torch

text_dataset = "sst2"  # Change to 'imdb', 'ag_news', or 'yelp_review_full'

config = {
    # Fall back to CPU when no GPU is available (experiments will be much slower)
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "output_dir": f"./outputs/notebook_experiments/{text_dataset}",
    "dataset_name": text_dataset,
    "batch_size": 8,
    "max_samples_text": 500,
    "epochs_text": 1,
    "comparison_samples_text": 5,
    "top_k": 10
}

# Run only the text experiment
from scripts.xai_experiments.experiments.diet_text_experiment import DiETTextExperiment

exp = DiETTextExperiment(config)
text_results = exp.run_full_experiment()
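The text config's `top_k` setting presumably controls how many of the highest-attributed tokens are reported or compared per example. This notebook does not show how the repository uses it, so treat that reading as an assumption. One common way to select such tokens is to rank them by absolute attribution, as in this hypothetical helper.

```python
def top_k_tokens(tokens, attributions, k=10):
    """Return the k (token, score) pairs with the largest |attribution|."""
    ranked = sorted(zip(tokens, attributions), key=lambda pair: abs(pair[1]), reverse=True)
    return ranked[:k]


tokens = ["the", "movie", "was", "not", "good"]
scores = [0.01, 0.10, -0.02, -0.45, 0.60]
print(top_k_tokens(tokens, scores, k=3))
# [('good', 0.6), ('not', -0.45), ('movie', 0.1)]
```

Ranking by absolute value keeps strongly negative evidence such as "not". To keep only tokens that support the predicted class, sort by the raw score instead.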

## 5. View Results

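Each experiment saves its results and visualizations under its `output_dir`. The cell below displays the DiET vs. GradCAM comparison for the image dataset and the metrics comparison chart for the text dataset, if those files exist.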

In [None]:
from IPython.display import Image, display

print("Image comparison (DiET vs. GradCAM):")
image_vis_path = f"./outputs/notebook_experiments/{dataset_name}/comparison_visualizations/diet_vs_gradcam.png"
if os.path.exists(image_vis_path):
    display(Image(filename=image_vis_path))
else:
    print(f"Not found: {image_vis_path}")

print("\nText metrics comparison:")
metrics_vis_path = f"./outputs/notebook_experiments/{text_dataset}/comparison_visualizations/metrics_comparison.png"
if os.path.exists(metrics_vis_path):
    display(Image(filename=metrics_vis_path))
else:
    print(f"Not found: {metrics_vis_path}")
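The cell above displays two specific files. This notebook does not name any other output files, so rather than guessing, you can list every PNG the experiments actually wrote under the output root. `list_figures` is a hypothetical helper. Each path it returns can be passed to `display(Image(filename=str(path)))`.

```python
from pathlib import Path


def list_figures(output_dir, pattern="*.png"):
    """Return all files under `output_dir` matching `pattern`, sorted by path."""
    root = Path(output_dir)
    if not root.is_dir():
        return []
    return sorted(root.rglob(pattern))


for fig in list_figures("./outputs/notebook_experiments"):
    print(fig)
```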
