In [59]:
import torch

# Path to the checkpoint inspected in this notebook (an AnomalyNCD "capsule" log directory).
ckpt_path = "../outputs/mvtec_cpr_crop/log/AnomalyNCD_capsule_(2025.12.15_09-21)/checkpoints/model.pt"

# NOTE(security): weights_only=False runs the full pickle machinery and can execute
# arbitrary code stored in the file. Only load checkpoints you produced yourself.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

print(type(ckpt))
if not isinstance(ckpt, dict):
    # Fail here instead of leaving `sd` undefined (or, worse, silently reusing a
    # stale `sd` from an earlier run of this cell with a different checkpoint).
    raise TypeError(
        f"expected the checkpoint to be a dict, got {type(ckpt).__name__}; "
        "cannot locate a state dict in it"
    )

print("keys:", ckpt.keys())
# Prefer a "state_dict" entry, then a "model" entry; otherwise treat the whole
# dict as the state dict itself.
sd = ckpt.get("state_dict", ckpt.get("model", ckpt))
print("state_dict example keys:", list(sd.keys())[:20])

<class 'dict'>
keys: dict_keys(['model', 'optimizer', 'epoch', 'loss_list', 'base_category', 'category', 'mask_layers'])
state_dict example keys: ['0.cls_token', '0.pos_embed', '0.patch_embed.proj.weight', '0.patch_embed.proj.bias', '0.blocks.0.norm1.weight', '0.blocks.0.norm1.bias', '0.blocks.0.attn.qkv.weight', '0.blocks.0.attn.qkv.bias', '0.blocks.0.attn.proj.weight', '0.blocks.0.attn.proj.bias', '0.blocks.0.norm2.weight', '0.blocks.0.norm2.bias', '0.blocks.0.mlp.fc1.weight', '0.blocks.0.mlp.fc1.bias', '0.blocks.0.mlp.fc2.weight', '0.blocks.0.mlp.fc2.bias', '0.blocks.1.norm1.weight', '0.blocks.1.norm1.bias', '0.blocks.1.attn.qkv.weight', '0.blocks.1.attn.qkv.bias']


In [55]:
import torch
import timm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Create the model and load the checkpoint weights into it.
model = timm.create_model("vit_base_patch8_224", pretrained=False, num_classes=0)

# The checkpoint keys are stored as "0.<param>" (e.g. "0.cls_token"), presumably
# because the network was saved as a container whose first child is the ViT.
# The timm model expects the bare "<param>" names, so without stripping the
# prefix load_state_dict(strict=False) matches nothing and the model silently
# keeps its random initialisation.
BACKBONE_PREFIX = "0."
backbone_sd = {
    key[len(BACKBONE_PREFIX):]: tensor
    for key, tensor in sd.items()
    if key.startswith(BACKBONE_PREFIX)
}
if not backbone_sd:
    # No prefixed keys at all: assume the state dict already uses bare names.
    backbone_sd = sd

missing, unexpected = model.load_state_dict(backbone_sd, strict=False)
n_matched = len(backbone_sd) - len(unexpected)
print(f"loaded {n_matched} tensors | missing: {len(missing)} | unexpected: {len(unexpected)}")
if missing:
    print("  missing (first 10):", list(missing)[:10])
if unexpected:
    print("  unexpected (first 10):", list(unexpected)[:10])
if n_matched == 0:
    # strict=False hides a total mismatch; features from random weights are meaningless.
    raise RuntimeError(
        "no checkpoint tensor matched the model's parameter names; "
        "the model would run with random weights"
    )

model = model.to(DEVICE).eval()
print("model device:", next(model.parameters()).device)

model device: cuda:0


In [None]:
import os, numpy as np
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt

# Input resolution fed to the model, and the side length of its patch-token grid.
IMG_SIZE = 224
PATCH_SIZE = 8                      # matches the "patch8" model created above
GRID = IMG_SIZE // PATCH_SIZE       # 28 -> 28*28 = 784 patch tokens per image

# Per-channel mean/std used to normalise RGB inputs (the standard ImageNet statistics).
NORM_MEAN = [0.485,0.456,0.406]
NORM_STD = [0.229,0.224,0.225]

# Preprocessing: force a square IMG_SIZE input, convert to a float tensor, normalise.
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

def minmax(x, eps=1e-8):
    """Rescale an array linearly so its minimum maps to 0 and its maximum to ~1.

    Args:
        x: NumPy array of any numeric dtype; it is cast to float32 first.
        eps: Small constant added to the value range so a constant array
            divides by `eps` instead of zero (and therefore maps to all zeros).

    Returns:
        Array with the same shape as `x`, holding the rescaled values.
    """
    arr = x.astype(np.float32)
    lo = arr.min()
    hi = arr.max()
    return (arr - lo) / (hi - lo + eps)

@torch.no_grad()
def extract_patch_feats(pil_img):
    """Return L2-normalised patch-token features for a single PIL image.

    The image is preprocessed with the module-level `tfm`, pushed through
    `model.forward_features` on `DEVICE`, and the first token (the class token
    for the ViT used here) is dropped so only patch tokens remain.

    Args:
        pil_img: PIL image accepted by `tfm`.

    Returns:
        Tensor of shape [num_patches, dim] on `DEVICE` ([784, 768] for the
        ViT-B/8 at 224 px), each row scaled to unit L2 norm.
    """
    batch = tfm(pil_img).unsqueeze(0).to(DEVICE)      # [1,3,224,224]
    out = model.forward_features(batch)               # [1, 1+784, 768] (cls + patches)
    # Some models return a tuple/list; the token tensor is taken from slot 0.
    tokens = out[0] if isinstance(out, (tuple, list)) else out
    # Skip token 0 and normalise along the channel axis so dot products are cosines.
    patch_tokens = F.normalize(tokens[:, 1:, :], dim=-1)
    return patch_tokens.squeeze(0)                    # [784,768]

@torch.no_grad()
def build_bank(normal_dir, max_imgs=50):
    """Build a memory bank of patch features from defect-free reference images.

    Files in `normal_dir` are taken in sorted name order, filtered to common
    image extensions (case-insensitive), and truncated to the first `max_imgs`.

    Args:
        normal_dir: Directory containing the reference ("good") images.
        max_imgs: Maximum number of images to include in the bank.

    Returns:
        Tensor of shape [n_images * num_patches, dim] holding the concatenated
        output of `extract_patch_feats` for every selected image.

    Raises:
        ValueError: If no image file is selected (empty directory, no matching
            extension, or `max_imgs` of 0), instead of the opaque error that
            concatenating an empty list would produce.
    """
    exts = (".png",".jpg",".jpeg",".bmp")
    files = [f for f in sorted(os.listdir(normal_dir)) if f.lower().endswith(exts)]
    files = files[:max_imgs]
    if not files:
        raise ValueError(
            f"no images with extensions {exts} selected from {normal_dir!r} "
            f"(max_imgs={max_imgs})"
        )

    feats = []
    for fn in files:
        # Context manager closes the underlying file handle as soon as the
        # pixels have been decoded by convert().
        with Image.open(os.path.join(normal_dir, fn)) as raw:
            feats.append(extract_patch_feats(raw.convert("RGB")))   # [784,768]
    return torch.cat(feats, dim=0)                                  # [n_images*784,768]

@torch.no_grad()
def compute_anomaly_map(pil_img, bank):
    """Score every patch of an image by its distance to the closest bank feature.

    Both the query patches and the bank rows are unit-normalised by
    `extract_patch_feats`, so their dot product is a cosine similarity and
    `1 - max similarity` acts as a nearest-neighbour distance.

    Args:
        pil_img: PIL image to score.
        bank: Feature bank of shape [M, dim], as returned by `build_bank`.

    Returns:
        NumPy array of shape [IMG_SIZE, IMG_SIZE] with the bilinearly
        upsampled per-patch distances (higher = less similar to the bank).
    """
    query = extract_patch_feats(pil_img)                   # [784,768]
    # Best cosine match in the bank for each query patch.
    best_sim = (query @ bank.t()).max(dim=1).values        # [784]
    dist_grid = (1 - best_sim).reshape(GRID, GRID)         # [28,28]
    # interpolate() needs a batch and a channel axis: [1,1,28,28] -> [1,1,224,224]
    heat = F.interpolate(
        dist_grid[None, None],
        size=(IMG_SIZE, IMG_SIZE),
        mode="bilinear",
        align_corners=False,
    )
    return heat.squeeze().cpu().numpy()                    # [224,224]

def save_evidence(pil_img, amap, out_png, percentile=95):
    """Save a five-panel figure visualising an anomaly map next to its input.

    Panels, left to right: input image, min-max normalised anomaly map, the map
    overlaid on the image, the binary mask of pixels at or above the given
    percentile, and that mask overlaid on the image.

    Args:
        pil_img: Input PIL image; it is resized to IMG_SIZE x IMG_SIZE for display.
        amap: 2-D NumPy anomaly map. It should be IMG_SIZE x IMG_SIZE (as
            produced by `compute_anomaly_map`) so the overlays line up.
        out_png: Destination file path. Missing parent directories are created;
            a bare file name (no directory part) is written to the current
            working directory.
        percentile: Percentile of the normalised map used as the mask threshold.
    """
    img = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE)).convert("RGB"))
    hm = minmax(amap)

    thr = float(np.percentile(hm, percentile))
    mask = (hm >= thr).astype(np.uint8)

    fig = plt.figure(figsize=(12,3))
    try:
        gs = fig.add_gridspec(1,5, wspace=0.05)

        ax = fig.add_subplot(gs[0,0])
        ax.imshow(img)
        ax.set_title("Input")
        ax.axis("off")

        ax = fig.add_subplot(gs[0,1])
        ax.imshow(hm, cmap="jet")
        ax.set_title("Anomaly Map")
        ax.axis("off")

        ax = fig.add_subplot(gs[0,2])
        ax.imshow(img)
        ax.imshow(hm, cmap="jet", alpha=0.45)
        ax.set_title("Overlay")
        ax.axis("off")

        ax = fig.add_subplot(gs[0,3])
        ax.imshow(mask, cmap="gray")
        ax.set_title(f"Binary (p{percentile})")
        ax.axis("off")

        ax = fig.add_subplot(gs[0,4])
        ax.imshow(img)
        ax.imshow(mask, cmap="Reds", alpha=0.35)
        ax.set_title("Mask Overlay")
        ax.axis("off")

        # dirname() is "" for a bare file name, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        out_dir = os.path.dirname(out_png)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_png, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

# Categories with a "good" reference folder at
# ../data/mvtec_cpr_crop/<category>/images/good:
#   bottle, cable, capsule, carpet, grid, hazelnut
DATA_ROOT = "../data/mvtec_cpr_crop"
CATEGORY = "capsule"

normal_dir = f"{DATA_ROOT}/{CATEGORY}/images/good"
# "squeeze" is a defect folder of the capsule category; pick a matching test
# image when CATEGORY is changed.
test_img_path = f"{DATA_ROOT}/{CATEGORY}/images/squeeze/000_crop0.png"

# Reference features from the first 30 defect-free images.
bank = build_bank(normal_dir, max_imgs=30)
img = Image.open(test_img_path).convert("RGB")
amap = compute_anomaly_map(img, bank)

# Note: this output path is relative to the current working directory.
save_evidence(img, amap, f"outputs/evidence/{CATEGORY}.png")
