## Downloading the Flickr Dataset from Kaggle

In [1]:
from google.colab import files

In [2]:
# Trailing ';' stops the notebook from echoing the returned dict, which holds
# the raw bytes of the uploaded kaggle.json (including the API key), into the
# saved cell output.
files.upload();

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"vatsalmpatel","key":"<REDACTED>"}'}

In [3]:
# -p makes the cell safe to re-run: no error if ~/.kaggle already exists.
! mkdir -p ~/.kaggle

In [4]:
# Copy the uploaded kaggle.json from the working directory into ~/.kaggle/.
!cp kaggle.json ~/.kaggle/

In [5]:
# Restrict the credentials file to owner read/write only (mode 600).
! chmod 600 ~/.kaggle/kaggle.json

In [6]:
# Download the aladdinpersson/flickr8kimagescaptions dataset archive with the
# Kaggle CLI; the archive is saved as flickr8kimagescaptions.zip.
!kaggle datasets download -d aladdinpersson/flickr8kimagescaptions

Downloading flickr8kimagescaptions.zip to /content
100% 1.03G/1.04G [00:38<00:00, 35.7MB/s]
100% 1.04G/1.04G [00:38<00:00, 28.6MB/s]


In [None]:
# -q keeps the per-file listing out of the cell output.
# -n never overwrites existing files, so re-running the cell does not stop at
# unzip's interactive "replace?" prompt.
!unzip -q -n flickr8kimagescaptions.zip

## Creating a DataLoader for the Flickr8k dataset and preparing it for training and testing

In [23]:
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms

In [24]:
spacy_eng = spacy.load("en_core_web_sm")

In [25]:
class Vocabulary:
  """Word-level vocabulary mapping tokens to integer ids and back.

  Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``
  and ``<UNK>``. Every other token receives an id from ``build_vocabulary``
  once it has been seen ``freq_thresh`` times.
  """

  def __init__(self,freq_thresh):
    """Create a vocabulary that holds only the four special tokens.

    Args:
      freq_thresh: number of occurrences a token needs before it gets an id.
        Counting starts at 1, so a value below 1 means no token is ever added.
    """
    self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
    self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
    self.freq_thresh = freq_thresh

  def __len__(self):
    """Return the number of ids currently assigned (special tokens included)."""
    return len(self.itos)

  @staticmethod
  def tokenizer_eng(text):
    """Tokenize ``text`` with the module-level ``spacy_eng`` tokenizer.

    Returns:
      A list with the lower-cased text of each token.
    """
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

  def build_vocabulary(self,sentence_list):
    """Assign ids to tokens that occur ``freq_thresh`` times in ``sentence_list``.

    Ids are handed out in the order in which tokens reach the threshold while
    scanning the sentences. Occurrence counts are local to one call.

    Args:
      sentence_list: iterable of raw sentences (strings).
    """
    frequencies = {}
    # Continue after the ids already in use instead of restarting at 4, so a
    # repeated call cannot overwrite existing ``itos`` entries.
    idx = len(self.itos)

    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1

        # A token that already owns an id keeps it; otherwise ``stoi`` and
        # ``itos`` would disagree after a second call.
        if frequencies[word] == self.freq_thresh and word not in self.stoi:
          self.stoi[word] = idx
          self.itos[idx] = word
          idx = idx + 1

  def numericalize(self,text):
    """Convert ``text`` into a list of token ids.

    Tokens without an id are mapped to the id of ``<UNK>``.
    """
    unk_idx = self.stoi["<UNK>"]
    return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [26]:
class FlickrDataset(Dataset):
  """Dataset of (image tensor, caption ids) pairs backed by a captions CSV.

  The CSV is read with pandas and must provide ``image`` and ``caption``
  columns. A ``Vocabulary`` is built from every caption at construction time.
  """

  def __init__(self,root_dir,captions_file,transform = None, freq_thresh = 5):
    """Load the captions table and build the vocabulary.

    Args:
      root_dir: directory that the ``image`` column entries are joined onto.
      captions_file: path of the CSV passed to ``pd.read_csv``.
      transform: optional callable applied to each loaded RGB image.
      freq_thresh: occurrence threshold forwarded to ``Vocabulary``.
    """
    self.root_dir = root_dir
    self.df = pd.read_csv(captions_file)
    self.transform = transform

    # Keep direct handles on the two columns used by __getitem__.
    self.imgs, self.captions = self.df["image"], self.df['caption']

    vocabulary = Vocabulary(freq_thresh)
    vocabulary.build_vocabulary(self.captions.tolist())
    self.vocab = vocabulary

  def __len__(self):
    """Return the number of rows (caption entries) in the table."""
    return len(self.df)

  def __getitem__(self,index):
    """Return ``(image, caption_ids)`` for row ``index``.

    The caption is wrapped in ``<SOS>`` / ``<EOS>`` ids and returned as a
    ``torch.tensor``; the image is converted to RGB and passed through
    ``transform`` when one was given.
    """
    caption_text = self.captions[index]
    image_path = os.path.join(self.root_dir, self.imgs[index])
    picture = Image.open(image_path).convert("RGB")

    if self.transform is not None:
      picture = self.transform(picture)

    token_to_id = self.vocab.stoi
    token_ids = [
        token_to_id["<SOS>"],
        *self.vocab.numericalize(caption_text),
        token_to_id["<EOS>"],
    ]

    return picture, torch.tensor(token_ids)

