In [10]:
import time
from argparse import ArgumentParser

import os
import faiss
import torch
from torch.utils.data import DataLoader, SequentialSampler

from src.feature_extraction import MyResnet50, MyVGG16, RGBHistogram, LBP
from src.dataloader import MyDataLoader

In [19]:
# Build a flat L2 FAISS index from the features that MyResnet50 extracts for
# every item of MyDataLoader(image_root), then write the index to feature_root.
image_root = './data/images'
feature_root = './data/features'

# Create the output directory before the slow extraction loop so a missing
# folder cannot make faiss.write_index fail after all features were computed.
os.makedirs(feature_root, exist_ok=True)

print('Start indexing .......')
start = time.time()

device = torch.device("cpu")
batch_size = 32

# Load module feature extraction
extractor = MyResnet50(device)

# A sequential sampler walks the dataset in order, so vectors are added to the
# index in dataset order.
dataset = MyDataLoader(image_root)
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# extractor.shape is used as the index dimensionality
# (presumably the length of one feature vector -- confirm in MyResnet50).
indexer = faiss.IndexFlatL2(extractor.shape)

for images, image_paths in dataloader:
    images = images.to(device)
    features = extractor.extract_features(images)
    indexer.add(features)

# Save features
faiss.write_index(indexer, feature_root + '/' + "Resnet50" + '.index.bin')

end = time.time()
print('Finish in ' + str(end - start) + ' seconds')

Start indexing .......
Finish in 192.44080781936646 seconds


In [28]:
from PIL import Image
from src.dataloader import get_transformation

# Query the index with a single image and display its five nearest neighbours.
# The path is built from image_root (set in the indexing cell) so the cell does
# not depend on one machine's absolute path; this assumes the kernel's working
# directory is the project root -- adjust image_root otherwise.
_img = os.path.join(image_root, "0ae050ad67e48ebad7f5.jpg")
transform = get_transformation()

# The context manager closes the underlying file handle; convert() returns a
# separate, fully loaded RGB image that stays usable after the file is closed.
with Image.open(_img) as raw_img:
    img = raw_img.convert('RGB')

image_tensor = transform(img)
# unsqueeze(0) adds a leading dimension of size 1 so the single image is passed
# to the extractor as a batch of one.
image_tensor = image_tensor.unsqueeze(0).to(device)
feat = extractor.extract_features(image_tensor)

# D and indices are the two arrays returned by the search for the k=5 nearest
# entries; the tuple on the last line is what the cell displays.
D, indices = indexer.search(feat, k=5)
D, indices

(array([[0.       , 3.3062997, 3.8204498, 6.138777 , 6.6625433]],
       dtype=float32),
 array([[ 37, 506, 326,  52, 595]]))

In [23]:
# Repeat the 5-nearest-neighbour lookup for the query feature `feat`.
D, indices = indexer.search(feat, 5)

In [26]:
# Display the second array returned by indexer.search for the query.
indices

array([[ 37, 506, 326,  52, 595]])