In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import os

In [2]:
# Device configuration: prefer the GPU whenever PyTorch can see one, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# download the COCO dataset
!wget http://images.cocodataset.org/zips/train2017.zip
!wget http://images.cocodataset.org/zips/val2017.zip
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip

# unzip the files
!unzip train2017.zip
!unzip val2017.zip
!unzip annotations_trainval2017.zip


In [None]:
# define the transforms
# Resize the PIL image *before* ToTensor. Resizing the uint8 PIL image avoids
# resizing a float tensor. It also avoids the tensor-only antialias default
# warning emitted by some torchvision releases.
# Normalize(mean=0.5, std=0.5) maps [0, 1] pixels to [-1, 1]; code that shows
# these tensors must invert it with x * 0.5 + 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# define the dataset
# Detection/instance annotations; used below only for previewing images and
# counting categories.
trainset = torchvision.datasets.CocoDetection(root='./train2017', annFile='./annotations/instances_train2017.json', transform=transform)
valset = torchvision.datasets.CocoDetection(root='./val2017', annFile='./annotations/instances_val2017.json', transform=transform)

In [None]:
# plot 9 random images from the dataset
def plot_random_images(dataset, num_cols=3, num_rows=3):
    """Show a ``num_rows`` x ``num_cols`` grid of randomly chosen dataset images.

    Args:
        dataset: indexable dataset whose items are ``(image_tensor, target)``;
            image_tensor is (C, H, W) and normalised with mean = std = 0.5.
        num_cols: number of grid columns.
        num_rows: number of grid rows.

    Indices are drawn independently, so the same image may appear twice.
    """
    # squeeze=False keeps `axes` 2-D even for a single row or column, so
    # axes[i, j] is valid for every grid shape (the default squeezes to 1-D).
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12), squeeze=False)
    for i in range(num_rows):
        for j in range(num_cols):
            index = np.random.randint(len(dataset))
            image = dataset[index][0].numpy().transpose(1, 2, 0)
            # Invert Normalize(0.5, 0.5); clip keeps values inside imshow's
            # valid [0, 1] float range.
            image = np.clip((image + 1) / 2, 0, 1)
            axes[i, j].imshow(image)
            axes[i, j].axis('off')
    plt.show()


plot_random_images(trainset)


In [None]:
# Report how many samples each split holds.
for label, split in (('Size of the training dataset:', trainset),
                     ('Size of the validation dataset:', valset)):
    print(label, len(split))

# Count the object categories declared in the training annotations.
num_of_classes = len(trainset.coco.getCatIds())
print('Number of classes in the dataset:', num_of_classes)


In [None]:
# what does the captions_train2017.json file contain?
import json
# JSON exchanged between systems is UTF-8 (RFC 8259). Without an explicit
# encoding, open() uses the platform default (e.g. cp1252 on Windows), which
# can mis-decode non-ASCII caption text.
with open('./annotations/captions_train2017.json', encoding='utf-8') as f:
    captions = json.load(f)
print(captions.keys())
print(captions['info'])
print(captions['licenses'])
print(captions['images'][0])
print(captions['annotations'][0])


In [None]:
# preprocess the captions
import string
from collections import Counter
import nltk
nltk.download('punkt')
# Recent NLTK releases look up the 'punkt_tab' resource inside word_tokenize.
# Older releases only report that the id is unknown (no exception), so
# requesting both is safe.
nltk.download('punkt_tab')
nltk.download('stopwords')
from nltk.tokenize import word_tokenize

# define the punctuation and the stopwords
punctuation = string.punctuation
# A set gives O(1) membership tests; the stopword filter runs once per token
# of every training caption.
stopwords = set(nltk.corpus.stopwords.words('english'))


def preprocess_captions(captions):
    """Group cleaned caption token lists by image id.

    Args:
        captions: parsed COCO captions JSON; only ``captions['annotations']`` is
            read, each entry supplying ``'image_id'`` and ``'caption'``.

    Returns:
        dict mapping image id -> list of token lists (one per caption, in
        annotation order). Tokens are lower-cased and stripped of punctuation.
        English stopwords are dropped, and only alphabetic tokens of length >= 2
        are kept.
    """
    # Translation table deleting punctuation; built once instead of per caption.
    strip_punctuation = str.maketrans('', '', punctuation)
    captions_dict = {}
    for annotation in captions['annotations']:
        text = annotation['caption'].translate(strip_punctuation)
        tokens = [
            token for token in word_tokenize(text.lower())
            if token not in stopwords and len(token) > 1 and token.isalpha()
        ]
        captions_dict.setdefault(annotation['image_id'], []).append(tokens)
    return captions_dict


