In [7]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import SimpleITK as sitk
from tqdm import tqdm
import csv

# =========================================================
# CONFIG
# =========================================================
# Input volume (CACHE_SLICE_DIR/NIFTI_NAME) and its mask file.
CACHE_SLICE_DIR = "./radiomics_ready/output_224/images"
NIFTI_NAME = "VP005_pre_0000.nii.gz"
MASK_PATH  = "./radiomics_ready/output_224/masks/VP005_pre_0000.nii.gz"

# Model folder + filename pattern (run0~run29)
MODEL_DIR = "/Users/oyeoncho/Dropbox/working_directory/prep_cx_ca_image/image_data/models_extract/beit0"
MODEL_PATTERN = "best_model_beit_run{run}.pt"
RUNS = list(range(30))  # run0~run29 (30 in total)

OUT_DIR = "./visualization/all_models"
TARGET_XY = (224, 224)
DEVICE = torch.device("cpu")

# Indices of the backbone output features whose sum is the attribution target.
INTEGRATED_FEATS = [436, 519]

# Hotspot percentiles to evaluate (one full analysis pass per value)
HOT_P_LIST = [50, 70, 90]  # change to e.g. [90,95,97,99] if desired

# RUN_TARGET-specific region: voxels that are a hotspot in RUN_TARGET but whose
# hotspot frequency across the other runs is <= FREQ_MAX
RUN_TARGET = 3
FREQ_MAX = 2  # stricter: 0~1, looser: 3~5

# Save options
SAVE_MRI_MASKED_NIFTI = True
SAVE_ALL_RUN_HEAT_NIFTI = False  # if True, every run's heatmap (run0~29) is also saved as NIfTI (large on disk)

# Memory saving: keep the heatmaps of all runs in an on-disk memmap
USE_MEMMAP = True

# =========================================================
# MODEL
# =========================================================
class BEiTBackbone(nn.Module):
    """Wrapper that exposes a timm ``beit_base_patch16_224`` model as ``self.backbone``."""

    def __init__(self):
        super().__init__()
        # Created without pretrained weights and with num_classes=0; weights
        # are expected to be loaded into ``self.backbone`` afterwards.
        self.backbone = timm.create_model(
            "beit_base_patch16_224",
            pretrained=False,
            num_classes=0,
        )

    def forward(self, x):
        """Return the wrapped backbone's output for ``x``."""
        out = self.backbone(x)
        return out

def load_beit_backbone(model_path: str, device):
    """Build a ``BEiTBackbone`` and load checkpoint weights into its backbone.

    The checkpoint is either a state dict or a dict carrying one under
    ``"state_dict"``. Keys starting with ``"backbone."`` are selected and the
    ``"backbone."`` substring is removed from them; when no key has that
    prefix the state dict is used unchanged.

    Args:
        model_path: Path of the checkpoint file.
        device: Device the model is placed on and the checkpoint is mapped to.

    Returns:
        The ``BEiTBackbone`` in eval mode.

    Raises:
        RuntimeError: If not a single checkpoint key was accepted by the
            backbone (the model would otherwise keep its random initialisation).
    """
    import warnings  # local import so the block does not depend on file-level edits

    model = BEiTBackbone().to(device)
    try:
        with torch.serialization.safe_globals([np._core.multiarray.scalar]):
            ckpt = torch.load(model_path, map_location=device)
    except Exception as err:
        # Deliberate best-effort fallback, kept from the original code.
        # NOTE(security): weights_only=False runs the full unpickler, so this
        # path must only ever see trusted checkpoint files.
        warnings.warn(
            f"safe torch.load failed for {model_path} ({type(err).__name__}: {err}); "
            "retrying with weights_only=False (full pickle, trusted files only)",
            RuntimeWarning,
            stacklevel=2,
        )
        ckpt = torch.load(model_path, map_location=device, weights_only=False)

    state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
    filtered = {k.replace("backbone.", ""): v for k, v in state_dict.items() if k.startswith("backbone.")}
    if not filtered:
        filtered = state_dict

    # strict=False tolerates extra/missing keys, so inspect the report instead
    # of silently continuing with a partially (or not at all) initialised model.
    report = model.backbone.load_state_dict(filtered, strict=False)
    rejected = set(report.unexpected_keys)
    n_loaded = sum(1 for key in filtered if key not in rejected)
    if n_loaded == 0:
        raise RuntimeError(
            f"No checkpoint key matched the BEiT backbone when loading {model_path}; "
            "the model would run with randomly initialised weights."
        )
    if report.missing_keys:
        warnings.warn(
            f"{model_path}: {len(report.missing_keys)} backbone parameter(s)/buffer(s) were not "
            f"found in the checkpoint and keep their initial values "
            f"(first: {report.missing_keys[0]}); {len(rejected)} checkpoint key(s) were ignored.",
            RuntimeWarning,
            stacklevel=2,
        )

    model.eval()
    return model

# =========================================================
# NIfTI I/O
# =========================================================
def load_nifti_as_zyx(nifti_path: str):
    """Read an image file with SimpleITK.

    Returns:
        ``(image, array)`` where ``image`` is the SimpleITK image and ``array``
        is its float32 NumPy copy from ``sitk.GetArrayFromImage`` (Z, Y, X order).
    """
    image = sitk.ReadImage(nifti_path)
    raw = sitk.GetArrayFromImage(image)
    volume = raw.astype(np.float32)
    return image, volume

def resample_xy_keep_z(sitk_img, target_xy=(224,224), interp=sitk.sitkLinear):
    """Resample the first two image axes to ``target_xy``; keep axis 2 untouched.

    The in-plane spacing is rescaled by ``old_size / new_size`` so the volume
    covers the same physical extent; slice count and slice spacing are copied.

    Args:
        sitk_img: SimpleITK image to resample.
        target_xy: Output size along the first two axes, ``(X, Y)``.
        interp: SimpleITK interpolator constant.

    Returns:
        The resampled SimpleITK image.
    """
    in_size = sitk_img.GetSize()        # (X, Y, Z)
    in_spacing = sitk_img.GetSpacing()

    out_x = int(target_xy[0])
    out_y = int(target_xy[1])
    out_size = [out_x, out_y, int(in_size[2])]
    out_spacing = [
        in_spacing[0] * (in_size[0] / out_x),
        in_spacing[1] * (in_size[1] / out_y),
        in_spacing[2],
    ]

    # NOTE(review): origin is copied unchanged although the voxel size changes;
    # confirm whether a half-voxel offset matters for this pipeline.
    rs = sitk.ResampleImageFilter()
    rs.SetInterpolator(interp)
    rs.SetOutputDirection(sitk_img.GetDirection())
    rs.SetOutputOrigin(sitk_img.GetOrigin())
    rs.SetOutputSpacing(out_spacing)
    rs.SetSize(out_size)
    return rs.Execute(sitk_img)

def save_zyx_as_nifti(vol_zyx: np.ndarray, ref_img_sitk, out_path: str):
    """Write ``vol_zyx`` to ``out_path`` as float32 with the geometry of ``ref_img_sitk``.

    Args:
        vol_zyx: Array in (Z, Y, X) order, as expected by ``sitk.GetImageFromArray``.
        ref_img_sitk: Image whose spacing, origin and direction are copied.
        out_path: Destination file path.
    """
    as_f32 = vol_zyx.astype(np.float32)
    image = sitk.GetImageFromArray(as_f32)
    # Copy the reference geometry onto the new image.
    image.SetDirection(ref_img_sitk.GetDirection())
    image.SetOrigin(ref_img_sitk.GetOrigin())
    image.SetSpacing(ref_img_sitk.GetSpacing())
    sitk.WriteImage(image, out_path)

