In [1]:
import torch
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import faiss
import pickle

In [2]:
# Pick the fastest available backend: CUDA first, then Apple's MPS, and fall
# back to the CPU when neither accelerator is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
# Build the feature extractor: take a pretrained ResNet-50, keep every child
# module except the last one (the classification head in torchvision's
# ResNet), move it to the selected device and switch it to inference mode.
weights = ResNet50_Weights.DEFAULT
base_model = resnet50(weights=weights)
model = (
    torch.nn.Sequential(*list(base_model.children())[:-1])
    .to(device)
    .eval()
)

In [4]:
# Preprocessing applied to every image (gallery and query alike) before it is
# passed to the model: resize to 224x224, convert to a tensor, then normalise
# each of the three channels with the given mean / std.
img_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
class GalleryDataset(Dataset):
    """Dataset over the image files located directly inside ``root_dir``.

    Each item is a ``(transformed_image, file_name)`` pair, where the image has
    been opened as RGB and passed through the module-level ``img_transforms``.
    """

    # File extensions (matched case-insensitively) treated as gallery images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, root_dir):
        """Collect the image file names found in ``root_dir``.

        Args:
            root_dir: Directory that is listed (non-recursively) for images.

        Raises:
            OSError: If ``root_dir`` cannot be listed.
        """
        self.root_dir = root_dir
        # os.listdir yields entries in arbitrary order; sorting makes the
        # position -> file name mapping reproducible from run to run.
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(img_transforms(image), file_name)`` for position ``idx``."""
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)
        # convert() returns a new, fully loaded image, so the underlying file
        # handle can be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        return img_transforms(image), img_name

In [6]:
def build_index(gallery_path, batch_size=32, index_path="gallery.index",
                names_path="names.pkl"):
    """Embed every gallery image and persist a FAISS index plus its name list.

    Images are read through ``GalleryDataset``, embedded with the module-level
    ``model`` on ``device``, L2-normalised and stored in a flat L2 index.
    Row ``i`` of the index corresponds to ``names[i]`` in the pickled list.

    Args:
        gallery_path: Directory containing the gallery images.
        batch_size: Number of images embedded per forward pass.
        index_path: Destination file for the FAISS index.
        names_path: Destination file for the pickled list of file names.

    Raises:
        ValueError: If ``gallery_path`` contains no supported image files.
    """
    dataset = GalleryDataset(gallery_path)
    if len(dataset) == 0:
        # Without this guard np.vstack([]) fails further down with an opaque
        # "need at least one array" ValueError.
        raise ValueError(f"No images found in {gallery_path!r}; nothing to index.")

    # shuffle=False keeps the embedding order aligned with the collected names.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    vectors, names = [], []
    print("Vectors are calculating...")

    with torch.no_grad():
        for imgs, batch_names in loader:
            features = model(imgs.to(device))
            # Flatten everything after the batch dimension into one vector.
            vectors.append(features.cpu().view(features.size(0), -1).numpy())
            names.extend(batch_names)

    # FAISS operates on C-contiguous float32 matrices.
    vectors = np.ascontiguousarray(np.vstack(vectors), dtype='float32')

    # In-place normalisation to unit length; on unit vectors, ranking by L2
    # distance is equivalent to ranking by cosine similarity.
    faiss.normalize_L2(vectors)

    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    faiss.write_index(index, index_path)
    with open(names_path, 'wb') as f:
        pickle.dump(names, f)
    print(f"{len(names)} images were indexed and saved.")

In [7]:
def search(query_img_path, k=5, index_path="gallery.index",
           names_path="names.pkl"):
    """Return the file names of the gallery images closest to a query image.

    The query is preprocessed with ``img_transforms``, embedded with the
    module-level ``model`` and L2-normalised (as ``build_index`` does for the
    gallery vectors) before the saved index is searched.

    Args:
        query_img_path: Path of the query image.
        k: Maximum number of neighbours to return.
        index_path: FAISS index file written by ``build_index``.
        names_path: Pickled name list written by ``build_index``.

    Returns:
        A list of at most ``k`` gallery file names, nearest first. It is
        shorter than ``k`` when the index holds fewer than ``k`` vectors.
    """
    index = faiss.read_index(index_path)
    # NOTE(security): pickle.load can execute arbitrary code. Only load a
    # names file that was produced by build_index, never an untrusted one.
    with open(names_path, 'rb') as f:
        names = pickle.load(f)

    # convert() yields a fully loaded copy, so the file can be closed at once.
    with Image.open(query_img_path) as img:
        rgb = img.convert('RGB')
    img_t = img_transforms(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        embedding = model(img_t).cpu().view(1, -1).numpy()

    # FAISS needs a C-contiguous float32 matrix; normalise in place so the
    # query lives on the same unit sphere as the indexed gallery vectors.
    query_vector = np.ascontiguousarray(embedding, dtype='float32')
    faiss.normalize_L2(query_vector)

    _, indices = index.search(query_vector, k)
    # FAISS pads missing neighbours with -1 when k exceeds the index size;
    # names[-1] would silently return the last gallery name as a bogus hit.
    return [names[i] for i in indices[0] if i >= 0]

In [8]:
# Embed every image under data/gallery and write gallery.index / names.pkl.
build_index(gallery_path='data/gallery')

Vectors are calculating...
5000 images were indexed and saved.
