<h1 align=center style='color:blue; border:1px dotted blue;'>Image Similarity Search in PyTorch</h1>

# Goal

Create a model that, given an image as input, retrieves the most similar image from the dataset.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from sklearn.neighbors import NearestNeighbors

import torch
from torch import nn
from torch import optim

import cv2
import os
from PIL import Image

# Load Data

In [None]:
# Base path of the dataset on Kaggle. The trailing slash matters: file paths
# below are built by plain string concatenation (e.g. path + 'validate.csv').
path = '/kaggle/input/images-alike/'

In [None]:
# Read the validation CSV of image pairs, using its first column as the index,
# and preview the first rows.
validate_csv = path + 'validate.csv'
df = pd.read_csv(validate_csv, index_col=0)
df.head()

In [None]:
def draw_pair_similars(df, idx):
    """
    Show the two images of row position `idx` of `df` side by side.

    Each row holds image ids in its 'image_a' and 'image_b' columns. Files are
    read from `path + 'images/' + <id> + '.jpg'`, where `path` is the
    module-level dataset root, and resized to 300x300 for display.

    Raises:
        FileNotFoundError: if either image file cannot be read.
    """
    dim = (300, 300)
    row = df.iloc[idx]

    imgs = []
    for column in ('image_a', 'image_b'):
        img_path = path + 'images/' + str(row[column]) + '.jpg'
        img = cv2.imread(img_path)
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, which would otherwise fail later in cv2.resize.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
        # OpenCV decodes to BGR channel order; matplotlib expects RGB.
        imgs.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    _, ax = plt.subplots(1, 2, figsize=(12, 12))
    for ix, img in enumerate(imgs):
        ax[ix].imshow(img)
        ax[ix].axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

In [None]:
# Display an example pair: the row at iloc position 84 of the validation CSV.
draw_pair_similars(df, idx=84)

# Dataset Class

In [None]:
class AlikeDataset(Dataset):
    """
    Dataset over every regular file directly inside `<path>/images/`.

    The file list is built once at construction time and sorted by path, so
    the index -> file mapping is reproducible across runs and machines
    (`os.scandir` yields entries in arbitrary order).

    Args:
        path: dataset root directory, with or without a trailing separator.
        transform: optional callable applied to each RGB PIL image.

    Items:
        With a transform: `(tensor_image, tensor_image)`, i.e. the same
        transformed image twice, used as (input, target) by the autoencoder.
        Without a transform: the bare RGB PIL image (not a tuple).
        NOTE(review): the untransformed branch does not match the
        (input, target) pairs the training/embedding code unpacks.
    """

    def __init__(self, path, transform=None):
        # Joining with a trailing '' keeps a trailing separator, matching the
        # original `path + 'images/'` while also accepting roots without '/'.
        self.path = os.path.join(path, 'images', '')
        self.files = self.absolute_file_paths(self.path)
        self.transform = transform

    def absolute_file_paths(self, directory):
        """Return the sorted absolute paths of the regular files directly in `directory`."""
        directory = os.path.abspath(directory)
        # The context manager releases the directory handle deterministically.
        with os.scandir(directory) as entries:
            return sorted(entry.path for entry in entries if entry.is_file())

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        # convert() returns an in-memory copy, so the file can be closed here.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            tensor_image = self.transform(image)
            return tensor_image, tensor_image

        return image

In [None]:
# Preprocessing: resize every image to 352 x 128 (height x width) and convert
# it to a tensor. 352 and 128 are both multiples of 32, the encoder's total
# downsampling factor.
transforms = T.Compose([
    T.Resize((352, 128)),
    T.ToTensor(),
])

# Dataset over all files in `<path>images/`.
alike_ds = AlikeDataset(path, transform=transforms)

In [None]:
# Preview the first dataset image; the tensor is channels-first, so move the
# channel axis last for matplotlib.
plt.figure(figsize=(10, 10))
first_tensor = alike_ds[0][0]
plt.imshow(first_tensor.permute(1, 2, 0))

