<a href="https://colab.research.google.com/github/vjhawar12/Image-Captioning/blob/main/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
from torchvision.transforms import v2
import torch.nn as nn
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader, Dataset
from pycocotools.coco import COCO
from pprint import pprint
import pandas as pd
from skimage import io
from os import path
from random import randint
from torchtext import vocab
from collections import Counter

In [None]:
# Quietly install torchtext; the import cell pulls `vocab` from it.
# NOTE(review): this cell is listed after the cell that imports torchtext, so
# on a fresh runtime it presumably has to be run first — confirm cell order.
!pip -q install torchtext

In [None]:
# Image encoder: MobileNetV2 with pretrained weights, fetched via torch.hub.
# Per the original note, the resulting feature map is [1, 1280].
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'mobilenet_v2',
    pretrained=True,
)

# Replace the classification head with a pass-through so the model returns
# the features that used to feed the classifier instead of class scores.
model.classifier = nn.Identity()

In [None]:
class GRU_Decoder(nn.Module):
  """GRU caption decoder seeded with an image feature vector.

  Token ids are embedded and run through a stacked GRU whose initial hidden
  state is a linear projection of the image features; the GRU outputs are
  then mapped to per-step vocabulary logits.

  Args:
    feature_map_size: Width of the incoming image feature vector.
    embed_size: Dimension of the token embeddings.
    hidden_size: Width of the GRU hidden state.
    num_layers: Number of stacked GRU layers.
    vocab_size: Number of token ids; sizes both the embedding table and the
      output layer.
  """

  def __init__(self, feature_map_size=1280, embed_size=256, hidden_size=512, num_layers=2, vocab_size=10000):
    super().__init__()

    self.feature_map_size = feature_map_size
    self.embed_size = embed_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.vocab_size = vocab_size

    self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embed_size)
    # Projects the image features to the size of one GRU hidden state.
    self.proj = nn.Linear(in_features=self.feature_map_size, out_features=self.hidden_size)
    self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.hidden_size, num_layers=self.num_layers)
    self.fc = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)

  def forward(self, x, words, feature_map):
    """Return vocabulary logits for every position of `words`.

    Args:
      x: Unused; kept in the signature so existing call sites keep working.
      words: Integer token ids. The GRU is built without `batch_first`, so
        the expected layout is (seq_len, batch).
      feature_map: Image features of shape (batch, feature_map_size).

    Returns:
      Logits of shape (seq_len, batch, vocab_size).
    """
    embedded = self.embed(words)

    # nn.GRU expects h0 shaped (num_layers, batch, hidden_size). There is one
    # projected vector per image, so it is repeated for every layer; a plain
    # reshape cannot create the extra layers and fails unless num_layers == 1.
    h0 = self.proj(feature_map).unsqueeze(0).repeat(self.num_layers, 1, 1)

    output, _ = self.gru(embedded, h0)
    return self.fc(output)



In [None]:
class MiniCoco(Dataset):
  """Image/caption dataset read from a JSON annotation file and an image folder.

  The JSON is loaded with pandas and must expose a list of image records at
  ["root"]["images"]. Each record needs a "split" tag, a "filename", and a
  "sentences" list whose items carry a "tokens" list. A vocabulary is built
  from the tokens of the selected split and every caption is stored as a list
  of vocabulary indices.

  Args:
    json_file: Path (or anything `pd.read_json` accepts) of the annotations.
    root_dir: Directory that contains the image files.
    split: One of "train", "val" or "test".
    transform: Optional callable applied to each loaded image.

  Raises:
    ValueError: If `split` is not one of the supported names.
  """

  # Public split name -> value stored in each record's "split" field.
  # NOTE(review): "train" selects the records tagged "restval" — confirm intended.
  _SPLIT_TAGS = {"train": "restval", "val": "val", "test": "test"}

  def __init__(self, json_file, root_dir, split="train", transform=None):
    super().__init__()

    # Fail fast on a bad split name, before any file is read.
    if split not in self._SPLIT_TAGS:
      raise ValueError("Invalid split")

    self.full_data = pd.read_json(json_file)
    self.split = split
    self.root_dir = root_dir
    self.transform = transform

    wanted_tag = self._SPLIT_TAGS[split]
    self.data = [obj for obj in self.full_data["root"]["images"] if obj["split"] == wanted_tag]

    # Token lists grouped per image: token_lists[i][j] is caption j of image i.
    token_lists = [
      [sentence["tokens"] for sentence in record["sentences"]]
      for record in self.data
    ]

    self.counter = Counter()
    for per_image in token_lists:
      for tokens in per_image:
        self.counter.update(tokens)

    self.vocab = vocab.vocab(self.counter, min_freq=1)
    # "<unk>" must exist before it can be looked up and used as the default
    # index; without this the lookup below fails on a vocab that lacks it.
    if "<unk>" not in self.vocab:
      self.vocab.insert_token("<unk>", 0)
    self.vocab.set_default_index(self.vocab["<unk>"])

    self.captions = [
      [self.encode(tokens) for tokens in per_image]
      for per_image in token_lists
    ]

  def __len__(self):
    """Return the number of image records in the selected split."""
    return len(self.data)

  def encode(self, text):
    """Map an iterable of tokens to a list of vocabulary indices.

    Indexing the vocab directly avoids rebuilding the full token->index dict
    on every call and lets unknown tokens fall back to the default index.
    """
    return [self.vocab[token] for token in text]

  def __getitem__(self, index):
    """Return `(image, captions)` for the record at `index`.

    For the "train" split `captions` is one randomly chosen encoded caption;
    for the other splits it is the list of all encoded captions of the image.
    """
    encoded = self.captions[index]
    if self.split == "train":
      captions = encoded[randint(0, len(encoded) - 1)]
    else:
      captions = encoded

    image_name = path.join(self.root_dir, self.data[index]["filename"])
    # NOTE(review): io.imread may return a 2-D array for grayscale files —
    # confirm the transform / model handle that.
    image = io.imread(image_name)

    if self.transform is not None:
      # Transform the loaded image and return the transformed result.
      image = self.transform(image)

    return image, captions

In [None]:
# Image preprocessing: convert to a tensor, resize to 224x224, then normalize
# each channel with the given mean / std.
# - ToTensor runs first so that Resize operates on a tensor; v2 transforms
#   pass NumPy arrays (what skimage's io.imread returns) through untouched.
# - SanitizeBoundingBoxes is omitted: it requires bounding boxes and labels in
#   the sample and raises for an image-only input.
transform = v2.Compose(
    [
        v2.ToTensor(),
        v2.Resize((224, 224), antialias=True),
        v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

