# Reverse Image Search

For our reverse image search, we used the ResNet18 model obtained in Task 1 together with Spotify's Annoy library (Approximate Nearest Neighbors Oh Yeah) to find the nearest neighbours of each image (more specifically, the nearest neighbours of the image's feature vector). The feature vectors are indexed in a tree-based data structure for faster retrieval.

In [None]:
import tensorflow as tf
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm, tqdm_notebook
from PIL import Image
import os
import pickle
import multiprocessing

In [None]:
import torch
from torchvision import transforms 
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Nominal input size. It is only used as the default of extract_features'
# `image_size` argument; the actual resize is hard-coded in transform_test.
IMAGE_SIZE = (224, 224)
# Table describing the dataset; its 'image_path' column is read further down.
data_csv = pd.read_csv('final_image_data_path.csv')

### Define feature extractor

In [None]:
# Preprocessing applied to every image before it is passed to the network:
# resize to 224x224, convert to a tensor, then normalise each channel.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm they match what the model was trained with.
transform_test = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [None]:
def extract_features(img_path, model, image_size=IMAGE_SIZE, verbose=False):
    """Return the L2-normalised feature vector of one image as a NumPy array.

    The network is truncated by dropping its last two child modules; the
    remaining stack is run on the preprocessed image, and the output is
    flattened and scaled to unit Euclidean norm.

    Args:
        img_path: Path of the image file to load.
        model: Trained torch module whose children form the backbone. It is
            not moved to another device here, so it must already be on the
            same device as the input (``device``).
        image_size: Kept for backward compatibility; currently unused because
            resizing is done by ``transform_test``.
        verbose: When True, print the image path and the input tensor shape
            (the debugging output this function used to print on every call).

    Returns:
        A 1-D ``numpy.ndarray`` holding the normalised features.
    """
    if verbose:
        # Printed before loading so a failing path is visible in the output.
        print(img_path)

    # Keep everything except the last two child modules as the feature
    # extractor (for a torchvision-style ResNet these are presumably the
    # pooling and classification layers -- confirm against the saved model).
    feature_model = torch.nn.Sequential(*list(model.children())[:-2])
    feature_model.eval()

    # Load and preprocess the image; the context manager closes the file
    # handle as soon as the tensor has been produced.
    with Image.open(img_path) as image:
        resize = transform_test(image.convert('RGB')).to(device)

    if verbose:
        print(resize.shape)

    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension expected by the network.
        features = feature_model(resize.unsqueeze(0))

        # Flatten and scale to unit length.
        flattened_features = features.flatten()
        normalised_features = flattened_features / torch.norm(flattened_features)

    # Convert to a NumPy array on the host.
    return normalised_features.cpu().detach().numpy()

In [None]:
# Load the trained model directly onto the device used for inference.
# map_location keeps the weights on the same device as the inputs and lets a
# checkpoint saved on a GPU be opened on a machine without one.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Recent PyTorch releases default to weights_only=True, which
# presumably rejects a fully pickled module; pass weights_only=False if
# model.pth stores the whole model and loading fails.
base_model = torch.load('model.pth', map_location=device)

In [None]:
# Pull the 'image_path' column out of the CSV as a plain Python list.
image_paths = data_csv.loc[:, 'image_path'].tolist()

### Extraction of all images in the dataset

Extract the features of each image in the path list.

In [None]:
# One feature vector per dataset image, in the same order as image_paths.
# The progress bar wraps the path list directly.
feature_list = [
    extract_features(path, base_model) for path in tqdm_notebook(image_paths)
]

Save the feature vectors and image paths as pickle files to avoid having to recompute the features.

In [None]:
# Folder that holds the cached features. The name is kept as-is because the
# loading cell reads from the same path.
pickle_dir = os.path.join(os.getcwd(), "reverse_image_pickles", "final_vgg")

# exist_ok=True creates the folder if needed without a separate existence check.
os.makedirs(pickle_dir, exist_ok=True)

# Context managers make sure each file is flushed and closed after writing.
with open(os.path.join(pickle_dir, 'features.pickle'), 'wb') as features_file:
    pickle.dump(feature_list, features_file)

with open(os.path.join(pickle_dir, 'image_paths.pickle'), 'wb') as paths_file:
    pickle.dump(image_paths, paths_file)

Load the feature vectors and image paths from the pickle files (this is mainly for local testing, as it allows us to avoid recomputing the features every time).

In [None]:
# Folder containing the cached features and image paths.
pickle_dir = os.path.join(os.getcwd(), "reverse_image_pickles", "final_vgg")
print(pickle_dir)

# Restore the feature vectors and the matching image paths from disk.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# pickles that were produced by this notebook.
features_path = os.path.join(pickle_dir, 'features.pickle')
with open(features_path, 'rb') as features_file:
    feature_list = pickle.load(features_file)

paths_path = os.path.join(pickle_dir, 'image_paths.pickle')
with open(paths_path, 'rb') as paths_file:
    filenames = pickle.load(paths_file)

Get the name of the label from the path

In [None]:
# The label is the name of the folder that directly contains the image.
# Backslashes are normalised to forward slashes first so that both
# Windows-style and POSIX-style paths are handled.
labels = [files.replace('\\', '/').split('/')[-2] for files in filenames]

Create a dataframe with the image path, image representation and label. This dataframe is used to store the location of each image, its representation and its label.

In [None]:
# One row per image: its path, its feature vector and its class label.
df = pd.DataFrame(
    data={
        'img_id': filenames,
        'img_repr': feature_list,
        'label': labels,
    },
)
# Number of rows, shown as the cell output.
df.shape[0]

