<a href="https://colab.research.google.com/github/vjhawar12/Image-Captioning/blob/main/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchtext==0.17.0 && pip install torch==2.2.0 && pip install torchvision==0.17.0 && pip install evaluate

In [None]:
from torchtext.vocab import vocab
import torch
import torchvision
from torchvision.transforms import v2
from torchvision.io import decode_image
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader, Dataset
from pycocotools.coco import COCO
from pprint import pprint
import pandas as pd
from skimage import io
from os import path
from random import randint
from collections import Counter
from google.cloud import storage
from tqdm import tqdm
from evaluate import load
from torch.func import vmap
from torch.nn.utils.rnn import pad_sequence

# CUDA Optimizations

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
  # TF32 uses a shortened mantissa to speed up float32 matmuls/convolutions on GPUs that
  # support it; results are slightly less precise than full float32.
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True
  # Leave every scaled-dot-product-attention backend enabled (flash, memory-efficient, math).
  torch.backends.cuda.enable_flash_sdp(True)
  torch.backends.cuda.enable_mem_efficient_sdp(True)
  torch.backends.cuda.enable_math_sdp(True)

In [None]:
# ImageNet-pretrained MobileNetV2 taken from the installed torchvision, instead of fetching an old
# hub snapshot from GitHub; the `pretrained=` flag is deprecated in favour of the weights enum.
encoder = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1)

encoder.classifier = nn.Identity()  # drop the classification head so the model returns its pooled features, [B, 1280]
encoder.to(device)  # move to the GPU when one is available

In [None]:
# The encoder is a fixed feature extractor. Freeze its weights so no gradients are computed for
# them, and switch it to eval mode: nn.Module starts in training mode, where BatchNorm layers
# normalise with per-batch statistics and keep overwriting their pretrained running averages.
encoder.requires_grad_(False)
encoder.eval()

In [None]:
class GRU_Decoder(nn.Module):
  """GRU caption decoder conditioned on an image feature vector.

  The image features are linearly projected to ``hidden_size`` and used as the initial hidden
  state of every GRU layer; tokens are embedded, run through the GRU and mapped back to
  vocabulary logits.

  Args:
    feature_map_size: width of each image's feature vector.
    embed_size: width of the word embeddings fed to the GRU.
    hidden_size: GRU hidden-state width.
    num_layers: number of stacked GRU layers.
    vocab_size: number of token ids the embedding and output layers cover.
  """

  def __init__(self, feature_map_size=1280, embed_size=256, hidden_size=512, num_layers=2, vocab_size=10000):
    super().__init__()

    self.feature_map_size = feature_map_size
    self.embed_size = embed_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.vocab_size = vocab_size

    self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embed_size)  # token id -> embedding vector
    self.proj = nn.Linear(in_features=self.feature_map_size, out_features=self.hidden_size)  # image features -> initial hidden state
    self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True)
    self.fc = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)  # hidden state -> vocabulary logits

  def _initial_hidden(self, feature_map):
    """Project ``[B, feature_map_size]`` features to the GRU's ``[num_layers, B, hidden_size]`` state."""
    h0 = self.proj(feature_map)  # [B, hidden_size]
    # Every layer starts from the same image projection. A reshape cannot produce this: the
    # projection holds B*hidden_size values but the GRU needs num_layers times as many.
    return h0.unsqueeze(0).repeat(self.num_layers, 1, 1)

  def generate(self, feature_map, bos_token, eos_token, max_len=10):
    """Greedily decode one caption per image.

    Args:
      feature_map: ``[B, feature_map_size]`` image features.
      bos_token: id of ``<bos>`` (Python int or scalar tensor); the first GRU input.
      eos_token: id of ``<eos>`` (Python int or scalar tensor).
      max_len: maximum number of tokens generated per caption.

    Returns:
      LongTensor ``[B, L]`` of predicted ids with ``L <= max_len``; ``<bos>`` is not included.
      Decoding stops early once every sequence has produced ``eos_token`` at least once.
    """
    batch_size = feature_map.size(0)
    h = self._initial_hidden(feature_map)
    word = torch.full((batch_size,), int(bos_token), dtype=torch.long, device=feature_map.device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=feature_map.device)
    caption = []

    for _ in range(max_len):
      embedded = self.embed(word).unsqueeze(1)  # [B, 1, embed_size]
      output, h = self.gru(embedded, h)  # output: [B, 1, hidden_size]
      logits = self.fc(output.squeeze(1))  # [B, vocab_size]
      word = logits.argmax(dim=-1)  # greedy choice, [B]
      caption.append(word)

      finished |= word == eos_token  # a sequence stays finished once it has emitted <eos>
      if torch.all(finished):
        break

    if not caption:  # max_len <= 0
      return torch.empty((batch_size, 0), dtype=torch.long, device=feature_map.device)
    return torch.stack(caption, dim=1)

  def forward(self, feature_map, words):
    """Teacher-forced pass.

    Args:
      feature_map: ``[B, feature_map_size]`` image features.
      words: ``[B, T]`` input token ids.

    Returns:
      ``[B, T, vocab_size]`` logits for the token following each input position.
    """
    embedding = self.embed(words)  # [B, T, embed_size]
    h0 = self._initial_hidden(feature_map)
    output, _ = self.gru(embedding, h0)  # [B, T, hidden_size]
    return self.fc(output)



