In [None]:
!pip install faiss-gpu

In [2]:
import glob
import os

import faiss
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

In [None]:
# --- Configuration -----------------------------------------------------------
# Root of the image database. The indexing cell globs "<IMAGE_DB_PATH>/*/*",
# i.e. it expects the image files one directory level below this root.
IMAGE_DB_PATH = "./data/train"
# Path of the query image; must be filled in before running the retrieval cell.
IMAGE_TEST = ""
# Local checkpoint path; only used if the load_state_dict line below is re-enabled.
PRETRAINED_MODEL = "./pretrained/mobilenet_v2.pth"

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every image (database and query alike):
# resize to 224x224, convert to a float tensor in [0, 1], then normalise per channel.
# NOTE(review): these mean/std values are not the ImageNet statistics
# ((0.485, 0.456, 0.406) / (0.229, 0.224, 0.225)) that torchvision documents for
# its pretrained weights -- confirm which statistics match the weights in use.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5071, 0.4867, 0.4408),
                        std=(0.2675, 0.2565, 0.2761))
])

# --- Feature extractor -------------------------------------------------------
model = models.mobilenet_v2(pretrained=True)
# model.load_state_dict(torch.load(PRETRAINED_MODEL))
# Drop the classification head so the forward pass returns the pooled feature
# vector instead of class logits.
model.classifier = nn.Identity()
# Place the model on the chosen device and switch to inference mode so that
# BatchNorm layers use their running statistics (deterministic embeddings).
model = model.to(DEVICE)
model.eval()

In [None]:
# Build an exact (brute-force) L2 index holding one embedding per database image.

model.eval()  # inference mode: BatchNorm must use running statistics, not batch statistics
device = next(model.parameters()).device  # send inputs to wherever the model's weights live

# With the classifier replaced by nn.Identity the network outputs its pooled
# feature vector, whose width is `model.last_channel` (1280 for the default
# MobileNetV2) -- not the 1000 class logits of the original head.
faiss_index = faiss.IndexFlatL2(model.last_channel)

# im_indices[i] is the file path of the image stored at position i of the index.
im_indices = []
with torch.no_grad():
    # sorted() makes the index order reproducible across runs and file systems.
    for f in sorted(glob.glob(os.path.join(IMAGE_DB_PATH, '*/*'))):
        # The context manager closes the file handle; convert("RGB") guarantees
        # three channels so grayscale / RGBA files pass the 3-channel Normalize.
        # `transform` already resizes to 224x224, so no manual resize is needed.
        with Image.open(f) as im:
            batch = transform(im.convert("RGB")).unsqueeze(0).to(device)

        # Faiss expects a C-contiguous float32 array of shape (n, d).
        preds = np.ascontiguousarray(model(batch).cpu().numpy(), dtype=np.float32)
        faiss_index.add(preds)  # add the representation to the index
        im_indices.append(f)    # remember the image path to look it up later

In [None]:
# Embed the query image and look up its nearest neighbours in the Faiss index.

model.eval()  # inference mode: BatchNorm must use running statistics, not batch statistics
device = next(model.parameters()).device  # send inputs to wherever the model's weights live

with torch.no_grad():
    # Same preprocessing as the database images: force RGB, then apply
    # `transform` (the Compose pipeline -- not the `transforms` module).
    with Image.open(IMAGE_TEST) as im:
        batch = transform(im.convert("RGB")).unsqueeze(0).to(device)

    # Faiss expects a C-contiguous float32 array of shape (n_queries, d).
    test_embed = np.ascontiguousarray(model(batch).cpu().numpy(), dtype=np.float32)
    # I[q][k] is the index position of the k-th nearest neighbour of query q.
    _, I = faiss_index.search(test_embed, 5)

    best_match = int(I[0][0])
    # Faiss pads missing neighbours with -1; using that as a list index would
    # silently return the wrong image (or fail on an empty list).
    if best_match < 0:
        raise RuntimeError("The Faiss index is empty; build it from a non-empty IMAGE_DB_PATH first.")
    print("Retrieved Image: {}".format(im_indices[best_match]))