# =========================================================
# Preprocess
# =========================================================
# Per-channel mean/std constants, shaped (1,3,1,1) to broadcast over (Z,3,Y,X).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1,3,1,1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1,3,1,1)

def volume_to_tensors(arr_zyx: np.ndarray):
    """Turn a (Z,Y,X) volume into model input and a [0,1] display volume.

    The whole volume is min-max scaled with a single global min/max (a constant
    volume becomes all zeros), replicated to three channels and normalised with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Returns:
        ``(x, imgs_vis)``: ``x`` is the normalised (Z,3,Y,X) tensor and
        ``imgs_vis`` the scaled (Z,Y,X) tensor in [0,1].
    """
    vol = torch.from_numpy(arr_zyx).float()
    lo = float(vol.min().item())
    hi = float(vol.max().item())

    # Guard against division by zero for a constant volume.
    scaled = (vol - lo) / (hi - lo) if hi > lo else torch.zeros_like(vol)

    rgb = scaled.unsqueeze(1).repeat(1, 3, 1, 1)    # (Z,3,Y,X)
    model_in = (rgb - IMAGENET_MEAN) / IMAGENET_STD
    display = scaled.clone()                        # (Z,Y,X) in [0,1]
    return model_in, display

def mask_to_binary(mask_zyx: np.ndarray, thr=0.5):
    """Return a uint8 array that is 1 where ``mask_zyx`` is strictly above ``thr``, else 0."""
    above = np.greater(mask_zyx, thr)
    return above.astype(np.uint8)

def upsample_14_to_224(m14: torch.Tensor):
    """Bilinearly resize a 2-D map to 224x224 (``align_corners=False``)."""
    batched = m14[None, None]  # (1,1,H,W), the layout F.interpolate expects
    resized = F.interpolate(batched, size=(224,224), mode="bilinear", align_corners=False)
    return resized[0, 0]

# =========================================================
# Integrated grad heatmap (raw) for one slice
# =========================================================
def compute_integrated_gradmap_one_slice(model: BEiTBackbone, x1: torch.Tensor, integrated_feats):
    """Compute a 224x224 gradient-times-activation map for one slice.

    The target is the sum of the selected output features. Its gradient with
    respect to the backbone's token sequence is multiplied element-wise with
    the patch tokens, the absolute value is averaged over the channel axis,
    and the resulting square patch grid is upsampled with
    ``upsample_14_to_224``.

    Only the gradient w.r.t. the tokens is needed, so the backbone forward
    runs without an autograd graph and ``torch.autograd.grad`` differentiates
    just the head. This yields the same map as calling ``backward()`` through
    the whole network, without back-propagating into every model parameter.
    Model parameter ``.grad`` fields are not touched.

    Args:
        model: Wrapper whose ``backbone`` provides ``forward_features`` (and
            optionally ``forward_head``).
        x1: Input batch holding a single slice.
        integrated_feats: Iterable of feature indices summed into the target.

    Returns:
        float32 NumPy array of shape (224, 224).
    """
    with torch.no_grad():
        tokens = model.backbone.forward_features(x1)  # (1,1+N,C)
    # Make the tokens the differentiation leaf for the head computation.
    tokens.requires_grad_(True)

    if hasattr(model.backbone, "forward_head"):
        feats = model.backbone.forward_head(tokens, pre_logits=True)  # (1,C)
    else:
        feats = tokens[:, 0]

    target = sum(feats[:, idx].sum() for idx in integrated_feats)
    (token_grad,) = torch.autograd.grad(target, tokens)

    # Index 0 is treated as a non-patch token and dropped, as in the original code.
    patch_tok = tokens.detach()[:, 1:, :]  # (1,N,C)
    patch_grad = token_grad[:, 1:, :]
    n = patch_tok.shape[1]
    h = w = int(np.sqrt(n))  # patch grid is assumed square

    imp14 = (patch_grad[0] * patch_tok[0]).abs().mean(dim=-1).reshape(h, w)  # (14,14)
    imp224 = upsample_14_to_224(imp14).detach().cpu().numpy().astype(np.float32)
    return imp224

# =========================================================
# Hotspot + metrics
# =========================================================
def hotspot_mask(H, p=95):
    """Binary mask of values at or above the ``p``-th percentile of the positive values.

    Returns:
        ``(mask, threshold)``: ``mask`` is uint8 with the shape of ``H``. When
        ``H`` has fewer than 10 positive values the mask is all zeros and the
        threshold is NaN.
    """
    positives = H[H > 0]
    if positives.size >= 10:
        cutoff = np.percentile(positives, p)
        selected = (H >= cutoff).astype(np.uint8)
        return selected, float(cutoff)
    # Too few positive values for a meaningful percentile.
    return np.zeros_like(H, dtype=np.uint8), np.nan

def bin_metrics(A, B):
    """Overlap statistics between two masks (each binarised as ``> 0``).

    ``A`` is treated as the prediction and ``B`` as the reference for
    precision/recall. Every ratio has ``1e-8`` added to its denominator, so
    empty masks give 0 instead of a division error.

    Returns:
        Dict with voxel counts (``A_vox``, ``B_vox``, ``inter``, ``union``,
        ``fp``, ``fn``) and ratios (``dice``, ``iou``, ``precision``,
        ``recall``, ``f1``).
    """
    eps = 1e-8
    in_a = A > 0
    in_b = B > 0

    n_a = in_a.sum()
    n_b = in_b.sum()
    n_both = np.logical_and(in_a, in_b).sum()
    n_either = np.logical_or(in_a, in_b).sum()
    n_only_a = np.logical_and(in_a, ~in_b).sum()   # false positives
    n_only_b = np.logical_and(~in_a, in_b).sum()   # false negatives

    dice = (2*n_both) / (n_a + n_b + eps)
    iou = n_both / (n_either + eps)
    precision = n_both / (n_both + n_only_a + eps)
    recall = n_both / (n_both + n_only_b + eps)
    f1 = 2*precision*recall / (precision + recall + eps)

    return {
        "A_vox": int(n_a), "B_vox": int(n_b),
        "inter": int(n_both), "union": int(n_either),
        "dice": float(dice), "iou": float(iou),
        "precision": float(precision), "recall": float(recall), "f1": float(f1),
        "fp": int(n_only_a), "fn": int(n_only_b),
    }

def centroid_distance(A, B, spacing_xyz=None):
    """Euclidean distance between the centroids of two masks.

    Args:
        A, B: Arrays whose non-zero entries define each region; indexed (Z,Y,X).
        spacing_xyz: Optional ``(sx, sy, sz)`` voxel spacing. When given, the
            index offset is scaled per axis before the norm is taken;
            otherwise the distance is in voxel units.

    Returns:
        The distance as a float, or NaN when either mask is empty.
    """
    pts_a = np.argwhere(A)
    pts_b = np.argwhere(B)
    if pts_a.size == 0 or pts_b.size == 0:
        return np.nan

    # Centroids and their offset are in (Z,Y,X) index order.
    offset = pts_a.mean(axis=0) - pts_b.mean(axis=0)
    if spacing_xyz is None:
        return float(np.linalg.norm(offset))

    sx, sy, sz = spacing_xyz  # (X,Y,Z)
    dz, dy, dx = offset[0], offset[1], offset[2]
    scaled = np.array([dz*sz, dy*sy, dx*sx], dtype=np.float32)
    return float(np.linalg.norm(scaled))