In [27]:
class MyCollate:
  """Collate function that batches images and pads captions to equal length."""

  def __init__(self,pad_idx):
    """Remember the id used to pad short captions."""
    self.pad_idx = pad_idx

  def __call__(self,batch):
    """Merge a list of ``(image, caption)`` samples into two batch tensors.

    Returns:
      ``(imgs, targets)``: images joined along a new leading batch dimension,
      and captions padded with ``pad_idx`` in sequence-first layout
      (``batch_first = False``), i.e. ``(longest_caption, batch)``.
    """
    pictures, sequences = [], []
    for sample in batch:
      pictures.append(sample[0])
      sequences.append(sample[1])

    # Stacking on a new dim 0 is the same as unsqueeze(0) on each + cat.
    imgs = torch.stack(pictures, dim = 0)
    targets = pad_sequence(sequences, batch_first = False, padding_value = self.pad_idx)
    return imgs, targets

In [28]:
def get_loader(root,annotation_file,transform,batch_size = 32,num_workers = 8,shuffle = True, pin_memory = True, freq_thresh = 5):
  """Build a ``FlickrDataset`` and a ``DataLoader`` that pads captions per batch.

  Args:
    root: image directory passed to ``FlickrDataset``.
    annotation_file: captions CSV passed to ``FlickrDataset``.
    transform: image transform passed to ``FlickrDataset``.
    batch_size, num_workers, shuffle, pin_memory: forwarded to ``DataLoader``.
    freq_thresh: vocabulary occurrence threshold forwarded to
      ``FlickrDataset``; defaults to 5, the dataset's own default, so existing
      callers are unaffected.

  Returns:
    ``(loader, dataset)``; the dataset is returned too so callers can reach
    ``dataset.vocab``.
  """
  dataset = FlickrDataset(root,annotation_file,transform = transform, freq_thresh = freq_thresh)

  # Captions in a batch are padded with the id of the <PAD> token.
  pad_idx = dataset.vocab.stoi["<PAD>"]
  loader = DataLoader(
      dataset = dataset,
      batch_size = batch_size,
      num_workers = num_workers,
      pin_memory = pin_memory,
      shuffle = shuffle,
      collate_fn = MyCollate(pad_idx = pad_idx)
  )

  return loader, dataset

In [29]:
# Resize every image to 224x224 and convert it to a tensor.
# NOTE(review): print_examples in this notebook resizes to 299x299 and also
# normalizes; confirm which preprocessing the model expects.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

# Build the loader together with its dataset (the dataset exposes the vocab).
loader, dataset = get_loader(
    root = "./flickr8k/images/",
    annotation_file = "./flickr8k/captions.txt",
    transform = transform,
)

In [32]:
# Sanity check: pull a single batch and show its tensor shapes.
# The batch index was never used, so iterate the loader directly.
for imgs, captions in loader:
  print("Image Shape",imgs.shape)
  print("Captions Shape",captions.shape)
  break

Image Shape torch.Size([32, 3, 224, 224])
Captions Shape torch.Size([22, 32])


## Helper Functions

In [33]:
import torch
import torchvision.transforms as transforms
from PIL import Image

In [34]:
def print_examples(model,device,dataset):
  """Print reference and generated captions for five fixed test images.

  The model is switched to eval mode while captioning and put back into train
  mode afterwards, even if loading or captioning an example raises.

  Args:
    model: object providing ``eval()``, ``train()`` and
      ``caption_image(image, vocab)``; the latter must return an iterable of
      strings, which are joined with spaces.
    device: device each image batch is moved to before captioning.
    dataset: object whose ``vocab`` attribute is passed to ``caption_image``.
  """
  # NOTE(review): the loader cell in this notebook resizes to 224x224 without
  # normalization; confirm which preprocessing the model expects.
  transform = transforms.Compose(
      [
          transforms.Resize((299, 299)),
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
      ]
  )

  # (image path, human-written reference caption), in print order.
  examples = [
      ("test_examples/dog.jpg", "Dog on a beach by the ocean"),
      ("test_examples/child.jpg", "Child holding red frisbee outdoors"),
      ("test_examples/bus.png", "Bus driving by parked cars"),
      ("test_examples/boat.png", "A small boat in the ocean"),
      ("test_examples/horse.png", "A cowboy riding a horse in the desert"),
  ]

  model.eval()
  try:
    for number, (path, reference) in enumerate(examples, start = 1):
      # Add a leading batch dimension of size 1 for the model.
      test_img = transform(Image.open(path).convert("RGB")).unsqueeze(0)
      print(f"Example {number} CORRECT: {reference}")
      print(
          f"Example {number} OUTPUT: "
          + " ".join(model.caption_image(test_img.to(device), dataset.vocab))
      )
  finally:
    # Always hand the model back in train mode, as callers expect.
    model.train()

In [36]:
def save_checkpoint(state,file_name = 'model_checkpoint.pth.tar'):
  """Announce the save on stdout, then write ``state`` with ``torch.save``.

  Args:
    state: object to serialize.
    file_name: destination passed to ``torch.save``.
  """
  banner = "************* Saving Model Checkpoint ***************************"
  print(banner)
  torch.save(obj = state, f = file_name)

In [37]:
def load_checkpoint(checkpoint, model, optimizer):
  """Restore model and optimizer state from ``checkpoint`` and return its step.

  Args:
    checkpoint: mapping with ``'state_dict'``, ``'optimizer'`` and ``'step'``.
    model: object whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
    optimizer: object whose ``load_state_dict`` receives ``checkpoint['optimizer']``.

  Returns:
    The value stored under ``checkpoint['step']``.
  """
  print("************ Loading Checkpoint ************")
  # Model first, then optimizer; the step is read only after both loads.
  for target, key in ((model, 'state_dict'), (optimizer, 'optimizer')):
    target.load_state_dict(checkpoint[key])
  return checkpoint['step']