## Downloading the Flickr Dataset from Kaggle

In [1]:
from google.colab import files

In [None]:
# Opens Colab's upload dialog. A later cell copies `kaggle.json` from the
# working directory, so that is presumably the file to upload here.
files.upload()

In [3]:
! mkdir ~/.kaggle

In [4]:
# Copy kaggle.json from the working directory into ~/.kaggle/.
!cp kaggle.json ~/.kaggle/

In [5]:
# Restrict kaggle.json to owner read/write only (mode 600).
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the dataset archive; the next cell unzips flickr8kimagescaptions.zip.
!kaggle datasets download -d aladdinpersson/flickr8kimagescaptions

In [None]:
!unzip flickr8kimagescaptions.zip

In [None]:
!wget https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/more_advanced/image_captioning/test_examples/boat.png
!wget https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/more_advanced/image_captioning/test_examples/bus.png
!wget https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/more_advanced/image_captioning/test_examples/child.jpg
!wget https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/more_advanced/image_captioning/test_examples/dog.jpg
!wget https://raw.githubusercontent.com/aladdinpersson/Machine-Learning-Collection/master/ML/Pytorch/more_advanced/image_captioning/test_examples/horse.png

## Building a DataLoader that loads the Flickr8k dataset and prepares it for training and testing

In [9]:
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms

In [10]:
# Load the spaCy "en_core_web_sm" pipeline once at module level;
# Vocabulary.tokenizer_eng uses its tokenizer for every caption.
spacy_eng = spacy.load("en_core_web_sm")

In [11]:
class Vocabulary:
  """Two-way mapping between caption tokens and integer indices.

  Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
  ``<EOS>`` and ``<UNK>``. Other tokens receive an index in
  ``build_vocabulary`` once they have been seen ``freq_thresh`` times.
  """

  def __init__(self,freq_thresh):
    """Create a vocabulary holding only the four special tokens.

    Args:
      freq_thresh: occurrence count at which a token is given an index.
    """
    self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
    self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
    self.freq_thresh = freq_thresh

  def __len__(self):
    """Return the number of indexed tokens, special tokens included."""
    return len(self.itos)

  @staticmethod
  def tokenizer_eng(text):
    """Split ``text`` with the module-level spaCy tokenizer and lowercase each token."""
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

  def build_vocabulary(self,sentence_list):
    """Index every token that reaches ``freq_thresh`` occurrences in ``sentence_list``.

    Safe to call more than once: new tokens are appended after the indices
    already in use and tokens that already have an index keep it, so ``stoi``
    and ``itos`` always stay inverse to each other. Counts are not carried
    over between calls.

    Args:
      sentence_list: iterable of caption strings.
    """
    frequencies = {}
    # Indices are handed out contiguously from 0, so the next free one equals
    # the current size (4 on a fresh vocabulary).
    idx = len(self.itos)

    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1

        # `==` (not `>=`) so a word is added exactly once, at the moment it
        # reaches the threshold.
        if frequencies[word] == self.freq_thresh and word not in self.stoi:
          self.stoi[word] = idx
          self.itos[idx] = word
          idx += 1

  def numericalize(self,text):
    """Tokenize ``text`` and return its token indices.

    Tokens without an index are mapped to the ``<UNK>`` index.
    """
    unk_idx = self.stoi["<UNK>"]
    return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [33]:
class FlickrDataset(Dataset):
  """Image/caption pairs read from a captions CSV and an image directory.

  The CSV must provide an ``image`` column (file name relative to
  ``root_dir``) and a ``caption`` column. A ``Vocabulary`` is built from all
  captions when the dataset is constructed.
  """

  def __init__(self,root_dir,captions_file,transform = None, freq_thresh = 5):
    """
    Args:
      root_dir: directory that the ``image`` column entries are joined onto.
      captions_file: path handed to ``pd.read_csv``.
      transform: optional callable applied to each loaded RGB PIL image.
      freq_thresh: frequency threshold forwarded to ``Vocabulary``.
    """
    self.root_dir = root_dir
    self.transform = transform

    self.df = pd.read_csv(captions_file)
    self.imgs = self.df["image"]
    self.captions = self.df['caption']

    self.vocab = Vocabulary(freq_thresh)
    self.vocab.build_vocabulary(self.captions.tolist())

  def __len__(self):
    """Number of rows (image/caption pairs) in the captions table."""
    return len(self.df)

  def __getitem__(self,index):
    """Return ``(image, caption_ids)`` for row ``index``.

    ``caption_ids`` is a 1-D tensor: ``<SOS>``, the numericalized caption,
    then ``<EOS>``.
    """
    caption_text = self.captions[index]
    image_path = os.path.join(self.root_dir, self.imgs[index])

    img = Image.open(image_path).convert("RGB")
    if self.transform is not None:
      img = self.transform(img)

    stoi = self.vocab.stoi
    token_ids = [
        stoi["<SOS>"],
        *self.vocab.numericalize(caption_text),
        stoi["<EOS>"],
    ]
    return img, torch.tensor(token_ids)

