In [None]:
import torch
import torch.nn as nn
from torchvision.utils import make_grid

from model import ResNet50Encoder
from dataset import build_dataset
from nce import nce_retrieval
from caption_encoder import CaptionEncoder
import mixed_precision

# Second positional argument of CaptionEncoder below
# (presumably the maximum caption length in tokens — TODO confirm).
SEQ_LEN = 20
# Embedding size handed to both CaptionEncoder and ResNet50Encoder (n_rkhs).
N_RKHS = 1024

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    # Mixed precision is only switched on for CUDA runs.
    mixed_precision.enable_mixed_precision()

In [None]:
# Build the caption and image encoders; both are configured with the same
# embedding size (N_RKHS).
caption_encoder = CaptionEncoder(N_RKHS, SEQ_LEN, device=device, hidden_size=4096)
resnet50 = ResNet50Encoder(encoder_size=128, n_rkhs=N_RKHS, ndf=128)
resnet50.to(device)
resnet50, _ = mixed_precision.initialize(resnet50, None)

# map_location remaps storages saved from a GPU run onto the active device, so
# the checkpoint also loads on the CPU fallback selected above.
ckpt = torch.load('checkpoints/ck4lxlpp_model.pth', map_location=device)
resnet50.load_state_dict(ckpt['resnet50'])
caption_encoder.fc.load_state_dict(ckpt['caption_fc'])
# This notebook only runs inference: switch train-time layers (batch norm,
# dropout, if the encoder has any) to evaluation behaviour so the forward pass
# neither depends on batch statistics nor updates running statistics.
resnet50.eval()
print('Checkpoint loaded.')


In [None]:
import os
from torchvision import datasets, transforms

# Number of images per DataLoader batch; the single batch fetched later is the
# whole retrieval pool used by this notebook.
batch_size = 1000
# Integer interpolation code passed to transforms.Resize
# (NOTE(review): 3 looks like PIL's BICUBIC filter — confirm).
INTERP = 3

class VizTransforms:
    '''
    Paired image transforms used for retrieval visualization.

    Calling an instance on one input image returns a tuple ``(out, raw)``:
    ``out`` is the resized, 128x128 center-cropped, normalized tensor produced
    by ``test_transform``; ``raw`` is the resized, 256x256 center-cropped,
    un-normalized tensor produced by ``raw_trans``.
    '''
    def __init__(self):
        # Tensor conversion followed by per-channel mean/std normalization.
        to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Pipeline producing the normalized 128x128 crop.
        encoder_steps = [
            transforms.Resize(146, interpolation=INTERP),
            transforms.CenterCrop(128),
            to_normalized_tensor,
        ]
        self.test_transform = transforms.Compose(encoder_steps)

        # Pipeline producing the larger, un-normalized 256x256 crop.
        display_steps = [
            transforms.Resize(256, interpolation=INTERP),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ]
        self.raw_trans = transforms.Compose(display_steps)

    def __call__(self, inp):
        # Apply both pipelines to the same input: normalized crop first, raw second.
        return self.test_transform(inp), self.raw_trans(inp)

# CocoCaptions yields a list of caption strings per image, and that list is not
# the same length for every image. The default collate function needs
# equal-length sequences across a batch, so every sample is cut to a fixed count.
# NOTE(review): assumes every image has at least this many captions — confirm
# against the annotation file.
CAPTIONS_PER_IMAGE = 5


def _first_captions(image_captions):
    '''Return the first CAPTIONS_PER_IMAGE captions of one image.'''
    return image_captions[:CAPTIONS_PER_IMAGE]


transforms128 = VizTransforms()
test_dataset = datasets.CocoCaptions(
    root=os.path.expanduser('~/data/coco/val2017'),
    annFile=os.path.expanduser('~/data/coco/annotations/captions_val2017.json'),
    transform=transforms128,
    target_transform=_first_captions)

loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin_memory=(device == 'cuda'),
    drop_last=True,
    # Never request more workers than the machine has cores.
    num_workers=min(16, os.cpu_count() or 1))

In [None]:
# Draw one shuffled batch: normalized crops, raw crops, and the captions.
(transformed_imgs, raw_imgs), captions = next(iter(loader))
transformed_imgs = transformed_imgs.to(device)
# The embeddings are only ranked, never back-propagated through, so skip
# building the autograd graph for this large forward pass.
with torch.no_grad():
    encoded_imgs, r7 = resnet50(transformed_imgs)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
%matplotlib inline

