# SHAP Analysis for MagniNet

This notebook computes SHAP GradientExplainer attributions
to visualize regions contributing to malignant predictions.

In [ ]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from models import MagniNet
from utils.checkpoints import load_checkpoint
from interpretability.shap_explainer import (
    build_shap_gradient_explainer,
    shap_values_for_images,
    shap_attribution_map,
    normalize_01
)

In [ ]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Checkpoint file passed to load_checkpoint below.
checkpoint = 'runs/magninet_40x/best.pt'

# Instantiate the network with pretrained_backbones=False and move it to the
# selected device before loading the checkpoint into it.
# NOTE(review): presumably load_checkpoint restores the trained weights in
# place -- confirm against utils.checkpoints.
model = MagniNet(pretrained_backbones=False)
model = model.to(device)
load_checkpoint(model, checkpoint, device)

# Switch the model to eval mode. Kept as the final bare expression of the cell
# so the notebook still echoes its return value.
model.eval()

In [ ]:
def load_img(path, size=(224, 224)):
    """Load an image file as a float32 CHW tensor scaled to [0, 1].

    The file is opened inside a ``with`` block so the underlying file handle
    is closed deterministically instead of whenever the image object happens
    to be garbage-collected.

    Args:
        path: Location of the image file, in any form ``PIL.Image.open``
            accepts.
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to
            ``(224, 224)``, the size previously hard-coded here.

    Returns:
        A ``torch`` float32 tensor of shape ``(3, height, width)`` whose
        values are the RGB pixel values divided by 255.
    """
    with Image.open(path) as src:
        # convert() decodes the pixel data and returns a new image, so the
        # result stays valid after the source file is closed.
        img = src.convert('RGB').resize(size)
    pixels = np.array(img, dtype=np.float32) / 255.0
    # PIL yields height x width x channel; move the channel axis to the front.
    return torch.from_numpy(pixels).permute(2, 0, 1)

# Image files that make up the background batch given to the SHAP explainer.
background_paths = [
    'bg1.jpg',
    'bg2.jpg',
    'bg3.jpg',
    'bg4.jpg',
]
# Load every background image and stack the per-image tensors along a new
# leading axis, one entry per path.
background = torch.stack(tuple(map(load_img, background_paths)), dim=0)

In [ ]:
# Build the SHAP gradient explainer for the loaded model from the background
# batch. The helper returns two items; only the first (the explainer) is used
# in this notebook and the second is discarded into `_`.
# NOTE(review): the meaning of mode='binary' and of the discarded second item
# is defined in interpretability.shap_explainer -- confirm there.
shap_setup = build_shap_gradient_explainer(model, background, mode='binary', device=device)
explainer, _ = shap_setup

In [ ]:
# Image to explain. load_img returns a single image tensor; unsqueeze(0) adds
# a leading axis of size 1 so the explainer receives a batch of one.
img_path = 'sample.jpg'
x = load_img(img_path).unsqueeze(0)
sv = shap_values_for_images(explainer, x, device=device)

# Reduce the SHAP values to a heat map, rescale it, and keep the first entry.
# NOTE(review): presumably 'sum_abs' collapses the channel axis and index 0 is
# the single image in the batch -- confirm in interpretability.shap_explainer.
heat = shap_attribution_map(sv, reduce_channels='sum_abs')
heat = normalize_01(heat)[0]

# Load the picture shown underneath the heat map. It is converted to RGB, the
# same conversion load_img applies to the model input, so grayscale, palette,
# or CMYK files are displayed as the colour image the model actually saw
# instead of being colour-mapped or misread by imshow. The `with` block
# closes the file handle as soon as the pixels have been copied out.
with Image.open(img_path) as src:
    img_np = np.array(src.convert('RGB').resize((224,224)))

plt.figure(figsize=(6,6))
plt.imshow(img_np)
# Semi-transparent overlay so the underlying image stays visible.
plt.imshow(heat, cmap='jet', alpha=0.4)
plt.axis('off')
plt.title('SHAP Attribution')
plt.show()