## Downloading dataset and predefined train, test and validation splits

In [None]:
! pip install -q kaggle
from google.colab import files
files.upload()
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle datasets list

In [None]:
# Download the Flickr8k archive (flickr8k.zip) from Kaggle into the working directory.
!kaggle datasets download -d adityajn105/flickr8k

Downloading flickr8k.zip to /content
100% 1.04G/1.04G [00:48<00:00, 25.1MB/s]
100% 1.04G/1.04G [00:48<00:00, 23.1MB/s]


In [None]:
!wget http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip

In [None]:
!unzip ./caption_datasets.zip

In [None]:
!unzip flickr8k.zip

## Helper Functions

In [17]:
import os
import numpy as np
import h5py
import json
import torch
from imageio import imread
from skimage.transform import resize
from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample

In [18]:
def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder,
                       max_len=100):
    """
    Build the training/validation/test input files from a Karpathy-split JSON.

    For each of the TRAIN, VAL and TEST splits this writes, into ``output_folder``:
      * ``<SPLIT>_IMAGES_<base>.hdf5``  - dataset ``images`` of shape (N, 3, 256, 256), dtype uint8,
        plus the attribute ``captions_per_image``;
      * ``<SPLIT>_CAPTIONS_<base>.json`` - list of encoded captions, each of length ``max_len + 2``;
      * ``<SPLIT>_CAPLENS_<base>.json``  - list of caption lengths (tokens + <start> + <end>).
    It also writes ``WORDMAP_<base>.json`` (token -> index).

    :param dataset: one of 'coco', 'flickr8k', 'flickr30k'
    :param karpathy_json_path: path to the JSON file holding images, splits and tokenised captions
    :param image_folder: folder containing the image files
    :param captions_per_image: number of captions stored per image (sampled, or padded by re-drawing)
    :param min_word_freq: tokens occurring this many times or fewer are left out of the word map
        (and are therefore encoded as <unk>)
    :param output_folder: folder the generated files are written to
    :param max_len: captions with more tokens than this are discarded
    """
    assert dataset in {'coco', 'flickr8k', 'flickr30k'}

    with open(karpathy_json_path, 'r') as j:
        data = json.load(j)

    train_image_paths = []
    train_image_captions = []
    val_image_paths = []
    val_image_captions = []
    test_image_paths = []
    test_image_captions = []
    word_freq = Counter()

    for img in data['images']:
        captions = []
        for c in img['sentences']:
            # Frequencies are counted over every caption, including over-long ones that get dropped.
            word_freq.update(c['tokens'])
            if len(c['tokens']) <= max_len:
                captions.append(c['tokens'])

        # Images left without any usable caption are skipped entirely.
        if len(captions) == 0:
            continue

        # COCO entries carry a sub-folder in 'filepath'; the Flickr sets are flat.
        if dataset == 'coco':
            path = os.path.join(image_folder, img['filepath'], img['filename'])
        else:
            path = os.path.join(image_folder, img['filename'])

        if img['split'] in {'train', 'restval'}:
            train_image_paths.append(path)
            train_image_captions.append(captions)
        elif img['split'] in {'val'}:
            val_image_paths.append(path)
            val_image_captions.append(captions)
        elif img['split'] in {'test'}:
            test_image_paths.append(path)
            test_image_captions.append(captions)

    assert len(train_image_paths) == len(train_image_captions)
    assert len(val_image_paths) == len(val_image_captions)
    assert len(test_image_paths) == len(test_image_captions)

    # Word map: index 0 is reserved for <pad>; regular words start at 1, followed by the special tokens.
    words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
    word_map = {k: v + 1 for v, k in enumerate(words)}
    word_map['<unk>'] = len(word_map) + 1
    word_map['<start>'] = len(word_map) + 1
    word_map['<end>'] = len(word_map) + 1
    word_map['<pad>'] = 0

    base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq'

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w') as j:
        json.dump(word_map, j)

    # Fixed seed so the caption sampling below is reproducible.
    seed(123)
    for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'),
                                   (val_image_paths, val_image_captions, 'VAL'),
                                   (test_image_paths, test_image_captions, 'TEST')]:

        # 'w' (not 'a'): with append mode a second run fails in create_dataset because
        # the 'images' dataset already exists in the file from the previous run.
        with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'w') as h:
            h.attrs['captions_per_image'] = captions_per_image

            images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8')

            print("\nReading %s images and captions, storing to file...\n" % split)

            enc_captions = []
            caplens = []

            for i, path in enumerate(tqdm(impaths)):
                # Too few captions: top up by re-drawing at random; otherwise sample without replacement.
                if len(imcaps[i]) < captions_per_image:
                    captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))]
                else:
                    captions = sample(imcaps[i], k=captions_per_image)

                assert len(captions) == captions_per_image

                img = imread(path)
                if img.ndim == 2:
                    # Greyscale: replicate the single channel three times.
                    img = np.stack([img, img, img], axis=2)
                elif img.shape[2] > 3:
                    # Drop any channel beyond the first three (e.g. alpha).
                    img = img[:, :, :3]
                # By default skimage's resize rescales to floats in [0, 1]; written into a uint8
                # dataset that truncates nearly every pixel to 0. Keep the original value range
                # and convert back to uint8 explicitly.
                img = resize(img, (256, 256), preserve_range=True)
                img = np.clip(np.rint(img), 0, 255).astype('uint8')
                img = img.transpose(2, 0, 1)
                assert img.shape == (3, 256, 256)

                images[i] = img

                for c in captions:
                    # <start> tokens... <end>, then <pad> up to a fixed length of max_len + 2.
                    enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + [
                        word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c))

                    # Length including <start> and <end>, excluding padding.
                    c_len = len(c) + 2

                    enc_captions.append(enc_c)
                    caplens.append(c_len)

            assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens)

            with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as j:
                json.dump(enc_captions, j)

            with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'w') as j:
                json.dump(caplens, j)

