# 03 — XAI Inspection (CAM overlays + quantitative XAI metrics)

**Objective**: Visual inspection of CAM overlays and a minimal quantitative proxy evaluation.

This notebook is designed to:
- load CAM heatmaps already exported by your pipeline (`outputs/figures/xai/...`)
- display class-wise samples for qualitative figures
- compute lightweight **coverage** statistics if GT boxes are available

Note: For paper-grade metrics (IoU-Heat / PointAcc / Relevance / Fidelity), prefer your `src/xai/` pipeline.


In [None]:
from pathlib import Path
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2

# Roots can be overridden through environment variables.
DATA_ROOT = Path(os.environ.get("DATA_ROOT", "../data"))
OUTPUT_ROOT = Path(os.environ.get("OUTPUT_ROOT", "../outputs"))

# Dataset variant in YOLO layout; alternative: "D2_balanced_yolo640".
DATASET_ROOT = DATA_ROOT.joinpath("processed", "D1_yolo640")
IMAGES_DIR, LABELS_DIR = DATASET_ROOT / "images", DATASET_ROOT / "labels"

# Where exported CAMs are expected (see configs/xai/cam_methods.yaml).
XAI_DIR = OUTPUT_ROOT.joinpath("figures", "xai")

# Class names in label-index order (index i <-> YOLO class id i).
CLASSES = list((
    "Cleft Lip",
    "Epibulbar Dermoid",
    "Eyelid Coloboma",
    "Facial Asymmetry",
    "Malocclusion",
    "Microtia",
    "Vertebral Abnormalities",
))
NUM_CLASSES = len(CLASSES)

def parse_yolo_label_file(p: Path):
    """Parse a YOLO detection label file into ``(class_id, x, y, w, h)`` tuples.

    Values are returned exactly as stored in the file (YOLO stores normalised
    centre/size). A missing or empty file yields ``[]``.

    Rows are skipped (not fatal) when they do not have exactly 5 columns
    (e.g. segmentation polygons) or contain non-numeric tokens, so a single
    malformed line cannot abort a whole evaluation run.
    """
    if not p.exists():
        return []
    rows = []
    # utf-8-sig transparently strips a leading BOM, which would otherwise
    # make the first class token unparseable.
    for line in p.read_text(encoding="utf-8-sig").splitlines():
        parts = line.split()
        if len(parts) != 5:
            continue
        try:
            c, x, y, w, h = (float(v) for v in parts)
        except ValueError:
            continue
        rows.append((int(c), x, y, w, h))
    return rows

def yolo_to_xyxy(norm_xywh, w, h):
    """Convert a ``(class, cx, cy, bw, bh)`` row to ``(class_id, [x1, y1, x2, y2])``.

    The centre/size values are scaled by the image width ``w`` and height ``h``;
    corners are returned as a float numpy array (not rounded or clipped).
    """
    cls, cx, cy, bw, bh = norm_xywh
    half_w = bw / 2
    half_h = bh / 2
    left, right = (cx - half_w) * w, (cx + half_w) * w
    top, bottom = (cy - half_h) * h, (cy + half_h) * h
    return int(cls), np.array([left, top, right, bottom], dtype=float)


## Qualitative gallery

Expected CAM export layout (recommended):

```
outputs/figures/xai/{method}/{class_name}/...png
```

If your pipeline uses a different folder structure, update the globbing below.


In [None]:
METHOD = "gradcampp"   # try: gradcam, layercam, eigencam, hirescam
N_SHOW = 8
class_idx = 0          # index into CLASSES

method_dir = XAI_DIR / METHOD
print("Method dir:", method_dir)


def _norm_key(s):
    """Lower-case *s* and drop spaces/underscores/hyphens for loose name matching."""
    return "".join(ch for ch in s.lower() if ch not in " _-")


if not method_dir.exists():
    print("XAI directory not found. Export CAMs first, or update XAI_DIR.")
else:
    class_name = CLASSES[class_idx]
    key = _norm_key(class_name)
    # Match against the path relative to method_dir (not just the file name) so
    # both the recommended {method}/{class_name}/x.png layout and flat
    # {method}/x_{class_name}.png names are found. The raw "heatmaps" folder is
    # excluded: it holds single-channel maps, not overlays.
    cand = []
    for p in sorted(method_dir.rglob("*.png")):
        rel = p.relative_to(method_dir)
        if "heatmaps" in rel.parts[:-1]:
            continue
        if key in _norm_key(rel.as_posix()):
            cand.append(p)
    print("Candidates:", len(cand))

    show = cand[:N_SHOW]
    if not show:
        print(f"No CAM overlays matched class '{class_name}' under {method_dir}.")
    else:
        cols = 4
        rows = int(np.ceil(len(show) / cols))
        fig = plt.figure(figsize=(12, 3*rows), dpi=200)
        for i, p in enumerate(show, start=1):
            ax = plt.subplot(rows, cols, i)
            ax.axis("off")
            img = cv2.imread(str(p))
            if img is None:
                # Corrupt/unreadable file: keep the grid slot but flag it.
                ax.set_title(f"unreadable: {p.name}", fontsize=8)
                continue
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax.set_title(p.name, fontsize=8)
        plt.tight_layout()
        plt.show()


## Lightweight quantitative proxy (coverage)

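This computes a simple proxy: the fraction of the ground-truth box area that falls inside the top-X% heat region (the X% of pixels with the highest CAM activation). When an image has several boxes, the best-covered box is reported.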

This requires raw heatmaps, not overlay PNGs: export them as single-channel (grayscale) images so that pixel intensities reflect CAM activations.


In [None]:
TOP_PERCENT = 0.15  # fraction of pixels (highest heat) treated as the "hot" region

def box_mask(h, w, box_xyxy):
    """Return an ``(h, w)`` boolean mask that is True inside a pixel-space box.

    ``box_xyxy`` is ``(x1, y1, x2, y2)`` in continuous pixel coordinates with an
    exclusive far edge: pixel column ``i`` (spanning ``[i, i+1)``) is included
    iff it overlaps ``[x1, x2)``, and likewise for rows. Partially covered
    border pixels are included. The box is clipped to the image; a box lying
    entirely outside the image (or with x2 <= x1 / y2 <= y1) yields an
    all-False mask.
    """
    x1, y1, x2, y2 = (float(v) for v in box_xyxy)
    # End indices are exclusive, so clip to [0, w] / [0, h] rather than w-1/h-1.
    c0 = int(np.clip(np.floor(x1), 0, w))
    c1 = int(np.clip(np.ceil(x2), 0, w))
    r0 = int(np.clip(np.floor(y1), 0, h))
    r1 = int(np.clip(np.ceil(y2), 0, h))
    m = np.zeros((h, w), dtype=bool)
    m[r0:r1, c0:c1] = True
    return m

def top_region_mask(heat, top_percent=0.15):
    """Return a boolean mask of the hottest ``top_percent`` fraction of pixels.

    The threshold is the ``1 - top_percent`` quantile of ``heat``; pixels at or
    above it are selected. If that quantile equals the map's minimum value
    (e.g. a mostly-zero ReLU CAM, or a completely flat map), only pixels
    strictly above it are selected. Otherwise the whole image would be
    marked hot and coverage would be trivially 1.0.

    ``top_percent <= 0`` selects nothing; ``top_percent >= 1`` selects everything.
    """
    heat = np.asarray(heat)
    if heat.size == 0 or top_percent <= 0:
        return np.zeros(heat.shape, dtype=bool)
    if top_percent >= 1:
        return np.ones(heat.shape, dtype=bool)
    thr = np.quantile(heat, 1.0 - top_percent)
    if thr <= heat.min():
        # The minimum-value plateau covers >= (1 - top_percent) of the map;
        # it carries no localisation signal, so exclude it.
        return heat > thr
    return heat >= thr

def coverage_score(gt_mask, hot_mask):
    """Return the fraction of ground-truth pixels that are also hot.

    Computed as ``|gt & hot| / |gt|``; a tiny epsilon in the denominator makes
    an empty GT mask score 0.0 instead of dividing by zero.
    """
    gt_area = gt_mask.sum() + 1e-12
    overlap = gt_mask & hot_mask
    return float(overlap.sum() / gt_area)


In [None]:
# Optional evaluation if you have grayscale heatmaps per image.
# Expected format (recommended): outputs/figures/xai/{method}/heatmaps/{image_stem}.png (single-channel)
MAX_HEATMAPS = 200  # cap for speed
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

HEAT_DIR = method_dir / "heatmaps"
if not HEAT_DIR.exists():
    print("No heatmaps folder found:", HEAT_DIR)
    print("Export raw heatmaps to compute coverage accurately.")
else:
    # Index images by stem ONCE (a recursive walk per heatmap is O(N*M)).
    # Sorted so that, for duplicated stems, the chosen image is deterministic.
    image_index = {}
    for ip in sorted(IMAGES_DIR.rglob("*")):
        if ip.suffix.lower() in IMAGE_EXTS:
            image_index.setdefault(ip.stem, ip)

    heatmaps = sorted(HEAT_DIR.glob("*.png"))
    n_eval = min(len(heatmaps), MAX_HEATMAPS)
    scores = []
    skipped = {"no_image": 0, "unreadable": 0, "no_labels": 0}
    for hp in heatmaps[:MAX_HEATMAPS]:
        img_path = image_index.get(hp.stem)
        if img_path is None:
            skipped["no_image"] += 1
            continue
        img = cv2.imread(str(img_path))
        heat = cv2.imread(str(hp), cv2.IMREAD_GRAYSCALE)
        if img is None or heat is None:
            skipped["unreadable"] += 1
            continue
        h, w = img.shape[:2]

        heat = cv2.resize(heat, (w, h), interpolation=cv2.INTER_LINEAR)
        heat = heat.astype(np.float32)
        heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-12)

        # YOLO layout mirrors images/<split>/x.jpg -> labels/<split>/x.txt;
        # fall back to a flat labels/x.txt for datasets without splits.
        gt_path = LABELS_DIR / img_path.relative_to(IMAGES_DIR).with_suffix(".txt")
        if not gt_path.exists():
            gt_path = LABELS_DIR / (hp.stem + ".txt")
        anns = parse_yolo_label_file(gt_path)
        if not anns:
            skipped["no_labels"] += 1
            continue

        hot = top_region_mask(heat, TOP_PERCENT)

        # Score the best-covered GT box. This is class-agnostic: the box's class
        # id is not compared against the CAM's target class.
        best = max(
            coverage_score(box_mask(h, w, yolo_to_xyxy(ann, w, h)[1]), hot)
            for ann in anns
        )
        scores.append(best)

    print(f"Scored {len(scores)} / {n_eval} heatmaps; skipped: {skipped}")
    if scores:
        print("Coverage (mean ± std):", float(np.mean(scores)), float(np.std(scores)))
        plt.figure(figsize=(6,4), dpi=200)
        plt.hist(scores, bins=20)
        plt.xlabel("Coverage")
        plt.ylabel("Frequency")
        plt.title(f"CAM Coverage Proxy ({METHOD}, top {int(TOP_PERCENT*100)}%)")
        plt.tight_layout()
        plt.show()
    else:
        print("No scores computed. Check heatmap/image/label naming consistency.")
