<a href="https://colab.research.google.com/github/seongwoojang1123/Improved-Diagnostic-Accuracy-of-TMJ-Osteoarthritis-through-Machine-Learning-Integration-of-CBCT-MRI/blob/main/7_Feature_map_and_Gradcam_heatmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ===================== Unified Verification Script =====================
# - Grad-CAM (hookless) + scores
# - Feature maps @ conv=CONV_IDX, pre-/post-ReLU capture
# - Sparsity: pre_sign, pre_eps, post_zero, post_eps (table/verification)
# ======================================================================
import os, torch, numpy as np, cv2
from PIL import Image
from torchvision import models, transforms
from torch import nn
import torch.nn.functional as F

# ---------------- Config ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 140         # square input size for Resize; also embedded in tile filenames
NUM_CLASSES = 4        # NOTE(review): in this cell only referenced by the commented-out head swap in build_model
USE_NORMALIZE = True   # <-- True recommended to match the ImageNet weights / training setup
CKPT_PATH = None       # <-- set a path here if fine-tuned weights are available
CONV_IDX = 0           # <-- conv index (into model.features) confirmed by the sweep results
EPS = 1e-6             # <-- threshold used for the ε-zero sparsity variants
SAVE_TILES = True      # save pre-ReLU feature-map tile grids as PNGs
TILES_DIR = "/content/drive/MyDrive/TMJ OA/FeatureMaps_Verify"  # output dir for tiles and Grad-CAM overlays
# NOTE(review): creates the directory unconditionally; if Drive is not mounted this
# makes a plain local directory at this path instead.
os.makedirs(TILES_DIR, exist_ok=True)

# Alias -> test-image path. The alias is used as the output filename prefix and
# as the "alias" column of the report. Paths assume Google Drive is mounted at /content/drive.
images = {
    # CBCT
    "CBCT1_Rt": "/content/drive/MyDrive/TMJ OA/학습_data_1,3(crop)/test_data/Rt_OA/20138975 1.JPG",
    "CBCT3_Lt": "/content/drive/MyDrive/TMJ OA/학습_data_1,3(crop)/test_data/Lt_OA/20188759 3.JPG",
    "CBCT2_Rt": "/content/drive/MyDrive/TMJ OA/학습_data_2,4(crop)/test_data/Rt_OA/20138975 2.JPG",
    "CBCT4_Lt": "/content/drive/MyDrive/TMJ OA/학습_data_2,4(crop)/test_data/Lt_OA/20252473 4.JPG",
    # MRI
    "MRI1_Rt":  "/content/drive/MyDrive/TMJ OA/학습_data_MRI_1,2/test_data/Rt_OA/20123901 1 M.jpg",
    "MRI2_Lt":  "/content/drive/MyDrive/TMJ OA/학습_data_MRI_1,2/test_data/Lt_OA/20123901 2 M.jpg",
    "MRI3_Rt":  "/content/drive/MyDrive/TMJ OA/학습_data_MRI_3,4/test_data/Rt_OA/20123901 3 M.jpg",
    "MRI4_Lt":  "/content/drive/MyDrive/TMJ OA/학습_data_MRI_3,4/test_data/Lt_OA/20131063 4 M.jpg",
}