In [None]:
class MiniCoco(Dataset):
  """Caption dataset over one split of a Karpathy-style ``dataset_coco.json`` file.

  Items are ``(image, caption)``. ``image`` is an RGB tensor read from ``root_dir`` and passed
  through ``transform`` when one is given. ``caption`` is a 1-D LongTensor of token ids framed
  by ``<bos>`` and ``<eos>``. A training item draws one of the image's captions at random each
  time it is fetched; val/test items always return the first caption, so scores are repeatable.

  Args:
    json_file: path to the JSON file. Each entry of its ``"images"`` list is read for
      ``"split"``, ``"filename"`` and ``"sentences"`` (each sentence providing ``"tokens"``).
    root_dir: directory holding the image files named by ``"filename"``.
    split: ``"train"`` (entries tagged ``"restval"``), ``"val"`` or ``"test"``.
    transform: optional callable applied to each decoded image tensor.

  Raises:
    Exception: if ``split`` is not one of the names above.
  """

  # split argument -> value of the entries' "split" field
  _SPLIT_TAGS = {"train": "restval", "val": "val", "test": "test"}

  def __init__(self, json_file, root_dir, split, transform=None):
    super().__init__()

    self.full_data = pd.read_json(json_file)
    self.split = split
    self.root_dir = root_dir
    self.transform = transform
    self.counter = Counter()  # token -> number of occurrences across this split's captions

    if split not in self._SPLIT_TAGS:
      raise Exception("Invalid split")
    self.data = [obj for obj in self.full_data["images"] if obj["split"] == self._SPLIT_TAGS[split]]
    self.length = len(self.data)

    # Captions are kept for every split. Training uses them for teacher forcing; val/test use
    # them only as references for scoring, and generation never sees them.
    token_lists = []
    for obj in self.data:
      caps = [sentence["tokens"] for sentence in obj["sentences"]]
      for tokens in caps:
        self.counter.update(tokens)
      token_lists.append(caps)

    special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']
    self.vocab = vocab(self.counter, specials=special_tokens, special_first=True, min_freq=2)
    self.vocab.set_default_index(self.vocab["<unk>"])  # words below min_freq map to <unk>

    # Nested list: self.captions[i][j] is the j-th caption of sample i as a list of ids
    self.captions = [[self.encode(tokens) for tokens in caps] for caps in token_lists]

  def encode(self, text):
    """Return ``[<bos>] + ids(text) + [<eos>]`` for a list of word tokens."""
    # vocab[s] honours the default index set in __init__. get_stoi()[s] raises KeyError for
    # words dropped by min_freq, and materialises the whole mapping on every call.
    return [self.vocab["<bos>"]] + [self.vocab[s] for s in text] + [self.vocab["<eos>"]]

  def itos(self, tens):
    """Turn one sequence of ids into a space-separated sentence.

    Reading stops at the first ``<eos>``, and ``<bos>``/``<pad>`` are skipped. The method
    therefore handles padded ground-truth rows and generated rows (which have no leading
    ``<bos>``) alike.
    """
    itos = self.vocab.get_itos()  # fetched once; it returns the whole vocabulary list
    bos, eos, pad = self.vocab["<bos>"], self.vocab["<eos>"], self.vocab["<pad>"]
    ids = tens.tolist() if torch.is_tensor(tens) else list(tens)

    words = []
    for i in ids:
      if i == eos:
        break
      if i != bos and i != pad:
        words.append(itos[i])
    return ' '.join(words)

  def decode(self, ints):
    """Decode a 1-D id tensor to one string, or a 2-D ``[B, T]`` tensor to a list of B strings."""
    if ints.dim() == 1:
      return self.itos(ints)

    return [self.itos(seq) for seq in ints]

  def __len__(self):
    return self.length

  def __getitem__(self, index):
    options = self.captions[index]
    # Train: a fresh random caption on every fetch. Val/test: always the first reference.
    chosen = options[randint(0, len(options) - 1)] if self.split == "train" else options[0]
    caption = torch.tensor(chosen, dtype=torch.long)  # pad_sequence in the collate step needs tensors

    # read_image accepts a file path (decode_image in torchvision 0.17 expects encoded bytes)
    image_name = path.join(self.root_dir, self.data[index]["filename"])
    image = torchvision.io.read_image(image_name, mode=torchvision.io.ImageReadMode.RGB)  # uint8 [3, H, W]
    if self.transform is not None:
      image = self.transform(image)

    return image, caption