def cont_metrics(H1, H2, roi_mask):
    """Compare two continuous maps inside a region of interest.

    Args:
        H1, H2: Arrays of equal shape.
        roi_mask: Boolean array selecting the voxels to compare.

    Returns:
        Dict with ``corr`` (Pearson correlation; ``1e-12`` is added to the
        denominator, so a constant input gives 0), ``MAE`` and ``RMSE``.
        All three are NaN when the ROI selects no voxel.
    """
    x = H1[roi_mask].astype(np.float64)
    y = H2[roi_mask].astype(np.float64)

    # Empty ROI: the statistics are undefined. Report NaN explicitly instead
    # of a misleading corr of 0.0 plus NumPy "mean of empty slice" warnings.
    if x.size == 0:
        nan = float("nan")
        return {"corr": nan, "MAE": nan, "RMSE": nan}

    x0 = x - x.mean()
    y0 = y - y.mean()
    corr = (x0*y0).sum() / (np.sqrt((x0*x0).sum() * (y0*y0).sum()) + 1e-12)

    diff = x - y
    mae  = np.mean(np.abs(diff))
    rmse = np.sqrt(np.mean(diff**2))
    return {"corr": float(corr), "MAE": float(mae), "RMSE": float(rmse)}

# =========================================================
# Heat3D runner (one model)
# =========================================================
def run_model_to_heat3d(model_path, imgs_norm, msk_bin, integrated_feats, device):
    """Load one checkpoint and build its masked heat volume slice by slice.

    Args:
        model_path: Checkpoint passed to ``load_beit_backbone``.
        imgs_norm: Normalised input tensor; axis 0 indexes slices.
        msk_bin: Mask array, indexed per slice along axis 0.
        integrated_feats: Feature indices forwarded to the per-slice map function.
        device: Device used for the model and each slice.

    Returns:
        float32 array of shape ``(Z, TARGET_XY[1], TARGET_XY[0])`` where each
        slice is the gradient map multiplied by that slice's mask.
    """
    model = load_beit_backbone(model_path, device)
    n_slices = imgs_norm.shape[0]
    volume_shape = (n_slices, TARGET_XY[1], TARGET_XY[0])  # (Z,Y,X)
    heat3d = np.zeros(volume_shape, dtype=np.float32)

    for z in range(n_slices):
        slice_in = imgs_norm[z:z+1].to(device)  # keep the batch axis
        slice_map = compute_integrated_gradmap_one_slice(model, slice_in, integrated_feats)
        slice_roi = msk_bin[z].astype(np.float32)
        # Zero out everything outside the mask.
        heat3d[z] = (slice_map * slice_roi).astype(np.float32)

    return heat3d