In [34]:
class MyCollate:
  """Collate function that batches images and pads captions to equal length."""

  def __init__(self,pad_idx):
    """
    Args:
      pad_idx: value used to fill captions shorter than the longest in a batch.
    """
    self.pad_idx = pad_idx

  def __call__(self,batch):
    """Turn a list of ``(image, caption)`` pairs into two batched tensors.

    Returns:
      ``(imgs, targets)``: images stacked along a new leading dimension, and
      captions padded with ``pad_idx`` with the batch on dim 1
      (``batch_first = False``).
    """
    # Stacking on a new dim 0 is the same as unsqueeze(0) on each + cat.
    image_batch = torch.stack([sample[0] for sample in batch], dim = 0)
    caption_batch = pad_sequence(
        [sample[1] for sample in batch],
        batch_first = False,
        padding_value = self.pad_idx,
    )
    return image_batch, caption_batch

In [35]:
def get_loader(root,annotation_file,transform,batch_size = 32,num_workers = 8,shuffle = True, pin_memory = True):
  """Build a ``FlickrDataset`` and a ``DataLoader`` over it.

  Args:
    root: image directory passed to ``FlickrDataset``.
    annotation_file: captions file passed to ``FlickrDataset``.
    transform: image transform passed to ``FlickrDataset``.
    batch_size, num_workers, shuffle, pin_memory: forwarded to ``DataLoader``.

  Returns:
    ``(loader, dataset)``. The dataset is returned as well so callers can
    reach its vocabulary.
  """
  dataset = FlickrDataset(root, annotation_file, transform = transform)

  # Captions in a batch are padded with the vocabulary's <PAD> index.
  collate = MyCollate(pad_idx = dataset.vocab.stoi["<PAD>"])

  loader = DataLoader(
      dataset = dataset,
      batch_size = batch_size,
      shuffle = shuffle,
      num_workers = num_workers,
      pin_memory = pin_memory,
      collate_fn = collate,
  )
  return loader, dataset

In [36]:
# Quick sanity check of the data pipeline: resize to 224x224, convert to a
# tensor, and build a loader with the default batch size.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

loader, dataset = get_loader("./flickr8k/images/", "./flickr8k/captions.txt", transform = transform)

  cpuset_checked))


In [37]:
# Inspect only the first batch: print both tensor shapes, then stop iterating.
for idx, batch in enumerate(loader):
  imgs, captions = batch
  print("Image Shape", imgs.shape)
  print("Captions Shape", captions.shape)
  break

Image Shape torch.Size([32, 3, 224, 224])
Captions Shape torch.Size([24, 32])


## Helper Functions

In [38]:
import torch
import torchvision.transforms as transforms
from PIL import Image

In [39]:
def print_examples(model,device,dataset):
  """Print the reference and generated caption for each of the five test images.

  The model is put in eval mode while captioning. On exit, including when
  loading or captioning raises, it is restored to the mode it was in on entry.

  Args:
    model: module exposing ``caption_image(image, vocabulary)`` that returns
      a list of token strings.
    device: device each image tensor is moved to before captioning.
    dataset: object whose ``vocab`` attribute is passed to ``caption_image``.
  """
  transform = transforms.Compose(
      [
          transforms.Resize((299, 299)),
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
      ]
  )

  # (image path, human-written reference caption), in print order.
  examples = [
      ("./dog.jpg", "Dog on a beach by the ocean"),
      ("./child.jpg", "Child holding red frisbee outdoors"),
      ("./bus.png", "Bus driving by parked cars"),
      ("./boat.png", "A small boat in the ocean"),
      ("./horse.png", "A cowboy riding a horse in the desert"),
  ]

  was_training = model.training
  model.eval()
  try:
    for number, (path, reference) in enumerate(examples, start = 1):
      # unsqueeze(0) adds the leading batch dimension for a single image.
      image = transform(Image.open(path).convert("RGB")).unsqueeze(0)
      print(f"Example {number} CORRECT: {reference}")
      print(
          f"Example {number} OUTPUT: "
          + " ".join(model.caption_image(image.to(device), dataset.vocab))
      )
  finally:
    # Restore the caller's mode instead of unconditionally forcing train mode.
    model.train(was_training)

