In [None]:
import glob

import torch
import torch.nn.functional as F
import tqdm

from PIL import Image
from torchvision import models
from torchvision import transforms

from IPython.display import display

from utils.imagenet import CLASSES

In [None]:
# glob expands a shell-style wildcard pattern (not a regular expression);
# sort the matches so the image order is the same on every run
image_files = glob.glob('./images/*.jpg')
image_files.sort()

In [None]:
# Longest side (in pixels) allowed for an image shown in the notebook.
max_side = 512

img = Image.open(image_files[3])
# thumbnail() shrinks the image in place, keeps its aspect ratio and never
# enlarges it. Image.ANTIALIAS was only an alias of Image.LANCZOS and was
# removed in Pillow 10, so the filter is named explicitly here.
img.thumbnail((max_side, max_side), Image.LANCZOS)

In [None]:
# show the downsized sample image inline
display(img)

In [None]:
# we will use a ResNet-50 model for tagging; eval() switches the module to
# inference behaviour and returns the module itself, so the calls chain
# NOTE(review): newer torchvision releases prefer the `weights=` argument
# over `pretrained=True` — confirm against the installed version.
res50_model = models.resnet50(pretrained=True).eval()

## NOTE: Preprocessing

Data preprocessing, in the context of image processing, means turning a data sample into a format suitable for a given network. This step is necessary to ensure that the distribution of pixel values in a test image matches the distribution that was used during training.

In the case of ResNet50, it was trained on images where each pixel had been shifted and scaled from the 0..255 range to *approximately* -1..1, so we need
to put our images into the same range.

In [None]:
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
# apply seaborn's default theme to the matplotlib figures drawn below
sns.set()

In [None]:
# Per-channel (R, G, B) mean and standard deviation computed from
# ImageNet data samples. They are expressed on the 0..1 pixel scale:
# the cells below divide 0..255 pixel values by 255 before using them.
mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]

Let's sample 500 points per channel from a normal distribution, shift and rescale them into the 0..255 range, and plot the result:

In [None]:
N = 500

# Seeded, cell-local generator: re-running this cell (or the whole notebook)
# draws exactly the same samples and leaves NumPy's global RNG state alone.
rng = np.random.RandomState(0)

# original 0...255 distribution: N synthetic pixel values per channel,
# centred on the channel mean (rescaled to 0..255)
# NOTE(review): the std is scaled by 127.5 rather than 255.0, so after the
# (x/255 - mean)/std normalisation the samples have a standard deviation of
# about 0.5 (mostly inside -1..1) instead of 1 — confirm this is the
# intended illustration before changing it.
img_r = (rng.randn(N)*std[0]*127.5 + 255.0*mean[0])
img_g = (rng.randn(N)*std[1]*127.5 + 255.0*mean[1])
img_b = (rng.randn(N)*std[2]*127.5 + 255.0*mean[2])


plt.figure(figsize=(10,5))

sns.distplot(img_r, color='red')
sns.distplot(img_g, color='green')
sns.distplot(img_b, color='blue')

# trailing semicolon keeps the (0.0, 255.0) return value out of the cell output
plt.xlim(0.0, 255.0);

Now, let's normalize the inputs into -1..1:

In [None]:
# normalized -1..1 distribution: bring each channel back to the 0..1 scale,
# then subtract that channel's mean and divide by its std
img_r_n, img_g_n, img_b_n = [
    (values/255.0 - mu)/sigma
    for values, mu, sigma in zip((img_r, img_g, img_b), mean, std)
]

plt.figure(figsize=(10, 5))

# one density curve per channel, drawn on the same axes
sns.distplot(img_r_n, color='red')
sns.distplot(img_g_n, color='green')
sns.distplot(img_b_n, color='blue')


Getting back to our task at hand, the same computations need to be performed on our input data. Thankfully, PyTorch (more specifically, torchvision) has a **transforms** module that will do this for us:

In [None]:
# Preprocessing pipeline applied to every PIL image before it reaches the model.
preprocess = transforms.Compose([
    # resize so that the shorter side is 224 px (aspect ratio is kept)
    transforms.Resize(224),
    # convert to a tensor with values in the 0..1 range
    transforms.ToTensor(),
    # shift/scale each channel to roughly -1 .. 1, reusing the `mean` / `std`
    # ImageNet statistics defined earlier instead of repeating the literals
    transforms.Normalize(mean=mean, std=std),
])