In [19]:
def init_embedding(embeddings):
    """
    Fill an embedding tensor in place with values drawn uniformly from
    [-sqrt(3 / dim), +sqrt(3 / dim)], where dim is the size of dimension 1.

    :param embeddings: 2-D tensor to initialise (modified in place)
    """
    emb_dim = embeddings.size(1)
    limit = np.sqrt(3.0 / emb_dim)
    torch.nn.init.uniform_(embeddings, a=-limit, b=limit)

In [20]:
def load_embeddings(emb_file, word_map):
    """
    Build an embedding matrix for the vocabulary from a text embedding file.

    Each line of ``emb_file`` is expected to be ``word v1 v2 ... vN`` separated by
    single spaces. Rows for vocabulary words missing from the file keep the random
    initialisation produced by ``init_embedding``.

    :param emb_file: path to the embedding text file
    :param word_map: dict mapping token -> row index
    :return: (embeddings, emb_dim) - a FloatTensor of shape (len(word_map), emb_dim)
        and the embedding dimension read from the first line of the file
    """
    # The dimension is inferred from the first line: number of fields minus the word itself.
    with open(emb_file, 'r') as f:
        emb_dim = len(f.readline().split(' ')) - 1

    vocab = set(word_map.keys())
    embeddings = torch.FloatTensor(len(vocab), emb_dim)
    init_embedding(embeddings)

    print("\nLoading embeddings...")
    # Context manager so the file handle is closed deterministically
    # (the bare ``for line in open(...)`` form left it to the garbage collector).
    with open(emb_file, 'r') as f:
        for line in f:
            fields = line.split(' ')
            emb_word = fields[0]

            # Skip out-of-vocabulary words before paying for float parsing.
            if emb_word not in vocab:
                continue

            # Ignore empty / whitespace-only fields (e.g. the trailing newline).
            embedding = [float(t) for t in fields[1:] if t and not t.isspace()]
            embeddings[word_map[emb_word]] = torch.FloatTensor(embedding)

    return embeddings, emb_dim