# =========================================================
# MAIN
# =========================================================
def main():
    """Run the multi-run heatmap comparison for one volume and write all outputs.

    Steps: load image and mask (resampling in-plane to TARGET_XY if needed),
    compute a masked heat volume per run in RUNS, save their mean as the
    consensus, then for every percentile in HOT_P_LIST save the hotspot
    frequency map, the consensus hotspot, the RUN_TARGET hotspot and its
    "unique" region, and a CSV of per-run metrics against the consensus.
    All configuration comes from the module-level constants.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    nifti_path = os.path.join(CACHE_SLICE_DIR, NIFTI_NAME)
    assert os.path.exists(nifti_path), f"파일 없음: {nifti_path}"
    assert os.path.exists(MASK_PATH),  f"마스크 없음: {MASK_PATH}"

    # Build the model paths + check that every checkpoint exists
    model_paths = {}
    for r in RUNS:
        p = os.path.join(MODEL_DIR, MODEL_PATTERN.format(run=r))
        assert os.path.exists(p), f"모델 없음: {p}"
        model_paths[r] = p

    # load image/mask
    sitk_img, arr = load_nifti_as_zyx(nifti_path)
    sitk_msk, msk = load_nifti_as_zyx(MASK_PATH)

    # resample to 224 (linear for the image, nearest-neighbour for the mask)
    if (sitk_img.GetSize()[0], sitk_img.GetSize()[1]) != TARGET_XY:
        sitk_img = resample_xy_keep_z(sitk_img, TARGET_XY, sitk.sitkLinear)
        arr = sitk.GetArrayFromImage(sitk_img).astype(np.float32)

    if (sitk_msk.GetSize()[0], sitk_msk.GetSize()[1]) != TARGET_XY:
        sitk_msk = resample_xy_keep_z(sitk_msk, TARGET_XY, sitk.sitkNearestNeighbor)
        msk = sitk.GetArrayFromImage(sitk_msk).astype(np.float32)

    assert arr.shape == msk.shape, f"shape mismatch: {arr.shape} vs {msk.shape}"

    imgs_norm, imgs_vis = volume_to_tensors(arr)
    Z = imgs_norm.shape[0]
    msk_bin = mask_to_binary(msk, 0.5)
    roi = (msk_bin > 0)
    spacing_xyz = sitk_img.GetSpacing()  # (X,Y,Z)

    # Strip both extensions of ".nii.gz"
    base = os.path.splitext(os.path.splitext(NIFTI_NAME)[0])[0]

    # NOTE(review): the folder name hard-codes "RUNS0_29" regardless of RUNS.
    hot_tag = "_".join([str(p) for p in HOT_P_LIST])
    out_dir = os.path.join(
        OUT_DIR,
        f"{base}_RUNS0_29_HOTP{hot_tag}_run{RUN_TARGET}_specific_freqLE{FREQ_MAX}"
    )
    os.makedirs(out_dir, exist_ok=True)
    print("[OUT]", out_dir)

    # save MRI masked
    if SAVE_MRI_MASKED_NIFTI:
        mri3d_masked = (imgs_vis.numpy() * msk_bin.astype(np.float32)).astype(np.float32)
        save_zyx_as_nifti(mri3d_masked, sitk_img, os.path.join(out_dir, f"{base}_mri_masked_224.nii.gz"))

    # -----------------------------------------------------
    # Allocate storage for the heat volumes of all runs
    # -----------------------------------------------------
    if USE_MEMMAP:
        mm_path = os.path.join(out_dir, f"{base}_heat30runs_float32.dat")
        heats = np.memmap(
            mm_path, dtype="float32", mode="w+",
            shape=(len(RUNS), Z, TARGET_XY[1], TARGET_XY[0])
        )
        print("[MEMMAP]", mm_path)
    else:
        heats = np.zeros((len(RUNS), Z, TARGET_XY[1], TARGET_XY[0]), dtype=np.float32)

    # -----------------------------------------------------
    # 1) Compute heatmaps for all runs
    # -----------------------------------------------------
    for i, r in enumerate(tqdm(RUNS, desc="Compute heatmaps (30 runs)")):
        H = run_model_to_heat3d(model_paths[r], imgs_norm, msk_bin, INTEGRATED_FEATS, DEVICE)
        heats[i] = H
        if SAVE_ALL_RUN_HEAT_NIFTI:
            save_zyx_as_nifti(H, sitk_img, os.path.join(out_dir, f"{base}_heat_integrated_raw_run{r}.nii.gz"))

    if USE_MEMMAP:
        heats.flush()

    # -----------------------------------------------------
    # 2) Save the consensus heat (voxel-wise mean over runs)
    # -----------------------------------------------------
    consensus = heats.mean(axis=0).astype(np.float32)  # (Z,Y,X)
    save_zyx_as_nifti(consensus, sitk_img, os.path.join(out_dir, f"{base}_CONSENSUS_mean_raw.nii.gz"))

    # -----------------------------------------------------
    # 2.5) dict holding each run's hotspot threshold per percentile
    # run_thr[p][run] = thr
    # -----------------------------------------------------
    run_thr = {p: {} for p in HOT_P_LIST}

    # -----------------------------------------------------
    # 3~4) Repeat for every percentile
    # -----------------------------------------------------
    for HOT_P in HOT_P_LIST:
        print(f"\n===== HOT_P = {HOT_P} =====")

        # 3-1) hotspot frequency map: number of runs in which each voxel is a hotspot
        freq = np.zeros((Z, TARGET_XY[1], TARGET_XY[0]), dtype=np.uint16)

        for i, r in enumerate(tqdm(RUNS, desc=f"Hotspot freq (p{HOT_P})")):
            M, thr = hotspot_mask(heats[i], p=HOT_P)
            run_thr[HOT_P][r] = thr
            freq += M.astype(np.uint16)

        save_zyx_as_nifti(
            freq.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_HOT_FREQ_p{HOT_P}_count.nii.gz")
        )

        # 3-2) consensus hotspot (the comparison reference)
        Mc, thr_c = hotspot_mask(consensus, p=HOT_P)
        save_zyx_as_nifti(
            Mc.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_CONSENSUS_HOT_p{HOT_P}.nii.gz")
        )

        # 3-3) run_target hotspot & unique
        # NOTE(review): RUNS.index raises ValueError when RUN_TARGET is not in
        # RUNS, and that only surfaces here, after all heatmaps were computed.
        idx_t = RUNS.index(RUN_TARGET)
        Ht = heats[idx_t].astype(np.float32)
        Mt, thr_t = hotspot_mask(Ht, p=HOT_P)

        # Subtract the target run's own mask so only the other runs are counted
        freq_minus_self = freq.astype(np.int32) - Mt.astype(np.int32)
        uniq = (Mt > 0) & (freq_minus_self <= FREQ_MAX)

        save_zyx_as_nifti(
            Mt.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_run{RUN_TARGET}_HOT_p{HOT_P}.nii.gz")
        )
        save_zyx_as_nifti(
            uniq.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_run{RUN_TARGET}_UNIQUE_p{HOT_P}_freqLE{FREQ_MAX}.nii.gz")
        )

        # Continuous heat of the unique region: run_target heat kept only where uniq is set
        Ht_unique = (Ht * uniq.astype(np.float32)).astype(np.float32)
        save_zyx_as_nifti(
            Ht_unique,
            sitk_img,
            os.path.join(out_dir, f"{base}_run{RUN_TARGET}_UNIQUE_heat_raw_p{HOT_P}.nii.gz")
        )

        # 4) Table: each run vs consensus (one CSV per percentile)
        rows = []
        for i, r in enumerate(tqdm(RUNS, desc=f"Metrics (run vs consensus) p{HOT_P}")):
            Hr = heats[i].astype(np.float32)
            Mr, _ = hotspot_mask(Hr, p=HOT_P)

            bm = bin_metrics(Mr, Mc)
            cd = centroid_distance(Mr > 0, Mc > 0, spacing_xyz=spacing_xyz)
            cm = cont_metrics(Hr, consensus, roi_mask=roi)

            # Fraction of this run's hotspot voxels that lie inside the "run_target unique" region
            uniq_overlap = float(((Mr > 0) & (uniq > 0)).sum() / ((Mr > 0).sum() + 1e-8))

            rows.append({
                "run": r,
                "hot_p": HOT_P,
                "thr_run": run_thr[HOT_P][r],
                "thr_consensus": thr_c,
                "dice_vs_cons": bm["dice"],
                "iou_vs_cons": bm["iou"],
                "prec_vs_cons": bm["precision"],
                "recall_vs_cons": bm["recall"],
                "f1_vs_cons": bm["f1"],
                "centroid_dist_mm_vs_cons": cd,
                "corr_roi_vs_cons": cm["corr"],
                "MAE_roi_vs_cons": cm["MAE"],
                "RMSE_roi_vs_cons": cm["RMSE"],
                "hot_vox_run": bm["A_vox"],
                "hot_vox_cons": bm["B_vox"],
                "uniq_overlap_ratio_with_runTargetUnique": uniq_overlap,
            })

        # Column order follows the key order of the first row.
        csv_path = os.path.join(out_dir, f"{base}_metrics_30runs_vs_consensus_p{HOT_P}.csv")
        with open(csv_path, "w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            w.writeheader()
            w.writerows(rows)

        print(f"[SAVE] CSV: {csv_path}")

    print("\n✅ DONE (all percentiles)")
    print("Output folder:", out_dir)
    print("\n[Key outputs]")
    print(" - MRI:", f"{base}_mri_masked_224.nii.gz")
    print(" - consensus heat:", f"{base}_CONSENSUS_mean_raw.nii.gz")
    for p in HOT_P_LIST:
        print(f" - hotspot freq p{p}:", f"{base}_HOT_FREQ_p{p}_count.nii.gz")
        print(f" - consensus hot p{p}:", f"{base}_CONSENSUS_HOT_p{p}.nii.gz")
        print(f" - run{RUN_TARGET} unique p{p}:", f"{base}_run{RUN_TARGET}_UNIQUE_p{p}_freqLE{FREQ_MAX}.nii.gz")
        print(f" - run{RUN_TARGET} unique heat p{p}:", f"{base}_run{RUN_TARGET}_UNIQUE_heat_raw_p{p}.nii.gz")
        print(f" - CSV p{p}:", f"{base}_metrics_30runs_vs_consensus_p{p}.csv")

# Run the analysis only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()


[OUT] ./visualization/all_models/VP005_pre_0000_RUNS0_29_HOTP50_70_90_run3_specific_freqLE2
[MEMMAP] ./visualization/all_models/VP005_pre_0000_RUNS0_29_HOTP50_70_90_run3_specific_freqLE2/VP005_pre_0000_heat30runs_float32.dat


Compute heatmaps (30 runs): 100%|██████████| 30/30 [20:19<00:00, 40.65s/it]



===== HOT_P = 50 =====


Hotspot freq (p50): 100%|██████████| 30/30 [00:00<00:00, 115.12it/s]
Metrics (run vs consensus) p50: 100%|██████████| 30/30 [00:01<00:00, 17.36it/s]


[SAVE] CSV: ./visualization/all_models/VP005_pre_0000_RUNS0_29_HOTP50_70_90_run3_specific_freqLE2/VP005_pre_0000_metrics_30runs_vs_consensus_p50.csv

===== HOT_P = 70 =====


Hotspot freq (p70): 100%|██████████| 30/30 [00:00<00:00, 116.76it/s]
Metrics (run vs consensus) p70: 100%|██████████| 30/30 [00:01<00:00, 18.12it/s]


[SAVE] CSV: ./visualization/all_models/VP005_pre_0000_RUNS0_29_HOTP50_70_90_run3_specific_freqLE2/VP005_pre_0000_metrics_30runs_vs_consensus_p70.csv

===== HOT_P = 90 =====


Hotspot freq (p90): 100%|██████████| 30/30 [00:00<00:00, 123.91it/s]
Metrics (run vs consensus) p90: 100%|██████████| 30/30 [00:01<00:00, 18.84it/s]

[SAVE] CSV: ./visualization/all_models/VP005_pre_0000_RUNS0_29_HOTP50_70_90_run3_specific_freqLE2/VP005_pre_0000_metrics_30runs_vs_consensus_p90.csv

✅ DONE (all percentiles)
Output folder: ./visualization/all_models/VP005_pre_0000_RUNS0_29_HOTP50_70_90_run3_specific_freqLE2

[Key outputs]
 - MRI: VP005_pre_0000_mri_masked_224.nii.gz
 - consensus heat: VP005_pre_0000_CONSENSUS_mean_raw.nii.gz
 - hotspot freq p50: VP005_pre_0000_HOT_FREQ_p50_count.nii.gz
 - consensus hot p50: VP005_pre_0000_CONSENSUS_HOT_p50.nii.gz
 - run3 unique p50: VP005_pre_0000_run3_UNIQUE_p50_freqLE2.nii.gz
 - run3 unique heat p50: VP005_pre_0000_run3_UNIQUE_heat_raw_p50.nii.gz
 - CSV p50: VP005_pre_0000_metrics_30runs_vs_consensus_p50.csv
 - hotspot freq p70: VP005_pre_0000_HOT_FREQ_p70_count.nii.gz
 - consensus hot p70: VP005_pre_0000_CONSENSUS_HOT_p70.nii.gz
 - run3 unique p70: VP005_pre_0000_run3_UNIQUE_p70_freqLE2.nii.gz
 - run3 unique heat p70: VP005_pre_0000_run3_UNIQUE_heat_raw_p70.nii.gz
 - CSV p70: VP005




In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import SimpleITK as sitk
from tqdm import tqdm
import csv

# =========================================================
# CONFIG
# =========================================================
# Input volume (CACHE_SLICE_DIR/NIFTI_NAME) and its mask file.
CACHE_SLICE_DIR = "./cache_slice"
NIFTI_NAME = "female_pelvis_t1.nii.gz"
MASK_PATH  = "./output_mask/female_pelvis_t1.nii.gz"

# Model folder + filename pattern (run0~run29)
MODEL_DIR = "/Users/oyeoncho/Dropbox/working_directory/prep_cx_ca_image/image_data/models_extract/beit0"
MODEL_PATTERN = "best_model_beit_run{run}.pt"
RUNS = list(range(30))  # run0~run29 (30 in total)

OUT_DIR = "./vis_out_3d_compare_30runs"
TARGET_XY = (224, 224)
DEVICE = torch.device("cpu")

# Indices of the backbone output features whose sum is the attribution target.
INTEGRATED_FEATS = [436, 519]

# Hotspot percentiles to evaluate (one full analysis pass per value)
HOT_P_LIST = [50, 70, 90]  # change to e.g. [90,95,97,99] if desired

# RUN_TARGET-specific region: voxels that are a hotspot in RUN_TARGET but whose
# hotspot frequency across the other runs is <= FREQ_MAX
RUN_TARGET = 3
FREQ_MAX = 2  # stricter: 0~1, looser: 3~5

# Save options
SAVE_MRI_MASKED_NIFTI = True
SAVE_ALL_RUN_HEAT_NIFTI = False  # if True, every run's heatmap (run0~29) is also saved as NIfTI (large on disk)

# Memory saving: keep the heatmaps of all runs in an on-disk memmap
USE_MEMMAP = True

# =========================================================
# MODEL
# =========================================================
class BEiTBackbone(nn.Module):
    """Wrapper that exposes a timm ``beit_base_patch16_224`` model as ``self.backbone``."""

    def __init__(self):
        super().__init__()
        # Created without pretrained weights and with num_classes=0; weights
        # are expected to be loaded into ``self.backbone`` afterwards.
        self.backbone = timm.create_model(
            "beit_base_patch16_224",
            pretrained=False,
            num_classes=0,
        )

    def forward(self, x):
        """Return the wrapped backbone's output for ``x``."""
        out = self.backbone(x)
        return out

