In [1]:
import json
import os
import random
import string
import time
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import nltk
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from google.colab.patches import cv2_imshow
from nltk.tokenize import word_tokenize
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [2]:
# Device configuration: run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# download the COCO dataset
!wget http://images.cocodataset.org/zips/train2017.zip
!wget http://images.cocodataset.org/zips/val2017.zip
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip

# unzip the files
!unzip train2017.zip
!unzip val2017.zip
!unzip annotations_trainval2017.zip


# Data Exploration

In [None]:
def plot_random_images(path, num_cols=3, num_rows=3):
    """Display a grid of images drawn at random (with replacement) from a directory.

    Args:
        path: Directory holding the image files; a trailing separator is optional.
        num_cols: Number of columns in the grid.
        num_rows: Number of rows in the grid.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # List the directory once instead of once per grid cell.
    file_names = os.listdir(path)
    # squeeze=False keeps `axes` two-dimensional even when the grid has a single
    # row or column, so every layout can be traversed the same way.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
        # os.path.join works whether or not `path` ends with a separator.
        img = cv2.imread(os.path.join(path, random.choice(file_names)))
        if img is None:
            # cv2.imread returns None for files it cannot decode; leave the cell blank.
            continue
        # OpenCV loads images as BGR, matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.show()

# Preview a random sample of the training images.
plot_random_images(path='train2017/')


In [None]:
# How many files does each image folder contain?
for message, folder in (
    ('Number of images in train2017 folder: ', 'train2017/'),
    ('Number of images in val2017 folder: ', 'val2017/'),
):
    print(message, len(os.listdir(folder)))

# Which annotation files were extracted?
print('Files in annotations folder: ', os.listdir('annotations/'))

# Object-instance annotations: show the top-level keys, then the keys of the
# first record in each of the three sections.
with open('annotations/instances_train2017.json') as f:
    instances = json.load(f)
print('Keys of instances dictionary: ', instances.keys())
for message, section in (
    ("Keys of instances['categories'] dictionary: ", 'categories'),
    ("Keys of instances['annotations'] dictionary: ", 'annotations'),
    ("Keys of instances['images'] dictionary: ", 'images'),
):
    print(message, instances[section][0].keys())
print('Number of categories: ', len(instances['categories']))

# Caption annotations (captions_train2017.json): same inspection.
with open('annotations/captions_train2017.json') as f:
    captions = json.load(f)
print('Keys of captions dictionary: ', captions.keys())
for message, section in (
    ("Keys of captions['annotations'] dictionary: ", 'annotations'),
    ("Keys of captions['images'] dictionary: ", 'images'),
):
    print(message, captions[section][0].keys())
print('Number of images: ', len(captions['images']))

In [None]:
def find_max_length(captions):
    """Print and return the word count of the longest caption.

    Args:
        captions: Mapping with an 'annotations' list whose items each carry a
            'caption' string.

    Returns:
        The largest number of whitespace-separated words found in any caption.
    """
    word_counts = [len(annotation['caption'].split()) for annotation in captions['annotations']]
    longest = max(word_counts)
    print('Maximum length of the captions: ', longest)
    return longest
# Word count of the longest raw caption in the loaded annotations.
max_len = find_max_length(captions=captions)

In [None]:
# Fetch the NLTK resources needed for caption preprocessing.
for resource in ('punkt', 'stopwords'):
    nltk.download(resource)
# Characters treated as punctuation, and the English stopword list.
punctuation = string.punctuation
stopwords = nltk.corpus.stopwords.words('english')

In [None]:
def preprocess_captions(captions, max_len):
    """Clean every caption and return fixed-length token lists grouped by image.

    For each annotation the caption is stripped of stopwords (module-level
    ``stopwords``), then of punctuation, then lower-cased; tokens that are not
    purely alphabetic are dropped. The result is prefixed with '<start>' and then
    truncated or padded with '<pad>' so that it has exactly ``max_len`` entries.
    No '<end>' token is added here.

    Args:
        captions: Mapping with an 'annotations' list; each item has an 'image_id'
            and a 'caption' string.
        max_len: Length of every returned token list, including '<start>'.

    Returns:
        Dict mapping image_id to a list of token lists, one per caption.
    """
    # Loop invariants: a set gives O(1) stopword checks, and the translation
    # table does not need to be rebuilt for every caption.
    stopword_set = set(stopwords)
    strip_punctuation = str.maketrans('', '', string.punctuation)

    captions_dict = {}
    for annotation in captions['annotations']:
        # Stopwords are matched on the raw whitespace-split words, before
        # punctuation is stripped (same order as the original pipeline).
        kept_words = ' '.join(
            word for word in annotation['caption'].split() if word.lower() not in stopword_set
        )
        cleaned = kept_words.translate(strip_punctuation)
        tokens = ['<start>'] + [token.lower() for token in cleaned.split() if token.isalpha()]
        # A caption longer than max_len used to produce a negative padding count,
        # which is a silent no-op and left a ragged caption that cannot be
        # stacked into a batch. Truncate first so every caption has max_len tokens.
        del tokens[max_len:]
        tokens.extend(['<pad>'] * (max_len - len(tokens)))
        captions_dict.setdefault(annotation['image_id'], []).append(tokens)
    return captions_dict

# Clean and pad every caption, then peek at the captions of one example image.
captions_dict = preprocess_captions(captions=captions, max_len=max_len)
example_image_id = 62443
print(captions_dict[example_image_id])

In [None]:
def create_vocabulary(captions_dict):
    """Build a token -> integer id mapping ordered by token frequency.

    Args:
        captions_dict: Dict mapping image_id to a list of token lists.

    Returns:
        Dict in which the most frequent token gets id 1, the next one id 2, and
        so on. Id 0 is never assigned.
    """
    token_counts = Counter()
    for image_captions in captions_dict.values():
        for caption in image_captions:
            token_counts.update(caption)

    vocabulary = {}
    for rank, (token, _count) in enumerate(token_counts.most_common(), start=1):
        vocabulary[token] = rank
    return vocabulary

vocabulary = create_vocabulary(captions_dict)
# '<end>' does not occur in the preprocessed captions, so register it by hand
# with the next free id.
end_token_id = len(vocabulary) + 1
vocabulary['<end>'] = end_token_id
print("Size of the vocabulary:", len(vocabulary))

In [None]:
def convert_captions_to_integers(captions_dict, vocabulary):
    """Replace every token in every caption by its id from ``vocabulary``.

    Args:
        captions_dict: Dict mapping image_id to a list of token lists.
        vocabulary: Dict mapping token to integer id.

    Returns:
        Dict with the same keys and nesting, holding ids instead of tokens.

    Raises:
        KeyError: If a token is missing from ``vocabulary``.
    """
    captions_dict_int = {}
    for image_id, image_captions in captions_dict.items():
        encoded_captions = []
        for caption in image_captions:
            encoded_captions.append([vocabulary[token] for token in caption])
        captions_dict_int[image_id] = encoded_captions
    return captions_dict_int

# Encode all captions as integer ids and peek at the same example image.
captions_dict_int = convert_captions_to_integers(captions_dict=captions_dict, vocabulary=vocabulary)
example_image_id = 62443
print(captions_dict_int[example_image_id])

In [None]:
# create dataset class for image captioning using COCO dataset and preprocessing steps like above given annotations file path 
class ImageCaptioningDataset(Dataset):
    """
    Custom dataset that pairs COCO images with integer-encoded captions.

    Args:
        root_dir (string): Directory containing the image files, named as the
            zero-padded 12-digit image id plus '.jpg'.
        ann_file: Path to the annotations JSON file that contains the captions.
        vocabulary: Dictionary that maps tokens to integers. It must contain
            '<start>', and '<pad>' whenever a caption needs padding.
        max_len: Length of every encoded caption, including the '<start>' token.
        transform: Optional transform applied to the PIL image before it is returned.

    Class Methods:
        __init__(): Loads the annotations file, preprocesses and encodes every
        caption, and stores the list of image ids.

        __len__(): Returns the number of distinct images that have captions.

        __getitem__(): Returns the image and one randomly chosen caption for an index.

        preprocess_captions_and_convert_to_integers(): Helper that preprocesses the
        captions and converts them to integers in one step.

    Example:
        >>> dataset = ImageCaptioningDataset(root_dir='train2017', ann_file='annotations/captions_train2017.json', vocabulary=vocabulary, max_len=max_len)
        >>> print('Length of the dataset: ', len(dataset))
        >>> image, caption = dataset[0]
        >>> print('Image shape: ', image.shape)
    """

    def __init__(self, root_dir, ann_file, vocabulary, max_len, transform=None):
        self.root_dir = root_dir
        self.vocabulary = vocabulary
        self.transform = transform
        # Load the annotations and encode every caption once, up front.
        with open(ann_file) as f:
            captions = json.load(f)
        self.captions_dict_int = self.preprocess_captions_and_convert_to_integers(captions, vocabulary, max_len)
        self.image_ids = list(self.captions_dict_int.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` where caption is a 1-D tensor of token ids."""
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.root_dir, '%012d.jpg' % image_id)
        # The context manager closes the file handle as soon as the pixel data
        # has been decoded into the converted copy.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        # Each image has several captions; pick one at random per access.
        caption = random.choice(self.captions_dict_int[image_id])
        return image, torch.tensor(caption)

    def preprocess_captions_and_convert_to_integers(self, captions, vocabulary, max_len):
        """Clean, encode and pad/truncate every caption in ``captions``.

        Stopwords (module-level ``stopwords``), punctuation and non-alphabetic
        tokens are removed and the rest is lower-cased. Each caption is prefixed
        with '<start>', mapped through ``vocabulary`` and truncated or padded
        with '<pad>' to exactly ``max_len`` ids.

        Tokens missing from ``vocabulary`` are mapped to '<unk>' when the
        vocabulary has that entry and are dropped otherwise, so an annotation
        file other than the one the vocabulary was built from (e.g. a validation
        split) does not raise KeyError.

        Returns:
            Dict mapping image_id to a list of id lists, one per caption.
        """
        # Loop invariants hoisted out of the per-caption loop.
        stopword_set = set(stopwords)
        strip_punctuation = str.maketrans('', '', string.punctuation)
        unk_index = vocabulary.get('<unk>')

        captions_dict_int = {}
        for annotation in captions['annotations']:
            # Stopwords are matched on the raw whitespace-split words, before
            # punctuation is stripped.
            kept_words = ' '.join(
                word for word in annotation['caption'].split() if word.lower() not in stopword_set
            )
            cleaned = kept_words.translate(strip_punctuation)

            token_ids = [vocabulary['<start>']]
            for token in cleaned.split():
                if not token.isalpha():
                    continue
                token_id = vocabulary.get(token.lower(), unk_index)
                if token_id is not None:
                    token_ids.append(token_id)

            # A negative padding count is a silent no-op, which used to leave
            # over-long captions ragged and impossible to stack into a batch.
            del token_ids[max_len:]
            padding = max_len - len(token_ids)
            if padding > 0:
                token_ids.extend([vocabulary['<pad>']] * padding)
            captions_dict_int.setdefault(annotation['image_id'], []).append(token_ids)
        return captions_dict_int