# Data Loaders

In [None]:
# 75% / 25% train/validation split.
train_size = int(len(alike_ds) * 0.75)
val_size = len(alike_ds) - train_size

# A seeded generator makes the split identical on every run, so losses are
# comparable across experiments.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    alike_ds, [train_size, val_size], generator=split_generator)

# Shuffled mini-batches for training.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Validation batches (not shuffled).
val_loader = DataLoader(val_dataset, batch_size=32)

# Every image of the dataset, unshuffled so batches come out in dataset order
# (used to build the retrieval embeddings).
full_loader = DataLoader(alike_ds, batch_size=32)

# Architecture

1. First, we train an encoder and a decoder together as an autoencoder:

[Input] -> [Encoder] -> [Decoder] -> [MSE] -> [Optimize]

2. We pass every image through the trained encoder to create the embeddings:

[Encoder] -> [Embeddings]

3. Fit a k-NN model on the embeddings, then embed a query image, find its nearest neighbors among the stored embeddings, and retrieve the corresponding image:

[Image] -> [k-nn] -> [Similar Image]

# Model

## Encoder

In [None]:
class ConvEncoder(nn.Module):
    """
    A simple convolutional encoder.

    Five identical stages of Conv2d(3x3, padding 1) -> ReLU -> MaxPool2d(2x2).
    Channels grow 3 -> 16 -> 32 -> 64 -> 128 -> 256. Each 3x3/padding-1
    convolution keeps the spatial size and each pool halves it (flooring), so
    an (N, 3, H, W) batch becomes (N, 256, H // 32, W // 32).

    Submodules are registered as conv{i}, relu{i}, maxpool{i} for i = 1..5.
    """

    # (in_channels, out_channels) of each stage, in order.
    _STAGE_CHANNELS = ((3, 16), (16, 32), (32, 64), (64, 128), (128, 256))

    def __init__(self):
        super().__init__()
        for stage, (c_in, c_out) in enumerate(self._STAGE_CHANNELS, start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, (3, 3), padding=(1, 1)))
            setattr(self, f"relu{stage}", nn.ReLU(inplace=True))
            setattr(self, f"maxpool{stage}", nn.MaxPool2d((2, 2)))

    def forward(self, x):
        for stage in range(1, len(self._STAGE_CHANNELS) + 1):
            x = getattr(self, f"conv{stage}")(x)
            x = getattr(self, f"relu{stage}")(x)
            x = getattr(self, f"maxpool{stage}")(x)
        return x

## Decoder

In [None]:
class ConvDecoder(nn.Module):
    """
    A simple convolutional decoder that mirrors the encoder.

    Five stages of ConvTranspose2d(2x2, stride 2) -> ReLU. Channels shrink
    256 -> 128 -> 64 -> 32 -> 16 -> 3, and each stage exactly doubles the
    spatial size, so an (N, 256, h, w) batch becomes (N, 3, 32 * h, 32 * w).

    The output passes through a final ReLU, so reconstructions are
    non-negative but not bounded above.
    NOTE(review): if targets lie in [0, 1], a Sigmoid output may fit better — confirm.

    Submodules are registered as deconv{i}, relu{i} for i = 1..5.
    """

    # (in_channels, out_channels) of each stage, in order.
    _STAGE_CHANNELS = ((256, 128), (128, 64), (64, 32), (32, 16), (16, 3))

    def __init__(self):
        super().__init__()
        for stage, (c_in, c_out) in enumerate(self._STAGE_CHANNELS, start=1):
            setattr(self, f"deconv{stage}",
                    nn.ConvTranspose2d(c_in, c_out, (2, 2), stride=(2, 2)))
            setattr(self, f"relu{stage}", nn.ReLU(inplace=True))

    def forward(self, x):
        for stage in range(1, len(self._STAGE_CHANNELS) + 1):
            upsampled = getattr(self, f"deconv{stage}")(x)
            x = getattr(self, f"relu{stage}")(upsampled)
        return x

