In [1]:
from __future__ import print_function

import matplotlib; matplotlib.use('Agg')
import os
import os.path as osp
import argparse

import numpy as np
import pickle 
import time
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm 
 
from torchvision import models                                                                     
from convcap import convcap
from vggfeats import Vgg16Feats
from coco_loader import Scale
from PIL import Image
from test_beam import repeat_img
from beamsearch import beamsearch 

In [2]:
# Options for the captioning demo. They are declared argparse-style so the same
# settings can be supplied on a command line.
parser = argparse.ArgumentParser(description='PyTorch Convolutional Image \
    Captioning Model -- Caption Me')

# `model_dir` is optional (nargs='?'). As a required positional it made
# parse_known_args() exit with an error whenever sys.argv carried no positional
# value (inside a notebook the value was presumably whatever stray path the
# kernel was launched with -- TODO confirm). Passing it explicitly still works
# exactly as before. The default mirrors the directory of the checkpoint path
# used later in this notebook ('output/bestmodel.pth').
parser.add_argument('model_dir', nargs='?', default='output',
                    help='output directory to save models & results')
# Left disabled: the image directory is hard-coded where load_images() is called.
# parser.add_argument('image_dir', help='directory containing input images \
#                     supported formats .png, .jpg, .jpeg, .JPG')

parser.add_argument('-g', '--gpu', type=int, default=0,
                    help='gpu device id')

parser.add_argument('--beam_size', type=int, default=1,
                    help='beam size to use to generate captions')

# --attention / --no-attention write to the same destination; attention is on
# unless --no-attention is given.
parser.add_argument('--attention', dest='attention', action='store_true',
                    help='set caption model with attention in use (by default set)')

parser.add_argument('--no-attention', dest='attention', action='store_false',
                    help='set caption model without attention in use')

parser.set_defaults(attention=True)

# parse_known_args() returns unrecognised arguments separately instead of
# failing on them; they are discarded here.
args, _ = parser.parse_known_args()

In [3]:
# NOTE(review): this cell repeats the parser setup of the previous cell (In [2]);
# one of the two can be removed. It is kept self-contained so it still works
# when run on its own.
parser = argparse.ArgumentParser(description='PyTorch Convolutional Image \
    Captioning Model -- Caption Me')

# `model_dir` is optional (nargs='?'). As a required positional it made
# parse_known_args() exit with an error whenever sys.argv carried no positional
# value. Passing it explicitly still works exactly as before. The default
# mirrors the directory of the checkpoint path used later in this notebook
# ('output/bestmodel.pth').
parser.add_argument('model_dir', nargs='?', default='output',
                    help='output directory to save models & results')

# Left disabled: the image directory is hard-coded where load_images() is called.
# parser.add_argument('image_dir', help='directory containing input images \
#                     supported formats .png, .jpg, .jpeg, .JPG')

parser.add_argument('-g', '--gpu', type=int, default=0,
                    help='gpu device id')

parser.add_argument('--beam_size', type=int, default=1,
                    help='beam size to use to generate captions')

# Both flags share the `attention` destination; the default is True.
parser.add_argument('--attention', dest='attention', action='store_true',
                    help='set caption model with attention in use (by default set)')

parser.add_argument('--no-attention', dest='attention', action='store_false',
                    help='set caption model without attention in use')

parser.set_defaults(attention=True)

# Unrecognised arguments are returned separately by parse_known_args() and
# dropped here rather than causing an error.
args, _ = parser.parse_known_args()

In [4]:
# Expose only the GPU chosen with -g/--gpu to CUDA (the value must be a string).
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(args.gpu)

In [5]:
# Path of the checkpoint to load. osp.join() with a single argument returns it
# unchanged, so the literal is used directly.
bestmodelfn = 'output/bestmodel.pth'

In [6]:
# Bare expression: displays the checkpoint path as the cell's output.
bestmodelfn

'output/bestmodel.pth'

