In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (20, 20)

import os
import json
import nltk
import numpy as np 
import pandas as pd
from PIL import Image
from scipy.spatial.distance import cdist
from tqdm import tqdm_notebook as tqdm

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

# Fetch nltk's 'punkt' tokenizer data (presumably what InferSent's
# `tokenize=True` path relies on — confirm against the InferSent source).
nltk.download('punkt')
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load COCO images and captions

In [None]:
# Read the COCO val2014 caption annotations.
with open('/mnt/efs/images/coco/annotations/captions_val2014.json') as f:
    meta = json.load(f)

# Join each annotation to its image record on the image id and keep only the
# caption text and the image file name.
df = pd.merge(
    pd.DataFrame(meta['images']).set_index('id'),
    pd.DataFrame(meta['annotations']).set_index('image_id'),
    left_index=True,
    right_index=True,
).reset_index()[['caption', 'file_name']]

# Turn bare file names into full paths inside the val2014 image directory.
df['file_name'] = '/mnt/efs/images/coco/val2014/' + df['file_name']

# Normalise captions in a single pass: drop everything that is neither a
# letter nor whitespace, lower-case, then collapse runs of whitespace.
df['caption'] = df['caption'].apply(
    lambda text: ' '.join(
        ''.join(ch for ch in text if ch.isalpha() or ch.isspace())
        .lower()
        .split()
    )
)

# train test splits

In [None]:
# Positional split of the caption rows: the first 80% train, the rest test.
split_ratio = 0.8
train_size = int(split_ratio * len(df))

# Use .iloc (end-exclusive positional slicing). Label slicing with .loc
# includes BOTH endpoints, so `.loc[:train_size]` / `.loc[train_size:]` put
# row `train_size` into the training set and the test set at the same time.
# NOTE(review): the split is per caption row, not per image. If an image has
# several captions, the image at the boundary can still end up in both splits —
# consider splitting on unique file names.
train_df = df.iloc[:train_size]
test_df  = df.iloc[train_size:]
len(train_df), len(test_df)

# load InferSent model

In [None]:
from InferSent import InferSent

In [None]:
# Location of the pretrained InferSent weights.
MODEL_PATH = '/mnt/efs/models/infersent2.pkl'

# Encoder hyper-parameters passed to InferSent; they have to match the
# architecture the checkpoint was trained with for load_state_dict to succeed.
params_model = {'bsize': 1024,
                'word_emb_dim': 300,
                'enc_lstm_dim': 2048,
                'pool_type': 'max',
                'dpout_model': 0.0,
                'version': 2}

infersent_model = InferSent(params_model)
# Deserialize onto the CPU: the freshly built model lives on the CPU anyway,
# and without map_location a checkpoint saved from a GPU cannot be loaded on
# a CPU-only machine (which the notebook's `device` fallback allows for).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
infersent_model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

In [None]:
# Word-vector file handed to InferSent (fastText crawl-300d-2M, per the path).
W2V_PATH = '/mnt/efs/nlp/word_vectors/fasttext/crawl-300d-2M.vec'
infersent_model.set_w2v_path(W2V_PATH)

In [None]:
# Build the encoder vocabulary with K=100000 (presumably the first K entries
# of the word-vector file — confirm against the InferSent implementation).
infersent_model.build_vocab_k_words(K=100000)

In [None]:
# Move the sentence encoder onto the selected device before encoding captions.
infersent_model = infersent_model.to(device)

# embed captions with InferSent

In [None]:
# Encode the captions of each split (train first, then test) with InferSent;
# the embeddings stay aligned row-for-row with the dataframe they came from.
train_embeddings, test_embeddings = (
    infersent_model.encode(split['caption'].values, tokenize=True)
    for split in (train_df, test_df)
)

len(train_embeddings), len(test_embeddings)

# PyTorch datasets and dataloaders

### dataset