def load_beit_backbone(model_path: str, device):
    """Build a ``BEiTBackbone`` and load checkpoint weights into its backbone.

    The checkpoint is either a state dict or a dict carrying one under
    ``"state_dict"``. Keys starting with ``"backbone."`` are selected and the
    ``"backbone."`` substring is removed from them; when no key has that
    prefix the state dict is used unchanged.

    Args:
        model_path: Path of the checkpoint file.
        device: Device the model is placed on and the checkpoint is mapped to.

    Returns:
        The ``BEiTBackbone`` in eval mode.

    Raises:
        RuntimeError: If not a single checkpoint key was accepted by the
            backbone (the model would otherwise keep its random initialisation).
    """
    import warnings  # local import so the block does not depend on file-level edits

    model = BEiTBackbone().to(device)
    try:
        with torch.serialization.safe_globals([np._core.multiarray.scalar]):
            ckpt = torch.load(model_path, map_location=device)
    except Exception as err:
        # Deliberate best-effort fallback, kept from the original code.
        # NOTE(security): weights_only=False runs the full unpickler, so this
        # path must only ever see trusted checkpoint files.
        warnings.warn(
            f"safe torch.load failed for {model_path} ({type(err).__name__}: {err}); "
            "retrying with weights_only=False (full pickle, trusted files only)",
            RuntimeWarning,
            stacklevel=2,
        )
        ckpt = torch.load(model_path, map_location=device, weights_only=False)

    state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
    filtered = {k.replace("backbone.", ""): v for k, v in state_dict.items() if k.startswith("backbone.")}
    if not filtered:
        filtered = state_dict

    # strict=False tolerates extra/missing keys, so inspect the report instead
    # of silently continuing with a partially (or not at all) initialised model.
    report = model.backbone.load_state_dict(filtered, strict=False)
    rejected = set(report.unexpected_keys)
    n_loaded = sum(1 for key in filtered if key not in rejected)
    if n_loaded == 0:
        raise RuntimeError(
            f"No checkpoint key matched the BEiT backbone when loading {model_path}; "
            "the model would run with randomly initialised weights."
        )
    if report.missing_keys:
        warnings.warn(
            f"{model_path}: {len(report.missing_keys)} backbone parameter(s)/buffer(s) were not "
            f"found in the checkpoint and keep their initial values "
            f"(first: {report.missing_keys[0]}); {len(rejected)} checkpoint key(s) were ignored.",
            RuntimeWarning,
            stacklevel=2,
        )

    model.eval()
    return model

# =========================================================
# NIfTI I/O
# =========================================================
def load_nifti_as_zyx(nifti_path: str):
    """Read an image file with SimpleITK.

    Returns:
        ``(image, array)`` where ``image`` is the SimpleITK image and ``array``
        is its float32 NumPy copy from ``sitk.GetArrayFromImage`` (Z, Y, X order).
    """
    image = sitk.ReadImage(nifti_path)
    raw = sitk.GetArrayFromImage(image)
    volume = raw.astype(np.float32)
    return image, volume