In [21]:
def clip_gradient(optimizer, grad_clip):
    """
    Clamp every gradient element of every parameter held by the optimizer
    into the interval [-grad_clip, grad_clip], in place.

    :param optimizer: optimizer whose param_groups hold the parameters to clip
    :param grad_clip: clip magnitude
    """
    all_params = (p for g in optimizer.param_groups for p in g['params'])
    for p in all_params:
        # Parameters that received no gradient are left alone.
        if p.grad is None:
            continue
        p.grad.data.clamp_(-grad_clip, grad_clip)

In [22]:
def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer,
                    bleu4, is_best):
    """
    Serialise the training state to 'checkpoint_<data_name>.pth.tar' in the
    current working directory, and additionally to 'BEST_checkpoint_<data_name>.pth.tar'
    when ``is_best`` is truthy.

    :param data_name: name used to build the checkpoint file name
    :param epoch: epoch number stored under 'epoch'
    :param epochs_since_improvement: stored under 'epochs_since_improvement'
    :param encoder: stored under 'encoder'
    :param decoder: stored under 'decoder'
    :param encoder_optimizer: stored under 'encoder_optimizer'
    :param decoder_optimizer: stored under 'decoder_optimizer'
    :param bleu4: stored under 'bleu-4'
    :param is_best: whether to also write the 'BEST_' copy
    """
    state = {}
    state['epoch'] = epoch
    state['epochs_since_improvement'] = epochs_since_improvement
    state['bleu-4'] = bleu4
    state['encoder'] = encoder
    state['decoder'] = decoder
    state['encoder_optimizer'] = encoder_optimizer
    state['decoder_optimizer'] = decoder_optimizer

    filename = 'checkpoint_' + data_name + '.pth.tar'
    destinations = [filename]
    if is_best:
        destinations.append('BEST_' + filename)
    for destination in destinations:
        torch.save(state, destination)