# --------------- Transforms ---------------
# Evaluation preprocessing: square resize -> float tensor in [0, 1] -> optional
# ImageNet mean/std normalization (controlled by USE_NORMALIZE).
trns = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
]
trns.extend(
    [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    if USE_NORMALIZE
    else []
)
transform_eval = transforms.Compose(trns)
def load_rgb(p):
    """Open the image at path *p* and return it converted to RGB.

    The source file is closed before returning: ``convert`` loads the pixel
    data into a new image, so the ``with`` block can release the file handle
    instead of leaving it open until garbage collection.
    """
    with Image.open(p) as im:
        return im.convert("RGB")

# --------------- Model --------------------
def disable_inplace_relu(model: nn.Module):
    """Switch every ``nn.ReLU`` inside *model* (recursively) to out-of-place mode.

    Mutates the modules in place and returns ``None``.
    """
    relu_layers = (mod for mod in model.modules() if isinstance(mod, nn.ReLU))
    for relu in relu_layers:
        relu.inplace = False

def build_model():
    """Build an ImageNet-pretrained VGG16 in eval mode on DEVICE.

    If CKPT_PATH is set, the checkpoint is loaded on top of the ImageNet weights
    with ``strict=False``. A checkpoint wrapped as ``{"state_dict": ...}`` or
    ``{"model_state_dict": ...}`` is unwrapped first. Missing/unexpected keys
    are printed instead of being silently ignored. A CKPT_PATH that does not
    exist is reported, and the ImageNet weights are used as-is.

    In-place ReLUs are disabled afterwards, so a hooked ReLU input is not
    overwritten by its own output.
    """
    m = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
    # If the fine-tuned classifier head differs, enable the lines below
    # (before loading the checkpoint).
    # in_f = m.classifier[-1].in_features
    # m.classifier[-1] = nn.Linear(in_f, NUM_CLASSES)
    if CKPT_PATH:
        if os.path.isfile(CKPT_PATH):
            state = torch.load(CKPT_PATH, map_location=DEVICE)
            # Unwrap common training-checkpoint layouts. With strict=False these
            # would otherwise load nothing at all without any error.
            if isinstance(state, dict):
                for key in ("state_dict", "model_state_dict"):
                    if isinstance(state.get(key), dict):
                        state = state[key]
                        break
            result = m.load_state_dict(state, strict=False)
            if result.missing_keys or result.unexpected_keys:
                print(
                    f"[warn] checkpoint load: {len(result.missing_keys)} missing, "
                    f"{len(result.unexpected_keys)} unexpected keys "
                    f"(e.g. missing={result.missing_keys[:3]}, "
                    f"unexpected={result.unexpected_keys[:3]})"
                )
        else:
            print(f"[warn] CKPT_PATH not found, using ImageNet weights only: {CKPT_PATH}")
    disable_inplace_relu(m)  # gradient stability / keep hooked ReLU inputs intact
    return m.to(DEVICE).eval()

# Build the network once; `features` is its convolutional stack, which is
# indexed below by CONV_IDX / RELU_IDX.
model: nn.Module = build_model()
features: nn.Sequential = model.features

# Locate the first ReLU that follows a given conv layer.
def find_relu_after_conv(seq: nn.Sequential, conv_idx: int) -> int:
    """Return the index in *seq* of the first ``nn.ReLU`` after ``seq[conv_idx]``.

    Args:
        seq: container of layers (e.g. ``model.features``).
        conv_idx: index into *seq*; it must point at an ``nn.Conv2d``.

    Raises:
        ValueError: if *conv_idx* is out of range (negative indices included)
            or ``seq[conv_idx]`` is not an ``nn.Conv2d``. Without this check
            the search silently started at the wrong layer; e.g. ``-1``
            returned the first ReLU of the whole network.
        RuntimeError: if no ReLU follows the conv.
    """
    n = len(seq)
    if not 0 <= conv_idx < n:
        raise ValueError(f"conv_idx {conv_idx} out of range for a sequence of length {n}")
    if not isinstance(seq[conv_idx], nn.Conv2d):
        raise ValueError(
            f"seq[{conv_idx}] is {type(seq[conv_idx]).__name__}, expected nn.Conv2d"
        )
    for j in range(conv_idx + 1, n):
        if isinstance(seq[j], nn.ReLU):
            return j
    raise RuntimeError(f"Conv idx {conv_idx} 뒤 ReLU 없음")
# Index (within `features`) of the ReLU whose input/output is captured below.
RELU_IDX: int = find_relu_after_conv(seq=features, conv_idx=CONV_IDX)

# ---- Hooks that capture the exact pre-/post-ReLU tensors ----
class PrePostTap:
    """Record the input and output of one module on every forward pass.

    After a forward through ``relu_mod``, ``self.pre`` holds its first
    positional input and ``self.post`` its output. Both are detached from the
    autograd graph, because only their values are inspected. Call
    :meth:`close` to remove the hooks.

    NOTE: if ``relu_mod`` runs in-place, ``pre`` shares storage with the
    input the module overwrites, so it would show post-activation values.
    """

    def __init__(self, relu_mod: nn.Module):
        self.pre = None   # last captured module input (detached), or None
        self.post = None  # last captured module output (detached), or None
        self.h1 = relu_mod.register_forward_pre_hook(self._pre)
        self.h2 = relu_mod.register_forward_hook(self._post)

    def _pre(self, module, inputs):
        # Detach: keeping a graph-attached reference would keep the whole
        # forward graph (and its saved activations) alive until the next pass.
        self.pre = inputs[0].detach()

    def _post(self, module, inputs, output):
        self.post = output.detach()

    def close(self):
        """Remove both hooks; the last captured tensors stay available."""
        self.h1.remove()
        self.h2.remove()

# Hook the selected ReLU so every forward pass records its pre/post activations.
tap = PrePostTap(relu_mod=features[RELU_IDX])

# --------------- Grad-CAM (hookless) ---------------
class GradCAM_NoBwdHook:
    """Grad-CAM that uses ``torch.autograd.grad`` instead of a backward hook.

    A forward hook on ``target_layer`` keeps its most recent output. Each call
    does three things:
    - run the model;
    - differentiate one logit of batch item 0 with respect to that output;
    - return a ReLU'd, min-max normalized 2-D numpy map.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.fmaps = None  # output of target_layer from the most recent forward
        self.h = target_layer.register_forward_hook(self._save)

    def _save(self, m, i, o):
        self.fmaps = o

    def __call__(self, x, target_category=None):
        """Return the Grad-CAM map for batch item 0 as an (H, W) numpy array in [0, 1].

        Args:
            x: model input; only item 0 of the batch is explained.
            target_category: class index to explain. Defaults to the argmax
                of item 0's logits.

        Raises:
            RuntimeError: if ``target_layer`` did not run during the forward.
        """
        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            self.fmaps = None  # ensure we never use activations from an earlier pass
            logits = self.model(x)
            if self.fmaps is None:
                raise RuntimeError("target_layer was not executed during the forward pass")
            if target_category is None:
                target_category = torch.argmax(logits[0]).item()
            score = logits[0, target_category]
            # No retain_graph: nothing else backpropagates through this graph.
            grads = torch.autograd.grad(score, self.fmaps, allow_unused=True)[0]
        if grads is None:
            # The score does not depend on the target layer.
            grads = torch.zeros_like(self.fmaps)
        weights = grads.mean(dim=(2, 3), keepdim=True)   # channel weights, (B, C, 1, 1)
        cam = F.relu((weights * self.fmaps).sum(dim=1))  # (B, H, W)
        cam = cam[0]
        # Min-max normalization: divide by the range (max - min), not by max.
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam.detach().cpu().numpy()

    def close(self):
        """Remove the forward hook from ``target_layer``."""
        self.h.remove()

# Locate the last Conv2d in the model (used as the Grad-CAM target layer).
def find_last_conv(m: nn.Module):
    """Return the last ``nn.Conv2d`` yielded by ``m.modules()``, or ``None`` if there is none."""
    convs = [mod for mod in m.modules() if isinstance(mod, nn.Conv2d)]
    return convs[-1] if convs else None
# Grad-CAM is computed on the final Conv2d of the whole model.
target_layer = find_last_conv(m=model)
camper = GradCAM_NoBwdHook(model=model, target_layer=target_layer)

# --------------- Sparsity variants ---------------
def sparsity_metrics(pre: torch.Tensor, post: torch.Tensor, eps=EPS):
    """Return four sparsity ratios for batch item 0 of a pre/post-ReLU pair.

    Each ratio is computed per channel (axis 0 after dropping the batch axis)
    and then averaged over channels:
      pre_sign  - fraction of pre-activation values <= 0
      pre_eps   - fraction of pre-activation values with |x| <= eps
      post_zero - fraction of post-activation values exactly 0
      post_eps  - fraction of post-activation values <= eps

    Args:
        pre, post: batched tensors, e.g. (1, C, H, W); only item 0 is used.
        eps: tolerance for the epsilon-based ratios (defaults to EPS).

    Returns:
        ``(pre_sign, pre_eps, post_zero, post_eps)`` as Python floats.
    """
    pre_np = pre.detach().cpu().numpy()[0]
    post_np = post.detach().cpu().numpy()[0]

    def channel_mean(mask: np.ndarray) -> float:
        # Per-channel fraction, then the mean over channels. This is a vectorized
        # form of the per-channel loop; since every channel has the same size,
        # it also equals the global fraction.
        per_channel = mask.mean(axis=tuple(range(1, mask.ndim)))
        return float(per_channel.mean())

    pre_sign = channel_mean(pre_np <= 0)
    pre_eps = channel_mean(np.abs(pre_np) <= eps)
    post_zero = channel_mean(post_np == 0)
    post_eps = channel_mean(post_np <= eps)
    return pre_sign, pre_eps, post_zero, post_eps

# --------------- Overlay util ---------------
def overlay_on_rgb(rgb_uint8, cam_hw, alpha=0.45):
    """Blend a JET-colored heatmap of *cam_hw* over an RGB image.

    Args:
        rgb_uint8: (H, W, 3) uint8 image in RGB order.
        cam_hw: 2-D map expected in [0, 1]. It is resized to (H, W); values
            outside [0, 1] are clipped and NaNs are treated as 0.
        alpha: weight of the heatmap in the blend.

    Returns:
        (H, W, 3) uint8 image in BGR order, ready for ``cv2.imwrite``.
    """
    H, W, _ = rgb_uint8.shape
    # Sanitize before the uint8 cast: NaN or out-of-range floats would otherwise
    # wrap around or produce garbage colors.
    cam01 = np.clip(np.nan_to_num(cam_hw, nan=0.0), 0.0, 1.0)
    cam_resized = cv2.resize((cam01 * 255).astype(np.uint8), (W, H))  # dsize is (width, height)
    heat = cv2.applyColorMap(cam_resized, cv2.COLORMAP_JET)           # BGR
    base = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2BGR)
    return (alpha * heat + (1 - alpha) * base).clip(0, 255).astype(np.uint8)

def save_tiles(chw: np.ndarray, save_path: str, rows=5, cols=20, max_tiles=100):
    """Save the first channels of a (C, H, W) array as a rows x cols grid PNG.

    At most ``min(max_tiles, C, rows * cols)`` tiles are drawn, so a too-small
    grid no longer makes ``add_subplot`` raise. Each tile uses ``imshow``'s
    default per-image color scaling, so brightness is not comparable across
    tiles. The parent directory is created if needed.
    """
    import matplotlib.pyplot as plt  # local import: only needed when tiles are saved

    C = chw.shape[0]
    tiles = min(max_tiles, C, rows * cols)
    fig = plt.figure(figsize=(cols * 0.9, rows * 0.9))
    try:
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
        for j in range(tiles):
            ax = fig.add_subplot(rows, cols, j + 1)
            ax.imshow(chw[j])
            ax.axis("off")
        out_dir = os.path.dirname(save_path)
        if out_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

# --------------- Run ---------------
# For each image: prediction, pre/post-ReLU capture at conv=CONV_IDX, the four
# sparsity ratios, optional tile PNG, and Grad-CAM overlay PNG -> one row in `rows`.
rows = []
try:
    for alias, p in images.items():
        if not os.path.isfile(p):
            # Explicit raise rather than assert: asserts are stripped under `python -O`.
            raise FileNotFoundError(f"파일 없음: {p}")
        rgb = load_rgb(p)
        x = transform_eval(rgb).unsqueeze(0).to(DEVICE)

        # 1) A single no-grad forward yields the prediction AND fills the ReLU taps.
        #    Clear the taps first, so stale tensors from the previous image (or from
        #    a Grad-CAM pass) cannot satisfy the check below.
        tap.pre = tap.post = None
        with torch.no_grad():
            logits = model(x)
        pred_id = int(torch.argmax(logits, dim=1).item())
        pre = tap.pre    # (1,C,H,W) torch
        post = tap.post  # (1,C,H,W) torch
        assert pre is not None and post is not None, "pre/post 캡처 실패"

        # 2) Grad-CAM for the predicted class (runs its own forward with autograd).
        cam = camper(x, target_category=pred_id)

        # 3) Sparsity (4 variants)
        s_pre_sign, s_pre_eps, s_post_zero, s_post_eps = sparsity_metrics(pre, post, eps=EPS)

        # 4) Optional: save pre-ReLU feature-map tiles
        if SAVE_TILES:
            pre_np = pre.detach().cpu().numpy()[0]
            save_path = os.path.join(TILES_DIR, f"{alias}_conv{CONV_IDX}_pre_t{IMG_SIZE}_preSign{ s_pre_sign:.4f}.png")
            save_tiles(pre_np, save_path, rows=5, cols=20, max_tiles=100)

        # 5) Grad-CAM overlay on the resized input, then record the row
        rgb_u8 = np.array(rgb.resize((IMG_SIZE, IMG_SIZE)), dtype=np.uint8)
        overlay_bgr = overlay_on_rgb(rgb_u8, cam, alpha=0.45)
        cam_path = os.path.join(TILES_DIR, f"{alias}_cam_pred{pred_id}.png")
        if not cv2.imwrite(cam_path, overlay_bgr):
            # cv2.imwrite signals failure only through its return value.
            print(f"[warn] cv2.imwrite failed: {cam_path}")

        rows.append({
            "alias": alias,
            "pred": pred_id,
            "pre_sign":  s_pre_sign,
            "pre_eps":   s_pre_eps,
            "post_zero": s_post_zero,
            "post_eps":  s_post_eps,
            "cam_path":  cam_path
        })
finally:
    # Always remove the hooks, even if an image fails mid-loop.
    camper.close()
    tap.close()

# --------------- Report ---------------
# Print the per-image table, or raw dicts when pandas is unavailable.
try:
    import pandas as pd
except ImportError:
    # Only a missing pandas triggers the fallback; other errors now surface
    # instead of being swallowed.
    for r in rows:
        print(r)
else:
    df = pd.DataFrame(rows)
    # Print a short interpretation hint
    print(df.to_string(index=False))
    print("\n[Hint]")
    # NOTE(review): post = ReLU(pre) means (post == 0) exactly where (pre <= 0),
    # so post_zero equals pre_sign for the same tensor. The ranges in this hint
    # look inconsistent; confirm which comparison was intended.
    print(" - pre_sign≈0.4~0.5, post_zero≈0.6~0.9면: post-ReLU sparsity가 훨씬 큼(정상).")
    print(" - pre_eps는 ε 설정(EPS) 민감. 시각적 착시는 imshow scaling 영향 가능.")
    print(f" - Normalize={'ON' if USE_NORMALIZE else 'OFF'}, Conv idx={CONV_IDX}, ε={EPS}")
    if rows:
        # Sanity check of the capture itself: the two columns should agree.
        gap = float((df["pre_sign"] - df["post_zero"]).abs().max())
        print(f" - max|pre_sign - post_zero| = {gap:.3g} (expected ~0 since post = ReLU(pre))")