In [None]:
def load_image(img_f, max_side=512):
    """Open an image and shrink it so that neither side exceeds ``max_side``.

    Args:
        img_f: path (or file object) accepted by ``PIL.Image.open``.
        max_side: upper bound, in pixels, for both width and height.

    Returns:
        A ``PIL.Image.Image`` in RGB mode. The aspect ratio is kept and an
        image that is already small enough is not enlarged.
    """
    img = Image.open(img_f)
    # Image.ANTIALIAS was an alias of Image.LANCZOS and was removed in
    # Pillow 10; LANCZOS selects the same filter on old and new versions.
    img.thumbnail((max_side, max_side), Image.LANCZOS)

    # Grayscale / CMYK files would otherwise reach the 3-channel
    # normalisation (and the model) with the wrong number of channels.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    return img


def tags_for_image(img, model, prep, top_k=5):
    """Predict tags for one image from the model's most probable classes.

    Args:
        img: image in whatever form ``prep`` accepts.
        model: callable mapping a mini-batch to per-class scores; index ``i``
            of its output must correspond to ``CLASSES[i]``.
        prep: callable turning ``img`` into a single (unbatched) tensor.
        top_k: number of highest-probability classes to turn into tags;
            clamped to the number of classes the model outputs.

    Returns:
        A flat list of tag strings, most probable class first. Every class
        label is split on commas, so one class may contribute several tags.
    """
    input_tensor = prep(img)

    # the model expects a mini-batch: unsqueeze inserts a leading
    # batch dimension of size 1
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        output = model(input_batch)
        probs = F.softmax(output, dim=1)

    # honour the caller's top_k (this used to be a hard-coded 5)
    k = min(top_k, probs.shape[1])
    # tolist() yields plain Python ints regardless of the tensor's device
    top_k_inds = probs[0].topk(k).indices.tolist()

    tags = []
    for ind in top_k_inds:
        tags.extend(item.strip() for item in CLASSES[ind].split(','))

    return tags

class TagsDatabase(object):
    """Minimal in-memory store mapping a key to its tags, searchable by tag overlap."""

    def __init__(self):
        # key -> tags collection, stored exactly as passed to ``insert``
        self.table = dict()

    def insert(self, key, data):
        """Store ``data`` under ``key``, replacing any earlier entry."""
        self.table[key] = data

    def select(self, where, sort=True):
        """Score every stored entry against the query tags in ``where``.

        An entry's score is the number of items of ``where`` that are
        contained in its stored tags.

        Returns:
            A list of ``(key, score)`` tuples. With ``sort`` true the list
            is ordered by descending score (ties keep insertion order);
            otherwise it follows the table's insertion order.
        """
        ranked = [
            (name, sum(1 for query_tag in where if query_tag in stored_tags))
            for name, stored_tags in self.table.items()
        ]

        if sort:
            # stable, highest score first
            ranked.sort(key=lambda pair: pair[1], reverse=True)

        return ranked

In [None]:
# tag the sample image loaded earlier to sanity-check the whole pipeline
tags_for_image(img, res50_model, preprocess)

In [None]:
# Build the searchable tag index: every image except the last one is
# tagged and stored under its file path. The last image is left out
# of the index and serves as the query further down.
# NOTE(review): recent tqdm releases deprecate tqdm_notebook in favour of
# tqdm.notebook.tqdm — confirm against the installed version.
database = TagsDatabase()

for img_f in tqdm.tqdm_notebook(image_files[:-1]):
    img = load_image(img_f)
    
    tags = tags_for_image(img, res50_model, preprocess)
    
    database.insert(img_f, tags)

In [None]:
# tags stored for the second image (the table is keyed by file path)
database.table[image_files[1]]

In [None]:
# show the second image itself
display(load_image(image_files[1]))

In [None]:
# query image: the last file, which was not inserted into the database
inpt_img = load_image(image_files[-1])
inpt_tags = tags_for_image(inpt_img, res50_model, preprocess)

In [None]:
# tags predicted for the query image
inpt_tags

In [None]:
# rank every stored image by the number of query tags it shares,
# then keep only the top 3 matches
results = database.select(where=inpt_tags)[:3]

In [None]:
# show the query image
display(inpt_img)

In [None]:
# show the retained matches in ranked order (highest tag overlap first)
for img_f, score in results:
    img = load_image(img_f)
    display(img)