In [1]:
import os, json
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as T
from tqdm import tqdm
import faiss
import matplotlib.pyplot as plt
import cv2
from transformers import AutoModel, AutoImageProcessor

In [2]:
# Pick the compute backend for the model: prefer a CUDA GPU, then Apple's MPS
# backend, and fall back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:::", device)

Device::: mps


In [3]:
# Load the FAISS index from disk.
index = faiss.read_index("oxford_index.bin")

# Load the embeddings JSON; its keys are used to map an index position back to
# an image file name. The encoding is given explicitly so the read does not
# depend on the platform's default text encoding.
with open("oxford_embeddings.json", "r", encoding="utf-8") as f:
    all_embeddings = json.load(f)

# Image file names in JSON key order.
# NOTE(review): this assumes the vectors were added to the index in exactly
# this key order — confirm against the notebook/script that built the index.
files = list(all_embeddings.keys())

# Search results are positions into `files`; a size mismatch means the index
# and the JSON were probably produced by different runs, so results could be
# attributed to the wrong image. Warn instead of failing so the cell still runs.
if index.ntotal != len(files):
    print(
        f"WARNING: FAISS index holds {index.ntotal} vectors but the JSON "
        f"lists {len(files)} files; result-to-file mapping may be wrong."
    )

In [4]:
# Query-time preprocessing. Per the original author's note this is meant to be
# the same transform that was used when the index was created, so it must not
# drift from that pipeline.
# NOTE(review): Resize(244) followed by CenterCrop(224) is unusual (256 is the
# common choice) — confirm it matches the index-building code before changing.
_preprocess_steps = [
    T.ToTensor(),                        # PIL image -> tensor
    T.Resize(size=244),                  # resize shorter side to 244
    T.CenterCrop(size=224),              # central 224x224 crop
    T.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5
]
transform = T.Compose(_preprocess_steps)

In [5]:
def load_image(img_path):
    """Load an image file and turn it into a single-image batch on `device`.

    Args:
        img_path: Path to an image file that PIL can open.

    Returns:
        The tensor produced by the notebook-level `transform` for the RGB
        version of the image, with a leading batch dimension added and moved
        to the notebook-level `device`.
    """
    # Image.open keeps the underlying file open until the pixel data is read
    # or the image is closed; the context manager releases the handle
    # deterministically. convert() returns an independent, fully loaded copy.
    with Image.open(img_path) as src:
        rgb = src.convert("RGB")
    # convert("RGB") already guarantees exactly three channels, so no channel
    # slicing is needed after the transform.
    return transform(rgb).unsqueeze(0).to(device)

In [6]:
# Hugging Face checkpoint shared by the backbone and its image processor.
dinov2_checkpoint = 'facebook/dinov2-base'

# DINOv2 backbone, moved to the selected device.
dinov2_model = AutoModel.from_pretrained(dinov2_checkpoint)
dinov2_model = dinov2_model.to(device)

# Matching image processor (slow, non-"fast" implementation requested).
processor = AutoImageProcessor.from_pretrained(dinov2_checkpoint, use_fast=False)

In [None]:
query_img = "collection/images/all_souls_000022.jpg"
top_k = 5  # number of nearest neighbours to retrieve

# Embed the query image. load_image() already places the batch on `device`,
# so no extra transfer is needed here.
with torch.no_grad():
    output = dinov2_model(load_image(query_img))
    # Pooled embedding vector; the original note gives its shape as (1, 768).
    embedding = output.pooler_output.cpu().numpy()

# FAISS expects a C-contiguous float32 array for queries.
embedding = np.ascontiguousarray(embedding, dtype=np.float32)

# Find the top-K most similar images.
D, I = index.search(embedding, k=top_k)

# Show the retrieved images.
for idx in I[0]:
    # FAISS pads the result with -1 when fewer than k neighbours exist;
    # without this guard files[-1] would silently display the last file.
    if idx < 0:
        continue
    print("Giống ảnh:", files[idx])
    img = cv2.imread(files[idx])
    # cv2.imread returns None (no exception) when the file cannot be read.
    if img is None:
        print("Could not read image:", files[idx])
        continue
    # OpenCV loads BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.show()