In [23]:
class AverageMeter(object):
    """Keeps the most recent value, running sum, sample count and mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set val, avg, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record a new observation.

        :param val: the observed value (treated as the mean over ``n`` samples)
        :param n: number of samples the value stands for
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [24]:
def adjust_learning_rate(optimizer, shrink_factor):
    """
    Multiply the learning rate of every parameter group by ``shrink_factor``
    and print the resulting rate of the first group.

    :param optimizer: optimizer whose param_groups are updated in place
    :param shrink_factor: multiplier applied to each group's 'lr'
    """
    print("\nDECAYING learning rate.")
    groups = optimizer.param_groups
    for group in groups:
        current_lr = group['lr']
        group['lr'] = current_lr * shrink_factor
    print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))

In [25]:
def accuracy(scores, targets, k):
    """
    Top-k accuracy, as a percentage, of predicted scores against target labels.

    :param scores: tensor of scores; the top-k is taken along dimension 1
    :param targets: tensor of true label indices, one per row of ``scores``
    :param k: how many top-scoring entries per row count as a hit
    :return: percentage (float) of rows whose target is among the top k
    """
    n_rows = targets.size(0)
    top_indices = scores.topk(k, 1, True, True)[1]
    target_column = targets.view(-1, 1)
    hits = top_indices.eq(target_column.expand_as(top_indices))
    n_hits = hits.view(-1).float().sum()
    return n_hits.item() * (100.0 / n_rows)

## Creating input files

In [26]:
# Generate the word map plus the per-split HDF5 image files and caption/caplen JSON
# files for Flickr8k. Note that output_folder is the same directory as image_folder,
# so the generated files are written alongside the raw images.
create_input_files(
    dataset = 'flickr8k',
    karpathy_json_path = './dataset_flickr8k.json',
    image_folder = './Images/',
    captions_per_image = 5,
    min_word_freq = 5,  # words must occur more than 5 times to get their own index
    output_folder = './Images/',
    max_len = 50  # captions longer than 50 tokens are discarded
)


Reading TRAIN images and captions, storing to file...



100%|██████████| 6000/6000 [02:59<00:00, 33.36it/s]



Reading VAL images and captions, storing to file...



100%|██████████| 1000/1000 [00:28<00:00, 35.05it/s]



Reading TEST images and captions, storing to file...



100%|██████████| 1000/1000 [00:33<00:00, 30.25it/s]


## Creating a PyTorch `Dataset` so the preprocessed data can be used for training, validation and testing

In [27]:
import torch
from torch.utils.data import Dataset
import h5py
import json
import os

In [29]:
class CaptionDataset(Dataset):
    """
    PyTorch Dataset over the files written by ``create_input_files``.

    Each item corresponds to one caption; consecutive groups of ``cpi`` captions
    share the same image (item ``i`` uses image ``i // cpi``).
    """

    def __init__(self, data_folder, data_name, split, transform=None):
        """
        :param data_folder: folder containing the HDF5 / JSON files
        :param data_name: base name shared by the files (the part after '<SPLIT>_IMAGES_')
        :param split: one of 'TRAIN', 'VAL', 'TEST'
        :param transform: optional callable applied to the image tensor
        """
        self.split = split
        assert self.split in {"TRAIN", "VAL", "TEST"}

        # The HDF5 file stays open for the lifetime of the dataset; images are read lazily.
        self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5'), 'r')
        self.imgs = self.h['images']
        # Plain int so index arithmetic does not depend on the attribute's stored numeric type.
        self.cpi = int(self.h.attrs['captions_per_image'])

        with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + data_name + '.json'), 'r') as j:
            self.captions = json.load(j)
        with open(os.path.join(data_folder, self.split + '_CAPLENS_' + data_name + '.json'), 'r') as j:
            self.caplens = json.load(j)

        self.transform = transform

        # One data point per caption, not per image.
        self.dataset_size = len(self.captions)

    def __getitem__(self, i):
        """
        :return: (img, caption, caplen) for the TRAIN split; for VAL/TEST additionally
            ``all_captions``, the ``cpi`` captions belonging to the same image.
        """
        img_idx = i // self.cpi
        # Stored pixel values are divided by 255 before being handed to the transform.
        img = torch.FloatTensor(self.imgs[img_idx] / 255.)
        if self.transform is not None:
            img = self.transform(img)

        caption = torch.LongTensor(self.captions[i])
        caplen = torch.LongTensor([self.caplens[i]])

        # Equality, not identity: ``is`` on a str only works by accident of interning
        # and would silently send TRAIN items down the evaluation branch.
        if self.split == 'TRAIN':
            return img, caption, caplen

        first = img_idx * self.cpi
        all_captions = torch.LongTensor(self.captions[first:first + self.cpi])
        return img, caption, caplen, all_captions

    def __len__(self):
        return self.dataset_size

## Creating the Encoder and Decoder Architecture for model training

In [30]:
import torch
from torch import nn
import torchvision

In [31]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [32]:
device

device(type='cuda')

In [33]:
class Encoder(nn.Module):
    """
    Image encoder: a pretrained ResNet-101 with its last two child modules removed,
    followed by adaptive average pooling to a fixed spatial size.
    """

    def __init__(self, encoded_image_size=14):
        """
        :param encoded_image_size: height and width of the pooled feature map
        """
        super(Encoder, self).__init__()
        self.enc_image_size = encoded_image_size

        resnet = torchvision.models.resnet101(pretrained=True)
        # Drop the last two children (the pooling and classification head).
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        # Attribute name (including its spelling) is kept as-is for compatibility.
        self.addaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune()

    def forward(self, images):
        """
        :param images: batch of images, channels-first
        :return: feature map with the channel dimension moved last, i.e.
            (batch, encoded_image_size, encoded_image_size, channels)
        """
        out = self.resnet(images)
        out = self.addaptive_pool(out)
        out = out.permute(0, 2, 3, 1)
        return out

    def fine_tune(self, fine_tune=True):
        """
        Freeze the whole backbone, then enable or disable gradients for its child
        modules from index 5 onward (layer2-layer4 in torchvision's ResNet ordering).

        :param fine_tune: whether those later modules should receive gradients
        """
        for p in self.resnet.parameters():
            p.requires_grad = False
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                # Must be `requires_grad`; the misspelt `required_grad` only created an
                # unused attribute and left every parameter frozen.
                p.requires_grad = fine_tune

In [36]:
class Attention(nn.Module):
    """
    Additive (soft) attention over encoder locations, conditioned on the decoder hidden state.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of each encoder location
        :param decoder_dim: size of the decoder hidden state
        :param attention_dim: size of the shared attention space
        """
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        # Was `nn.Lienar`, which raised AttributeError as soon as the module was constructed.
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        :param encoder_out: encoder features, (batch, num_locations, encoder_dim)
        :param decoder_hidden: decoder hidden state, (batch, decoder_dim)
        :return: (attention_weighted_encoding, alpha) with shapes
            (batch, encoder_dim) and (batch, num_locations); alpha sums to 1 over locations
        """
        att1 = self.encoder_att(encoder_out)        # (batch, num_locations, attention_dim)
        att2 = self.decoder_att(decoder_hidden)     # (batch, attention_dim)
        # Broadcast the decoder projection over all locations, then score each location.
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)
        alpha = self.softmax(att)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return attention_weighted_encoding, alpha

In [37]:
class DecoderWithAttention(nn.Module):
    """
    LSTM caption decoder that attends over the encoder feature map at every step.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        """
        :param attention_dim: size of the attention network
        :param embed_dim: word embedding size
        :param decoder_dim: LSTM hidden size
        :param vocab_size: number of tokens in the vocabulary
        :param encoder_dim: feature size of the encoded images
        :param dropout: dropout probability applied before the output layer
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # `self.dropout` is the Dropout module (it is no longer first set to the float rate).
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)   # initial hidden state from the image
        self.init_c = nn.Linear(encoder_dim, decoder_dim)   # initial cell state from the image
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)   # produces the sigmoid gate on the context
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)        # scores over the vocabulary
        self.init_weights()

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for the embedding and output weights; zero output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding weights with the given tensor."""
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Enable or disable gradient updates for the embedding layer."""
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Derive the initial LSTM state from the mean encoder feature.

        :param encoder_out: (batch, num_pixels, encoder_dim)
        :return: (h, c), each (batch, decoder_dim)
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        :param encoder_out: encoded images, (batch, ..., encoder_dim); spatial dims are flattened
        :param encoded_captions: token indices, (batch, max_caption_length)
        :param caption_lengths: caption lengths, (batch, 1)
        :return: (predictions, encoded_captions, decode_lengths, alphas, sort_ind) where
            predictions is (batch, max(decode_lengths), vocab_size), alphas is
            (batch, max(decode_lengths), num_pixels), and the batch is reordered by
            decreasing caption length (``sort_ind`` gives the permutation applied)
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        # Allocate outputs on the device of the inputs rather than on a notebook-level global,
        # so the module works wherever it (and its inputs) have been placed.
        out_device = encoder_out.device

        # Flatten the spatial dimensions into a single "pixel" axis.
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort by decreasing length so that, at step t, the still-active captions are a prefix.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)
        h, c = self.init_hidden_state(encoder_out)

        # One fewer decoding step than tokens: nothing is predicted from the final token.
        decode_lengths = (caption_lengths - 1).tolist()
        max_steps = max(decode_lengths)

        predictions = torch.zeros(batch_size, max_steps, vocab_size, device=out_device)
        alphas = torch.zeros(batch_size, max_steps, num_pixels, device=out_device)

        for t in range(max_steps):
            # Number of captions that still have a token to predict at step t.
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind