<a href="https://colab.research.google.com/github/siliang0312/SSZ/blob/master/dinov3_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import numpy as np


class DinoV2Embedder:
    def __init__(self, model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m", device: str | None = None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
        self.device = device
        self.model_name = model_name

        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device).eval()

    @torch.inference_mode()
    def encode(self, image: Image.Image, pooling: str = "cls_patch_mean") -> np.ndarray:
        """
        Args:
            image: PIL.Image (RGB)
            pooling:
              - "cls": only CLS token
              - "patch_mean": mean of patch tokens
              - "cls_patch_mean": (CLS + patch_mean) / 2  (default, recommended for retrieval)
        Returns:
            vec: np.ndarray, shape [D], L2-normalized
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)
        x = outputs.last_hidden_state  # [B, tokens, dim]

        cls = x[:, 0]     # [B, dim]
        patch = x[:, 1:]  # [B, tokens-1, dim]

        if pooling == "cls":
            emb = cls
        elif pooling == "patch_mean":
            emb = patch.mean(dim=1)
        elif pooling == "cls_patch_mean":
            emb = (cls + patch.mean(dim=1)) / 2
        else:
            raise ValueError(f"Unknown pooling: {pooling}")

        emb = F.normalize(emb, dim=-1)     # [B, dim]
        vec = emb[0].detach().cpu().numpy()
        return vec


# Usage example:
# embedder = DinoV2Embedder("facebook/dinov3-vits16-pretrain-lvd1689m")
# vec = embedder.encode(image)  # shape [D]
# print(vec.shape)


In [23]:
import os
from pathlib import Path
from typing import List, Dict

In [28]:
from google.colab import userdata
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the token stored under the
# Colab user secret "HF_TOKEN" (presumably needed to download the model
# checkpoint used below — confirm against the model's access settings).
login(token=userdata.get("HF_TOKEN"))

In [29]:
def build_gallery_pt(
    image_dir: str = "candys",
    out_pt: str = "candys_gallery_dinov2_small.pt",
    model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m",
    pooling: str = "cls_patch_mean",
) -> str:
    """Embed every image in a folder and save the result as a gallery file.

    Each image file directly inside ``image_dir`` becomes one gallery entry
    whose label is the file name without its extension.

    Args:
        image_dir: Folder containing the gallery images (not searched
            recursively).
        out_pt: Destination path passed to ``torch.save``.
        model_name: Hugging Face model id handed to ``DinoV2Embedder``.
        pooling: Pooling mode handed to ``DinoV2Embedder.encode``.

    Returns:
        ``out_pt``, unchanged.

    Raises:
        FileNotFoundError: If ``image_dir`` does not exist.
        NotADirectoryError: If ``image_dir`` exists but is not a folder.
        ValueError: If the folder contains no supported image files.
    """
    image_dir = Path(image_dir)
    # Explicit exceptions instead of ``assert`` so validation survives ``-O``.
    if not image_dir.exists():
        raise FileNotFoundError(f"Folder not found: {image_dir.resolve()}")
    if not image_dir.is_dir():
        raise NotADirectoryError(f"Not a folder: {image_dir.resolve()}")

    exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    # ``is_file`` keeps sub-folders with an image-like name out of the gallery.
    paths = sorted(
        p for p in image_dir.iterdir() if p.is_file() and p.suffix.lower() in exts
    )
    if not paths:
        raise ValueError(f"No images found in: {image_dir.resolve()}")

    labels = [p.stem for p in paths]     # class label = file name without extension
    rel_paths = [str(p) for p in paths]  # path of each image as a string

    embedder = DinoV2Embedder(model_name=model_name)

    vecs = []
    for p in paths:
        with Image.open(p) as img:
            vec = embedder.encode(img, pooling=pooling)  # one image per call, not a list

        # Accept either a numpy array or a torch tensor from ``encode``.
        vecs.append(torch.as_tensor(vec).detach().cpu().float())

    embeddings = torch.stack(vecs, dim=0)  # [N, D] on CPU

    payload: dict = {
        "embeddings": embeddings,   # float tensor [N, D], as returned by encode
        "labels": labels,           # list[str]
        "paths": rel_paths,         # list[str]
        "model_name": model_name,
        "pooling": pooling,
    }

    torch.save(payload, out_pt)
    print(f"Saved: {out_pt}")
    print("embeddings shape:", tuple(embeddings.shape))
    print("example:", labels[0], rel_paths[0])

    return out_pt


# Run: embed the images in the "candys" folder and save the gallery file.
pt_path = build_gallery_pt(
    image_dir="candys",
    out_pt="candys_gallery_dinov2_small.pt",
    model_name="facebook/dinov3-vits16-pretrain-lvd1689m",
    pooling="cls_patch_mean",
)

preprocessor_config.json:   0%|          | 0.00/585 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/86.4M [00:00<?, ?B/s]

Saved: candys_gallery_dinov2_small.pt
embeddings shape: (5, 384)
example: milkplus candys/milkplus.jpg


In [30]:
import torch
import numpy as np
from PIL import Image
from typing import List, Tuple

def knn_search(
    query_image_path: str,
    gallery_pt: str,
    topk: int = 5,
    embedder: "DinoV2Embedder | None" = None,
) -> list[tuple[float, str, str]]:
    """Return the gallery entries most similar to a query image.

    Args:
        query_image_path: Path of the image to look up.
        gallery_pt: Gallery file written by ``build_gallery_pt``.
        topk: Maximum number of results; capped at the gallery size.
        embedder: Optional already-loaded embedder to reuse across queries.
            When ``None`` a new one is created from the model name stored in
            the gallery, which reloads the model on every call.

    Returns:
        List of ``(score, label, path)`` tuples in the order returned by
        ``torch.topk`` (highest similarity first).

    Raises:
        ValueError: If ``embedder`` was created with a different model than
            the one recorded in the gallery.
    """
    # The gallery only holds a tensor, lists of strings and strings, so the
    # restricted unpickler is sufficient and avoids running arbitrary pickle
    # code from the file.
    data = torch.load(gallery_pt, map_location="cpu", weights_only=True)
    gallery_emb = data["embeddings"]  # [N, D]
    labels = data["labels"]
    paths = data["paths"]

    if embedder is None:
        embedder = DinoV2Embedder(model_name=data["model_name"])
    elif embedder.model_name != data["model_name"]:
        # Embeddings from different models are not comparable.
        raise ValueError(
            f"Embedder model {embedder.model_name!r} does not match "
            f"gallery model {data['model_name']!r}"
        )

    with Image.open(query_image_path) as img:
        q_vec = embedder.encode(img, pooling=data["pooling"])  # one image per call

    # Accept either a numpy array or a torch tensor from ``encode``.
    q = torch.as_tensor(q_vec).detach().cpu().float()

    # Redundant when ``encode`` already normalizes; kept as a safeguard.
    q = torch.nn.functional.normalize(q, dim=-1)

    # With both sides L2-normalized the dot product equals cosine similarity.
    sims = gallery_emb @ q  # [N]

    k = min(topk, len(labels))
    vals, idx = torch.topk(sims, k=k)

    return [
        (score, labels[j], paths[j])
        for score, j in zip(vals.tolist(), idx.tolist())
    ]

In [38]:
# Query the saved gallery with a test image; the notebook displays the result.
knn_search(
    query_image_path="original_test.jpg",
    gallery_pt="candys_gallery_dinov2_small.pt",
)

[(0.7603328227996826, 'orange', 'candys/orange.jpg'),
 (0.7515507936477661, 'milkplus', 'candys/milkplus.jpg'),
 (0.7380675673484802, 'original', 'candys/original.jpg'),
 (0.6968473196029663, 'purple', 'candys/purple.jpg'),
 (0.6668444871902466, 'pink', 'candys/pink.jpg')]