In [None]:
import glob

import torch
import torch.nn.functional as F
import tqdm
import tqdm.auto

from PIL import Image
from torchvision import models
from torchvision import transforms

from IPython.display import display

from utils.imagenet import CLASSES

In [None]:
# glob returns the files matching a shell-style wildcard pattern (not a
# regular expression); sorting gives a deterministic order
image_files = sorted(glob.glob('./images/*.jpg'))

In [None]:
def load_image(img_f, max_side=512):
    """ Open an image and shrink it so that no side exceeds ``max_side``

    The aspect ratio is preserved and images that already fit are left
    untouched (``Image.thumbnail`` never enlarges).

    Parameters:
        img_f (str): path of the image file
        max_side (int): maximum width / height of the returned image

    Returns:
        PIL.Image.Image: the (possibly downscaled) image
    """
    img = Image.open(img_f)
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS,
    # which is available in old and new Pillow releases alike
    img.thumbnail((max_side, max_side), Image.LANCZOS)

    return img


def embed_image(img, model, emb, prep):
    """ Compute the embedding of a single image

    Parameters:
        img: image accepted by ``prep``
        model (callable): network that is run on the preprocessed image;
            its return value is ignored
        emb: object whose ``unit_tensor`` attribute holds the batch of
            embeddings once ``model`` has been run
        prep (callable): preprocessing that turns ``img`` into a tensor

    Returns:
        the first (and only) row of ``emb.unit_tensor``
    """
    # create a mini-batch as expected by the model:
    # unsqueeze inserts a new leading (batch) dimension
    input_batch = prep(img).unsqueeze(0)

    with torch.no_grad():
        # the forward pass is executed only so that `emb` gets refreshed
        model(input_batch)
        embedding = emb.unit_tensor

    return embedding[0]

class IntermediateTensor(object):
    """ Capture the output of a layer during the forward pass

    A forward hook is registered on ``layer``; every time the layer is
    executed its output is remembered and can be read back through the
    ``tensor`` / ``unit_tensor`` properties. Reading either property before
    the layer has been executed fails, because nothing was captured yet.

    Parameters:
        layer (torch.nn.Module): module whose output should be recorded
    """

    def __init__(self, layer):
        self.__tensor = None
        # keep the handle so that the hook can be detached again
        self.__handle = layer.register_forward_hook(self.__hook)

    def __hook(self, module, inpt, output):
        # called by PyTorch after each forward pass of the hooked layer
        self.__tensor = output

    def __squeezed(self):
        """ Captured output without size-1 dimensions, batch dim kept """
        t_out = self.__tensor.squeeze()

        # preserve batch dim after squeeze
        if t_out.ndim == 1:
            t_out = t_out[None]

        return t_out

    def remove(self):
        """ Detach the forward hook; the last captured output is kept """
        self.__handle.remove()

    @property
    def unit_tensor(self):
        """ Captured output with every batch entry scaled to unit L2 norm

        A new tensor is returned; the captured output itself is left
        untouched, so ``tensor`` keeps returning the raw values.
        """
        t_out = self.__squeezed()

        # normalize each batch entry over all of its remaining dimensions;
        # unlike an in-place renorm this also scales up entries whose norm
        # is below 1 and does not modify the captured tensor
        flat = t_out.flatten(start_dim=1)

        return F.normalize(flat, p=2, dim=1).reshape(t_out.shape)

    @property
    def tensor(self):
        """ Captured output with size-1 dimensions removed (no copy) """
        return self.__squeezed()
        
        
def distance(e1, e2):
    """ Compute the squared Euclidean (L2) distance between unit vectors

    For unit vectors ``|a - b|^2 = 2 - 2 * (a . b)``, so the result lies
    between 0 (same direction) and 4 (opposite direction). The inputs are
    not normalized here; the identity only holds if they already have unit
    norm.

    Parameters:
        e1 (torch.Tensor): first tensor, a vector or a matrix of row vectors
        e2 (torch.Tensor): second tensor, a vector or a matrix of row vectors

    Returns:
        torch.Tensor: a scalar for two vectors, otherwise the result of the
        corresponding matrix product
    """
    # `.T` on a 1-D tensor is a no-op that newer PyTorch versions deprecate,
    # so only transpose tensors with two or more dimensions
    e2_t = e2 if e2.ndim < 2 else e2.T

    return 2 * (1 - e1 @ e2_t)


