# Title: Interpretability Analysis (Grad-CAM and SHAP)
# Description: Generates Grad-CAM overlays and SHAP plots for model auditing.

In [None]:
import torch
import numpy as np
from pathlib import Path
from src.model_architecture import FusionModel
from src.interpretability import make_gradcam, compute_shap_for_metadata
from src.data_loading import make_dataloaders
from src.utils import load_checkpoint
from src import config

# Load model

In [None]:
# Build the fusion model with one metadata input per configured feature,
# load the trained checkpoint into it (load_checkpoint presumably restores
# the weights in place — confirm in src.utils), then switch to eval mode so
# dropout / batch-norm layers, if any, use inference behaviour.
n_metadata_features = len(config.METADATA_FEATURES)
model = FusionModel(metadata_dim=n_metadata_features)

ckpt_path = "experiments/checkpoints/epoch10_valauc0.7290.pt"
load_checkpoint(ckpt_path, model)
model.eval()

# Select a reproducible random sample of 8 images from the processed dataset

In [None]:
import pandas as pd

# Reproducible random sample of 8 rows. The same frame is passed as the
# train, val and test split; only the test loader is kept for analysis.
df = (
    pd.read_csv("data/processed/rsna_processed.csv")
    .sample(n=8, random_state=config.SEED)
)
_train_loader, _val_loader, test_loader = make_dataloaders(
    df,
    df,
    df,
    batch_size=4,
)

# Get one batch

In [None]:
# Take only the first batch from the test loader (with 8 sampled rows and
# batch_size=4 this is half of the sample). The batch must unpack into
# exactly four items; the fourth is not used in this notebook.
batch = next(iter(test_loader))
images, metas, labels, _ = batch

# Convert a batch of image tensors to channel-last NumPy arrays min-max scaled to [0, 1] (over the whole batch)


In [None]:
def tensor_to_np(imgs):
    """Convert a batch of image tensors to channel-last NumPy arrays in [0, 1].

    The tensor is permuted with ``permute(0, 2, 3, 1)``, so dimension 1
    (presumably channels of an ``(N, C, H, W)`` batch — confirm against the
    data loader) is moved last. Values are then min-max scaled using the
    minimum and maximum of the *whole batch*, not per image.

    Args:
        imgs: 4-D ``torch.Tensor``; may live on any device and may require
            grad.

    Returns:
        ``numpy.ndarray`` with the permuted shape, scaled to [0, 1]. A
        constant batch (max == min) maps to all zeros instead of NaN, and an
        empty batch is returned as an empty array.
    """
    # detach() so tensors that are part of an autograd graph (e.g. Grad-CAM
    # inputs) can be converted; Tensor.numpy() raises on them otherwise.
    arr = imgs.detach().permute(0, 2, 3, 1).cpu().numpy()
    if arr.size == 0:
        # min()/max() on an empty array raise; nothing to scale.
        return arr
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        # Constant input would give 0/0 -> NaN; return a black image instead.
        return np.zeros_like(arr)
    return (arr - lo) / span

# Grad-CAM visualization

In [None]:
# Grad-CAM target: the `denseblock4` module of the image encoder's feature
# stack (presumably its final dense block — confirm in FusionModel).
encoder_features = model.image_encoder.features
target_layer = encoder_features.denseblock4

# Overlays are produced by make_gradcam for the first test batch, using
# tensor_to_np to turn image tensors into displayable arrays.
make_gradcam(
    model,
    target_layer,
    images,
    tensor_to_np,
    "outputs/gradcam",
)

print("✅ Grad-CAM visualizations saved to outputs/gradcam/")

# SHAP for metadata

In [None]:
# Background set for SHAP. NOTE(review): `metas` is a single test batch
# (batch_size=4 above), so `[:50]` keeps every row and the background is
# identical to the samples being explained. A larger, separate background
# sample would give more meaningful attributions.
background = metas[:50].numpy()
X = metas.numpy()

# Run predictions on whatever device the model's parameters live on, so this
# cell works whether the checkpoint was loaded on CPU or GPU.
_model_device = next(model.parameters()).device
# Placeholder image shape taken from the real batch (instead of a hard-coded
# 3x224x224), so it always matches what the image encoder actually receives.
_image_shape = tuple(images.shape[1:])


def model_predict(x):
    """Return model probabilities for metadata rows ``x`` with a blank image.

    Every row of ``x`` is paired with an all-zeros image, so the resulting
    SHAP values describe the metadata branch conditioned on a blank image,
    not on each patient's actual image.

    Args:
        x: array-like of metadata rows as supplied by the SHAP explainer
            (presumably 2-D, ``(n_samples, n_metadata_features)``).

    Returns:
        NumPy array holding the first output of ``model`` (``probs``),
        moved to CPU.
    """
    meta = torch.as_tensor(x, dtype=torch.float32, device=_model_device)
    blank = torch.zeros(
        (len(meta), *_image_shape), dtype=torch.float32, device=_model_device
    )
    with torch.no_grad():
        probs, _ = model(blank, meta)
    # .cpu() is required before .numpy() when the model runs on a GPU.
    return probs.cpu().numpy()


shap_values = compute_shap_for_metadata(model_predict, background, X)
print("✅ SHAP computation complete.")