In [None]:
class CaptionsDataset(Dataset):
    """Pairs each image with the precomputed embedding of its caption.

    Parameters
    ----------
    path_df : DataFrame with 'file_name' (image path) and 'caption' columns.
    caption_embeddings : sequence indexed by position, aligned row-for-row
        with `path_df`.
    transform : callable applied to the loaded RGB PIL image, or None to
        return the PIL image unchanged.
    """

    def __init__(self, path_df, caption_embeddings, 
                 transform=transforms.ToTensor()):
        self.transform = transform
        self.caption_embeddings = caption_embeddings
        # Plain arrays so that items are looked up by position, not by label.
        self.ids = path_df.index.values
        self.image_paths = path_df['file_name'].values
        self.titles = path_df['caption'].values

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        """Return `(image, caption_embedding)` for the row at position `index`."""
        path = self.image_paths[index]
        image = Image.open(path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.caption_embeddings[index]

In [None]:
# Augmentation pipeline applied to every image: random resized crop to 224,
# random horizontal flip, random grayscale (p=0.25), then conversion to a tensor.
transform = transforms.Compose([transforms.RandomResizedCrop(224, scale=[0.5, 0.9]),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomGrayscale(0.25),
                                transforms.ToTensor()])

In [None]:
# Wrap each split together with its caption embeddings.
# NOTE(review): the test set reuses the same random augmentation pipeline as
# training, so test loss and the predictions used for search vary from run to
# run — confirm this is intended, or use a deterministic resize/crop for test.
train_dataset = CaptionsDataset(train_df, train_embeddings, transform=transform)
test_dataset = CaptionsDataset(test_df, test_embeddings, transform=transform)

In [None]:
# Sanity check: display the first (image tensor, caption embedding) pair.
train_dataset[0]

### dataloader

In [None]:
batch_size = 128

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=5)

# Test batches keep the dataframe order, so stacked predictions line up with
# the rows of test_df.
test_loader = DataLoader(test_dataset, batch_size, num_workers=5)

# create DeViSE model

In [None]:
# Keep only the convolutional feature extractor (`.features`) of a pretrained
# VGG16 with batch norm; the network's own classifier is not used.
backbone = models.vgg16_bn(pretrained=True).features

In [None]:
# Freeze the first 34 modules of the backbone; whatever follows stays trainable.
# NOTE(review): in torchvision's vgg16_bn this slice presumably stops right
# before the final convolutional block — confirm against the layer listing.
for param in backbone[:34].parameters():
    param.requires_grad_(False)

In [None]:
class DeViSE(nn.Module):
    """Image encoder that projects backbone features into an embedding space.

    Parameters
    ----------
    backbone : module mapping an image batch to features which flatten to
        25088 values per sample (presumably 512 x 7 x 7 from VGG on 224 x 224
        inputs — confirm).
    target_size : width of the output embedding; the first hidden layer of the
        head is twice as wide.

    NOTE(review): the output is divided by the maximum over the WHOLE batch,
    so the scale of a sample's embedding depends on what else is in the batch.
    """

    def __init__(self, backbone, target_size=300):
        super().__init__()
        hidden_size = target_size * 2
        self.backbone = backbone
        self.head = nn.Sequential(
            nn.Linear(25088, hidden_size),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_size, target_size),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(target_size, target_size),
        )

    def forward(self, x):
        """Return the rescaled embedding, shape `(batch, target_size)`."""
        features = self.backbone(x)
        # Flatten everything except the batch dimension for the linear head.
        flat = features.view(features.size(0), -1)
        projected = self.head(flat)
        # Scale by the single largest value in the batch.
        return projected / projected.max()

In [None]:
# target_size has to equal the width of the caption embeddings the model is
# trained against (4096 here; presumably 2 * enc_lstm_dim of InferSent — confirm).
devise_model = DeViSE(backbone, target_size=4096).to(device)

# train

In [None]:
# Per-batch [loss, additional metric] history, appended to by train().
losses = []

def train(model, train_loader, n_epochs, loss_function, 
          additional_metric, optimiser, device=device):
    '''
    Train `model` for `n_epochs` passes over `train_loader`.

    model: module to optimise; it must already live on `device`.
    train_loader: iterable yielding (data, target) tensor batches.
    n_epochs: number of passes over `train_loader`.
    loss_function: callable(prediction, target, flags) returning the scalar
        loss to back-propagate; `flags` is a vector of ones, one per sample.
    additional_metric: callable(prediction, target) returning a scalar that is
        only logged, never optimised.
    optimiser: optimiser holding the trainable parameters of `model`.
    device: device each batch is moved to (defaults to the notebook's device).

    Appends [loss, metric] for every batch to the module-level `losses` list.
    '''
    model.train()
    for epoch in range(n_epochs):
        loop = tqdm(train_loader)
        for data, target in loop:
            # Honour the `device` argument instead of hard-coding .cuda(),
            # so the same loop also runs on a CPU-only machine.
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            flags = torch.ones(len(target), device=device)

            optimiser.zero_grad()
            prediction = model(data)

            loss = loss_function(prediction, target, flags)
            # The metric is for monitoring only; keep it out of the autograd graph.
            with torch.no_grad():
                mean_sq_error = additional_metric(prediction, target)

            # Read each scalar back from the device once and reuse it.
            loss_value, metric_value = loss.item(), mean_sq_error.item()
            losses.append([loss_value, metric_value])

            loss.backward()
            optimiser.step()

            loop.set_description('Epoch {}/{}'.format(epoch + 1, n_epochs))
            loop.set_postfix(loss=loss_value, mse=metric_value)