In [40]:
def save_checkpoint(state,file_name = 'model_checkpoint.pth.tar'):
  """Announce the save on stdout, then write ``state`` to ``file_name`` with ``torch.save``.

  Args:
    state: object to serialize (the training loop passes a dict of state dicts).
    file_name: destination path.
  """
  banner = "************* Saving Model Checkpoint ***************************"
  print(banner)
  torch.save(state, file_name)

In [41]:
def load_checkpoint(checkpoint, model, optimizer):
  """Restore model and optimizer state from a checkpoint dict.

  Args:
    checkpoint: mapping with ``'state_dict'`` and ``'optimizer'`` entries and,
      optionally, a ``'step'`` entry.
    model: object whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
    optimizer: object whose ``load_state_dict`` receives ``checkpoint['optimizer']``.

  Returns:
    The saved global step, or 0 when the checkpoint has no ``'step'`` entry.
  """
  print("************ Loading Checkpoint ************")
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  # Checkpoints written without a "step" key must still load; a plain
  # checkpoint['step'] would raise KeyError for them.
  return checkpoint.get('step', 0)

## Defining the Model

In [42]:
import torch
import torch.nn as nn
import statistics
import torchvision.models as models

In [43]:
class EncoderCNN(nn.Module):
  """Image encoder: pretrained Inception v3 whose final ``fc`` layer is
  replaced by a linear projection to ``embed_size``, followed by ReLU and
  dropout (p = 0.5)."""

  def __init__(self,embed_size,train_cnn = False):
    """
    Args:
      embed_size: output width of the replacement ``fc`` layer.
      train_cnn: stored on the instance only; nothing in this class reads it.
    """
    super().__init__()

    self.train_cnn = train_cnn

    backbone = models.inception_v3(pretrained = True, aux_logits = True)
    # Swap the classifier head for a projection into the embedding space.
    backbone.fc = nn.Linear(backbone.fc.in_features, embed_size)
    self.inception = backbone

    self.relu = nn.ReLU()
    self.times = []
    self.dropout = nn.Dropout(0.5)

  def forward(self,images):
    """Run the backbone and return dropout(relu(.)) of element 0 of its output.

    NOTE(review): with ``aux_logits = True`` the backbone presumably returns
    a (logits, aux_logits) pair in training mode but a plain tensor in eval
    mode. If so, ``[0]`` selects the main logits when training and the first
    sample of the batch when evaluating — confirm against torchvision.
    """
    backbone_out = self.inception(images)
    activated = self.relu(backbone_out[0])
    return self.dropout(activated)

In [44]:
class DecoderRNN(nn.Module):
  """Caption decoder: token embedding -> LSTM -> linear projection to vocabulary scores."""

  def __init__(self,embed_size,hidden_size,vocab_size,num_layers):
    """
    Args:
      embed_size: width of token embeddings (and of the image features fed in).
      hidden_size: LSTM hidden state width.
      vocab_size: number of tokens; size of the output scores.
      num_layers: number of stacked LSTM layers.
    """
    super().__init__()
    self.embed = nn.Embedding(vocab_size, embed_size)
    self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
    self.linear = nn.Linear(hidden_size, vocab_size)
    self.dropout = nn.Dropout(0.5)

  def forward(self,features,captions):
    """Score every vocabulary token at each step.

    ``features`` gains a new leading dim and is prepended along dim 0 to the
    (dropout-regularised) caption embeddings, so the output has one more
    entry along dim 0 than ``captions``.
    """
    token_vectors = self.dropout(self.embed(captions))
    image_step = features.unsqueeze(0)
    sequence = torch.cat([image_step, token_vectors], dim = 0)
    lstm_out = self.lstm(sequence)[0]
    return self.linear(lstm_out)

In [45]:
class CNNtoRNN(nn.Module):
  """Captioning model that feeds ``EncoderCNN`` features into ``DecoderRNN``."""

  def __init__(self,embed_size,hidden_size,vocab_size,num_layers):
    """Build the encoder and decoder; arguments are forwarded to them unchanged."""
    super().__init__()
    self.encoderCNN = EncoderCNN(embed_size)
    self.decoderRNN = DecoderRNN(embed_size,hidden_size,vocab_size,num_layers)

  def forward(self,images,captions):
    """Encode ``images`` and return the decoder's scores for ``captions``."""
    return self.decoderRNN(self.encoderCNN(images), captions)

  def caption_image(self,image,vocabulary,max_len = 50):
    """Greedily decode a caption for ``image``.

    At each step the highest-scoring token is taken and its embedding is fed
    back in as the next LSTM input. Decoding stops after ``<EOS>`` is produced
    or after ``max_len`` steps. Runs under ``torch.no_grad()``.

    NOTE(review): ``argmax(0)`` followed by ``.item()`` needs a single
    predicted index per step, so this presumably expects one image and an
    encoder output without a batch dimension — confirm.

    Returns:
      List of token strings looked up in ``vocabulary.itos`` (``<EOS>``
      included when it was produced).
    """
    decoder = self.decoderRNN
    token_ids = []

    with torch.no_grad():
      step_input = self.encoderCNN(image).unsqueeze(0)
      lstm_state = None

      for _ in range(max_len):
        hiddens, lstm_state = decoder.lstm(step_input, lstm_state)
        scores = decoder.linear(hiddens.squeeze(0))
        next_token = scores.argmax(0)
        token_id = next_token.item()
        token_ids.append(token_id)
        step_input = decoder.embed(next_token).unsqueeze(0)
        if vocabulary.itos[token_id] == '<EOS>':
          break

    return [vocabulary.itos[token_id] for token_id in token_ids]

