In [1]:
import torch.utils.data as data
from pycocotools.coco import COCO
from PIL import Image
import os
import nltk
import pickle
import torch
from torchvision import transforms

In [2]:
class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

In [3]:
class CoCoDataset(data.Dataset):
    
    def __init__(self, root, json, vocab, transform=None):
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']

        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        return len(self.ids)

In [4]:
json = 'annotations/captions_train2014.json'
root = 'resized_images/train2014/'

with open('data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
    
transform = transforms.Compose([ 
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])

coco = CoCoDataset(root, json, vocab, transform)

loading annotations into memory...
Done (t=0.51s)
creating index...
index created!


In [5]:
test_iter = iter(coco)
next(test_iter)

(
 ( 0 ,.,.) = 
   1.6495  1.6324  1.6324  ...  -1.0904 -1.0904 -1.0904
   1.7694  1.7180  1.6838  ...  -1.0733 -1.0904 -1.0904
   1.5639  1.5468  1.5639  ...  -1.0733 -1.0904 -1.0904
            ...             ⋱             ...          
   1.3413  1.3242  1.3584  ...   1.2899  1.2043  1.2899
   1.5982  1.6153  1.5810  ...   1.2214  1.3070  1.3413
   1.4440  1.3927  1.3755  ...   1.3242  1.1872  1.2385
 
 ( 1 ,.,.) = 
   1.8158  1.7983  1.7983  ...  -0.2150 -0.2150 -0.2150
   1.9384  1.8859  1.8508  ...  -0.1975 -0.2150 -0.2150
   1.7283  1.7108  1.7283  ...  -0.1975 -0.2150 -0.2150
            ...             ⋱             ...          
   1.5007  1.4832  1.5182  ...   1.4307  1.3256  1.4132
   1.7633  1.7808  1.7458  ...   1.3606  1.4657  1.5007
   1.6057  1.5532  1.5357  ...   1.4657  1.3431  1.3957
 
 ( 2 ,.,.) = 
   2.0300  2.0125  2.0125  ...   1.0714  1.0714  1.0714
   2.1520  2.0997  2.0648  ...   1.0888  1.0714  1.0714
   1.9428  1.9254  1.9428  ...   1.0888  1.0714  1.0714


In [6]:
def collate_fn(data):
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    print(images)
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]        
    return images, targets, lengths

In [7]:
batch_size = 5
shuffle = True
num_workers = 2

data_loader = torch.utils.data.DataLoader(dataset=coco, 
                                          batch_size=batch_size,
                                          shuffle=shuffle,
                                          num_workers=num_workers,
                                          collate_fn=collate_fn)

In [8]:
next(iter(data_loader))

(
( 0 ,.,.) = 
  0.2967  0.2967  0.2967  ...   0.3823  0.3652  0.3994
  0.2967  0.2967  0.3138  ...   0.3823  0.3823  0.4166
  0.2967  0.3138  0.3138  ...   0.3823  0.3823  0.4166
           ...             ⋱             ...          
 -0.2513 -0.2513 -0.2513  ...  -0.2856 -0.2856 -0.2856
 -0.2513 -0.2342 -0.2171  ...  -0.1486 -0.1657 -0.1486
 -0.1486 -0.1657 -0.1828  ...  -0.0629 -0.0629 -0.1143

( 1 ,.,.) = 
  0.6078  0.6078  0.6078  ...   0.6954  0.6779  0.6604
  0.6078  0.6078  0.6254  ...   0.6954  0.6954  0.6779
  0.6078  0.6254  0.6254  ...   0.6954  0.6954  0.6779
           ...             ⋱             ...          
 -0.1275 -0.1275 -0.1275  ...  -0.1275 -0.1275 -0.1275
 -0.1275 -0.1099 -0.0924  ...  -0.0224 -0.0399 -0.0224
 -0.0224 -0.0399 -0.0574  ...   0.0651  0.0651  0.0126

( 2 ,.,.) = 
  0.9842  0.9842  0.9842  ...   1.0714  1.0539  1.0191
  0.9842  0.9842  1.0017  ...   1.0714  1.0714  1.0365
  0.9842  1.0017  1.0017  ...   1.0714  1.0714  1.0365
           ...        

)
(
( 0 ,.,.) = 
  1.8208  1.8379  1.8379  ...   1.6838  1.6838  1.6667
  1.8208  1.8208  1.8208  ...   1.6838  1.6838  1.6838
  1.8037  1.8037  1.8037  ...   1.7009  1.6838  1.6838
           ...             ⋱             ...          
  0.0398  0.0227  0.0056  ...   0.4679  0.4508  0.4508
  0.0741  0.0912  0.0912  ...   0.3481  0.3652  0.3994
  0.0912  0.1254  0.1768  ...   0.3652  0.3823  0.4337

( 1 ,.,.) = 
  2.0434  2.0609  2.0609  ...   1.9734  1.9734  1.9559
  2.0434  2.0434  2.0434  ...   1.9734  1.9734  1.9734
  2.0259  2.0259  2.0259  ...   1.9909  1.9734  1.9734
           ...             ⋱             ...          
 -0.7402 -0.7577 -0.7752  ...  -0.4776 -0.4951 -0.4951
 -0.7052 -0.6877 -0.6527  ...  -0.6001 -0.5826 -0.5476
 -0.6877 -0.6176 -0.5651  ...  -0.5826 -0.5651 -0.5126

( 2 ,.,.) = 
  2.3786  2.3960  2.3960  ...   2.3611  2.3611  2.3437
  2.3786  2.3786  2.3786  ...   2.3611  2.3611  2.3611
  2.3611  2.3611  2.3611  ...   2.3786  2.3611  2.3611
           ...      

)


(
 ( 0 , 0 ,.,.) = 
  -1.2959 -1.2959 -1.2959  ...   0.7419  0.7248  0.7248
  -1.2959 -1.2959 -1.2959  ...   0.7419  0.7248  0.7248
  -1.3130 -1.3130 -1.3130  ...   0.7591  0.7591  0.7591
            ...             ⋱             ...          
  -1.8953 -1.9809 -1.9638  ...   2.0434  2.0263  1.9920
  -1.8953 -1.9980 -1.9809  ...   1.9749  1.9920  2.0263
  -1.9295 -1.9809 -1.9809  ...   1.9064  1.9407  2.0263
 
 ( 0 , 1 ,.,.) = 
  -1.3704 -1.3704 -1.3704  ...   1.1681  1.1506  1.1506
  -1.3704 -1.3704 -1.3704  ...   1.1681  1.1506  1.1506
  -1.3880 -1.3880 -1.3880  ...   1.1856  1.1856  1.1856
            ...             ⋱             ...          
  -1.2129 -1.3004 -1.3004  ...   2.1660  2.1485  2.1134
  -1.1954 -1.3004 -1.3179  ...   2.1310  2.1485  2.1835
  -1.2304 -1.3179 -1.3529  ...   2.0784  2.1134  2.2010
 
 ( 0 , 2 ,.,.) = 
  -1.3164 -1.3164 -1.3164  ...   1.3328  1.3154  1.3154
  -1.3164 -1.3164 -1.3164  ...   1.3328  1.3154  1.3154
  -1.3339 -1.3339 -1.3339  ...   1.3502  1.3

(
( 0 ,.,.) = 
  2.2318  2.2489  2.2489  ...   2.2489  2.2489  2.2489
  2.2318  2.2489  2.2489  ...   2.2489  2.2489  2.2489
  2.2318  2.2489  2.2489  ...   2.2489  2.2489  2.2489
           ...             ⋱             ...          
 -0.6452 -0.4739 -0.5253  ...  -0.0287 -0.0458  0.0569
 -0.7308 -0.3198 -0.0972  ...   0.2111  0.3823  0.5022
 -0.8507 -0.5253 -0.4226  ...   0.0398  0.3481  0.5193

( 1 ,.,.) = 
  2.4111  2.4111  2.4286  ...   2.4286  2.4286  2.4286
  2.4111  2.4111  2.4286  ...   2.4286  2.4286  2.4286
  2.4111  2.4111  2.4286  ...   2.4286  2.4286  2.4286
           ...             ⋱             ...          
 -0.7402 -0.4776 -0.4951  ...  -0.2150 -0.2325 -0.1275
 -0.7402 -0.3025 -0.0049  ...   0.0476  0.2227  0.3452
 -0.8627 -0.4601 -0.2850  ...  -0.0574  0.2577  0.4328

( 2 ,.,.) = 
  2.6226  2.5877  2.6051  ...   2.6400  2.6400  2.6400
  2.6226  2.5877  2.6051  ...   2.6400  2.6400  2.6400
  2.6226  2.5877  2.6051  ...   2.6400  2.6400  2.6400
           ...        

In [9]:
[vocab.idx2word[i] for i in [1, 7196, 7155, 9599, 742, 7196, 8829, 9825, 5580, 6602, 7196, 6154, 6511, 2]]

['<start>',
 'a',
 'cat',
 'laying',
 'on',
 'a',
 'bed',
 'in',
 'front',
 'of',
 'a',
 'television',
 '.',
 '<end>']