In [None]:
# Preprocessing for the ImageNet-pretrained encoder. It accepts a single [C, H, W] uint8 image or
# a batch [B, C, H, W]. SanitizeBoundingBoxes is not used: it is a detection transform that
# expects bounding boxes in the sample, and captioning has none.
transform_encoder = v2.Compose(
    [
        v2.Resize((224, 224), antialias=True),
        # uint8 [0, 255] -> float32 [0, 1]. v2.ToTensor would leave tensor inputs untouched,
        # so Normalize would receive uint8 data.
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet channel mean / std
    ]
)



In [None]:
# Create Application Default Credentials; google.cloud.storage.Client() in download_blob picks them up.
!gcloud auth application-default login

In [None]:
def download_blob(bucket_name, source_blob_name, destination_file_name):
  """Copy the object ``gs://<bucket_name>/<source_blob_name>`` to ``destination_file_name``.

  The storage client authenticates with the environment's default credentials.
  """
  # NOTE(review): "Image Captioning" looks like a display name rather than a GCP project ID — confirm.
  (
      storage.Client(project="Image Captioning")
      .bucket(bucket_name)
      .blob(source_blob_name)
      .download_to_filename(destination_file_name)
  )

In [None]:
# Fetch the COCO 2014 image zips and archive.zip (whose extracted contents, the dataset_*.json
# split files, are listed further down) from the "img-captioning" bucket into /content.
download_blob("img-captioning", "images.cocodataset.org/zips/test2014.zip", "/content/test2014.zip")
download_blob("img-captioning", "images.cocodataset.org/zips/train2014.zip", "/content/train2014.zip")
download_blob("img-captioning", "images.cocodataset.org/zips/val2014.zip", "/content/val2014.zip")
download_blob("img-captioning", "archive.zip", "/content/archive.zip")



In [None]:
!unzip /content/test2014.zip -d /content/test2014/ && unzip /content/train2014.zip -d /content/train2014/ && unzip /content/archive.zip -d /content/archive/ && !unzip /content/val2014.zip -d /content/val2014/

In [None]:
# Free disk space: the zips are no longer needed once extracted.
!rm /content/test2014.zip /content/train2014.zip /content/val2014.zip /content/archive.zip

In [None]:
# Sanity check: list the split files extracted from archive.zip.
!cd /content/archive && ls

dataset_coco.json  dataset_flickr30k.json  dataset_flickr8k.json


In [None]:
# Defaults spelled out. NOTE(review): vocab_size must be at least the size of the vocabulary
# the datasets build, otherwise embedding lookups go out of range — confirm.
decoder = GRU_Decoder(feature_map_size=1280, embed_size=256, hidden_size=512, num_layers=2, vocab_size=10000)
decoder.to(device)

