In [214]:
import os
import re
import json
import glob
import random
from PIL import Image, ImageDraw
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [215]:
from annoy import AnnoyIndex

In [216]:
import torch
from torch.utils import data
from torchvision import transforms
import torchvision.models as models
import torch.nn as nn

In [217]:
# List the dataset root with the standard library instead of the IPython `ls`
# automagic, so the cell also works outside IPython and on non-POSIX systems.
sorted(os.listdir("./dataset/"))

[0m[01;34mold_dataset_all[0m/  [01;34mrefined_dataset[0m/


In [218]:
# Root of the training split; every file found below it is collected.
src_path = "./dataset/refined_dataset/train/"

# os.walk yields entries in filesystem order, which can differ between
# machines and runs. Sorting makes the path list (and any positional ids
# derived from it later) reproducible.
image_list = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(src_path)
    for name in files
)

In [219]:
# Sort first so re-running this cell always starts from the same order, then
# shuffle with a dedicated seeded generator: the result is reproducible and
# the global `random` state is left untouched.
image_list.sort()
random.Random(42).shuffle(image_list)

# Keep the count as the cell's last expression so the notebook displays it
# (previously it was evaluated first and silently discarded).
len(image_list)

In [220]:
# Preprocessing pipeline: resize straight to 224x224 (no cropping), convert
# to a tensor, then normalise each channel with the mean/std listed below.
data_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [221]:
# Keyword arguments unpacked into the DataLoader constructor.
PARAMS = dict(
    batch_size=8,
    shuffle=False,
    num_workers=16,
)

In [222]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [228]:
class ImageDataset(data.Dataset):
    """Dataset yielding ``(image_path, image)`` pairs for a list of image files.

    Args:
        images: sequence of image file paths.
        transforms: optional callable applied to the RGB ``PIL.Image``. When
            ``None`` the untransformed RGB image is returned.
    """

    def __init__(self, images, transforms=None):
        self.images = images
        self.transforms = transforms

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(path, image)``."""
        path = self.images[index]
        # The context manager closes the file handle. convert("RGB") loads the
        # pixels into a new image and guarantees three channels, so RGBA,
        # grayscale and palette files are handled instead of failing on the
        # old ``image[:, :, :3]`` slice.
        with Image.open(path) as image_file:
            image = image_file.convert("RGB")
        # transforms is optional (its default is None), so honour that here
        # rather than calling None.
        if self.transforms is None:
            return path, image
        return path, self.transforms(image)

In [229]:
# Wrap the image paths in the dataset, then batch it with the loader
# settings collected in PARAMS.
dataset = ImageDataset(images=image_list, transforms=data_transform)
data_loader = data.DataLoader(dataset=dataset, **PARAMS)

In [230]:
# loading the trained model and generating embedding based on that
base_model = models.resnet18(pretrained=False).to(DEVICE)
for param in base_model.parameters():
    param.requires_grad = False
num_ftrs = base_model.fc.in_features
base_model.fc = nn.Sequential(nn.Linear(num_ftrs, 256), nn.Linear(256, 128))
base_model = base_model.to(DEVICE)

# loading the trained model with trained weights
checkpoint = torch.load("./weights/refined_weights/model_best.pth")
base_model.load_state_dict(checkpoint['state_dict'])
base_model = base_model.eval()

In [231]:
def create_annoy_index(image_embeddings, embedding_size, n_trees=10000):
    """Build an Annoy index (euclidean metric) over precomputed embeddings.

    Args:
        image_embeddings: sequence of dicts holding an ``"image"`` path and an
            ``"embedding"`` vector.
        embedding_size: dimensionality passed to ``AnnoyIndex``.
        n_trees: number of trees passed to ``AnnoyIndex.build``; defaults to
            the previously hard-coded 10000.

    Returns:
        ``(annoy_index, index_to_label)`` where ``index_to_label`` maps each
        item's integer position to the name of the directory that contains
        the image file.
    """
    index_to_label = {}
    annoy_index = AnnoyIndex(embedding_size, metric="euclidean")
    # Wrap the sequence itself (not the enumerate object) so tqdm can read its
    # length and show a real progress total.
    for index, embedding in enumerate(tqdm(image_embeddings)):
        # The label is the parent directory name; os.path handles the
        # platform's path separators, unlike a literal split on "/".
        parent_dir = os.path.dirname(embedding["image"])
        index_to_label[index] = os.path.basename(parent_dir)
        annoy_index.add_item(index, embedding["embedding"])
    annoy_index.build(n_trees)
    return annoy_index, index_to_label

In [232]:
def get_embeddings(emb_dataloader, model=None):
    """Run a model over a dataloader and collect one embedding per image.

    Args:
        emb_dataloader: iterable yielding ``(image_names, images)`` batches.
        model: network used to embed each batch; defaults to the
            module-level ``base_model``.

    Returns:
        A list of ``{"image": name, "embedding": tensor}`` dicts, in the order
        the dataloader produced them. Embedding tensors are moved to the CPU.
    """
    if model is None:
        model = base_model
    model.eval()

    # Send inputs to the device the model actually lives on. The model can be
    # moved after DEVICE was chosen, and using the global DEVICE then causes a
    # device mismatch. Fall back to DEVICE for a model without parameters.
    first_param = next(model.parameters(), None)
    device = DEVICE if first_param is None else first_param.device

    embeddings = []
    with torch.no_grad():  # inference only, no autograd graph needed
        for image_names, images in tqdm(emb_dataloader):
            batch_embeddings = model(images.to(device)).cpu()
            embeddings.extend(
                {"image": name, "embedding": embedding}
                for name, embedding in zip(image_names, batch_embeddings)
            )
    return embeddings

In [233]:
# Embed every image served by the data loader.
image_embeddings = get_embeddings(emb_dataloader=data_loader)

100%|██████████| 275/275 [00:04<00:00, 67.39it/s]


In [234]:
# Index the embeddings; 128 is the width of the model's final Linear layer.
annoy_index, annoy_index_to_label = create_annoy_index(
    image_embeddings=image_embeddings,
    embedding_size=128,
)

2200it [00:00, 4961.42it/s]


In [235]:
# Write the built index to "annoy_index.ann" in the working directory; the
# return value of save() is shown as the cell output.
annoy_index.save("annoy_index.ann")

True

In [236]:
# Store the id -> label mapping alongside the index file.
# Note: JSON object keys are always strings, so the integer ids become "0",
# "1", ... in the file and must be converted back when the mapping is reloaded.
with open('annoy_index_to_label.json', 'w') as f:
    f.write(json.dumps(annoy_index_to_label))

In [240]:
# Path of the image used as the similarity query.
query_img_name = "./dataset/old_dataset_all/old_dataset_2/val_retina/4.png"

In [241]:
# The query below is run on the CPU, so move the model there.
base_model = base_model.cpu()

In [242]:
# Load the query image the same way the dataset does: force three RGB
# channels (the Normalize step uses 3-channel statistics) and close the file
# handle once the pixels are loaded.
with Image.open(query_img_name) as query_file:
    query_img = query_file.convert("RGB")

# Preprocess and add the batch dimension the model expects.
query_img = data_transform(query_img).unsqueeze(0)

# Inference only: no_grad avoids building an autograd graph, so the embedding
# is a plain tensor without gradient history.
with torch.no_grad():
    query_img_embedding = base_model(query_img).squeeze(0)  # drop the batch dimension only

query_img_embedding.shape

torch.Size([128])

In [243]:
# Display the embedding tensor computed for the query image.
query_img_embedding

tensor([ 0.2999,  0.3010,  0.5467, -0.0063, -0.5192, -0.0473, -0.4775, -0.1055,
         0.1031, -0.0983,  0.1274, -0.2049, -0.5786, -0.4274, -0.3502, -0.1535,
        -0.2950, -0.6859, -0.5819,  0.8397, -0.1136,  0.6252,  0.0801,  0.0777,
        -0.3418,  0.0991, -0.4070,  0.2345, -0.9932, -0.1629, -0.0989,  0.3058,
         0.5070,  0.1654, -0.8238,  0.6589, -0.2118,  0.2043, -0.4324, -0.3545,
         0.1861,  0.1656,  0.2705,  0.6299,  0.6270,  0.0326, -0.2639,  0.0637,
        -0.1423,  0.3423, -0.0616, -0.1244,  0.3716,  0.2894,  0.0755, -0.0934,
         0.4488, -0.3963,  0.2110, -0.7865,  0.4953,  0.5968,  0.0142, -0.3433,
         0.3595, -0.2530,  0.0385,  0.9638,  0.1151,  0.0786, -0.0301,  0.1556,
        -0.0017,  0.1398, -0.3046,  0.8517,  0.1961,  0.3400, -0.1359,  0.3403,
        -0.2605,  0.8044, -0.3182, -0.1349,  0.1944, -0.0462, -0.2366,  0.5896,
        -0.1673, -0.6148, -0.2910,  0.2694, -0.3890, -0.0062, -0.3453, -0.0335,
        -0.2148, -0.2522, -0.3993, -0.17

In [244]:
# NOTE(review): this cell evaluates the same expression as the cell directly
# above it and looks redundant; consider removing one of the two.
query_img_embedding

tensor([ 0.2999,  0.3010,  0.5467, -0.0063, -0.5192, -0.0473, -0.4775, -0.1055,
         0.1031, -0.0983,  0.1274, -0.2049, -0.5786, -0.4274, -0.3502, -0.1535,
        -0.2950, -0.6859, -0.5819,  0.8397, -0.1136,  0.6252,  0.0801,  0.0777,
        -0.3418,  0.0991, -0.4070,  0.2345, -0.9932, -0.1629, -0.0989,  0.3058,
         0.5070,  0.1654, -0.8238,  0.6589, -0.2118,  0.2043, -0.4324, -0.3545,
         0.1861,  0.1656,  0.2705,  0.6299,  0.6270,  0.0326, -0.2639,  0.0637,
        -0.1423,  0.3423, -0.0616, -0.1244,  0.3716,  0.2894,  0.0755, -0.0934,
         0.4488, -0.3963,  0.2110, -0.7865,  0.4953,  0.5968,  0.0142, -0.3433,
         0.3595, -0.2530,  0.0385,  0.9638,  0.1151,  0.0786, -0.0301,  0.1556,
        -0.0017,  0.1398, -0.3046,  0.8517,  0.1961,  0.3400, -0.1359,  0.3403,
        -0.2605,  0.8044, -0.3182, -0.1349,  0.1944, -0.0462, -0.2366,  0.5896,
        -0.1673, -0.6148, -0.2910,  0.2694, -0.3890, -0.0062, -0.3453, -0.0335,
        -0.2148, -0.2522, -0.3993, -0.17

In [245]:
# Ask the index for the 20 nearest neighbours of the query embedding.
# With include_distances=True the result is a pair (item ids, distances).
similar_images = annoy_index.get_nns_by_vector(
    query_img_embedding,
    20,
    include_distances=True,
)

# Translate each neighbour id into its label; the paired distance is unused.
similar_image_labels = [
    annoy_index_to_label[neighbor_id]
    for neighbor_id, _distance in zip(*similar_images)
]
similar_image_labels

['salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad',
 'salad']

In [92]:
# Show the raw nearest-neighbour result: (item ids, distances).
# NOTE(review): this cell's execution count (92) is lower than that of the
# cells above it, so its saved output comes from an earlier run. Re-run the
# notebook top to bottom to refresh it.
similar_images

([487,
  435,
  405,
  526,
  410,
  452,
  427,
  492,
  501,
  421,
  537,
  446,
  428,
  542,
  432,
  426,
  522,
  407,
  475,
  456],
 [0.8291601538658142,
  0.8713507056236267,
  0.886089026927948,
  0.9035285115242004,
  0.9094564914703369,
  0.9120723009109497,
  0.9153848886489868,
  0.9263986945152283,
  0.9478886127471924,
  0.9479461908340454,
  0.9500370621681213,
  0.9624344706535339,
  0.9677736759185791,
  0.9706259369850159,
  0.9721582531929016,
  0.9772781133651733,
  0.986505389213562,
  0.9901344180107117,
  0.9902897477149963,
  0.995437502861023])