# Data Loader Creation

In [17]:
import sys
sys.path.append('/opt/cocoapi/PythonAPI')
from pycocotools.coco import COCO
!pip install nltk
import nltk
nltk.download('punkt')
from data_loader import get_loader
from torchvision import transforms

# Training-image preprocessing: resize, random 224 crop, random horizontal
# flip, conversion to a tensor, then per-channel normalization.
channel_means = (0.485, 0.456, 0.406)
channel_stds = (0.229, 0.224, 0.225)
preprocessing_steps = [
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_means, channel_stds),
]
tt = transforms.Compose(preprocessing_steps)

# Minimum word count threshold handed to the loader as `vocab_threshold`.
min_word_threshold = 4

# Number of samples per batch.
batch_size = 10

# Build the training data loader; vocab_from_file=False asks get_loader not to
# reuse a previously saved vocabulary file.
data_loader = get_loader(
    transform=tt,
    mode='train',
    batch_size=batch_size,
    vocab_threshold=min_word_threshold,
    vocab_from_file=False,
)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
loading annotations into memory...
Done (t=0.84s)
creating index...
index created!
[0/414113] Tokenizing captions...
[100000/414113] Tokenizing captions...
[200000/414113] Tokenizing captions...
[300000/414113] Tokenizing captions...
[400000/414113] Tokenizing captions...
loading annotations into memory...


  0%|          | 882/414113 [00:00<01:36, 4300.31it/s]

Done (t=0.83s)
creating index...
index created!
Obtaining caption lengths...


100%|██████████| 414113/414113 [01:29<00:00, 4607.37it/s]


In [18]:
# Sample caption used to walk through the caption pre-processing steps below.
caption_smpl = "A person doing a trick on a rail while riding a skateboard."

In [19]:
import nltk

# Lower-case the sample caption first, then split it into word-level tokens
# (punctuation comes out as its own token, as the printed list shows).
lowered_caption = str(caption_smpl).lower()
tokens_smpl = nltk.tokenize.word_tokenize(lowered_caption)
print(tokens_smpl)

['a', 'person', 'doing', 'a', 'trick', 'on', 'a', 'rail', 'while', 'riding', 'a', 'skateboard', '.']


## Defining the start word of the caption

In [20]:
# Read the special start-of-caption word stored on the vocabulary object.
start_word = data_loader.dataset.vocab.start_word
print('Special start word:', start_word)

# The encoded caption begins with the integer the vocabulary returns for
# the start word.
image_caption = [data_loader.dataset.vocab(start_word)]
print(image_caption)

Special start word: <start>
[0]


## Appending the caption's token indices after the start word

In [21]:
# Convert every token of the sample caption to its vocabulary index, then
# append the whole run to the caption in place.
# NOTE(review): this cell mutates image_caption, so re-running it appends the
# tokens a second time.
encoded_tokens = list(map(data_loader.dataset.vocab, tokens_smpl))
image_caption += encoded_tokens
print(len(image_caption))
print(image_caption)

14
[0, 3, 98, 756, 3, 396, 39, 3, 1014, 207, 139, 3, 755, 18]


## Defining the end symbol of the caption

In [22]:
# Read the special end-of-caption word stored on the vocabulary object.
end_symbol = data_loader.dataset.vocab.end_word
print('Special end word:', end_symbol)

# Close the encoded caption with the integer the vocabulary returns for it.
image_caption += [data_loader.dataset.vocab(end_symbol)]
print(image_caption)

Special end word: <end>
[0, 3, 98, 756, 3, 396, 39, 3, 1014, 207, 139, 3, 755, 18, 1]


### Converting the list of token indices to a tensor of type long

In [23]:
import torch

# Build the integer tensor directly with dtype=torch.long. The legacy
# torch.Tensor(...) constructor creates a float32 tensor first, which cannot
# represent every integer above 2**24 exactly before the cast to long.
# as_tensor also accepts an existing tensor, so re-running this cell is safe.
image_caption = torch.as_tensor(image_caption, dtype=torch.long)
print(image_caption)

tensor([    0,     3,    98,   756,     3,   396,    39,     3,  1014,
          207,   139,     3,   755,    18,     1])


In [24]:
# Preview the word2idx dictionary: materialize its (word, index) pairs and
# display the first ten as a small dict (the cell's last expression).
word2idx_pairs = list(data_loader.dataset.vocab.word2idx.items())
dict(word2idx_pairs[:10])

{'<start>': 0,
 '<end>': 1,
 '<unk>': 2,
 'a': 3,
 'very': 4,
 'clean': 5,
 'and': 6,
 'well': 7,
 'decorated': 8,
 'empty': 9}