# Build the training dataset: every image is resized to 224x224 and converted to a tensor.
root_dir = 'train2017'
ann_file = 'annotations/captions_train2017.json'
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = ImageCaptioningDataset(
    root_dir=root_dir,
    ann_file=ann_file,
    vocabulary=vocabulary,
    max_len=max_len,
    transform=transform,
)

In [None]:
# Show a random image together with one of its captions.
# randrange excludes the upper bound; random.randint(0, len(dataset)) could return
# len(dataset), which is one past the last valid index.
index = random.randrange(len(dataset))
# Local names are distinct from the global `captions` annotation dict so that
# earlier cells can still be re-run after this one.
image, caption_ids = dataset[index]
print('Image shape: ', image.shape)
print('Captions shape: ', caption_ids.shape)
# Build the id -> word lookup once instead of rebuilding two lists per token.
index_to_word = {token_id: word for word, token_id in vocabulary.items()}
special_ids = {vocabulary['<pad>'], vocabulary['<start>'], vocabulary['<end>']}
# Convert the ids back to words, dropping the padding, start and end tokens.
caption_words = [index_to_word[token_id] for token_id in caption_ids.tolist() if token_id not in special_ids]
print(' '.join(caption_words))
# The tensor is channels-first; matplotlib expects channels-last.
plt.imshow(image.permute(1, 2, 0))
plt.axis('off')
plt.show()

