In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
# from torch.nn.utils.rnn import pack_padded_sequence
from model import ImageCNN,MatchCNN

import argparse
import os
import pickle
from data_loader import get_loader,CocoDataset
from build_vocab import Vocabulary
from torchvision import transforms
import time
from pycocotools.coco import COCO
from PIL import Image
import nltk
from random import shuffle

In [3]:
"""load coco dataset"""
data_dir = "../data/coco/"
annotation_file = data_dir + "annotations/captions_train2014.json"
coco=COCO(annotation_file)

loading annotations into memory...
Done (t=0.58s)
creating index...
index created!


In [4]:
# anns = coco.anns
# imgs = coco.imgs

In [5]:
"""extract 100 imgid and corresponding 500 captionid"""
sample_num = 100
caption_num = sample_num * 5
img_ids_all = list(coco.imgs.keys())
shuffle(img_ids_all)
img_ids = []
ann_ids = []

In [6]:
for key in img_ids_all:
    temp = coco.getAnnIds(key)
    if(len(temp) != 5):
        continue
    ann_ids.append(temp)
    img_ids.append(key)
    if(len(img_ids) == sample_num):
        break

In [7]:
"""preprocess images"""
image_dir = data_dir + "resized2014/"
imgs = []

 # Image preprocessing
transform = transforms.Compose([ 
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(), 
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), 
                         (0.229, 0.224, 0.225))])

for i, img_id in enumerate(img_ids):
    img_new = {}
    img = coco.imgs[img_id]
    image = Image.open(image_dir + img["file_name"]).convert("RGB")
    image = transform(image)
    img_new["ann_ids"] = ann_ids[i]
    img_new["data"] = image
    img_new["id"] = img_id
    imgs.append(img_new)

In [8]:
"""preprocess annotations"""
vocab_file = "../data/coco/vocab.pkl"
pad_len = 62
# Load vocabulary wrapper.
with open(vocab_file, 'rb') as f:
    vocab = pickle.load(f)

