# XAI Evaluation: Grad-CAM, IG, RISE + Faithfulness Metrics

**Goals:**
- Generate saliency maps using Grad-CAM, Grad-CAM++, Integrated Gradients, and RISE
- Evaluate faithfulness (Insertion/Deletion AUC) and localization against ground-truth masks (Pointing Game, Saliency IoU)
- Compare methods across sample test images
- Visualize overlays with ground truth masks
- Perform sanity checks (randomization test)

In [None]:
# Imports
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image

from src.models.cls_resnet18 import ResNet18Classifier
from src.xai.grad_cam import GradCAM, GradCAMPlusPlus
from src.xai.integrated_gradients import IntegratedGradientsExplainer
from src.xai.rise import RISE
from src.xai.faithfulness import (
    insertion_auc,
    deletion_auc,
    pointing_game,
    saliency_iou,
    randomization_test
)

%matplotlib inline

## 1. Load Trained Model

In [None]:
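# TODO: Load best classification checkpoint
# model = ResNet18Classifier.load_from_checkpoint('path/to/best.ckpt')
# model.eval()  # disable dropout, use BatchNorm running statistics
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)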

## 2. Select Test Samples

In [None]:
# TODO: Select ~20 test images across the lesion-size spectrum
# - Small, medium, large lesions
# - Both benign and malignant
# - Correct and incorrect predictions

# TODO: Load corresponding ground truth masks

## 3. Generate Saliency Maps

In [None]:
# TODO: Initialize XAI methods
# grad_cam = GradCAM(model, target_layer=model.layer4[-1])
# grad_cam_pp = GradCAMPlusPlus(model, target_layer=model.layer4[-1])
# ig_explainer = IntegratedGradientsExplainer(model)
# rise = RISE(model, n_masks=800)

# TODO: Generate saliency for each test image
# for img in test_images:
#     saliency_gradcam = grad_cam.generate(img)
#     saliency_gradcam_pp = grad_cam_pp.generate(img)
#     saliency_ig = ig_explainer.generate(img, n_steps=50)
#     saliency_rise = rise.generate(img)
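
Integrated Gradients averages the gradient along a straight path from a baseline x' to the input x and scales it by (x - x'). The sketch below is a standalone illustration, not the project's `IntegratedGradientsExplainer`; it assumes a zero (black) baseline and a midpoint Riemann sum with `n_steps` points, and `integrated_gradients` is a hypothetical name. For a linear model the approximation is exact, which gives a quick completeness check: the attributions should sum to f(x) - f(x').

```python
import torch


def integrated_gradients(model, x, target, baseline=None, n_steps=50):
    """Midpoint Riemann approximation of Integrated Gradients.

    x: (C, H, W) input tensor; target: class index. Returns attributions shaped like x.
    """
    if baseline is None:
        baseline = torch.zeros_like(x)  # assumption: black-image baseline
    # Midpoint alphas in (0, 1), one per interpolation step
    alphas = (torch.arange(n_steps, dtype=x.dtype, device=x.device) + 0.5) / n_steps
    # Interpolated inputs along baseline -> x, shape (n_steps, C, H, W)
    path = baseline.unsqueeze(0) + alphas.view(-1, 1, 1, 1) * (x - baseline).unsqueeze(0)
    path = path.detach().requires_grad_(True)
    scores = model(path)[:, target]
    grads, = torch.autograd.grad(scores.sum(), path)
    # Average path gradient, scaled by the input difference
    return (x - baseline) * grads.mean(dim=0)


# Completeness check on a linear model, where IG is exact
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 2)).eval()
x = torch.rand(3, 8, 8)
attr = integrated_gradients(model, x, target=1)
with torch.no_grad():
    gap = model(x[None])[0, 1] - model(torch.zeros(1, 3, 8, 8))[0, 1]
print(float(attr.sum()), float(gap))  # the two numbers should agree
```

A 2-D saliency map for overlays is typically obtained by collapsing channels, e.g. `attr.abs().sum(0)`. Batching all `n_steps` points at once is fast but memory-hungry at full resolution; chunking the path is a straightforward change.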

## 4. Faithfulness and Localization Metrics

In [None]:
# TODO: Compute faithfulness and localization metrics for each method
# `saliencies` is assumed to map method name -> list of saliency maps (one per image)
# results = []
# for method_name, method_maps in saliencies.items():
#     for img, mask, saliency in zip(images, masks, method_maps):
#         ins_auc = insertion_auc(model, img, saliency)
#         del_auc = deletion_auc(model, img, saliency)
#         pg = pointing_game(saliency, mask)
#         sal_iou = saliency_iou(saliency, mask, top_k_percent=20)
#         results.append({'method': method_name, 'insertion': ins_auc, 'deletion': del_auc,
#                         'pointing': pg, 'iou': sal_iou})

# TODO: Aggregate metrics across all samples per method
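
The two localization metrics reduce to a few array operations. A minimal sketch of what `pointing_game` and `saliency_iou` might compute, assuming both maps are `(H, W)` arrays at the same resolution; the helper names, the square tolerance window, and the tie handling are assumptions and may differ from `src.xai.faithfulness`.

```python
import numpy as np


def pointing_game_hit(saliency, mask, tolerance=0):
    """True if the saliency maximum lands inside the GT mask (optionally within a pixel tolerance)."""
    saliency = np.asarray(saliency, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    r, c = np.unravel_index(np.argmax(saliency), saliency.shape)
    # Check a (2 * tolerance + 1)^2 square window around the peak
    r0, r1 = max(r - tolerance, 0), r + tolerance + 1
    c0, c1 = max(c - tolerance, 0), c + tolerance + 1
    return bool(mask[r0:r1, c0:c1].any())


def saliency_iou_topk(saliency, mask, top_k_percent=20):
    """IoU between the top-k% most salient pixels and the GT mask."""
    saliency = np.asarray(saliency, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    k = max(1, int(round(saliency.size * top_k_percent / 100)))
    # Threshold at the k-th largest value (ties can admit extra pixels)
    thresh = np.partition(saliency.ravel(), -k)[-k]
    binarized = saliency >= thresh
    union = np.logical_or(binarized, mask).sum()
    return float(np.logical_and(binarized, mask).sum() / union) if union else 0.0


sal = np.zeros((10, 10))
sal[2, 3] = 1.0                      # single peak inside the lesion
mask = np.zeros((10, 10), dtype=bool)
mask[2:5, 2:5] = True                # 3x3 lesion
print(pointing_game_hit(sal, mask))                   # True
print(saliency_iou_topk(sal, mask, top_k_percent=1))  # 1/9, about 0.111
```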

## 5. Insertion/Deletion Curves

In [None]:
# TODO: Plot insertion/deletion curves
# - Compare Grad-CAM vs IG vs RISE
# - Plot on same axes
# - Show random baseline
# - Compute AUC for each method
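
Insertion and deletion curves track the target-class probability as pixels are added to a blank canvas or removed from the image, most salient first; the AUC is the area under that curve (higher is better for insertion, lower for deletion). A sketch, assuming a zero substrate for both modes and a saliency map already resized to the input resolution; `perturbation_curve` is a hypothetical helper rather than the project's `insertion_auc`/`deletion_auc`, and a blurred copy of the image is a common alternative substrate for insertion. A uniformly random saliency map provides the random baseline.

```python
import numpy as np
import torch


@torch.no_grad()
def perturbation_curve(model, img, saliency, target, mode="deletion",
                       n_steps=50, substrate=None, batch_size=16):
    """Target-class probability as pixels are deleted from / inserted into the image,
    most salient first. Returns (fractions, probs, auc)."""
    if mode not in ("deletion", "insertion"):
        raise ValueError(f"unknown mode: {mode}")
    C, H, W = img.shape
    if substrate is None:
        substrate = torch.zeros_like(img)  # assumption: black canvas for both modes
    order = np.argsort(-np.asarray(saliency, dtype=float).ravel())  # most salient first
    fractions = np.linspace(0.0, 1.0, n_steps + 1)
    # deletion: image -> substrate; insertion: substrate -> image
    start, finish = (img, substrate) if mode == "deletion" else (substrate, img)
    finish_flat = finish.reshape(C, -1)
    inputs = []
    for f in fractions:
        idx = torch.as_tensor(order[: int(round(f * H * W))], device=img.device)
        x = start.reshape(C, -1).clone()
        x[:, idx] = finish_flat[:, idx]  # swap in the top-f fraction of pixels
        inputs.append(x.reshape(C, H, W))
    probs = []
    for i in range(0, len(inputs), batch_size):
        logits = model(torch.stack(inputs[i:i + batch_size]))
        probs.append(torch.softmax(logits, dim=1)[:, target].cpu())
    probs = torch.cat(probs).numpy()
    # Trapezoidal area under the curve on the [0, 1] fraction axis
    auc = float(np.sum((probs[1:] + probs[:-1]) / 2 * np.diff(fractions)))
    return fractions, probs, auc


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 16 * 16, 2)).eval()
img = torch.rand(3, 16, 16)
random_sal = np.random.default_rng(0).random((16, 16))  # random-saliency baseline
fr, p_del, auc_del = perturbation_curve(model, img, random_sal, target=1, mode="deletion")
_, p_ins, auc_ins = perturbation_curve(model, img, random_sal, target=1, mode="insertion")
print(f"random baseline: deletion AUC {auc_del:.3f}, insertion AUC {auc_ins:.3f}")
```

Plotting `probs` against `fractions` for each method, together with the random-saliency curve, on a single `matplotlib` axis gives the side-by-side comparison listed above.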

## 6. Visualization Panel

In [None]:
# TODO: Create comprehensive visualization for each test image
# Row layout:
# [Input Image] [GT Mask Overlay] [Grad-CAM] [Grad-CAM++] [IG] [RISE]

# TODO: Add metrics annotation:
# - Pointing Game result (✓/✗)
# - Saliency IoU
# - Insertion/Deletion AUC
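
A sketch of the row layout with per-method annotations, assuming an `(H, W, 3)` float image in [0, 1], a boolean mask, and saliency maps already upsampled to the image size; `plot_panel` and the keys of the `metrics` dictionary are hypothetical.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_panel(image, mask, saliencies, metrics=None):
    """One row: input, GT overlay, then one heatmap overlay per method."""
    mask_f = np.asarray(mask, dtype=float)
    n = 2 + len(saliencies)
    fig, axes = plt.subplots(1, n, figsize=(3 * n, 3.4))
    axes[0].imshow(image)
    axes[0].set_title("Input")
    axes[1].imshow(image)
    axes[1].contour(mask_f, levels=[0.5], colors="lime")  # GT boundary
    axes[1].set_title("GT mask")
    for ax, (name, sal) in zip(axes[2:], saliencies.items()):
        sal = np.asarray(sal, dtype=float)
        # Min-max normalize so every method uses the same color scale
        sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)
        ax.imshow(image)
        ax.imshow(sal, cmap="jet", alpha=0.5)
        ax.contour(mask_f, levels=[0.5], colors="white", linewidths=0.8)
        title = name
        if metrics and name in metrics:
            m = metrics[name]
            hit = "✓" if m["pointing"] else "✗"
            title += (f"\nPG {hit}  IoU {m['iou']:.2f}"
                      f"\nIns {m['insertion']:.2f}  Del {m['deletion']:.2f}")
        ax.set_title(title, fontsize=9)
    for ax in axes:
        ax.axis("off")
    fig.tight_layout()
    return fig


# Demo with random arrays and placeholder metric values (not results)
rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))
mask = np.zeros((64, 64), dtype=bool)
mask[20:40, 20:40] = True
maps = {name: rng.random((64, 64)) for name in ["Grad-CAM", "Grad-CAM++", "IG", "RISE"]}
fig = plot_panel(image, mask, maps,
                 metrics={"RISE": {"pointing": True, "iou": 0.41, "insertion": 0.72, "deletion": 0.18}})
```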

## 7. Sanity Checks

In [None]:
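# TODO: Model-parameter randomization test (cascading)
# - Re-initialize model weights layer by layer, from the output layer down to the input
# - Generate Grad-CAM at each randomization level
# - Verify the saliency changes: similarity to the trained-model map (e.g., Spearman rank correlation, SSIM) should decrease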

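# TODO: Data randomization test (optional)
# - Retrain the classifier on randomly permuted labels; its saliency should not match the GT mask
# - Cheaper proxy: feed random noise as input; saliency should not match the GT mask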

## 8. Method Comparison

In [None]:
# TODO: Create comparison table
# Columns: Method | Insertion AUC | Deletion AUC | Pointing Game (%) | Saliency IoU
# Rows: Grad-CAM, Grad-CAM++, IG, RISE

# TODO: Statistical comparison
# - Paired t-test or Wilcoxon signed-rank test between methods (paired by sample)
# - Highlight best performing method per metric
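
A sketch of the comparison table and paired tests, assuming per-sample results in a long-format `pandas` DataFrame with `sample` and `method` columns; `compare_methods` is hypothetical and the demo values are random placeholders, not results. With four methods there are six pairs, so a multiple-comparison correction (for example Holm) is advisable, and for the binary Pointing Game outcome McNemar's test is a more natural paired test than Wilcoxon.

```python
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon


def compare_methods(df, metric, higher_is_better=True):
    """Per-method mean/std of one metric, the best method, and pairwise Wilcoxon p-values."""
    # One row per sample, one column per method (pairs observations by sample)
    wide = df.pivot(index="sample", columns="method", values=metric)
    summary = wide.agg(["mean", "std"]).T
    best = summary["mean"].idxmax() if higher_is_better else summary["mean"].idxmin()
    methods = list(wide.columns)
    pvals = {}
    for i, a in enumerate(methods):
        for b in methods[i + 1:]:
            pvals[(a, b)] = float(wilcoxon(wide[a], wide[b]).pvalue)
    return summary, best, pvals


# Synthetic placeholder values, only to show the expected shapes
rng = np.random.default_rng(0)
rows = [{"sample": s, "method": m, "insertion": rng.random(), "deletion": rng.random()}
        for s in range(20) for m in ["Grad-CAM", "Grad-CAM++", "IG", "RISE"]]
df = pd.DataFrame(rows)
summary, best, pvals = compare_methods(df, "deletion", higher_is_better=False)  # lower is better
print(summary.round(3))
print("best (deletion):", best)
```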

## 9. Qualitative Analysis

In [None]:
# TODO: Analyze strengths/weaknesses of each method
# - Which method best highlights lesion boundaries?
# - Which is most consistent across samples?
# - Any spurious highlights outside lesion?
# - Computational cost comparison
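
For the computational-cost comparison, a small timing helper is enough. A sketch assuming each explainer can be called as a plain function; `time_method` is hypothetical. On GPU, passing `sync=torch.cuda.synchronize` keeps asynchronous kernel launches from making a method look faster than it is.

```python
import statistics
import time


def time_method(fn, *args, repeats=5, warmup=1, sync=None, **kwargs):
    """Median wall-clock seconds per call of fn(*args, **kwargs)."""
    for _ in range(warmup):   # exclude one-off costs (allocation, autotuning)
        fn(*args, **kwargs)
    timings = []
    for _ in range(repeats):
        if sync is not None:
            sync()            # wait for previously queued GPU work
        start = time.perf_counter()
        fn(*args, **kwargs)
        if sync is not None:
            sync()            # ensure this call's GPU work has finished
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


print(f"{time_method(sorted, list(range(10_000, 0, -1))):.6f} s per call")
```

With the interface sketched in Section 3, `time_method(rise.generate, img)` would time one RISE explanation; RISE's cost scales with `n_masks` and IG's with `n_steps`, so both settings belong in the comparison.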

## Summary

TODO: Print final XAI evaluation summary:
- Best method by Insertion AUC
- Best method by Pointing Game
- Overall recommendation
- Quality gate: does Pointing Game accuracy on malignant samples reach ≥ 0.8?
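
The quality gate can be checked explicitly. A minimal sketch, assuming per-sample pointing-game hits and class labels are available as arrays with 1 = malignant; `pointing_game_gate` is a hypothetical helper and the demo inputs are made up for illustration.

```python
import numpy as np


def pointing_game_gate(hits, labels, malignant_label=1, threshold=0.8):
    """Return (Pointing Game accuracy on malignant samples, whether it meets the gate)."""
    hits = np.asarray(hits, dtype=bool)
    labels = np.asarray(labels)
    malignant = labels == malignant_label
    if not malignant.any():
        raise ValueError("no malignant samples to evaluate")
    accuracy = float(hits[malignant].mean())
    return accuracy, accuracy >= threshold


acc, passed = pointing_game_gate([True, True, False, True, True], [1, 1, 1, 0, 1])
print(f"Pointing Game (malignant): {acc:.2f} -> {'PASS' if passed else 'FAIL'}")  # 0.75 -> FAIL
```

Running the gate once per method makes the overall recommendation traceable to a concrete threshold.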