# DINOv2 — Image→Image Similarity Search (Complete Notebook)
Swap CLIP for DINOv2 (a self‑supervised ViT) to run image → image semantic similarity search. No text queries are needed.

✅ Works on CPU or GPU • Optional FAISS for speed • Supports multiple DINOv2 sizes via timm.



## 1) Install


In [None]:
# Packages imported by the cells below.
!pip install torch torchvision timm pillow numpy tqdm matplotlib
# Optional for large libraries: when faiss fails to import, the notebook sets
# HAS_FAISS = False and search falls back to a NumPy dot product.
!pip install faiss-cpu


In [2]:
# Colab only: mount Google Drive at /content/drive so that files stored there
# (PHOTOS_DIR below points under /content/drive/MyDrive) can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## 2) Imports & Config


In [3]:
from __future__ import annotations
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple, Iterable, Optional, Union

import os, json, math, time, dataclasses
import numpy as np
from PIL import Image, ImageOps
from tqdm import tqdm

import torch
import torch.nn.functional as F
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

try:
    import faiss  # optional
    HAS_FAISS = True
except Exception:
    HAS_FAISS = False

# --- Paths ---------------------------------------------------------------
# Root folder of the photo library to index. 👈 replace with your own folder
PHOTOS_DIR = Path("/content/drive/MyDrive/digi/digi_image-similarity/Photos")
# Folder that receives the saved index files; created up front if missing.
INDEX_DIR = Path('./index_dinov2')
INDEX_DIR.mkdir(parents=True, exist_ok=True)

# --- Model ---------------------------------------------------------------
# DINOv2 backbones, by timm model name:
#   'vit_small_patch14_dinov2.lvd142m'  (dim≈384)
#   'vit_base_patch14_dinov2.lvd142m'   (dim≈768)
#   'vit_large_patch14_dinov2.lvd142m'  (dim≈1024)
#   'vit_giant_patch14_dinov2.lvd142m'  (dim≈1536)
DINO_MODEL = 'vit_base_patch14_dinov2.lvd142m'

# Number of images embedded per batch; reduce if you hit CUDA OOM.
BATCH_SIZE = 32

# Use the GPU whenever torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'


## 3) Utilities


In [4]:
# File suffixes (lower-case, including the dot) that are treated as images.
IMG_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff'}


def list_images(root: Path) -> List[Path]:
    """Return every image file below ``root``, sorted by path.

    The folder is walked recursively. An entry is kept when it is a regular
    file whose suffix, compared case-insensitively, appears in ``IMG_EXTS``.
    """
    return sorted(
        candidate
        for candidate in root.rglob('*')
        if candidate.is_file() and candidate.suffix.lower() in IMG_EXTS
    )


def open_image(path: Path) -> Image.Image:
    """Load ``path`` and return it as an RGB image with EXIF orientation applied.

    The file handle opened by ``Image.open`` is released when the ``with``
    block exits; the RGB conversion happens before that, inside the block.
    """
    with Image.open(path) as raw:
        upright = ImageOps.exif_transpose(raw)
        rgb = upright.convert('RGB')
    return rgb