anns = np.zeros((caption_num, pad_len),dtype = int)
for i,ann_ids_image in enumerate(ann_ids):
    for j,ann_id in enumerate(ann_ids_image):
        caption_str = coco.anns[ann_id]["caption"]
        tokens = nltk.tokenize.word_tokenize(str(caption_str).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        caption = np.array(caption)
        anns[i*5 + j][:len(tokens) + 2] = caption[:]


In [9]:
"""parameters"""
image_vector_size = 256
embed_size = 100
margin = 0.5
batch_size = 10
epochs = 1
vocab_size = 9956
momentum=0.9
lr = 0.0001
pad_len = 62
num_workers = 2
batch_size = 100

"""set model"""
imageCNN = ImageCNN(image_vector_size=image_vector_size)
matchCNN = MatchCNN(embed_size = embed_size, 
                    image_vector_size = image_vector_size, 
                    vocab_size = vocab_size, 
                    pad_len = pad_len)

if torch.cuda.is_available():
    print("cuda is available")
    imageCNN = imageCNN.cuda()
    matchCNN = matchCNN.cuda()

imageCNN.eval()
matchCNN.eval()


cuda is available


MatchCNN (
  (embed): Embedding(9956, 100)
  (muti_conv1_word): Linear (556 -> 200)
  (conv2_word): Linear (600 -> 300)
  (conv3_word): Linear (900 -> 300)
  (linear1_word): Linear (1800 -> 400)
  (linear2_word): Linear (400 -> 1)
  (conv1_phs): Linear (300 -> 200)
  (muti_conv2_phs): Linear (856 -> 300)
  (conv3_phs): Linear (900 -> 300)
  (linear1_phs): Linear (1800 -> 400)
  (linear2_phs): Linear (400 -> 1)
  (conv1_phl): Linear (300 -> 200)
  (conv2_phl): Linear (600 -> 300)
  (muti_conv3_phl): Linear (1156 -> 300)
  (linear1_phl): Linear (1800 -> 400)
  (linear2_phl): Linear (400 -> 1)
  (conv1_sen): Linear (300 -> 200)
  (conv2_sen): Linear (600 -> 300)
  (conv3_sen): Linear (900 -> 300)
  (muti_linear1_sen): Linear (2056 -> 400)
  (linear2_sen): Linear (400 -> 1)
)

In [10]:
"""load models"""
model_path = "../models"
imageCNN.load_state_dict(torch.load(os.path.join(model_path, 'imageCNN1513584698-2-0.099086.pkl')))
matchCNN.load_state_dict(torch.load(os.path.join(model_path, 'matchCNN1513584698-2-0.099086.pkl')))

In [11]:
"""extract image feature"""
img_features = torch.zeros((sample_num, image_vector_size))
img_input = torch.zeros((batch_size, 3, 224, 224))
for i, img in enumerate(imgs):
    img_input[i%batch_size] = img["data"]
    if((i+1)%batch_size == 0):
        img_features[i+1-batch_size:i+1] = imageCNN(Variable(img_input).cuda()).data[:]
    

In [12]:
scores = torch.zeros((sample_num, sample_num*5))
for i,img in enumerate(img_features):
    img.unsqueeze_(0)
    img = img.repeat(batch_size,1)
    img = Variable(img).cuda()
    for j in range(int(caption_num/batch_size)):
        ann_input = Variable(torch.from_numpy(anns[j*batch_size:(j+1)*batch_size])).cuda()

        scores[i][j*batch_size:(j+1)*batch_size] = matchCNN(img, ann_input).data[:]

In [20]:
scores[1]


 -4.3912
 -8.0284
 -4.1740
 -5.6836
 -5.5024
 -3.4043
 -3.3955
 -4.0552
 -3.2483
 -3.3294
 -6.0023
 -4.3673
 -4.8983
 -5.9815
 -4.6593
 -3.5766
 -2.9850
 -3.7978
 -3.3847
 -3.0843
 -5.6165
 -7.9343
 -7.3320
 -7.3406
 -4.5620
 -4.7784
 -3.5087
 -3.6286
 -4.9043
 -3.7532
-10.8512
-10.7423
 -9.6521
 -8.3601
 -5.6349
 -3.9638
 -3.9641
 -2.9482
 -4.3373
 -2.5172
 -3.0222
 -3.9355
 -5.1903
 -5.8846
 -5.2306
 -5.0450
 -4.9086
 -3.7830
 -4.2595
 -4.8241
 -5.4933
 -4.1694
 -2.5362
 -4.2888
 -5.0269
 -4.4679
 -6.6931
 -5.9393
 -4.6897
 -5.0898
 -8.1498
 -5.5907
 -7.2148
 -6.2689
 -4.6996
 -2.9625
 -3.8484
 -3.9312
 -5.8361
 -5.9312
 -3.4390
 -4.0774
 -3.8916
 -4.3954
 -3.8559
 -5.9995
 -2.8812
 -7.2780
 -5.4364
 -4.7646
 -4.9933
 -4.8139
 -4.5694
 -5.8373
 -4.9829
 -3.5730
 -3.0351
 -2.7821
 -6.0184
 -3.2136
 -4.5833
 -4.3143
 -3.7560
 -4.4999
 -5.5077
 -5.1256
 -4.9243
 -5.1443
 -4.2053
 -4.1429
 -9.1681
 -6.8289
 -8.8536
 -7.3052
 -6.2879
 -3.4121
 -4.6529
 -6.1512
 -6.8170
 -6.7554
 -8.6813


In [16]:
ranks

array([[465,  30, 187, ..., 364, 136,  37],
       [465, 197,  30, ..., 136, 175, 296],
       [465, 187,  30, ..., 312, 353, 136],
       ..., 
       [187, 319, 465, ...,  84, 139, 398],
       [465, 187, 197, ...,  26, 494, 163],
       [465, 187, 188, ...,  98, 179, 136]])

In [13]:
score_np = scores.numpy()

ranks = np.argsort(score_np)

ranks_image = np.zeros((sample_num,5),dtype=int)

for i in range(sample_num):
    ranks_image[i][:] = ranks[i][i*5:(i+1)*5]

In [14]:
r1 = len(ranks_image[ranks_image==1])/caption_num * 100
r5 = len(ranks_image[ranks_image<=5])/caption_num * 100
r10 = len(ranks_image[ranks_image<=10])/caption_num * 100
med = np.mean(ranks_image)

print("r1:",r1)
print("r5:",r5)
print("r10:",r10)
print("med:",med)

r1: 0.0
r5: 1.2
r10: 1.7999999999999998
med: 242.772


In [None]:
scores