## Model Training

In [46]:
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms

In [47]:
def train():
  """Train the CNN-to-RNN captioning model on Flickr8k and return it.

  Per epoch: print captions for the test images, optionally save a checkpoint,
  then run one pass over the training loader, logging the loss of every batch
  to TensorBoard under "./runs/flickr".

  Returns:
    The trained ``CNNtoRNN`` model.
  """
  transform = transforms.Compose(
      [
          transforms.Resize((356,356)),
          transforms.RandomCrop((299,299)),
          transforms.ToTensor(),
          transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
      ]
  )

  train_loader, dataset = get_loader(
      root = './flickr8k/images/',
      annotation_file = './flickr8k/captions.txt',
      transform = transform,
      num_workers = 2
  )

  torch.backends.cudnn.benchmark = True
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  load_model = False
  save_model = True
  train_cnn = False

  embed_size = 256
  hidden_size = 256
  vocab_size = len(dataset.vocab)
  num_layers = 1
  learning_rate = 3e-4
  num_epochs = 20

  writer = SummaryWriter("./runs/flickr")
  step = 0
  model = CNNtoRNN(embed_size,hidden_size,vocab_size,num_layers).to(device)
  # Padding positions must not contribute to the loss.
  criterion = nn.CrossEntropyLoss(ignore_index = dataset.vocab.stoi["<PAD>"])
  optimizer = optim.Adam(model.parameters(),lr = learning_rate)

  # Always train the fc layers; train the rest of the CNN only when train_cnn
  # is set. The attribute is `requires_grad` — assigning to any other name
  # merely creates an unused attribute and freezes nothing.
  for name, param in model.encoderCNN.inception.named_parameters():
    if "fc.weight" in name or "fc.bias" in name:
      param.requires_grad = True
    else:
      param.requires_grad = train_cnn

  if load_model:
    step = load_checkpoint(
        torch.load("model_checkpoint.pth.tar", map_location = device),
        model,
        optimizer,
    )

  model.train()

  try:
    for epoch in range(num_epochs):

      print_examples(model, device, dataset)

      print("<============Epoch=============>",epoch)
      if save_model:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            # Stored so a resumed run continues the TensorBoard step count.
            "step": step,
        }
        save_checkpoint(checkpoint)

      for idx, (imgs,captions) in tqdm(enumerate(train_loader), total = len(train_loader), leave = False):
        imgs = imgs.to(device)
        captions = captions.to(device)

        # The decoder prepends the image features as the first input, so
        # dropping the last caption token keeps outputs and targets the
        # same length.
        outputs = model(imgs,captions[:-1])
        loss = criterion(outputs.reshape(-1,outputs.shape[2]),captions.reshape(-1))

        writer.add_scalar("Training Loss",loss.item(),global_step = step)
        step += 1

        optimizer.zero_grad()
        # No argument: passing the loss itself as `gradient` would scale
        # every gradient by the loss value.
        loss.backward()
        optimizer.step()
  finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

  return model

In [None]:
# Run the full training loop and keep the returned model for later cells.
model = train()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: padded teacher creating wine gold strange rowing puppy skimpy lasso that alongside small followed electronic touching runners outside snowy milkshake vine lunch self pugs for wine wires apples wrapping gravel squirted naked attention gathers creating wine gold strange rowing puppy skimpy lasso that alongside small followed electronic touching runners outside
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: victory mug men pouring footballers headset offering onlookers batting lining cones unhappy mug men travels formally pet spot four color teacher tiny cannon field caps marble interested still padded crosswalk fight dragging to grey dive big instrument ten when struggles climbers refrigerator climbers refrigerator squeezing lunch dimly converse patio checkered
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: victory mug tries connected shown cups clouds rollerblader unhappy motorcycl

 14%|█▎        | 171/1265 [01:23<08:20,  2.18it/s]