def chunked(it: Iterable, size: int):
    """Yield consecutive lists of at most ``size`` items taken from ``it``.

    Every yielded list has exactly ``size`` items except possibly the last,
    which holds whatever remains. An empty iterable yields nothing.

    Args:
        it: Any iterable; it is consumed lazily, one chunk at a time.
        size: Maximum number of items per chunk; must be at least 1.

    Raises:
        ValueError: If ``size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # Without this guard a non-positive size never matches len(buf) == size,
    # so the whole input would silently come back as one giant chunk.
    if size < 1:
        raise ValueError(f'size must be >= 1, got {size!r}')
    buf = []
    for x in it:
        buf.append(x)
        if len(buf) == size:
            yield buf
            buf = []
    if buf:
        yield buf


def normalize(a: np.ndarray) -> np.ndarray:
    """Scale vectors to unit L2 length along the last axis.

    Accepts a matrix of row vectors ``(N, D)`` as well as a single vector
    ``(D,)``. A small epsilon is added to each norm so all-zero vectors are
    returned as zeros instead of producing a division by zero.

    Args:
        a: Array whose last axis holds the vector components.

    Returns:
        A new array with the same shape as ``a``.
    """
    # axis=-1 matches the former axis=1 for 2-D input and also covers 1-D.
    norms = np.linalg.norm(a, axis=-1, keepdims=True) + 1e-12
    return a / norms


## 4) DINOv2 Encoder (via timm)


In [None]:
class DinoEncoder:
    """Turn PIL images into L2-normalized embedding vectors with a timm model.

    After construction the instance exposes:
      * ``device``    - device string the model lives on
      * ``model``     - the timm model, in eval mode
      * ``dim``       - embedding width (``model.num_features``, else 768)
      * ``transform`` - eval-time preprocessing built from the model's data config
    """

    def __init__(self, model_name: str = DINO_MODEL, device: Optional[str] = None):
        """Create the pretrained model and its matching eval transform.

        Args:
            model_name: timm model name passed to ``timm.create_model``.
            device: Target device; falls back to the module-level ``DEVICE``
                when ``None`` or empty.
        """
        self.device = device or DEVICE
        self.model = timm.create_model(model_name, pretrained=True).to(self.device)
        self.model.eval()
        # Fall back to 768 when the model does not report num_features.
        self.dim = getattr(self.model, 'num_features', None) or 768
        cfg = resolve_data_config({}, model=self.model)
        # Build eval transform matching model's expected mean/std/size
        self.transform = create_transform(**cfg, is_training=False)

    @staticmethod
    def _fallback_pool(feats):
        """Reduce raw ``forward_features`` output to one vector per image.

        Used only when ``forward_head(..., pre_logits=True)`` is unavailable
        or fails. Handles a dict of tensors or a bare tensor.

        Raises:
            ValueError: If the selected tensor has a rank other than 2, 3 or 4.
        """
        if isinstance(feats, dict):
            # Compare against None explicitly: `a or b` on a multi-element
            # tensor raises "Boolean value of Tensor ... is ambiguous".
            for key in ('x_norm_clstoken', 'avgpool'):
                chosen = feats.get(key)
                if chosen is not None:
                    break
            else:
                chosen = next(iter(feats.values()))
            feats = chosen
        if feats.ndim == 4:
            # assumes (batch, channels, H, W) - TODO confirm for non-ViT backbones
            feats = feats.mean(dim=(2, 3))
        elif feats.ndim == 3:
            # assumes (batch, tokens, dim); average the tokens - TODO confirm
            feats = feats.mean(dim=1)
        elif feats.ndim != 2:
            raise ValueError(
                f'Cannot pool features with shape {tuple(feats.shape)} into (batch, dim)'
            )
        return feats

    @torch.inference_mode()
    def encode_images(self, pil_images: List[Image.Image]) -> np.ndarray:
        """Embed a list of PIL images in a single forward pass.

        Args:
            pil_images: Images to embed; the caller controls the batch size.

        Returns:
            ``float32`` array with one L2-normalized row per input image.
            An empty input gives an array of shape ``(0, self.dim)``.
        """
        if not pil_images:
            return np.zeros((0, self.dim), dtype=np.float32)
        x = torch.stack([self.transform(im) for im in pil_images]).to(self.device)
        feats = self.model.forward_features(x)
        # Try to grab pre-logits features robustly
        try:
            feats = self.model.forward_head(feats, pre_logits=True)
        except Exception:
            # Deliberately broad: models differ in what forward_features
            # returns and in whether forward_head accepts pre_logits.
            feats = self._fallback_pool(feats)
        feats = F.normalize(feats, p=2, dim=-1)
        return feats.detach().cpu().to(torch.float32).numpy()

# Build the encoder once; the indexing and query cells below reuse ``enc``.
enc = DinoEncoder(model_name=DINO_MODEL)
print(f'Using {DINO_MODEL} on {enc.device} | dim = {enc.dim}')


## 5) On‑Disk Index


In [6]:
@dataclasses.dataclass
class ImageMeta:
    """Per-image record stored alongside each embedding row."""

    # Location of the image file, kept as a plain string.
    path: str
    # File modification time captured when the image was indexed.
    mtime: float

class DiskIndex:
    """Store and reload an embedding index as plain files in one folder.

    Files inside ``index_dir``:
      * ``embeddings.npy`` - float32 matrix, one L2-normalized row per image
      * ``images.jsonl``   - one JSON object per line, in the same row order
      * ``header.json``    - free-form build information
      * ``faiss.index``    - optional serialized FAISS index
    """

    def __init__(self, index_dir: Path):
        """Remember the file locations; nothing is read or written here."""
        self.index_dir = index_dir
        self.emb_path = index_dir / 'embeddings.npy'
        self.meta_path = index_dir / 'images.jsonl'
        self.header_path = index_dir / 'header.json'
        self.faiss_path = index_dir / 'faiss.index'

    def save(self, header: dict, embeddings: np.ndarray, metas: List[ImageMeta], faiss_index=None):
        """Write the index to disk, replacing any previous contents.

        Args:
            header: JSON-serializable build information.
            embeddings: Matrix with one row per entry of ``metas``; it is cast
                to float32 and L2-normalized before being written.
            metas: Metadata records, in the same order as the embedding rows.
            faiss_index: Optional FAISS index to persist next to the arrays.

        Raises:
            ValueError: If ``embeddings`` and ``metas`` differ in length.
        """
        if len(embeddings) != len(metas):
            raise ValueError(
                f'embeddings has {len(embeddings)} rows but metas has {len(metas)} entries'
            )
        self.index_dir.mkdir(parents=True, exist_ok=True)
        embeddings = normalize(embeddings.astype(np.float32, copy=False))
        np.save(self.emb_path, embeddings)
        with open(self.meta_path, 'w', encoding='utf-8') as f:
            f.writelines(
                json.dumps(dataclasses.asdict(m), ensure_ascii=False) + '\n' for m in metas
            )
        with open(self.header_path, 'w', encoding='utf-8') as f:
            json.dump(header, f, ensure_ascii=False, indent=2)
        if faiss_index is not None:
            faiss.write_index(faiss_index, str(self.faiss_path))
        else:
            # A FAISS file left over from an earlier build no longer matches
            # the embeddings just written, and try_load_faiss would pick it up.
            try:
                self.faiss_path.unlink()
            except FileNotFoundError:
                pass

    def load(self):
        """Read the index back from disk.

        Returns:
            ``(header, embeddings, metas)`` where ``metas`` is a list of plain
            dicts (not ``ImageMeta`` instances), one per embedding row.

        Raises:
            ValueError: If the metadata and embedding row counts disagree.
        """
        header = json.loads(self.header_path.read_text('utf-8'))
        embs = np.load(self.emb_path)
        # Iterate the file instead of str.splitlines(): splitlines also breaks
        # on characters such as U+2028 that ensure_ascii=False writes raw
        # inside a JSON string. Blank lines are skipped.
        with open(self.meta_path, encoding='utf-8') as f:
            metas = [json.loads(line) for line in f if line.strip()]
        if len(metas) != len(embs):
            raise ValueError(
                f'Index is inconsistent: {len(embs)} embeddings but {len(metas)} metadata rows'
            )
        return header, embs, metas

    def try_load_faiss(self):
        """Return the persisted FAISS index, or ``None`` when FAISS or the file is missing."""
        if HAS_FAISS and self.faiss_path.exists():
            return faiss.read_index(str(self.faiss_path))
        return None


## 6) Build the Index (embed your photo folder)


In [None]:
# Embed every image under PHOTOS_DIR in batches and write the index to INDEX_DIR.
files = list_images(PHOTOS_DIR)
print(f'Found {len(files)} images under {PHOTOS_DIR}')

metas: List[ImageMeta] = []
all_feats: List[np.ndarray] = []

# Batches are produced lazily; tqdm gets the batch count up front so the
# progress bar still shows a total without materializing every batch.
num_batches = math.ceil(len(files) / BATCH_SIZE)
for batch_paths in tqdm(chunked(files, BATCH_SIZE), total=num_batches,
                        desc='Embedding', unit='batch'):
    pil_batch, batch_metas = [], []
    for p in batch_paths:
        # Build the image and its metadata first and append both only when
        # both succeeded, so embedding rows and metas can never drift apart.
        try:
            img = open_image(p)
            meta = ImageMeta(path=str(p), mtime=p.stat().st_mtime)
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped.
            print('[skip]', p, e)
            continue
        pil_batch.append(img)
        batch_metas.append(meta)
    if not pil_batch:
        continue
    feats = enc.encode_images(pil_batch)
    all_feats.append(feats)
    metas.extend(batch_metas)

if not metas:
    raise SystemExit('No images embedded — check PHOTOS_DIR')

embeddings = np.vstack(all_feats)
header = {
    'backbone': DINO_MODEL,
    'dim': int(embeddings.shape[-1]),
    'created_at': time.time(),
    'num_items': len(metas),
    'device_build': enc.device,
}

index = DiskIndex(INDEX_DIR)
index.save(header, embeddings, metas)
print(f'Saved DINOv2 index → {INDEX_DIR} with {len(metas)} items.')


## 7) Image → Image Search (cosine or FAISS IP)


In [8]:
from typing import Union

@torch.inference_mode()
def encode_query_image(img: Union[str, Path, Image.Image]) -> np.ndarray:
    """Embed one query image with the shared encoder ``enc``.

    Args:
        img: A file path (``str`` or ``Path``), which is loaded with
            ``open_image``, or an image object that is used as given.

    Returns:
        The first (and only) row of ``enc.encode_images``, as ``float32``.
    """
    pil = open_image(Path(img)) if isinstance(img, (str, Path)) else img
    encoded = enc.encode_images([pil])
    return encoded[0].astype(np.float32)


def search_image(index_dir: Path, query_image: Union[str, Path, Image.Image], topk: int = 12,
                 prefer_faiss: bool = True, include_self: bool = False):
    """Return the ``topk`` indexed images most similar to ``query_image``.

    Similarity is the inner product of L2-normalized embeddings, computed
    with FAISS when it is available and preferred, otherwise with NumPy.

    Args:
        index_dir: Folder holding an index written by ``DiskIndex.save``.
        query_image: File path or already-loaded image to search with.
        topk: Maximum number of neighbours to return; values <= 0 give ``[]``.
        prefer_faiss: Use FAISS when ``HAS_FAISS`` is true.
        include_self: When false and ``query_image`` is a path, the indexed
            entry that resolves to the same path is left out.

    Returns:
        List of ``(meta, score)`` pairs ordered from most to least similar,
        where ``meta`` is the metadata dict loaded from the index.

    Raises:
        ValueError: If the index is internally inconsistent or the query
            embedding width differs from the indexed embeddings.
    """
    if topk <= 0:
        return []

    index = DiskIndex(index_dir)
    header, embs, metas = index.load()
    if len(embs) == 0:
        # Nothing to rank; FAISS also rejects a search with k == 0.
        return []
    if len(metas) != len(embs):
        raise ValueError(
            f'Index is inconsistent: {len(embs)} embeddings but {len(metas)} metadata rows'
        )
    embs = np.ascontiguousarray(normalize(embs.astype(np.float32, copy=False)))
    dim = int(embs.shape[1])

    q = np.ascontiguousarray(encode_query_image(query_image), dtype=np.float32)  # (D,)
    if q.shape[0] != dim:
        raise ValueError(
            f'Query embedding has dim {q.shape[0]} but the index has dim {dim} '
            f"(index backbone: {header.get('backbone')!r})"
        )

    # Fetch a few extra candidates so dropping the query itself below still
    # leaves topk results.
    k = min(topk + 10, len(embs))

    if prefer_faiss and HAS_FAISS:
        fidx = index.try_load_faiss()
        # A persisted FAISS index that disagrees with the embeddings would
        # return ids for different (or missing) metadata rows; ignore it.
        if fidx is not None and (fidx.ntotal != len(embs) or fidx.d != dim):
            fidx = None
        if fidx is None:
            fidx = faiss.IndexFlatIP(dim)
            fidx.add(embs)
        D, I = fidx.search(q[None, :], k)
        # FAISS pads missing results with id -1.
        candidates = [(metas[int(i)], float(d)) for i, d in zip(I[0], D[0]) if int(i) >= 0]
    else:
        sims = embs @ q
        idx = np.argsort(-sims)[:k]
        candidates = [(metas[int(i)], float(sims[int(i)])) for i in idx]

    q_path = None
    if isinstance(query_image, (str, Path)):
        q_path = str(Path(query_image).resolve())
    results = []
    for meta, score in candidates:
        p = str(Path(meta['path']).resolve())
        if not include_self and q_path is not None and p == q_path:
            continue
        results.append((meta, score))
        if len(results) >= topk:
            break
    return results


## 8) Visualization


In [9]:
import matplotlib.pyplot as plt

def show_query_and_results(query_image: Union[str, Path, Image.Image], results: List[Tuple[dict, float]],
                           cols: int = 4, max_w: int = 320, title: str = None):
    """Display the query image, then a grid of its neighbours.

    Args:
        query_image: File path or already-loaded image shown as the query.
        results: ``(meta, score)`` pairs; ``meta['path']`` locates each image.
        cols: Number of grid columns.
        max_w: Bounding-box edge, in pixels, for each neighbour thumbnail.
        title: Optional text printed above the grid.
    """
    if isinstance(query_image, (str, Path)):
        query_path = Path(query_image)
        q_img = open_image(query_path)
        q_title = query_path.name
    else:
        q_img = query_image
        q_title = 'query'

    # Query figure.
    plt.figure(figsize=(4, 4))
    plt.imshow(q_img)
    plt.axis('off')
    plt.title(f"Query: {q_title}")
    plt.show()

    if not results:
        print('No neighbors.')
        return

    rows = math.ceil(len(results) / cols)
    if title:
        print(title)

    # Neighbour grid; an image that fails to load or draw is reported and skipped.
    plt.figure(figsize=(cols * 4, rows * 4))
    for slot, (meta, score) in enumerate(results, start=1):
        try:
            thumb = open_image(Path(meta['path']))
            thumb.thumbnail((max_w, max_w))
            plt.subplot(rows, cols, slot)
            plt.imshow(thumb)
            plt.axis('off')
            plt.title(f"{score:.3f}\n{Path(meta['path']).name}", fontsize=9)
        except Exception as err:
            print('[viz-skip]', meta['path'], err)
    plt.tight_layout()
    plt.show()


In [None]:
# Image to search with. 👈 set a specific file if you like
query_image_path = "/content/Gemini_Generated_Image_4qdc6c4qdc6c4qdc.png"

# Look up the 12 nearest neighbours (the query file itself excluded) and plot them.
neighbors = search_image(
    INDEX_DIR,
    query_image_path,
    topk=12,
    prefer_faiss=True,
    include_self=False,
)
show_query_and_results(
    query_image_path,
    neighbors,
    cols=4,
    title='DINOv2 — Image→Image',
)
