In [1]:
# Goal: convert text ==> numerical values
# 1- build a vocabulary that maps each word to an index
# 2- set up a PyTorch Dataset to load the data
# 3- set up padding for every batch


In [2]:
import os
from customDataset import Cat
import pandas as pd
import numpy as np
import torch
import spacy #for tokenize
from torch.nn.utils.rnn import pad_sequence #pad batch
from torch.utils.data import Dataset, DataLoader
from PIL import Image #load image
import torchvision.transforms as transforms

In [3]:
!python -m spacy download en_core_web_sm # download the en_core_web_sm model
import spacy

Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m56.9 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [5]:
# Load the "en_core_web_sm" spaCy pipeline once at module level;
# vocabulary.tokenizer_eng reads this global for tokenization.
spacy_eng = spacy.load("en_core_web_sm")

class vocabulary:
  """Two-way mapping between lower-cased word tokens and integer indices.

  Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
  ``<EOS>`` and ``<UNK>``. Words added by ``build_vocabulary`` receive the
  next free index in the order in which they reach ``freq_threshold``.
  """

  def __init__(self, freq_threshold):
    """Create a vocabulary holding only the four special tokens.

    Args:
      freq_threshold: number of occurrences a word needs before it is given
        its own index. Values <= 1 admit every word.
    """
    # itos: index -> string, stoi: string -> index (exact inverses of each other).
    self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<EOS>", 3:"<UNK>"}
    self.stoi = {"<PAD>":0, "<SOS>":1, "<EOS>":2, "<UNK>":3}
    self.freq_threshold = freq_threshold

  def __len__(self):
    """Return the number of entries, special tokens included."""
    return len(self.itos)

  @staticmethod
  def tokenizer_eng(text):
    """Split ``text`` with the module-level ``spacy_eng`` tokenizer and lower-case each token."""
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

  def build_vocabulary(self, sentence_list):
    """Add every word that occurs at least ``freq_threshold`` times.

    Counts are taken over ``sentence_list`` only (they are not carried over
    between calls). Words that already have an index are left untouched.

    Args:
      sentence_list: iterable of raw text strings.
    """
    frequencies = {}
    # Continue after the last index already handed out, so calling this
    # method more than once never overwrites an existing itos entry.
    idx = len(self.itos)
    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1

        # ">=" (not "==") so thresholds <= 1 admit a word on first sight;
        # the membership test guarantees one index per word.
        if frequencies[word] >= self.freq_threshold and word not in self.stoi:
          self.stoi[word] = idx
          self.itos[idx] = word
          idx += 1

  def numericalize(self, text):
    """Tokenize ``text`` and return the index of each token.

    Tokens without an index are mapped to the index of ``<UNK>``.
    """
    unk_idx = self.stoi["<UNK>"]
    return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]


# CapWords alias so the class can be referenced under either spelling.
Vocabulary = vocabulary



In [None]:
class mycollate:
  """Collate callable for ``DataLoader``: stacks images and pads captions.

  Each batch item is expected to be an ``(image_tensor, caption_tensor)``
  pair, as returned by ``flickerDataset.__getitem__``.
  """

  def __init__(self, pad_idx):
    """
    Args:
      pad_idx: value used to pad shorter captions up to the longest one in the batch.
    """
    self.pad_idx = pad_idx

  def __call__(self, batch):
    """Combine a list of ``(image, caption)`` pairs into two batch tensors.

    Returns:
      imgs: the images stacked along a new leading batch dimension.
      targets: the captions padded with ``pad_idx``; because
        ``batch_first=False`` the batch dimension is the second one.
    """
    imgs = torch.stack([item[0] for item in batch], dim=0)
    targets = pad_sequence([item[1] for item in batch],
                           batch_first=False,
                           padding_value=self.pad_idx)

    return imgs, targets

  # Static (no ``self``): the loader is built from the dataset, not from a
  # collate instance. Calls through an instance keep working unchanged.
  @staticmethod
  def get_loader(root_folder, annotation_file, transform, batch_size=32, num_workers=8, shuffle=True, pin_memory=True):
    """Build a ``DataLoader`` over ``flickerDataset`` that pads captions per batch.

    Args:
      root_folder: passed to ``flickerDataset`` as the image directory.
      annotation_file: passed to ``flickerDataset`` as the captions file.
      transform: passed to ``flickerDataset`` as ``transform``.
      batch_size, num_workers, shuffle, pin_memory: forwarded to ``DataLoader``.

    Returns:
      The configured ``DataLoader``.
    """
    dataset = flickerDataset(root_folder, annotation_file, transform=transform)
    # Pad with the dataset's own <PAD> index so padding matches its vocabulary.
    pad_idx = dataset.vocab.stoi["<PAD>"]
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        shuffle=shuffle,
                        pin_memory=pin_memory,
                        collate_fn=mycollate(pad_idx=pad_idx))
    return loader


# Module-level name: main() calls get_loader(...) as a plain function.
get_loader = mycollate.get_loader

def main(root_folder="flickr8k/images", annotation_file="flickr8k/captions.txt"):
  """Build the data loader and print the shape of every batch it yields.

  Args:
    root_folder: directory containing the images. The default is only a
      placeholder; point it at your local copy of the data.
    annotation_file: CSV file with the captions. The default is only a
      placeholder; point it at your local copy of the data.
  """
  transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
  ])

  data_loader = get_loader(root_folder, annotation_file, transform=transform)
  for imgs, captions in data_loader:
    print(imgs.shape, captions.shape)

# NOTE: inside a notebook __name__ is "__main__", so this runs as soon as the
# cell executes; get_loader and flickerDataset must already be defined by then.
if __name__ == "__main__":
  main()




In [None]:
class flickerDataset(Dataset):
  """Image-caption dataset read from a captions CSV and an image directory.

  The CSV must have an ``image`` column (file name relative to ``root_dir``)
  and a ``caption`` column. A vocabulary is built from all captions when the
  dataset is constructed.
  """

  def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
    """
    Args:
      root_dir: directory that the ``image`` column entries are joined onto.
      captions_file: path passed to ``pd.read_csv``.
      transform: optional callable applied to each loaded image.
      freq_threshold: forwarded to ``vocabulary``.
    """
    self.root_dir = root_dir
    self.df = pd.read_csv(captions_file)
    self.transform = transform

    # Image file names and caption strings, one row per sample.
    self.imgs = self.df["image"]
    self.captions = self.df["caption"]

    # Initialise the vocabulary and build it from every caption in the file.
    self.vocab = vocabulary(freq_threshold)
    self.vocab.build_vocabulary(self.captions.tolist())

  def __len__(self):
    """Return the number of rows in the captions file."""
    return len(self.df)

  def __getitem__(self, index):
    """Return ``(image, caption_tensor)`` for the row at position ``index``.

    The caption is numericalized and wrapped in ``<SOS>`` ... ``<EOS>``.
    """
    # Positional lookup: does not depend on the DataFrame's index labels.
    caption = self.captions.iloc[index]
    img_id = self.imgs.iloc[index]
    # Open inside a context manager so the file handle is closed once the
    # RGB copy has been created.
    with Image.open(os.path.join(self.root_dir, img_id)) as img_file:
      img = img_file.convert("RGB")

    if self.transform is not None:
      img = self.transform(img)

    numericalized_caption = [self.vocab.stoi["<SOS>"]]
    numericalized_caption += self.vocab.numericalize(caption)
    numericalized_caption.append(self.vocab.stoi["<EOS>"])

    return img, torch.tensor(numericalized_caption)