# Training

In [None]:
# Prefer the GPU, but fall back to CPU instead of failing when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reconstruction loss between the decoder output and the target image.
loss_fn = nn.MSELoss()

encoder = ConvEncoder().to(device)
decoder = ConvDecoder().to(device)

# One optimizer updates the encoder and decoder jointly.
autoencoder_params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(autoencoder_params, lr=1e-3)

In [None]:
def train_step(encoder, decoder, train_loader, loss_fn, optimizer, device):
    """
    Run one training epoch of the autoencoder over `train_loader`.

    Args:
        encoder: module whose output is fed to `decoder`.
        decoder: module producing the reconstruction compared with the target.
        train_loader: iterable of `(input_batch, target_batch)` pairs.
        loss_fn: loss assumed to be averaged over the batch (as
            `nn.MSELoss()` is by default).
        optimizer: optimizer over the encoder and decoder parameters.
        device: device the models live on; batches are moved there.

    Returns:
        float: the epoch loss, i.e. batch losses averaged with batch-size
        weights (not just the loss of the last batch).

    Raises:
        ValueError: if `train_loader` yields no samples.
    """
    encoder.train()
    decoder.train()

    total_loss = 0.0
    num_samples = 0

    for train_img, target_img in train_loader:
        train_img = train_img.to(device)
        target_img = target_img.to(device)

        optimizer.zero_grad()

        dec_output = decoder(encoder(train_img))
        loss = loss_fn(dec_output, target_img)

        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not skew the mean.
        batch_size = train_img.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / num_samples

In [None]:
def val_step(encoder, decoder, val_loader, loss_fn, device):
    """
    Evaluate the autoencoder over `val_loader` without updating it.

    Args:
        encoder: module whose output is fed to `decoder`.
        decoder: module producing the reconstruction compared with the target.
        val_loader: iterable of `(input_batch, target_batch)` pairs.
        loss_fn: loss assumed to be averaged over the batch (as
            `nn.MSELoss()` is by default).
        device: device the models live on; batches are moved there.

    Returns:
        float: the validation loss, i.e. batch losses averaged with
        batch-size weights (not just the loss of the last batch).

    Raises:
        ValueError: if `val_loader` yields no samples.
    """
    encoder.eval()
    decoder.eval()

    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for val_img, target_img in val_loader:
            val_img = val_img.to(device)
            target_img = target_img.to(device)

            dec_output = decoder(encoder(val_img))
            loss = loss_fn(dec_output, target_img)

            # Weight by batch size so a short final batch does not skew the mean.
            batch_size = val_img.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return total_loss / num_samples

In [None]:
num_epochs = 10
log_every = 1  # print progress every `log_every` epochs

train_losses, val_losses = [], []

for epoch in range(num_epochs):
    train_loss = train_step(encoder, decoder, train_loader, loss_fn, optimizer, device=device)
    train_losses.append(train_loss)

    val_loss = val_step(encoder, decoder, val_loader, loss_fn, device=device)
    val_losses.append(val_loss)

    # Log periodically, and always report the final epoch.
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print("Epoch:{0:3d}, Train_Loss:{1:1.3f}, Valid_Loss:{2:1.3f}"
              .format(epoch, train_loss, val_loss))

# Plot Losses

In [None]:
# Per-epoch training and validation loss curves on one figure.
plt.figure(figsize=(8, 8))
for series, fmt, label in ((train_losses, '-o', 'Train Losses'),
                           (val_losses, 'g-o', 'Valid Losses')):
    plt.plot(series, fmt, label=label)
plt.legend()

# Embeddings