In [7]:
def load_images(image_dir):
    """Load and preprocess every supported image found directly in image_dir.

    Files whose extension (case-insensitive) is .jpg, .jpeg or .png are opened,
    converted to RGB, passed through Scale([224, 224]), ToTensor and Normalize,
    and stacked into a single batch.

    Args:
        image_dir: path of the directory to scan (sub-directories are not
            visited).

    Returns:
        A tuple (imgs, imgs_fn). imgs is a float tensor of shape
        (N, 3, 224, 224) and imgs_fn is the list of the N file paths in the
        same order. When no file matches, imgs has shape (0, 3, 224, 224) and
        imgs_fn is empty.
    """
    exts = ('.jpg', '.jpeg', '.png')

    img_transforms = transforms.Compose([
        Scale([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    imgs_fn = []
    tensors = []
    # Sorted so the image order does not depend on the order in which the
    # filesystem happens to list the directory.
    for fn in sorted(os.listdir(image_dir)):
        if osp.splitext(fn)[-1].lower() not in exts:
            continue
        path = osp.join(image_dir, fn)
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made, instead of leaving the handle open.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        tensors.append(img_transforms(img))
        imgs_fn.append(path)

    if not tensors:
        return torch.zeros(0, 3, 224, 224, dtype=torch.float32), imgs_fn

    # Stack once at the end rather than re-copying the whole batch per image.
    return torch.stack(tensors, 0), imgs_fn

In [8]:
# Preprocess every image in the 'my_image' folder; imgs_fn keeps the matching paths.
imgs, imgs_fn = load_images(image_dir='my_image')

In [9]:
# Settings for the trained model released with the code.
batchsize = 1
max_tokens = 15
num_layers = 3

# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling; only
# load a wordlist file that comes from a trusted source.
# The `with` block closes the file right after loading instead of leaving the
# handle open for the rest of the kernel session.
with open('data/wordlist.p', 'rb') as wordlist_file:
    worddict_tmp = pickle.load(wordlist_file)

# Vocabulary: every key except '</S>', sorted, with 'EOS' placed at index 0.
wordlist = ['EOS'] + sorted(word for word in worddict_tmp.keys() if word != '</S>')
numwords = len(wordlist)

In [10]:
# Vgg16Feats network, moved to the GPU.
model_imgcnn = Vgg16Feats()
model_imgcnn.cuda()

# convcap network sized by the vocabulary and layer count; attention follows
# the --attention / --no-attention option.
model_convcap = convcap(numwords, num_layers, is_attention=args.attention)
model_convcap.cuda()

# Both networks are restored from the same checkpoint file, each from its own key.
print('[DEBUG] Loading checkpoint %s' % (bestmodelfn,))
checkpoint = torch.load(bestmodelfn)
for net, state_key in ((model_convcap, 'state_dict'),
                       (model_imgcnn, 'img_state_dict')):
    net.load_state_dict(checkpoint[state_key])

# eval() is train(False). The last call stays a bare expression so the cell
# still displays the convcap module.
model_imgcnn.eval()
model_convcap.eval()

[DEBUG] Loading checkpoint output/bestmodel.pth


convcap(
  (emb_0): Embedding(9221, 512, padding_idx=0)
  (emb_1): Linear(in_features=512, out_features=512, bias=True)
  (imgproj): Linear(in_features=4096, out_features=512, bias=True)
  (resproj): Linear(in_features=1024, out_features=512, bias=True)
  (convs): ModuleList(
    (0): Conv1d(1024, 1024, kernel_size=(5,), stride=(1,), padding=(4,))
    (1): Conv1d(512, 1024, kernel_size=(5,), stride=(1,), padding=(4,))
    (2): Conv1d(512, 1024, kernel_size=(5,), stride=(1,), padding=(4,))
  )
  (attention): ModuleList(
    (0): AttentionLayer(
      (in_projection): Linear(in_features=512, out_features=512, bias=True)
      (out_projection): Linear(in_features=512, out_features=512, bias=True)
    )
    (1): AttentionLayer(
      (in_projection): Linear(in_features=512, out_features=512, bias=True)
      (out_projection): Linear(in_features=512, out_features=512, bias=True)
    )
    (2): AttentionLayer(
      (in_projection): Linear(in_features=512, out_features=512, bias=True)
      

In [11]:
# Generate one caption per loaded image with beam search and collect the
# results as dicts with the keys 'img_fn' and 'caption'.
pred_captions = []

# The index of '<S>' is the same for every image, so it is looked up once. It
# is written into the first column of every decoder input row below.
start_idx = wordlist.index('<S>')

# Inference only: run without autograd bookkeeping so no graph is retained
# across the repeated model calls.
with torch.no_grad():
    for batch_idx, img_fn in tqdm(enumerate(imgs_fn), total=len(imgs_fn)):

        img = imgs[batch_idx, ...].view(batchsize, 3, 224, 224)
        imgfeats, imgfc7 = model_imgcnn(img.cuda())

        # Repeat both feature tensors beam_size times along a new second axis
        # and fold that axis into the first one (b * beam_size rows).
        b, f_dim, f_h, f_w = imgfeats.size()
        imgfeats = imgfeats.unsqueeze(1).expand(b, args.beam_size, f_dim, f_h, f_w)
        imgfeats = imgfeats.contiguous().view(b*args.beam_size, f_dim, f_h, f_w)

        b, f_dim = imgfc7.size()
        imgfc7 = imgfc7.unsqueeze(1).expand(b, args.beam_size, f_dim)
        imgfc7 = imgfc7.contiguous().view(b*args.beam_size, f_dim)

        beam_searcher = beamsearch(args.beam_size, batchsize, max_tokens)

        wordclass_feed = np.zeros((args.beam_size*batchsize, max_tokens), dtype='int64')
        wordclass_feed[:, 0] = start_idx
        # One empty list per batch element.
        outcaps = np.empty((batchsize, 0)).tolist()

        # NOTE(review): when expand_beam() returns no beams before the final
        # step, the loop keeps running until max_tokens-2. Whether an early
        # `break` would be safe depends on beamsearch internals -- confirm.
        for j in range(max_tokens-1):
            wordclass = torch.from_numpy(wordclass_feed).cuda()

            wordact, _attn = model_convcap(imgfeats, imgfc7, wordclass)
            wordact = wordact[:, :, :-1]
            wordact_j = wordact[..., j]

            beam_indices, wordclass_indices = beam_searcher.expand_beam(wordact_j)

            if len(beam_indices) == 0 or j == (max_tokens-2):  # Beam search is over.
                generated_captions = beam_searcher.get_results()
                for k in range(batchsize):
                    g = generated_captions[:, k]
                    outcaps[k] = [wordlist[x] for x in g]
            else:
                wordclass_feed = wordclass_feed[beam_indices]
                # Build the index tensor once and reuse it for both selections.
                keep = torch.cuda.LongTensor(beam_indices)
                imgfc7 = imgfc7.index_select(0, keep)
                imgfeats = imgfeats.index_select(0, keep)
                for i, wordclass_idx in enumerate(wordclass_indices):
                    wordclass_feed[i, j+1] = wordclass_idx

        for k in range(batchsize):
            # Cut the caption at the first 'EOS' entry, if there is one.
            num_words = len(outcaps[k])
            if 'EOS' in outcaps[k]:
                num_words = outcaps[k].index('EOS')
            outcap = ' '.join(outcaps[k][:num_words])
            pred_captions.append({'img_fn': img_fn, 'caption': outcap})

  x = F.softmax(x.view(sz[0] * sz[1], sz[2]))
  beam_word_logprobs = self.logsoftmax(output).cpu().data.tolist()
100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.73it/s]


In [12]:
# Bare expression: displays the list of {'img_fn', 'caption'} dicts built above.
pred_captions

[{'img_fn': 'my_image\\image1.jpg',
  'caption': 'a group of people standing on top of a mountain'},
 {'img_fn': 'my_image\\image10.jpg',
  'caption': 'two men are drinking from a white table'},
 {'img_fn': 'my_image\\image2.jpg',
  'caption': 'a cat sitting on a ground next to a dead tree'},
 {'img_fn': 'my_image\\image3.jpg',
  'caption': 'a group of people sitting around a table with a laptop'},
 {'img_fn': 'my_image\\image4.jpg',
  'caption': 'a large building with a large building in the background'},
 {'img_fn': 'my_image\\image5.jpg',
  'caption': 'a plate of food with a fork and a fork'},
 {'img_fn': 'my_image\\image6.jpg',
  'caption': 'a table with a bunch of different types of items'},
 {'img_fn': 'my_image\\image7.jpg',
  'caption': 'a busy city street with cars and cars'},
 {'img_fn': 'my_image\\image8.jpg',
  'caption': 'two brown bears are sitting on a tree'},
 {'img_fn': 'my_image\\image9.jpg',
  'caption': 'a man is standing next to a large elephant'}]

In [13]:
# Write one "image: <path>, caption: <text>" line per entry of pred_captions
# into captions.txt inside the 'my_image' folder.
resfile = osp.join('my_image', 'captions.txt')
with open(resfile, 'w') as fp:
    fp.writelines(
        'image: %s, caption: %s\n' % (entry['img_fn'], entry['caption'])
        for entry in pred_captions
    )