def resample_xy_keep_z(sitk_img, target_xy=(224,224), interp=sitk.sitkLinear):
    """Resample the first two image axes to ``target_xy``; keep axis 2 untouched.

    The in-plane spacing is rescaled by ``old_size / new_size`` so the volume
    covers the same physical extent; slice count and slice spacing are copied.

    Args:
        sitk_img: SimpleITK image to resample.
        target_xy: Output size along the first two axes, ``(X, Y)``.
        interp: SimpleITK interpolator constant.

    Returns:
        The resampled SimpleITK image.
    """
    in_size = sitk_img.GetSize()        # (X, Y, Z)
    in_spacing = sitk_img.GetSpacing()

    out_x = int(target_xy[0])
    out_y = int(target_xy[1])
    out_size = [out_x, out_y, int(in_size[2])]
    out_spacing = [
        in_spacing[0] * (in_size[0] / out_x),
        in_spacing[1] * (in_size[1] / out_y),
        in_spacing[2],
    ]

    # NOTE(review): origin is copied unchanged although the voxel size changes;
    # confirm whether a half-voxel offset matters for this pipeline.
    rs = sitk.ResampleImageFilter()
    rs.SetInterpolator(interp)
    rs.SetOutputDirection(sitk_img.GetDirection())
    rs.SetOutputOrigin(sitk_img.GetOrigin())
    rs.SetOutputSpacing(out_spacing)
    rs.SetSize(out_size)
    return rs.Execute(sitk_img)

def save_zyx_as_nifti(vol_zyx: np.ndarray, ref_img_sitk, out_path: str):
    """Write ``vol_zyx`` to ``out_path`` as float32 with the geometry of ``ref_img_sitk``.

    Args:
        vol_zyx: Array in (Z, Y, X) order, as expected by ``sitk.GetImageFromArray``.
        ref_img_sitk: Image whose spacing, origin and direction are copied.
        out_path: Destination file path.
    """
    as_f32 = vol_zyx.astype(np.float32)
    image = sitk.GetImageFromArray(as_f32)
    # Copy the reference geometry onto the new image.
    image.SetDirection(ref_img_sitk.GetDirection())
    image.SetOrigin(ref_img_sitk.GetOrigin())
    image.SetSpacing(ref_img_sitk.GetSpacing())
    sitk.WriteImage(image, out_path)

# =========================================================
# Preprocess
# =========================================================
# Per-channel mean/std constants, shaped (1,3,1,1) to broadcast over (Z,3,Y,X).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1,3,1,1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1,3,1,1)

def volume_to_tensors(arr_zyx: np.ndarray):
    """Turn a (Z,Y,X) volume into model input and a [0,1] display volume.

    The whole volume is min-max scaled with a single global min/max (a constant
    volume becomes all zeros), replicated to three channels and normalised with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Returns:
        ``(x, imgs_vis)``: ``x`` is the normalised (Z,3,Y,X) tensor and
        ``imgs_vis`` the scaled (Z,Y,X) tensor in [0,1].
    """
    vol = torch.from_numpy(arr_zyx).float()
    lo = float(vol.min().item())
    hi = float(vol.max().item())

    # Guard against division by zero for a constant volume.
    scaled = (vol - lo) / (hi - lo) if hi > lo else torch.zeros_like(vol)

    rgb = scaled.unsqueeze(1).repeat(1, 3, 1, 1)    # (Z,3,Y,X)
    model_in = (rgb - IMAGENET_MEAN) / IMAGENET_STD
    display = scaled.clone()                        # (Z,Y,X) in [0,1]
    return model_in, display

def mask_to_binary(mask_zyx: np.ndarray, thr=0.5):
    """Return a uint8 array that is 1 where ``mask_zyx`` is strictly above ``thr``, else 0."""
    above = np.greater(mask_zyx, thr)
    return above.astype(np.uint8)

def upsample_14_to_224(m14: torch.Tensor):
    """Bilinearly resize a 2-D map to 224x224 (``align_corners=False``)."""
    batched = m14[None, None]  # (1,1,H,W), the layout F.interpolate expects
    resized = F.interpolate(batched, size=(224,224), mode="bilinear", align_corners=False)
    return resized[0, 0]

# =========================================================
# Integrated grad heatmap (raw) for one slice
# =========================================================
def compute_integrated_gradmap_one_slice(model: BEiTBackbone, x1: torch.Tensor, integrated_feats):
    """Compute a 224x224 gradient-times-activation map for one slice.

    The target is the sum of the selected output features. Its gradient with
    respect to the backbone's token sequence is multiplied element-wise with
    the patch tokens, the absolute value is averaged over the channel axis,
    and the resulting square patch grid is upsampled with
    ``upsample_14_to_224``.

    Only the gradient w.r.t. the tokens is needed, so the backbone forward
    runs without an autograd graph and ``torch.autograd.grad`` differentiates
    just the head. This yields the same map as calling ``backward()`` through
    the whole network, without back-propagating into every model parameter.
    Model parameter ``.grad`` fields are not touched.

    Args:
        model: Wrapper whose ``backbone`` provides ``forward_features`` (and
            optionally ``forward_head``).
        x1: Input batch holding a single slice.
        integrated_feats: Iterable of feature indices summed into the target.

    Returns:
        float32 NumPy array of shape (224, 224).
    """
    with torch.no_grad():
        tokens = model.backbone.forward_features(x1)  # (1,1+N,C)
    # Make the tokens the differentiation leaf for the head computation.
    tokens.requires_grad_(True)

    if hasattr(model.backbone, "forward_head"):
        feats = model.backbone.forward_head(tokens, pre_logits=True)  # (1,C)
    else:
        feats = tokens[:, 0]

    target = sum(feats[:, idx].sum() for idx in integrated_feats)
    (token_grad,) = torch.autograd.grad(target, tokens)

    # Index 0 is treated as a non-patch token and dropped, as in the original code.
    patch_tok = tokens.detach()[:, 1:, :]  # (1,N,C)
    patch_grad = token_grad[:, 1:, :]
    n = patch_tok.shape[1]
    h = w = int(np.sqrt(n))  # patch grid is assumed square

    imp14 = (patch_grad[0] * patch_tok[0]).abs().mean(dim=-1).reshape(h, w)  # (14,14)
    imp224 = upsample_14_to_224(imp14).detach().cpu().numpy().astype(np.float32)
    return imp224

# =========================================================
# Hotspot + metrics
# =========================================================
def hotspot_mask(H, p=95):
    """Binary mask of values at or above the ``p``-th percentile of the positive values.

    Returns:
        ``(mask, threshold)``: ``mask`` is uint8 with the shape of ``H``. When
        ``H`` has fewer than 10 positive values the mask is all zeros and the
        threshold is NaN.
    """
    positives = H[H > 0]
    if positives.size >= 10:
        cutoff = np.percentile(positives, p)
        selected = (H >= cutoff).astype(np.uint8)
        return selected, float(cutoff)
    # Too few positive values for a meaningful percentile.
    return np.zeros_like(H, dtype=np.uint8), np.nan

def bin_metrics(A, B):
    """Overlap statistics between two masks (each binarised as ``> 0``).

    ``A`` is treated as the prediction and ``B`` as the reference for
    precision/recall. Every ratio has ``1e-8`` added to its denominator, so
    empty masks give 0 instead of a division error.

    Returns:
        Dict with voxel counts (``A_vox``, ``B_vox``, ``inter``, ``union``,
        ``fp``, ``fn``) and ratios (``dice``, ``iou``, ``precision``,
        ``recall``, ``f1``).
    """
    eps = 1e-8
    in_a = A > 0
    in_b = B > 0

    n_a = in_a.sum()
    n_b = in_b.sum()
    n_both = np.logical_and(in_a, in_b).sum()
    n_either = np.logical_or(in_a, in_b).sum()
    n_only_a = np.logical_and(in_a, ~in_b).sum()   # false positives
    n_only_b = np.logical_and(~in_a, in_b).sum()   # false negatives

    dice = (2*n_both) / (n_a + n_b + eps)
    iou = n_both / (n_either + eps)
    precision = n_both / (n_both + n_only_a + eps)
    recall = n_both / (n_both + n_only_b + eps)
    f1 = 2*precision*recall / (precision + recall + eps)

    return {
        "A_vox": int(n_a), "B_vox": int(n_b),
        "inter": int(n_both), "union": int(n_either),
        "dice": float(dice), "iou": float(iou),
        "precision": float(precision), "recall": float(recall), "f1": float(f1),
        "fp": int(n_only_a), "fn": int(n_only_b),
    }

