In [1]:
import os

import pandas as pd
import spacy
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.nn.utils.rnn import pack_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class Vocabulary:
    """Word-level vocabulary mapping caption tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Words registered by :meth:`build_vocabulary`
    receive the following ids in the order they reach the frequency threshold.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that contains only the four special tokens.

        Args:
            freq_threshold: Number of occurrences a word must reach in the
                sentences given to :meth:`build_vocabulary` before it gets
                its own id. Words below the threshold are never registered
                and are encoded as ``<UNK>`` by :meth:`numericalize`.
        """
        self.index_to_string = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.string_to_index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of registered tokens, special tokens included."""
        return len(self.index_to_string)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` into a list of lower-cased token strings.

        NOTE(review): ``spacy_eng`` is a module-level name that is not defined
        in the cells visible here; presumably it is a loaded spaCy English
        pipeline created in another cell -- confirm it exists before this
        method is called.
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Register every word that occurs at least ``freq_threshold`` times.

        Args:
            sentence_list: Iterable of raw sentence strings. Each one is
                tokenized with :meth:`tokenizer_eng`.
        """
        frequencies = {}
        # Continue after the ids already handed out (4 on a fresh instance,
        # because the special tokens occupy 0-3) so that calling this method
        # again never overwrites an existing entry.
        idx = len(self.index_to_string)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # Equality (not >=) makes the registration fire exactly once
                # per word: at the moment its count reaches the threshold.
                if (
                    frequencies[word] == self.freq_threshold
                    and word not in self.string_to_index
                ):
                    self.string_to_index[word] = idx
                    self.index_to_string[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Encode ``text`` as a list of token ids.

        Args:
            text: Raw sentence string.

        Returns:
            One id per token produced by :meth:`tokenizer_eng`; tokens that
            are not in the vocabulary are replaced by the ``<UNK>`` id.
        """
        unk_idx = self.string_to_index["<UNK>"]
        return [
            self.string_to_index.get(token, unk_idx)
            for token in self.tokenizer_eng(text)
        ]



In [4]:
class FlickrDataset(Dataset):
    """Image/caption dataset backed by a CSV file of ``image``/``caption`` rows.

    Each item is the image stored at ``root_dir/<image>`` together with its
    caption encoded as a 1-D tensor of vocabulary ids, wrapped in the
    ``<SOS>`` and ``<EOS>`` ids.
    """

    def __init__(self, root_dir, caption_file, transform=None, freq_threshold=5):
        """Read the caption table and build the vocabulary from its captions.

        Args:
            root_dir: Directory that the file names in the ``image`` column
                are joined onto.
            caption_file: Path passed to ``pd.read_csv``; the resulting table
                must have ``image`` and ``caption`` columns.
            transform: Optional callable applied to each RGB PIL image before
                it is returned.
            freq_threshold: Forwarded to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.caption_dataFrame = pd.read_csv(caption_file)
        # The attribute name (including its spelling) is kept unchanged so
        # code that already reads ``dataset.transfroms`` keeps working.
        self.transfroms = transform

        # Column views used for per-item lookups.
        self.imgs = self.caption_dataFrame["image"]
        self.captions = self.caption_dataFrame["caption"]

        # Build the vocabulary from every caption in the table.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of rows (image/caption pairs) in the table."""
        return len(self.caption_dataFrame)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for the row at position ``index``.

        Returns:
            A tuple of the RGB image (after ``transform`` when one was given)
            and a 1-D tensor ``[<SOS>, token ids..., <EOS>]``.
        """
        # Positional lookup: ``index`` is a row position, not an index label.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # ``convert`` returns a new, fully loaded image, so the file handle
        # opened by ``Image.open`` can be released right away.
        with Image.open(os.path.join(self.root_dir, img_id)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transfroms is not None:
            img = self.transfroms(img)

        numericalized_caption = [self.vocab.string_to_index["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.string_to_index["<EOS>"])

        return img, torch.tensor(numericalized_caption)

        

In [5]:
class MyCollate:
    """Collate callable that batches ``(image, caption)`` pairs.

    Captions in a batch may have different lengths, so they are padded with
    ``pad_idx`` up to the longest caption in the batch.
    """

    def __init__(self, pad_idx):
        """
        Args:
            pad_idx: Id written into the padded positions of shorter captions.
        """
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Merge a list of ``(image, caption)`` items into two batch tensors.

        Args:
            batch: Sequence of ``(image_tensor, caption_tensor)`` pairs; the
                image tensors must all share one shape so they can be stacked.

        Returns:
            ``(imgs, targets)``: ``imgs`` has a new leading batch dimension;
            ``targets`` is laid out as ``(longest_caption, batch)`` because
            ``pad_sequence`` is called with ``batch_first=False``.
        """
        # Stacking adds the batch dimension in front of each image tensor.
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = pad_sequence(
            [item[1] for item in batch],
            batch_first=False,
            padding_value=self.pad_idx,
        )

        return imgs, targets

In [6]:
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
    freq_threshold=5,
):
    """Build a ``FlickrDataset`` and a ``DataLoader`` that pads its captions.

    Args:
        root_folder: Directory containing the image files.
        annotation_file: Caption CSV passed to ``FlickrDataset``.
        transform: Callable applied to every image (may be ``None``).
        batch_size: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        shuffle: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.
        freq_threshold: Vocabulary frequency threshold forwarded to
            ``FlickrDataset``; the default matches the dataset's own default.

    Returns:
        ``(loader, dataset)``. The dataset is returned as well so callers can
        reach its vocabulary.
    """
    dataset = FlickrDataset(
        root_folder,
        annotation_file,
        transform=transform,
        freq_threshold=freq_threshold,
    )

    # Shorter captions in a batch are padded with the vocabulary's <PAD> id.
    pad_idx = dataset.vocab.string_to_index["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset





In [7]:
if __name__ == "__main__":
    # Preprocessing: resize every image to 224x224, then convert it to a tensor.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )

    loader, dataset = get_loader(
        root_folder="../Datasets/Flickr/Images/",
        annotation_file="../Datasets/Flickr/captions.txt",
        transform=transform,
    )

    # Smoke test: walk the whole loader and show the shape of each batch
    # (images first, padded captions second).
    for idx, (imgs, captions) in enumerate(loader):
        print(imgs.shape, captions.shape, sep="\n")