In [None]:
# Preview the first rows of the dataframe.
df.head()

The dataframe is saved as a pickle file for deployment

In [None]:
# Store the dataframe next to the other pickles (pickle_dir is set in an earlier cell).
df.to_pickle(os.path.join(pickle_dir, 'df.pickle'))

In [None]:
import io

def convert_to_jpg_arr(image_path):
    """Load an image, re-encode it as an RGB JPEG in memory and return its pixels.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A NumPy array built from the JPEG-decoded image.
    """
    with Image.open(image_path) as image, io.BytesIO() as output:
        rgb_image = image.convert('RGB')
        rgb_image.save(output, format='JPEG')
        # Rewind so the encoded bytes can be read back from the start.
        output.seek(0)
        # The array is built while the buffer is still open.
        decoded = Image.open(output)
        return np.asarray(decoded)

Below is where we can specify a new image to test the reverse image search. The features of this new image are extracted and compared to the features of the images in the dataset. The nearest neighbours are then returned.

In [None]:
# Path of the query image used for the search.
img_location = "IMG_0634.PNG"

# Feature vector of the query image.
new_features = extract_features(img_location, base_model)

# Work on copies so the original feature list and dataframe stay untouched:
# a new list with the query features at the end, and a cloned dataframe
# (the query row itself is appended in a later cell).
feature_list2 = feature_list + [new_features]
df2 = df.copy()

# Display the query image.
plt.imshow(convert_to_jpg_arr(img_location))

We have to create a copy of the feature_list and the dataframe because, during development, using the originals would add the new image to them every time the code is run. This is not ideal, as we want to keep the feature_list and dataframe constant for testing and deployment.

In [None]:
# Append the query image as a new last row, with an empty label.
# NOTE: each run of this cell adds another row (assuming the default integer
# index); re-run the cell that rebuilds df2 first to keep it aligned with
# feature_list2.
df2.loc[len(df2)] = [img_location, new_features, '']
df2.tail()

In [None]:
# Length of the first feature vector; the same value sets the Annoy index dimension below.
len(df2['img_repr'][0])

This is where we produce the nearest neighbours tree using AnnoyIndex from the annoy library. This tree allows us to retrieve the closest images to the query image.

In [None]:
from annoy import AnnoyIndex
import random

# Dimensionality of the index, taken from the first feature vector.
f = len(df2['img_repr'][0])
t = AnnoyIndex(f, metric='euclidean')

# Register every feature vector under its position in feature_list2, which is
# also its row position in df2.
for item_id, vector in enumerate(tqdm(feature_list2)):
    t.add_item(item_id, vector)

# Build the forest with 150 trees; n_jobs=-1 is passed straight to Annoy.
_ = t.build(150, n_jobs=-1)

Below is the function we use to get a new dataframe containing the 10 images whose features are closest to those of the input image.

In [None]:
def get_similar_images_annoy(img_index, n_neighbours=10):
    """Look up the images most similar to the one at row ``img_index``.

    Reads the module-level dataframe ``df2`` and Annoy index ``t``.

    Args:
        img_index: Row position of the query image in ``df2``; it must also
            be the item id under which the image was added to ``t``.
        n_neighbours: Maximum number of similar images to return
            (defaults to 10, the previously hard-coded value).

    Returns:
        A tuple ``(base_img_id, base_label, similar_df)`` where
        ``similar_df`` holds the rows of ``df2`` for the retrieved neighbours.
    """
    base_img_id, base_label = df2.iloc[img_index, [0, 2]]

    # Ask for one extra hit because the query item is normally among its own
    # nearest neighbours.
    candidate_ids = t.get_nns_by_item(img_index, n_neighbours + 1)

    # Remove the query by id rather than by position: the search is
    # approximate and an identical image can tie at distance 0, so the query
    # is not guaranteed to be the first hit.
    similar_img_ids = [i for i in candidate_ids if i != img_index][:n_neighbours]

    return base_img_id, base_label, df2.iloc[similar_img_ids]

In [None]:
# Query with the last row of df2, which is where the new image was appended.
base_image, base_label, similar_images_df = get_similar_images_annoy(len(df2)-1)

In [None]:
# Display the retrieved neighbours.
similar_images_df

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import cv2 as cv

In [None]:
# Another show images function but this time will be in 3 rows with 4 images each
def show_images(new_img_path):
    """Plot the query image followed by the retrieved similar images.

    The figure uses a 3x4 grid: slot 1 shows the query image and the following
    slots show the rows of the module-level ``similar_images_df``, each titled
    with its label.

    Args:
        new_img_path: Path of the query image.
    """
    plt.figure(figsize=(20, 20))

    # Slot 1: the query image.
    plt.subplot(3, 4, 1)
    plt.imshow(convert_to_jpg_arr(new_img_path))
    plt.title('Base Image')
    plt.axis('off')

    # Slots 2 onwards: one neighbour per slot, in dataframe order.
    neighbour_paths = similar_images_df.iloc[:, 0]
    neighbour_labels = similar_images_df.iloc[:, 2]
    for slot, (path, label) in enumerate(zip(neighbour_paths, neighbour_labels), start=2):
        neighbour = mpimg.imread(path)
        plt.subplot(3, 4, slot)
        plt.imshow(neighbour)
        plt.title('Similar Image ' + label)
        plt.axis('off')

In [None]:
# Plot the query image together with its retrieved neighbours.
show_images(img_location)