In [25]:
# Report the size of the vocabulary as given by len() on the vocab object.
vocab_size = len(data_loader.dataset.vocab)
print('Total number of tokens in vocabulary:', vocab_size)

Total number of tokens in vocabulary: 9955


Check this for yourself below, by pre-processing the provided nonsense words that never appear in the training captions. 

In [27]:
# Obtain the data loader again, this time with vocab_from_file=True so the
# vocabulary is loaded from file. Note that it runs much faster than before!
loader_options = {
    'transform': tt,
    'mode': 'train',
    'batch_size': batch_size,
    'vocab_from_file': True,
}
data_loader = get_loader(**loader_options)

Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...


100%|██████████| 414113/414113 [01:29<00:00, 4605.00it/s]


Done (t=0.83s)
creating index...
index created!
Obtaining caption lengths...


## Obtaining Batches of Training Data

In [28]:
from collections import Counter


# Tally how many training captions have each length, then list the lengths
# from the most frequent to the least frequent.
total_training_captions = Counter(data_loader.dataset.caption_lengths)
each_caption_length = total_training_captions.most_common()
for length, occurrences in each_caption_length:
    print('value: %2d --- count: %5d' % (length, occurrences))

value: 10 --- count: 86334
value: 11 --- count: 79948
value:  9 --- count: 71934
value: 12 --- count: 57637
value: 13 --- count: 37645
value: 14 --- count: 22335
value:  8 --- count: 20771
value: 15 --- count: 12841
value: 16 --- count:  7729
value: 17 --- count:  4842
value: 18 --- count:  3104
value: 19 --- count:  2014
value:  7 --- count:  1597
value: 20 --- count:  1451
value: 21 --- count:   999
value: 22 --- count:   683
value: 23 --- count:   534
value: 24 --- count:   383
value: 25 --- count:   277
value: 26 --- count:   215
value: 27 --- count:   159
value: 28 --- count:   115
value: 29 --- count:    86
value: 30 --- count:    58
value: 31 --- count:    49
value: 32 --- count:    44
value: 34 --- count:    39
value: 37 --- count:    32
value: 33 --- count:    31
value: 35 --- count:    31
value: 36 --- count:    26
value: 38 --- count:    18
value: 39 --- count:    18
value: 43 --- count:    16
value: 44 --- count:    16
value: 48 --- count:    12
value: 45 --- count:    11
v

In [34]:
import numpy as np
import torch.utils.data as data


# Ask the dataset for a set of training indices to form one batch.
# NOTE(review): presumably these all point at captions of one common length
# (the captions tensor below is rectangular) — confirm in data_loader.py.
idxs = data_loader.dataset.get_train_indices()
print('sampled indices:', idxs)

# Restrict the loader's batch sampler to exactly those indices; this must
# happen before the loader is iterated.
sampler_new = data.sampler.SubsetRandomSampler(indices=idxs)
data_loader.batch_sampler.sampler = sampler_new

# Pull a single (images, captions) batch from the loader.
batch_iterator = iter(data_loader)
img, image_captions = next(batch_iterator)

# Show the batch shapes first, then the raw tensor contents.
for label, shape in (('images.shape:', img.shape),
                     ('captions.shape:', image_captions.shape)):
    print(label, shape)

print('images:', img)
print('captions:', image_captions)

sampled indices: [255961, 384592, 240469, 196185, 292053, 169926, 392656, 404516, 354966, 353065]
images.shape: torch.Size([10, 3, 224, 224])
captions.shape: torch.Size([10, 16])
images: tensor([[[[ 0.8104,  0.8276,  0.7933,  ...,  1.2385,  1.2557,  1.2385],
          [ 0.8104,  0.8447,  0.7933,  ...,  1.2043,  1.2043,  1.2043],
          [ 0.7591,  0.8104,  0.7933,  ...,  1.1872,  1.1872,  1.1700],
          ...,
          [-0.6794, -0.7479, -0.9363,  ..., -0.6452, -0.6623, -0.7479],
          [-0.7308, -0.7650, -1.0904,  ..., -0.8335, -0.6965, -0.8164],
          [-0.6965, -0.9877, -1.4329,  ..., -1.2959, -1.2617, -1.3815]],

         [[ 1.2031,  1.2381,  1.1856,  ...,  1.5707,  1.5707,  1.5882],
          [ 1.2031,  1.2381,  1.2031,  ...,  1.5532,  1.5532,  1.5707],
          [ 1.1681,  1.1856,  1.1856,  ...,  1.5707,  1.5707,  1.5532],
          ...,
          [-0.7402, -0.7752, -0.8978,  ..., -0.3550, -0.4076, -0.4601],
          [-0.7577, -0.7402, -0.9678,  ..., -0.6877, -0.5826,