def nce_retrieval_reverse(encoded_images, encoded_queries, top_k=5):
    '''
    Rank the queries for every image (image -> caption direction).

    Args:
        encoded_images: tensor whose dim 0 is the number of images and dim 1
            the embedding size; any further dims must be singleton so the
            tensor can be reshaped to (n_images, n_rkhs).
        encoded_queries: 2-D tensor of shape (n_queries, n_rkhs).
        top_k: how many query indices to keep per image.

    Returns:
        Integer tensor of shape (n_images, min(top_k, n_queries)); row i holds
        the indices of the highest-scoring queries for image i, best first.
    '''
    n_images = encoded_images.size(0)
    n_rkhs = encoded_images.size(1)
    n_queries = encoded_queries.size(0)

    # e.g. (bs, rkhs, 1, 1) -> (bs, rkhs); dim 1 is read as the embedding size.
    encoded_images = encoded_images.reshape(n_images, n_rkhs)
    encoded_images = F.normalize(encoded_images)

    # Dot product of unit-length image embeddings with the queries as given.
    # NOTE(review): the queries are not normalized here, so this is a true
    # cosine similarity only if the caller passes unit-length queries.
    scores = torch.mm(encoded_images, encoded_queries.t())

    # Partial selection instead of sorting every row; clamp k so that asking
    # for more results than there are queries returns all of them.
    k = min(top_k, n_queries)
    return torch.topk(scores, k, dim=1, largest=True, sorted=True)[1]

def show(img, figsize=(30, 60)):
    '''
    Draw a channels-first image tensor on a new matplotlib figure.

    Args:
        img: tensor laid out as (channels, height, width).
        figsize: figure size in inches, forwarded to ``plt.subplots``.
    '''
    # detach()/cpu() make this safe for tensors that track gradients or live on
    # the GPU; both are no-ops for a plain CPU tensor.
    npimg = img.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=figsize)
    # imshow expects channels last, so move (C, H, W) to (H, W, C).
    ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    
def visualize(encoded_queries, encoded_imgs, raw_imgs, top_k=5):
    '''
    Show the top-ranked images for each query as an image grid.

    Args:
        encoded_queries: query embeddings passed to ``nce_retrieval``.
        encoded_imgs: image embeddings passed to ``nce_retrieval``.
        raw_imgs: image tensor indexed by the retrieved indices.
        top_k: number of images retrieved per query; also the grid's ``nrow``.
    '''
    top_k_idx = nce_retrieval(encoded_imgs, encoded_queries, top_k)
    # The indices come from the embeddings' device (possibly the GPU) while
    # raw_imgs may still be on the CPU; indexing needs them on the same device.
    top_k_idx = torch.flatten(top_k_idx).to(raw_imgs.device)
    matches = raw_imgs[top_k_idx]
    viz = make_grid(matches, nrow=top_k)
    show(viz.cpu())

def visualize_captions(raw_captions, encoded_queries, encoded_imgs, raw_imgs, top_k=5):
    '''
    Show the given images, then print the best-ranked captions for each one.

    Args:
        raw_captions: sequence of caption strings, aligned with
            ``encoded_queries`` (index i is the text of query i).
        encoded_queries: caption embeddings to rank.
        encoded_imgs: embeddings of the images being captioned.
        raw_imgs: image tensor to display, laid out in a single grid.
        top_k: number of captions printed per image.
    '''
    viz = make_grid(raw_imgs, nrow=raw_imgs.size(0))
    show(viz.cpu())
    top_k_idx = nce_retrieval_reverse(encoded_imgs, encoded_queries, top_k)
    # Convert once to plain Python ints so the captions are indexed with ints
    # rather than 0-d tensors.
    for caption_ids in top_k_idx.tolist():
        print('\n'.join(raw_captions[idx] for idx in caption_ids))
        print('-----------')

In [None]:
# Caption -> Image retrieval
# Free-text queries; each one is encoded and matched against the image batch.
queries = [
   'tennis'
]
# Only the first value returned by the caption encoder is used here.
encoded_queries, _ = caption_encoder(queries)
visualize(encoded_queries, encoded_imgs, raw_imgs) 

In [None]:
# Image -> Caption retrieval
# Candidate captions are captions[0]
# (presumably the first caption of every image in the batch — TODO confirm).
encoded_queries, _ = caption_encoder(captions[0])
# Only the first six images of the batch are shown and captioned.
visualize_captions(captions[0], encoded_queries, encoded_imgs[0:6], raw_imgs[0:6])    