In [8]:
import os
import pandas as pd
import spacy
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
import torchvision.transforms as transforms

We want to convert text into numerical values:
 1. We need a vocabulary that maps each word to an index.
 2. We need to set up a PyTorch dataset to load the data.
 3. We need to pad every batch (all examples in a batch must have the same seq_len) and set up the DataLoader.

Note that loading the images is very easy compared to processing the text.

In [2]:
# spaCy English pipeline; below, only `spacy_eng.tokenizer` is used
# (inside Vocabulary.tokenizer_eng).
SPACY_MODEL_NAME = "en_core_web_sm"
spacy_eng = spacy.load(SPACY_MODEL_NAME)

In [36]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from a list of sentences.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Every other word gets the next free index the
    moment its count reaches ``threshold``.
    """

    def __init__(self, threshold):
        """
        Args:
            threshold: minimum number of occurrences (counted within a single
                ``build_vocabulary`` call) a word needs to get its own index.
                Values below 1 are treated as 1, i.e. every word is kept.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.threshold = threshold

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Tokenize ``text`` with the spaCy tokenizer and lowercase each token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word occurring at least ``threshold`` times in ``sentence_list``.

        Words are indexed in the order in which they reach the threshold.
        The method may be called more than once: new words continue from the
        next free index, and words already in the vocabulary keep their index.
        """
        frequencies = {}
        # Counts start at 1, so a threshold of 0 (or less) would otherwise
        # never match and silently keep no words at all.
        min_count = max(self.threshold, 1)
        # Continue after the existing entries; restarting at a fixed 4 would
        # overwrite indices assigned by an earlier call.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # `word not in self.stoi` guards against re-indexing a word
                # that an earlier call already added.
                if frequencies[word] == min_count and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Map ``text`` to a list of token indices; unknown words map to ``<UNK>``."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [41]:
class FlickrDataset(Dataset):
    """Image/caption dataset read from a captions CSV.

    Each row of ``captions_file`` supplies an image filename (column
    ``image``, relative to ``root_dir``) and a caption (column ``caption``).
    A :class:`Vocabulary` is built from all captions at construction time.
    """

    def __init__(self, root_dir, captions_file, transform=None, threshold=5):
        """
        Args:
            root_dir: directory containing the image files.
            captions_file: path to a CSV with ``image`` and ``caption`` columns.
            transform: optional callable applied to each RGB PIL image.
            threshold: minimum word count for a word to enter the vocabulary.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(threshold)
        self.vocab.build_vocabulary(self.captions.to_list())

    def __len__(self):
        """Number of (image, caption) rows."""
        return len(self.captions)

    def __getitem__(self, index):
        """Return ``(image, caption_indices)`` for row ``index``.

        ``caption_indices`` is a 1-D int64 tensor: ``<SOS>``, the
        numericalized caption, then ``<EOS>``.
        """
        # Positional lookup, independent of the DataFrame's index labels.
        image_name = self.imgs.iloc[index]
        caption = self.captions.iloc[index]

        image_path = os.path.join(self.root_dir, image_name)
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return image, torch.tensor(numericalized_caption, dtype=torch.long)

In [46]:
class MyCollate:
    """DataLoader ``collate_fn`` that stacks images and pads variable-length captions.

    Args:
        pad_idx: token index used to pad shorter captions.
        batch_first: if False (default), captions are returned as
            ``(max_len, batch)``; if True, as ``(batch, max_len)``.
    """

    def __init__(self, pad_idx, batch_first=False):
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` items into batched tensors.

        Returns:
            ``(imgs, targets)``: ``imgs`` stacks the images along a new leading
            batch dimension; ``targets`` holds the captions padded with
            ``pad_idx``.
        """
        # torch.stack requires every image tensor to have the same shape.
        imgs = torch.stack([item[0] for item in batch], dim=0)

        targets = pad_sequence(
            [item[1] for item in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return imgs, targets



In [44]:
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=4,
    shuffle=True,
    pin_memory=True,
    threshold=5,
):
    """Build a FlickrDataset and a DataLoader that pads every caption batch.

    Args:
        root_folder: directory containing the image files.
        annotation_file: CSV with ``image`` and ``caption`` columns.
        transform: transform applied to each PIL image; it should produce
            same-shaped tensors so images can be stacked into a batch.
        batch_size, num_workers, shuffle, pin_memory: forwarded to DataLoader.
        threshold: minimum word count for a word to enter the vocabulary.

    Returns:
        ``(loader, dataset)``; the vocabulary is available as ``dataset.vocab``.
    """
    dataset = FlickrDataset(
        root_folder, annotation_file, transform=transform, threshold=threshold
    )

    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset

In [47]:
if __name__ == "__main__":
    IMAGES_DIR = "/kaggle/input/flickr8k/Images"
    CAPTIONS_FILE = "/kaggle/input/flickr8k/captions.txt"
    LAST_BATCH_IDX = 20  # inspect batches 0..20 only

    # Fixed-size resize so images in a batch can be stacked, then PIL -> tensor.
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

    loader, dataset = get_loader(IMAGES_DIR, CAPTIONS_FILE, transform=transform)

    # Smoke test: print image-batch and caption-batch shapes for a few batches.
    for idx, (imgs, captions) in enumerate(loader):
        print(imgs.shape)
        print(captions.shape)
        if idx >= LAST_BATCH_IDX:
            break

torch.Size([32, 3, 224, 224])
torch.Size([22, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([25, 32])
torch.Size([32, 3, 224, 224])
torch.Size([33, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([36, 32])
torch.Size([32, 3, 224, 224])
torch.Size([22, 32])
torch.Size([32, 3, 224, 224])
torch.Size([22, 32])
torch.Size([32, 3, 224, 224])
torch.Size([19, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([19, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([22, 32])
torch.Size([32, 3, 224, 224])
torch.Size([25, 32])
torch.Size([32, 3, 224, 224])
torch.Size([17, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
torch.Size([20, 32])
torch.Size([32, 3, 224, 224])
torch.Size([22, 32])
torch.Size([32, 3, 224, 224])
torch.Size([23, 32])
torch.Size([32, 3, 224, 224])
t