In [None]:
# Let cuDNN benchmark its algorithms (input shapes are fixed at 224 x 224).
torch.backends.cudnn.benchmark = True

# Only the parameters that were left unfrozen are handed to the optimiser.
trainable_parameters = (p for p in devise_model.parameters() if p.requires_grad)

# Cosine loss is optimised; MSE is only tracked alongside it.
loss_function = nn.CosineEmbeddingLoss()
mse = nn.MSELoss()
optimiser = optim.Adam(trainable_parameters, lr=1e-3)

In [None]:
# Three passes over the training set, logging cosine loss and MSE per batch.
train(
    model=devise_model,
    train_loader=train_loader,
    n_epochs=3,
    loss_function=loss_function,
    additional_metric=mse,
    optimiser=optimiser,
)

In [None]:
# Smooth the per-batch history with a 15-step rolling mean, one column per metric.
loss_data = pd.DataFrame(losses).rolling(window=15).mean()
loss_data.columns = ['cosine loss', 'mse']

# One subplot per metric; ax[0] is the cosine loss, ax[1] the MSE.
ax = loss_data.plot(subplots=True)

ax[0].set_xlim(left=0)
ax[0].set_ylim(bottom=0.3, top=0.6)
ax[1].set_ylim(bottom=0);

# evaluate on test set

In [None]:
# Per-batch predictions (numpy arrays) and per-batch cosine losses.
preds = []
test_loss = []

# Switch to inference behaviour once, instead of calling .eval() on every batch.
devise_model.eval()

with torch.no_grad():
    test_loop = tqdm(test_loader)
    test_loop.set_description('Test set')
    for data, target in test_loop:
        # Use the notebook's `device` rather than hard-coding .cuda(), so the
        # evaluation also runs on a CPU-only machine.
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        flags = torch.ones(len(target), device=device)

        prediction = devise_model(data)
        loss = loss_function(prediction, target, flags)

        # No gradients are tracked here, so the legacy `.data` hop is not needed.
        preds.append(prediction.cpu().numpy())
        batch_loss = loss.item()
        test_loss.append(batch_loss)

        test_loop.set_postfix(loss=batch_loss)

In [None]:
# Stack the per-batch predictions into a single array with 4096 columns
# (the target_size the model was built with).
preds = np.concatenate(preds).reshape(-1, 4096)
# Mean of the per-batch losses; every batch counts equally, including a
# possibly smaller final one.
np.mean(test_loss)

# run a test search

In [None]:
def search(query):
    """Return a 4 x 5 contact sheet of the test images closest to `query`.

    The query is embedded with InferSent and compared, by cosine distance, to
    the predicted embeddings in `preds`; row i of `preds` is taken to belong
    to row i of `test_df`. The 20 nearest rows are loaded, resized to
    224 x 224 and tiled into one PIL image, best match in the top-left corner.

    NOTE(review): `preds` holds one row per caption row, so an image with
    several captions may presumably show up more than once — confirm.
    """
    query_embedding = infersent_model.encode([query], tokenize=True)
    distances = cdist(query_embedding, preds, 'cosine').squeeze()

    # Rank every test row by distance and keep the 20 closest file paths.
    ranking = np.argsort(distances)
    nearby_image_paths = test_df['file_name'].values[ranking][:20]

    nearby_images = []
    for path in nearby_image_paths:
        thumbnail = Image.open(path).convert('RGB').resize((224, 224), Image.BILINEAR)
        nearby_images.append(np.array(thumbnail))

    # Four rows of five tiles each, stacked top to bottom.
    grid_rows = [np.concatenate(nearby_images[start:start + 5], axis=1)
                 for start in range(0, 20, 5)]
    return Image.fromarray(np.concatenate(grid_rows, axis=0))

In [None]:
# Example free-text query; the cell output is the grid of retrieved test images.
search('a man playing tennis')