# preprocess the captions
captions_dict = preprocess_captions(captions)
print(captions_dict[62443])


In [None]:
# Add <start> and <end> tokens to the captions.
# captions_dict is mutated in place, so guard on a leading '<start>' to keep the
# cell idempotent. Re-running it, even after the padding cell, must not add a
# second <start>/<end> pair. Cleaned tokens are punctuation-free, so a real
# token can never equal '<start>'.
for image_id in captions_dict:
    for caption in captions_dict[image_id]:
        if caption[:1] != ['<start>']:
            caption.insert(0, '<start>')
            caption.append('<end>')
print(captions_dict[62443])

# What is the maximum length of the captions? (0 when there are none)
max_length = max(
    (len(tokens) for caption_list in captions_dict.values() for tokens in caption_list),
    default=0,
)

In [None]:
# Right-pad every caption in place with '<pad>' so each has max_length tokens
# (captions already that long are left untouched).
for caption_list in captions_dict.values():
    for tokens in caption_list:
        shortfall = max_length - len(tokens)
        tokens.extend(['<pad>'] * shortfall)
print(captions_dict[62443])


In [None]:
# create a vocabulary
def create_vocabulary(captions_dict):
    """Map every token in ``captions_dict`` to a 1-based integer id.

    Args:
        captions_dict: image id -> list of token lists.

    Returns:
        dict token -> id. Ids follow descending frequency, with ties kept in
        first-seen order, so the most frequent token is 1. Id 0 is never used.
    """
    token_counts = Counter()
    for caption_list in captions_dict.values():
        for tokens in caption_list:
            token_counts.update(tokens)
    return {
        token: rank
        for rank, (token, _count) in enumerate(token_counts.most_common(), start=1)
    }

# create the vocabulary
# Built from the start/end-wrapped, padded caption lists, so the special tokens
# '<start>', '<end>' and '<pad>' receive ids as well.
vocabulary = create_vocabulary(captions_dict)
print("Size of the vocabulary:", len(vocabulary))

In [None]:
# create a function to convert the captions to integers
def convert_captions_to_integers(captions_dict, vocabulary):
    """Return a new dict mirroring ``captions_dict`` with tokens replaced by ids.

    Args:
        captions_dict: image id -> list of token lists.
        vocabulary: token -> int id mapping.

    Returns:
        dict image id -> list of id lists, in the same key and caption order.

    Raises:
        KeyError: if a token is missing from ``vocabulary``.
    """
    return {
        image_id: [[vocabulary[token] for token in tokens] for tokens in caption_list]
        for image_id, caption_list in captions_dict.items()
    }

# convert the captions to integers
# Every token is present in `vocabulary`, because the vocabulary was built from
# this same captions_dict, so no KeyError is expected here.
captions_dict_int = convert_captions_to_integers(captions_dict, vocabulary)
print(captions_dict_int[62443])

In [None]:
# create a function to get the captions for a given image id
def get_captions_for_image_id(captions_dict, image_id):
    """Return the captions of ``image_id`` as space-joined strings.

    Raises KeyError if ``image_id`` is not in ``captions_dict``.
    """
    return list(map(' '.join, captions_dict[image_id]))

# get the captions for a given image id
# Use a distinct name. `captions` holds the parsed captions JSON that
# preprocess_captions() consumes; overwriting it with this list would make a
# re-run of the preprocessing cell fail.
image_id = 62443
sample_captions = get_captions_for_image_id(captions_dict, image_id)
print(sample_captions)


In [None]:
# create dataset class for image captioning using COCO dataset
# Dataset, COCO and Image are used below but were not imported anywhere in the
# notebook, so this cell raised NameError on a fresh kernel. pycocotools and
# PIL are already required by torchvision's CocoDetection used above.
from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset


class ImageCaptioningDataset(Dataset):
    """One item per COCO caption annotation: ``(image, caption_ids)``.

    The caption is cleaned as follows:
      - lower-cased and stripped of punctuation;
      - English stopwords removed;
      - only alphabetic tokens of length >= 2 kept.
    It is then wrapped in <start>/<end>, mapped through ``vocabulary``, and
    right-padded with '<pad>' to the notebook-global ``max_length``.
    Also reads the notebook globals ``stopwords`` and ``word_tokenize``.
    """

    def __init__(self, root_dir, ann_file, vocabulary, transform=None):
        """
        Args:
            root_dir: directory containing the image files named in ``ann_file``.
            ann_file: path to a COCO captions annotation JSON.
            vocabulary: token -> int mapping. It must contain every token the
                cleaning produces, plus '<start>', '<end>' and '<pad>';
                otherwise __getitem__ raises KeyError.
            transform: optional callable applied to the RGB PIL image.
        """
        self.root_dir = root_dir
        self.ann_file = ann_file
        self.vocabulary = vocabulary
        self.transform = transform
        self.coco = COCO(ann_file)
        # One dataset item per caption annotation (an image appears once per caption).
        self.ids = list(self.coco.anns.keys())

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        coco = self.coco
        vocabulary = self.vocabulary
        ann = coco.anns[self.ids[index]]
        img_id = ann['image_id']

        path = coco.loadImgs(img_id)[0]['file_name']
        # `with` releases the file handle once convert() has produced the RGB copy.
        with Image.open(os.path.join(self.root_dir, path)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # preprocess the caption
        text = ann['caption'].translate(str.maketrans('', '', string.punctuation))
        tokens = [
            token for token in word_tokenize(text.lower())
            if token not in stopwords and len(token) > 1 and token.isalpha()
        ]

        # wrap in <start>/<end>, convert to ids and right-pad to max_length
        ids = [vocabulary['<start>']]
        ids.extend(vocabulary[token] for token in tokens)
        ids.append(vocabulary['<end>'])
        ids.extend([vocabulary['<pad>']] * (max_length - len(ids)))

        return image, torch.tensor(ids)

# create the dataset
# One item per caption annotation in captions_train2017.json. Paths are relative
# to the working directory where the download cell unzipped the archives.
root_dir = 'train2017'
ann_file = 'annotations/captions_train2017.json'
dataset = ImageCaptioningDataset(root_dir, ann_file, vocabulary, transform=transform)

In [None]:
# get the image and caption for a given index
index = 100
image, caption = dataset[index]
print("Caption:", caption)

idx2word = {v: k for k, v in vocabulary.items()}
# Ids of structural tokens that should not appear in a human-readable caption.
special_ids = {vocabulary['<start>'], vocabulary['<end>'], vocabulary['<pad>']}
print("Caption:", ' '.join(idx2word[i] for i in caption.tolist() if i not in special_ids))


def display_image_and_caption(image, caption):
    """Show a transformed image with its decoded caption as the title.

    Args:
        image: (C, H, W) tensor produced by ``transform``, i.e. normalised with
            mean = std = 0.5 per channel.
        caption: 1-D tensor of vocabulary ids. <start>, <end> and <pad> are
            dropped before decoding with the global ``idx2word``.
    """
    skip = {vocabulary['<start>'], vocabulary['<end>'], vocabulary['<pad>']}
    text = ' '.join(idx2word[i] for i in caption.tolist() if i not in skip)

    # Invert transforms.Normalize((0.5,)*3, (0.5,)*3): x * std + mean. This must
    # mirror the Normalize in `transform`; ImageNet statistics would distort colours.
    pixels = image.permute(1, 2, 0).numpy() * 0.5 + 0.5
    pixels = pixels.clip(0, 1)

    plt.imshow(pixels)
    plt.title(text)
    plt.axis('off')
    plt.show()


# display the image and the caption
display_image_and_caption(image, caption)


In [None]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence


def collate_fn(data, pad_idx=None):
    """Batch ``(image, caption)`` pairs for the DataLoader.

    Items are ordered by true caption length, longest first; ties keep their
    incoming order. This is the ordering ``pack_padded_sequence`` needs with
    ``enforce_sorted=True``.

    Args:
        data: list of ``(image, caption)`` with equally-shaped image tensors and
            1-D caption id tensors (possibly already right-padded).
        pad_idx: padding id; defaults to ``vocabulary['<pad>']`` (notebook global).

    Returns:
        images: image tensors stacked along a new dim 0.
        captions: (batch, longest) id tensor, right-padded with ``pad_idx``.
        lengths: list of int, the number of non-pad tokens per caption.
    """
    if pad_idx is None:
        pad_idx = vocabulary['<pad>']

    # The dataset already pads every caption to max_length, so len(cap) is equal
    # for all items. Count the real (non-pad) tokens so the sort and the
    # lengths are meaningful.
    items = [(int((cap != pad_idx).sum()), img, cap) for img, cap in data]
    items.sort(key=lambda item: item[0], reverse=True)
    lengths, images, captions = zip(*items)

    # merge images (from tuple of 3D tensor to 4D tensor)
    images = torch.stack(images, 0)
    # merge captions (from tuple of 1D tensor to 2D tensor)
    captions = pad_sequence(captions, batch_first=True, padding_value=pad_idx)

    return images, captions, list(lengths)

# create the data loader
# collate_fn orders each batch by caption length (longest first) and yields
# (images, captions, lengths) triples instead of default-collated pairs.
batch_size = 32
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