In [None]:
def create_embedding(encoder, full_loader, embedding_dim, device):
    """
    Encode every input batch yielded by `full_loader` and stack the results.

    Args:
        encoder: module mapping an image batch to its embedding batch.
        full_loader: iterable of `(input_batch, target_batch)` pairs; only the
            inputs are encoded.
        embedding_dim: shape of one embedding including a leading batch axis,
            e.g. `(1, c, h, w)`. Only used to shape the result when the
            loader yields no batches.
        device: device the encoder lives on; batches are moved there.

    Returns:
        CPU tensor of shape `(num_images_in_loader, c, h, w)`. Row i is the
        embedding of the i-th image the loader yields, so indices line up
        with the loader's (dataset) order.
    """
    encoder.eval()

    # Collect per-batch outputs and concatenate once at the end. This avoids
    # re-copying the growing tensor on every batch, and there is no
    # placeholder row that would shift every index by one.
    batches = []
    with torch.no_grad():
        for train_img, _ in full_loader:
            batches.append(encoder(train_img.to(device)).cpu())

    if not batches:
        return torch.empty((0, *tuple(embedding_dim)[1:]))
    return torch.cat(batches, dim=0)

In [None]:
# Shape of one encoder output: 256 channels on an 11 x 4 grid, i.e. the
# 352 x 128 input downsampled 32x by the five pooling stages.
embedding_shape = (1, 256, 11, 4)
embeddings = create_embedding(encoder, full_loader, embedding_dim=embedding_shape, device=device)

# k-nearest neighbors

In [None]:
def get_similar_image_idx(image_tensor, num_images, embeddings, device, model=None):
    """
    Return the index of the stored embedding closest to an image's embedding.

    A k-NN index with cosine distance is fitted on `embeddings` (each row
    flattened) and queried with the encoded image.

    Args:
        image_tensor: one image tensor without a batch axis, as returned by
            the dataset (assumed channels-first `(c, h, w)` — TODO confirm).
        num_images: `n_neighbors` for the k-NN search; only the closest
            neighbour's index is returned.
        embeddings: stacked embeddings, one per entry along the first axis
            (e.g. the output of `create_embedding`).
        device: device the encoder lives on; the query is moved there.
        model: encoder used to embed the query. Defaults to the module-level
            `encoder`, so existing callers keep working.

    Returns:
        int: row index into `embeddings` of the nearest neighbour. If the
        query image itself is among `embeddings`, this is typically its own
        row (cosine distance 0).
    """
    if model is None:
        model = encoder  # module-level encoder from the training setup cell

    # Add a batch axis and move to `device` as float32; honouring `device`
    # (instead of a hard-coded CUDA tensor type) keeps this working on CPU.
    query = image_tensor.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        query_embedding = model(query).cpu().numpy()

    flattened_query = query_embedding.reshape(query_embedding.shape[0], -1)

    knn = NearestNeighbors(n_neighbors=num_images, metric="cosine")
    knn.fit(embeddings.reshape(embeddings.shape[0], -1))

    _, indices = knn.kneighbors(flattened_query)
    return int(indices[0][0])

In [None]:
# Query with the first dataset image and keep the index of the closest stored embedding.
img_idx = 0
similar_idx = get_similar_image_idx(
    alike_ds[img_idx][0], num_images=1, embeddings=embeddings, device=device)

In [None]:
def draw_similar_idx(df, img_idx, similar_idx):
    """
    Show a query image next to the image retrieved for it.

    Args:
        df: dataset whose items are `(tensor, tensor)` pairs with channels-first
            image tensors (e.g. `AlikeDataset` with a transform); despite the
            name, not a DataFrame.
        img_idx: index of the query image in `df`.
        similar_idx: index of the retrieved image in `df`.
    """
    # Move the channel axis last for matplotlib.
    imgs = [df[img_idx][0].permute(1, 2, 0), df[similar_idx][0].permute(1, 2, 0)]
    titles = ['Query', 'Most similar']

    _, ax = plt.subplots(1, 2, figsize=(12, 12))
    for axis, img, title in zip(ax, imgs, titles):
        axis.imshow(img)
        axis.set_title(title)
        axis.axis('off')

    # Same layout and explicit display as draw_pair_similars.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

In [None]:
# Display the query image beside the image retrieved for it.
draw_similar_idx(alike_ds, img_idx=img_idx, similar_idx=similar_idx)