def centroid_distance(A, B, spacing_xyz=None):
    """Euclidean distance between the centroids of two masks.

    Args:
        A, B: Arrays whose non-zero entries define each region; indexed (Z,Y,X).
        spacing_xyz: Optional ``(sx, sy, sz)`` voxel spacing. When given, the
            index offset is scaled per axis before the norm is taken;
            otherwise the distance is in voxel units.

    Returns:
        The distance as a float, or NaN when either mask is empty.
    """
    pts_a = np.argwhere(A)
    pts_b = np.argwhere(B)
    if pts_a.size == 0 or pts_b.size == 0:
        return np.nan

    # Centroids and their offset are in (Z,Y,X) index order.
    offset = pts_a.mean(axis=0) - pts_b.mean(axis=0)
    if spacing_xyz is None:
        return float(np.linalg.norm(offset))

    sx, sy, sz = spacing_xyz  # (X,Y,Z)
    dz, dy, dx = offset[0], offset[1], offset[2]
    scaled = np.array([dz*sz, dy*sy, dx*sx], dtype=np.float32)
    return float(np.linalg.norm(scaled))

def cont_metrics(H1, H2, roi_mask):
    """Compare two continuous maps inside a region of interest.

    Args:
        H1, H2: Arrays of equal shape.
        roi_mask: Boolean array selecting the voxels to compare.

    Returns:
        Dict with ``corr`` (Pearson correlation; ``1e-12`` is added to the
        denominator, so a constant input gives 0), ``MAE`` and ``RMSE``.
        All three are NaN when the ROI selects no voxel.
    """
    x = H1[roi_mask].astype(np.float64)
    y = H2[roi_mask].astype(np.float64)

    # Empty ROI: the statistics are undefined. Report NaN explicitly instead
    # of a misleading corr of 0.0 plus NumPy "mean of empty slice" warnings.
    if x.size == 0:
        nan = float("nan")
        return {"corr": nan, "MAE": nan, "RMSE": nan}

    x0 = x - x.mean()
    y0 = y - y.mean()
    corr = (x0*y0).sum() / (np.sqrt((x0*x0).sum() * (y0*y0).sum()) + 1e-12)

    diff = x - y
    mae  = np.mean(np.abs(diff))
    rmse = np.sqrt(np.mean(diff**2))
    return {"corr": float(corr), "MAE": float(mae), "RMSE": float(rmse)}

# =========================================================
# Heat3D runner (one model)
# =========================================================
def run_model_to_heat3d(model_path, imgs_norm, msk_bin, integrated_feats, device):
    """Load one checkpoint and build its masked heat volume slice by slice.

    Args:
        model_path: Checkpoint passed to ``load_beit_backbone``.
        imgs_norm: Normalised input tensor; axis 0 indexes slices.
        msk_bin: Mask array, indexed per slice along axis 0.
        integrated_feats: Feature indices forwarded to the per-slice map function.
        device: Device used for the model and each slice.

    Returns:
        float32 array of shape ``(Z, TARGET_XY[1], TARGET_XY[0])`` where each
        slice is the gradient map multiplied by that slice's mask.
    """
    model = load_beit_backbone(model_path, device)
    n_slices = imgs_norm.shape[0]
    volume_shape = (n_slices, TARGET_XY[1], TARGET_XY[0])  # (Z,Y,X)
    heat3d = np.zeros(volume_shape, dtype=np.float32)

    for z in range(n_slices):
        slice_in = imgs_norm[z:z+1].to(device)  # keep the batch axis
        slice_map = compute_integrated_gradmap_one_slice(model, slice_in, integrated_feats)
        slice_roi = msk_bin[z].astype(np.float32)
        # Zero out everything outside the mask.
        heat3d[z] = (slice_map * slice_roi).astype(np.float32)

    return heat3d