In [None]:
json_file = "/content/archive/dataset_coco.json"
# In Karpathy's dataset_coco.json, the "restval" (used here for training), "val" and "test"
# entries are all COCO val2014 images (filenames COCO_val2014_*); COCO's test2014 release has no
# captions at all. All three datasets therefore read their images from the val2014 folder.
root_train_dir = "/content/val2014/val2014/"
root_test_dir = "/content/val2014/val2014/"
root_val_dir = "/content/val2014/val2014/"

train_data = MiniCoco(json_file, root_train_dir, "train")
test_data = MiniCoco(json_file, root_test_dir, "test")
val_data = MiniCoco(json_file, root_val_dir, "val")

In [None]:
def pad(data, image_size=(224, 224)):
  """Collate ``(image, caption)`` samples into one batch (used as a DataLoader ``collate_fn``).

  Args:
    data: list of ``(image, caption)`` pairs. ``image`` is a ``[C, H, W]`` tensor; ``caption``
      is a sequence of token ids (list or 1-D tensor).
    image_size: ``(H, W)`` that every image is resized to when the batch mixes image sizes.

  Returns:
    ``(images, captions)``: images stacked to ``[B, C, H, W]`` and captions right-padded with
    the training vocabulary's ``<pad>`` id to ``[B, T_max]``.
  """
  images, captions = zip(*data)

  # torch.stack needs identical shapes, and raw COCO images come in many resolutions
  if len({tuple(img.shape) for img in images}) > 1:
    images = [v2.functional.resize(img, list(image_size), antialias=True) for img in images]
  images = torch.stack(images, dim=0)

  # pad_sequence only accepts tensors, so id lists are converted first
  captions = [torch.as_tensor(c, dtype=torch.long) for c in captions]
  captions = pad_sequence(captions, batch_first=True, padding_value=train_data.vocab["<pad>"])

  return images, captions

In [None]:
# Only the training loader shuffles; val/test keep dataset order so evaluation is reproducible.
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=32, collate_fn=pad)
val_dataloader = DataLoader(val_data, shuffle=False, batch_size=32, collate_fn=pad)
test_dataloader = DataLoader(test_data, shuffle=False, batch_size=32, collate_fn=pad)

In [None]:
# <pad> positions are filler added when batching captions; they must not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=train_data.vocab["<pad>"])

# The decoder's embedding and output layers are sized by vocab_size. If that differs from the
# vocabulary built from the training captions, ids can index out of range (or logits are wasted),
# so rebuild the decoder with the same hyper-parameters and the right size before creating the
# optimizer over its parameters.
if decoder.vocab_size != len(train_data.vocab):
  decoder = GRU_Decoder(
      feature_map_size=decoder.feature_map_size,
      embed_size=decoder.embed_size,
      hidden_size=decoder.hidden_size,
      num_layers=decoder.num_layers,
      vocab_size=len(train_data.vocab),
  ).to(device)

optimizer = optim.Adagrad(decoder.parameters())

In [None]:
bertscore = load("bertscore")  # Hugging Face `evaluate` metric used to score captions
epochs = 20  # number of passes over the training set

In [None]:
def train_one_epoch():
  """Run one teacher-forced pass over ``train_dataloader`` and update the decoder.

  Returns:
    ``(avg_loss, avg_acc)``: the mean per-batch loss, and the mean BERTScore F1 over all
    training samples comparing the decoder's next-token predictions with the ground-truth
    captions (``avg_acc`` is 0.0 when the loader yields nothing).
  """
  running_loss = 0.0
  decoded_predictions = []
  decoded_captions = []

  for images, captions in train_dataloader:
    images = images.to(device)      # [B, C, H, W]
    captions = captions.to(device)  # [B, T] token ids

    # Teacher forcing: feed tokens 0..T-2 and train the decoder to predict tokens 1..T-1.
    # Using the same slice as input and target would only teach it to copy its input.
    decoder_input = captions[:, :-1]
    targets = captions[:, 1:]

    optimizer.zero_grad()  # gradients accumulate across backward() calls otherwise

    with torch.no_grad():  # the encoder is frozen, so skip recording its autograd graph
      feature_map = encoder(transform_encoder(images))
    outputs = decoder(feature_map, decoder_input)  # [B, T-1, vocab_size]

    # CrossEntropyLoss expects the class dimension second; flattening batch and time satisfies that
    loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
    loss.backward()
    optimizer.step()
    running_loss += loss.item()

    predicted = outputs.detach().argmax(dim=-1)  # [B, T-1] predicted next tokens
    # extend (not append): BERTScore takes one flat list with one string per sample
    decoded_predictions.extend(train_data.decode(predicted))
    decoded_captions.extend(train_data.decode(captions))

  avg_loss = running_loss / max(len(train_dataloader), 1)
  if not decoded_predictions:
    return avg_loss, 0.0

  f1 = bertscore.compute(predictions=decoded_predictions, references=decoded_captions, lang="en")["f1"]
  avg_acc = sum(f1) / len(f1)  # mean over samples, not over batches

  return avg_loss, avg_acc