class EmbeddingsDatabase(object):
    """ Minimal in-memory store mapping a key to its embedding """

    def __init__(self):
        # key -> stored embedding
        self.table = dict()

    def insert(self, key, data):
        """ Store ``data`` under ``key``, replacing any previous entry """
        self.table.update({key: data})

    def select(self, where, sort=True):
        """ Score every stored entry against ``where``

        Parameters:
            where: query embedding handed to ``distance``
            sort (bool): order the result by ascending distance when True,
                keep insertion order otherwise

        Returns:
            list of ``(key, distance)`` tuples
        """
        scored = [
            (name, distance(vec, where))
            for name, vec in self.table.items()
        ]

        if not sort:
            return scored

        # ascending sort: the smallest distance (closest match) comes first
        scored.sort(key=lambda pair: pair[1])

        return scored

In [None]:
# we will use a pretrained ResNet-50 model for embedding;
# eval() switches it to inference mode
res50_model = models.resnet50(pretrained=True)
res50_model = res50_model.eval()

# record the output of the average-pooling layer on every forward pass
embeddings_op = IntermediateTensor(res50_model.avgpool)

# define preprocessing pipeline
preprocess = transforms.Compose([
    # resize so that the shorter side is 256px (aspect ratio is kept)
    transforms.Resize(256),
    # convert to a tensor with values in the 0..1 range
    transforms.ToTensor(),
    # standardize each channel: subtract the mean, divide by the std
    # (the usual ImageNet statistics)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

### Example

Let's take an image, a similar image and a dissimilar image, and compute the distance from the original to each of the other two:

In [None]:
# reference image that the other two are compared against
img_orig = load_image('./images/eiliv-sonas-aceron-gqxSUgngBPA-unsplash.jpg')
# content-wise similar image
img_similar = load_image('./images/duncan-kidd-Js4jgpksRGk-unsplash.jpg')
# content-wise dissimilar image
img_dissimilar = load_image('./images/taneli-lahtinen-0cSOFraG4uc-unsplash.jpg')

In [None]:
# reference image
display(img_orig)

# image with similar content
display(img_similar)

# image with dissimilar content
display(img_dissimilar)

In [None]:
# embed the three example images with the same model / preprocessing
e_orig, e_similar, e_dissimilar = (
    embed_image(example, res50_model, embeddings_op, preprocess)
    for example in (img_orig, img_similar, img_dissimilar)
)

In [None]:
# a smaller value means the two embeddings are closer to each other
print('orig, similar', distance(e_orig, e_similar))
print('orig, dissimilar', distance(e_orig, e_dissimilar))

In [None]:
database = EmbeddingsDatabase()

# every image except the last one goes into the database
database_images = image_files[:-1]

# tqdm.auto picks the notebook widget when available and falls back to the
# console bar otherwise; tqdm.tqdm_notebook is deprecated
for img_f in tqdm.auto.tqdm(database_images):
    img = load_image(img_f)

    e = embed_image(img, res50_model, embeddings_op, preprocess)

    # the file path doubles as the database key
    database.insert(img_f, e)

In [None]:
# embedding stored for the second database image
database.table[database_images[1]]

In [None]:
# the last file in image_files serves as the query image
inpt_img = load_image(image_files[-1])
inpt_emb = embed_image(inpt_img, res50_model, embeddings_op, preprocess)

In [None]:
# show the query image
display(inpt_img)

In [None]:
# score the whole database against the query and keep the top 3 matches
results = database.select(where=inpt_emb)[:3]

In [None]:
# display the matched images in the order they appear in `results`
for img_f, score in results:
    img = load_image(img_f)
    display(img)