# =========================================================
# MAIN
# =========================================================
def main():
    """Run the multi-run heatmap comparison for one volume and write all outputs.

    Steps: load image and mask (resampling in-plane to TARGET_XY if needed),
    compute a masked heat volume per run in RUNS, save their mean as the
    consensus, then for every percentile in HOT_P_LIST save the hotspot
    frequency map, the consensus hotspot, the RUN_TARGET hotspot and its
    "unique" region, and a CSV of per-run metrics against the consensus.
    All configuration comes from the module-level constants.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    nifti_path = os.path.join(CACHE_SLICE_DIR, NIFTI_NAME)
    assert os.path.exists(nifti_path), f"파일 없음: {nifti_path}"
    assert os.path.exists(MASK_PATH),  f"마스크 없음: {MASK_PATH}"

    # Build the model paths + check that every checkpoint exists
    model_paths = {}
    for r in RUNS:
        p = os.path.join(MODEL_DIR, MODEL_PATTERN.format(run=r))
        assert os.path.exists(p), f"모델 없음: {p}"
        model_paths[r] = p

    # load image/mask
    sitk_img, arr = load_nifti_as_zyx(nifti_path)
    sitk_msk, msk = load_nifti_as_zyx(MASK_PATH)

    # resample to 224 (linear for the image, nearest-neighbour for the mask)
    if (sitk_img.GetSize()[0], sitk_img.GetSize()[1]) != TARGET_XY:
        sitk_img = resample_xy_keep_z(sitk_img, TARGET_XY, sitk.sitkLinear)
        arr = sitk.GetArrayFromImage(sitk_img).astype(np.float32)

    if (sitk_msk.GetSize()[0], sitk_msk.GetSize()[1]) != TARGET_XY:
        sitk_msk = resample_xy_keep_z(sitk_msk, TARGET_XY, sitk.sitkNearestNeighbor)
        msk = sitk.GetArrayFromImage(sitk_msk).astype(np.float32)

    assert arr.shape == msk.shape, f"shape mismatch: {arr.shape} vs {msk.shape}"

    imgs_norm, imgs_vis = volume_to_tensors(arr)
    Z = imgs_norm.shape[0]
    msk_bin = mask_to_binary(msk, 0.5)
    roi = (msk_bin > 0)
    spacing_xyz = sitk_img.GetSpacing()  # (X,Y,Z)

    # Strip both extensions of ".nii.gz"
    base = os.path.splitext(os.path.splitext(NIFTI_NAME)[0])[0]

    # NOTE(review): the folder name hard-codes "RUNS0_29" regardless of RUNS.
    hot_tag = "_".join([str(p) for p in HOT_P_LIST])
    out_dir = os.path.join(
        OUT_DIR,
        f"{base}_RUNS0_29_HOTP{hot_tag}_run{RUN_TARGET}_specific_freqLE{FREQ_MAX}"
    )
    os.makedirs(out_dir, exist_ok=True)
    print("[OUT]", out_dir)

    # save MRI masked
    if SAVE_MRI_MASKED_NIFTI:
        mri3d_masked = (imgs_vis.numpy() * msk_bin.astype(np.float32)).astype(np.float32)
        save_zyx_as_nifti(mri3d_masked, sitk_img, os.path.join(out_dir, f"{base}_mri_masked_224.nii.gz"))

    # -----------------------------------------------------
    # Allocate storage for the heat volumes of all runs
    # -----------------------------------------------------
    if USE_MEMMAP:
        mm_path = os.path.join(out_dir, f"{base}_heat30runs_float32.dat")
        heats = np.memmap(
            mm_path, dtype="float32", mode="w+",
            shape=(len(RUNS), Z, TARGET_XY[1], TARGET_XY[0])
        )
        print("[MEMMAP]", mm_path)
    else:
        heats = np.zeros((len(RUNS), Z, TARGET_XY[1], TARGET_XY[0]), dtype=np.float32)

    # -----------------------------------------------------
    # 1) Compute heatmaps for all runs
    # -----------------------------------------------------
    for i, r in enumerate(tqdm(RUNS, desc="Compute heatmaps (30 runs)")):
        H = run_model_to_heat3d(model_paths[r], imgs_norm, msk_bin, INTEGRATED_FEATS, DEVICE)
        heats[i] = H
        if SAVE_ALL_RUN_HEAT_NIFTI:
            save_zyx_as_nifti(H, sitk_img, os.path.join(out_dir, f"{base}_heat_integrated_raw_run{r}.nii.gz"))

    if USE_MEMMAP:
        heats.flush()

    # -----------------------------------------------------
    # 2) Save the consensus heat (voxel-wise mean over runs)
    # -----------------------------------------------------
    consensus = heats.mean(axis=0).astype(np.float32)  # (Z,Y,X)
    save_zyx_as_nifti(consensus, sitk_img, os.path.join(out_dir, f"{base}_CONSENSUS_mean_raw.nii.gz"))

    # -----------------------------------------------------
    # 2.5) dict holding each run's hotspot threshold per percentile
    # run_thr[p][run] = thr
    # -----------------------------------------------------
    run_thr = {p: {} for p in HOT_P_LIST}

    # -----------------------------------------------------
    # 3~4) Repeat for every percentile
    # -----------------------------------------------------
    for HOT_P in HOT_P_LIST:
        print(f"\n===== HOT_P = {HOT_P} =====")

        # 3-1) hotspot frequency map: number of runs in which each voxel is a hotspot
        freq = np.zeros((Z, TARGET_XY[1], TARGET_XY[0]), dtype=np.uint16)

        for i, r in enumerate(tqdm(RUNS, desc=f"Hotspot freq (p{HOT_P})")):
            M, thr = hotspot_mask(heats[i], p=HOT_P)
            run_thr[HOT_P][r] = thr
            freq += M.astype(np.uint16)

        save_zyx_as_nifti(
            freq.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_HOT_FREQ_p{HOT_P}_count.nii.gz")
        )

        # 3-2) consensus hotspot (the comparison reference)
        Mc, thr_c = hotspot_mask(consensus, p=HOT_P)
        save_zyx_as_nifti(
            Mc.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_CONSENSUS_HOT_p{HOT_P}.nii.gz")
        )

        # 3-3) run_target hotspot & unique
        # NOTE(review): RUNS.index raises ValueError when RUN_TARGET is not in
        # RUNS, and that only surfaces here, after all heatmaps were computed.
        idx_t = RUNS.index(RUN_TARGET)
        Ht = heats[idx_t].astype(np.float32)
        Mt, thr_t = hotspot_mask(Ht, p=HOT_P)

        # Subtract the target run's own mask so only the other runs are counted
        freq_minus_self = freq.astype(np.int32) - Mt.astype(np.int32)
        uniq = (Mt > 0) & (freq_minus_self <= FREQ_MAX)

        save_zyx_as_nifti(
            Mt.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_run{RUN_TARGET}_HOT_p{HOT_P}.nii.gz")
        )
        save_zyx_as_nifti(
            uniq.astype(np.float32),
            sitk_img,
            os.path.join(out_dir, f"{base}_run{RUN_TARGET}_UNIQUE_p{HOT_P}_freqLE{FREQ_MAX}.nii.gz")
        )

        # Continuous heat of the unique region: run_target heat kept only where uniq is set
        Ht_unique = (Ht * uniq.astype(np.float32)).astype(np.float32)
        save_zyx_as_nifti(
            Ht_unique,
            sitk_img,
            os.path.join(out_dir, f"{base}_run{RUN_TARGET}_UNIQUE_heat_raw_p{HOT_P}.nii.gz")
        )

        # 4) Table: each run vs consensus (one CSV per percentile)
        rows = []
        for i, r in enumerate(tqdm(RUNS, desc=f"Metrics (run vs consensus) p{HOT_P}")):
            Hr = heats[i].astype(np.float32)
            Mr, _ = hotspot_mask(Hr, p=HOT_P)

            bm = bin_metrics(Mr, Mc)
            cd = centroid_distance(Mr > 0, Mc > 0, spacing_xyz=spacing_xyz)
            cm = cont_metrics(Hr, consensus, roi_mask=roi)

            # Fraction of this run's hotspot voxels that lie inside the "run_target unique" region
            uniq_overlap = float(((Mr > 0) & (uniq > 0)).sum() / ((Mr > 0).sum() + 1e-8))

            rows.append({
                "run": r,
                "hot_p": HOT_P,
                "thr_run": run_thr[HOT_P][r],
                "thr_consensus": thr_c,
                "dice_vs_cons": bm["dice"],
                "iou_vs_cons": bm["iou"],
                "prec_vs_cons": bm["precision"],
                "recall_vs_cons": bm["recall"],
                "f1_vs_cons": bm["f1"],
                "centroid_dist_mm_vs_cons": cd,
                "corr_roi_vs_cons": cm["corr"],
                "MAE_roi_vs_cons": cm["MAE"],
                "RMSE_roi_vs_cons": cm["RMSE"],
                "hot_vox_run": bm["A_vox"],
                "hot_vox_cons": bm["B_vox"],
                "uniq_overlap_ratio_with_runTargetUnique": uniq_overlap,
            })

        # Column order follows the key order of the first row.
        csv_path = os.path.join(out_dir, f"{base}_metrics_30runs_vs_consensus_p{HOT_P}.csv")
        with open(csv_path, "w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            w.writeheader()
            w.writerows(rows)

        print(f"[SAVE] CSV: {csv_path}")

    print("\n✅ DONE (all percentiles)")
    print("Output folder:", out_dir)
    print("\n[Key outputs]")
    print(" - MRI:", f"{base}_mri_masked_224.nii.gz")
    print(" - consensus heat:", f"{base}_CONSENSUS_mean_raw.nii.gz")
    for p in HOT_P_LIST:
        print(f" - hotspot freq p{p}:", f"{base}_HOT_FREQ_p{p}_count.nii.gz")
        print(f" - consensus hot p{p}:", f"{base}_CONSENSUS_HOT_p{p}.nii.gz")
        print(f" - run{RUN_TARGET} unique p{p}:", f"{base}_run{RUN_TARGET}_UNIQUE_p{p}_freqLE{FREQ_MAX}.nii.gz")
        print(f" - run{RUN_TARGET} unique heat p{p}:", f"{base}_run{RUN_TARGET}_UNIQUE_heat_raw_p{p}.nii.gz")
        print(f" - CSV p{p}:", f"{base}_metrics_30runs_vs_consensus_p{p}.csv")

# Run the analysis only when executed as a script / notebook cell.
if __name__ == "__main__":
    main()