In [None]:
@torch.no_grad()
def validate():
  """Score greedily generated captions on the validation split with BERTScore.

  Returns:
    Mean BERTScore F1 over all validation samples. Each generated caption is compared with the
    reference caption the validation loader yields for that image. Returns 0.0 if the split is
    empty.
  """
  decoded_predictions = []
  decoded_captions = []

  for vimages, vcaptions in val_dataloader:
    vimages = vimages.to(device)

    feature_map = encoder(transform_encoder(vimages))
    # generate() already returns token ids [B, L], so no argmax is applied here
    vpredicted = decoder.generate(feature_map, train_data.vocab["<bos>"], train_data.vocab["<eos>"])

    # Predictions are ids in the training vocabulary (the one the decoder learned); references
    # are ids in the vocabulary val_data used to encode them.
    decoded_predictions.extend(train_data.decode(vpredicted))
    decoded_captions.extend(val_data.decode(vcaptions))

  if not decoded_predictions:
    return 0.0

  f1 = bertscore.compute(predictions=decoded_predictions, references=decoded_captions, lang="en")["f1"]
  vacc = sum(f1) / len(f1)  # mean over samples, not over batches

  return vacc

In [None]:
best_acc = -1  # sentinel below any expected validation score, so the first epoch checkpoints
loop = tqdm(range(epochs))

for epoch in loop:
  # 1) teacher-forced training pass
  decoder.train()
  avg_loss, train_acc = train_one_epoch()

  # 2) free-running generation on the validation split, without autograd
  decoder.eval()
  with torch.no_grad():
    vacc = validate()

  loop.set_description(f"Avg Loss: {avg_loss} \t Train Acc: {train_acc} \t Val Acc: {vacc}")

  # 3) keep only the weights that score best on validation
  if vacc > best_acc:
    best_acc = vacc
    torch.save(decoder.state_dict(), "decoder.pt")


In [None]:
@torch.no_grad()
def test():
  """Score greedily generated captions on the test split with BERTScore.

  Returns:
    Mean BERTScore F1 over all test samples. Each generated caption is compared with the
    reference caption the test loader yields for that image. Returns 0.0 if the split is empty.
  """
  decoded_predictions = []
  decoded_captions = []

  for timages, tcaptions in test_dataloader:
    timages = timages.to(device)

    feature_map = encoder(transform_encoder(timages))
    # generate() already returns token ids [B, L], so no argmax is applied here
    tpredicted = decoder.generate(feature_map, train_data.vocab["<bos>"], train_data.vocab["<eos>"])

    # Predictions are ids in the training vocabulary (the one the decoder learned); references
    # are ids in the vocabulary test_data used to encode them.
    decoded_predictions.extend(train_data.decode(tpredicted))
    decoded_captions.extend(test_data.decode(tcaptions))

  if not decoded_predictions:
    return 0.0

  f1 = bertscore.compute(predictions=decoded_predictions, references=decoded_captions, lang="en")["f1"]
  tacc = sum(f1) / len(f1)  # mean over samples, not over batches

  return tacc

In [None]:
# Evaluate the checkpoint that scored best on validation (saved by the training loop), not
# whatever weights the last epoch happened to leave in memory.
if path.exists("decoder.pt"):
  decoder.load_state_dict(torch.load("decoder.pt", map_location=device, weights_only=True))

with torch.no_grad():
  decoder.eval()
  print(test())