In [None]:
# Wrap the dataset in a shuffling loader; batches are loaded in the main process.
batch_size = 32
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [None]:
# define A CNN-RNN model for image captioning in which the CNN is a pretrained ResNet-50 model(use timm library)
# Each image is passed through the CNN and the output along with <start> token is passed to the RNN.
# The RNN outputs the next word and the process is repeated until the <end> token is generated or the maximum length of the caption is reached.
class ImageCaptioningModel(nn.Module):
    """ResNet-50 image encoder followed by an LSTM caption decoder.

    Args:
        vocab_size: Number of rows in the token embedding and number of output
            scores per time step.
        embedding_dim: Size of the image feature vector and of each token embedding.
        hidden_dim: Hidden size of the LSTM.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Encoder: pretrained ResNet-50 whose classifier head is replaced by a
        # linear projection to embedding_dim. Only that new head is trainable.
        backbone = timm.create_model('resnet50', pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, embedding_dim)
        backbone.requires_grad_(False)
        backbone.fc.requires_grad_(True)
        self.encoder = backbone
        # Decoder: single LSTM over the [image feature, token embeddings] sequence.
        self.decoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        # Projects each LSTM output to one score per vocabulary entry.
        self.linear = nn.Linear(hidden_dim, vocab_size)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, images, captions):
        """Return per-step vocabulary scores.

        The image feature is placed in front of the embedded caption tokens, so
        the output has one more time step than ``captions``.
        """
        image_features = self.encoder(images)
        token_embeddings = self.embedding(captions)
        # Image feature becomes the first element of the input sequence.
        sequence = torch.cat((image_features.unsqueeze(1), token_embeddings), dim=1)
        hidden_states, _ = self.decoder(sequence)
        return self.linear(hidden_states)
# define the hyperparameters
# Token ids start at 1 (create_vocabulary assigns i + 1) and '<end>' holds the
# largest id, len(vocabulary). The embedding table and the output layer must
# cover ids 0..max_id, so they need max_id + 1 entries; with len(vocabulary)
# the '<end>' id is out of range for both the embedding and CrossEntropyLoss.
vocab_size = max(vocabulary.values()) + 1
embedding_dim = 256
hidden_dim = 512
model = ImageCaptioningModel(vocab_size, embedding_dim, hidden_dim)
model = model.to(device)

# define the loss function; padded positions do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=vocabulary['<pad>'])

# define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)




In [None]:
# Train the model for 10 epochs.
# The decoder input sequence is [image feature, <start>, w1, ..., wn, <end>, <pad>, ...]
# without its last element, and the target sequence is
# [<start>, w1, ..., wn, <end>, <pad>, ...]; i.e. at every step the decoder is
# trained to emit the token that follows what it has seen so far.
num_epochs = 10
pad_index = vocabulary['<pad>']
end_index = vocabulary['<end>']
for epoch in range(num_epochs):
    # `caption_batch` (not `captions`) so the global annotation dict is not overwritten.
    for i, (images, caption_batch) in enumerate(dataloader):
        images = images.to(device)
        caption_batch = caption_batch.to(device)
        # zero the gradients
        optimizer.zero_grad()

        # Put <end> directly after the last real token of each caption.
        # Appending it after the padding (as before) meant the model was never
        # trained to stop where the caption actually ends.
        # Padding is always trailing, so the count of non-pad ids is the position
        # of the first pad (or one past the end for a caption without padding).
        lengths = (caption_batch != pad_index).sum(dim=1, keepdim=True)
        pad_column = torch.full(
            (caption_batch.size(0), 1), pad_index,
            dtype=caption_batch.dtype, device=caption_batch.device,
        )
        targets = torch.cat((caption_batch, pad_column), dim=1)
        targets.scatter_(1, lengths, end_index)

        # forward pass: the image feature takes the first input slot, so the last
        # target token is not fed in and inputs/targets end up the same length
        outputs = model(images, targets[:, :-1])

        # calculate the loss over all time steps (padded targets are ignored by the criterion)
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
        # backward pass
        loss.backward()
        # update the weights
        optimizer.step()
        # print the loss
        if i % 100 == 0:
            print('Epoch: {}/{} | Step: {}/{} | Loss: {:.4f}'.format(epoch+1, num_epochs, i, len(dataloader), loss.item()))
