Preprocessing and formatting data to be used for the model

In [None]:
import os  
import pandas as pd  
import spacy  
import torch
from torch.nn.utils.rnn import pad_sequence 
from torch.utils.data import DataLoader, Dataset
from PIL import Image 
import torchvision.transforms as transforms
import re
from transformers import BertTokenizer


# Shared module-level tokenizer, loaded under the name 'bert-base-uncased'.
# Vocabulary calls tokenizer.tokenize(...) to split caption text into tokens.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class Vocabulary:
    """Bidirectional token <-> integer-id mapping built from text.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Every other token receives the next free id at
    the moment its occurrence count reaches ``freq_threshold``.

    Args:
        freq_threshold: Number of occurrences a token needs before it is
            added to the vocabulary. Values below 1 are treated as 1, i.e.
            every token that appears is kept.
        tokenize_fn: Optional callable mapping a string to a list of string
            tokens. When omitted, the module-level ``tokenizer.tokenize``
            is used.
    """

    def __init__(self, freq_threshold, tokenize_fn=None):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}  # integer to string
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}  # string to integer
        self.freq_threshold = freq_threshold
        self.tokenize_fn = tokenize_fn

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.itos)

    def _tokenize(self, text):
        """Split ``text`` into tokens with the configured tokenizer."""
        if self.tokenize_fn is not None:
            return self.tokenize_fn(text)
        # Resolved at call time so that rebinding the module-level
        # ``tokenizer`` is still honoured.
        return tokenizer.tokenize(text)

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent token of ``sentence_list``.

        Ids are handed out in the order in which tokens reach the frequency
        threshold. Counts start from zero on each call; tokens that already
        have an id keep it, so calling this method more than once never
        overwrites or re-indexes existing entries.

        Args:
            sentence_list: Iterable of strings to tokenize and count.
        """
        frequencies = {}
        # A threshold of 0 (or less) could never be hit by the ``==`` test
        # below and would silently produce an empty vocabulary.
        threshold = max(1, self.freq_threshold)
        # Continue after the highest id in use instead of assuming that only
        # the four special tokens exist.
        next_idx = len(self.itos)

        for sentence in sentence_list:
            for word in self._tokenize(sentence):
                count = frequencies.get(word, 0) + 1
                frequencies[word] = count

                if count == threshold and word not in self.stoi:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of ids, using ``<UNK>`` for unknown tokens.

        Args:
            text: String to tokenize.

        Returns:
            A list of integer ids, one per token.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self._tokenize(text)]
    
    
class FlickrDataset(Dataset):
    """Dataset of (image, numericalized caption) pairs read from a CSV file.

    The captions file is loaded with ``pandas.read_csv`` and must provide an
    ``"image"`` column (file names relative to ``root_dir``) and a
    ``"caption"`` column. A ``Vocabulary`` is built from all captions when
    the dataset is constructed.

    Args:
        root_dir: Directory containing the image files.
        captions_file: Path to the CSV file with the captions.
        transform: Optional callable applied to each RGB PIL image.
        freq_threshold: Frequency threshold forwarded to ``Vocabulary``.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of rows in the captions file."""
        return len(self.df)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Zero-based row position.

        Returns:
            A tuple ``(img, caption)``: ``img`` is the RGB image after
            ``transform`` (or the PIL image itself when no transform is
            set), and ``caption`` is a 1-D tensor of token ids wrapped in
            ``<SOS>`` ... ``<EOS>``.
        """
        # Positional access: ``index`` counts rows, matching ``__len__``,
        # and does not depend on the DataFrame's index labels.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # ``convert`` returns a new, fully loaded image, so the file opened
        # by ``Image.open`` can be closed as soon as the block exits.
        with Image.open(os.path.join(self.root_dir, img_id)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [
            self.vocab.stoi["<SOS>"],
            *self.vocab.numericalize(caption),
            self.vocab.stoi["<EOS>"],
        ]

        return img, torch.tensor(numericalized_caption)


# Image preprocessing that FlickrDataset applies to every loaded image:
# resize to 356x356, take a random 299x299 crop, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(356, 356)),
        transforms.RandomCrop(size=(299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Flickr8k images and captions; the vocabulary is built with a frequency
# threshold of 5.
dataset = FlickrDataset(
    'D:/pytorch_exercise/dataset/flickr8k/flickr8k/images',
    'D:/pytorch_exercise/dataset/flickr8k/flickr8k/captions.txt',
    transform=transform,
    freq_threshold=5,
)

class MyCollate:
    """Collate callable that batches images and pads captions to equal length.

    Args:
        pad_idx: Token id used to fill captions that are shorter than the
            longest caption in the batch.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Combine a list of samples into one batch.

        Args:
            batch: Sequence of samples where ``sample[0]`` is an image
                tensor (all images must share one shape) and ``sample[1]``
                is a 1-D sequence of token ids (tensor or list).

        Returns:
            A tuple ``(images, targets)``: ``images`` is the input images
            stacked along a new leading dimension, and ``targets`` is a
            ``torch.long`` tensor of shape ``(len(batch), max_caption_len)``
            padded on the right with ``pad_idx``.
        """
        images = torch.stack([sample[0] for sample in batch], dim=0)

        # ``as_tensor`` accepts lists as well as tensors of any integer
        # dtype and avoids a copy when the caption is already int64.
        captions = [
            torch.as_tensor(sample[1], dtype=torch.long) for sample in batch
        ]

        # Right-pad every caption with ``pad_idx`` (not necessarily zero) up
        # to the length of the longest caption in this batch.
        targets = pad_sequence(
            captions, batch_first=True, padding_value=self.pad_idx
        )

        return images, targets
    
    
# Id of the "<PAD>" token in the dataset's vocabulary; used to pad captions.
pad_idx = dataset.vocab.stoi["<PAD>"]

# Shuffled batches of 64 samples, with captions padded per batch by MyCollate.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=MyCollate(pad_idx),
)


# x, y = next(iter(loader))

