# Import Libraries

In [13]:
import pandas as pd
import numpy as np
import os

from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer

# Dataset

In [48]:
class Flickr30kDataset(Dataset):
    """Image/caption pairs backed by a pipe-delimited caption table.

    Every row of the caption file names one image (column ``image_name``)
    and one caption for it (column `` comment``).  Items are returned as
    ``(image, caption)`` tuples, where the image has been passed through
    ``transform`` and the caption through ``tokenizer`` when those are given.
    """

    # Headers looked up in the caption table.  The file is split on "|" only,
    # so whitespace following the delimiter stays part of the header; the
    # leading space in the caption column name is therefore intentional.
    IMAGE_COLUMN = "image_name"
    CAPTION_COLUMN = " comment"

    def __init__(self, image_dir, caption_file, transform=None, tokenizer=None, max_length=50):
        """Store the configuration and load the caption table.

        Args:
            image_dir: Directory that the ``image_name`` values are relative to.
            caption_file: Path (or file-like object) of the "|"-delimited table.
            transform: Optional callable applied to the RGB ``PIL.Image``.
            tokenizer: Optional object exposing ``encode_plus``; when ``None``
                the raw caption string is returned instead of token ids.
            max_length: Length captions are padded/truncated to by the tokenizer.

        Raises:
            KeyError: If the table lacks the image or caption column.
        """
        self.image_dir = image_dir
        self.caption_file = caption_file
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.captions = self._load_captions()

    def _load_captions(self):
        """Read the caption table and discard rows that cannot form a sample.

        A row without an image name or caption (for example a line that is
        missing a "|" delimiter) is parsed with NaN in that column and would
        only fail later, inside the tokenizer or the image loader.  Such rows
        are dropped up front, with a warning, and the index is renumbered.
        """
        import warnings

        df = pd.read_csv(self.caption_file, delimiter="|")

        required = [self.IMAGE_COLUMN, self.CAPTION_COLUMN]
        incomplete = df[required].isna().any(axis=1)
        dropped = int(incomplete.sum())
        if dropped:
            warnings.warn(
                f"Dropping {dropped} caption row(s) with a missing image name or caption",
                stacklevel=3,
            )
            df = df.loc[~incomplete].reset_index(drop=True)
        return df

    def __len__(self):
        """Return the number of usable (image, caption) rows."""
        return len(self.captions)

    def __getitem__(self, i):
        """Return the ``i``-th sample as ``(image, caption)``.

        Raises:
            IndexError: If ``i`` is outside the table.
        """
        # Positional lookup: an out-of-range index raises IndexError (which
        # is what ends plain ``for sample in dataset`` iteration) and
        # negative indices count from the end.
        row = self.captions.iloc[i]
        image_name = row[self.IMAGE_COLUMN]
        caption = row[self.CAPTION_COLUMN]

        image_path = os.path.join(self.image_dir, image_name)
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        # Mirror the optional ``transform``: without a tokenizer the caption
        # is handed back untouched.
        if self.tokenizer is None:
            return image, caption

        encoding = self.tokenizer.encode_plus(
            caption,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # encode_plus returns a (1, max_length) batch; drop the batch axis.
        caption_tokenized = encoding['input_ids'].squeeze(0)

        return image, caption_tokenized


In [49]:
# Where the images and the pipe-delimited caption table live.
image_dir = "data/flickr30k_images"
caption_file = "data/results.csv"

# Image preprocessing: fixed 224x224 size, tensor conversion, then
# per-channel normalisation (these look like the usual ImageNet statistics).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Caption preprocessing: pretrained BERT word-piece vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

ds = Flickr30kDataset(
    image_dir,
    caption_file,
    transform=transform,
    tokenizer=tokenizer,
    max_length=50,
)
dl = DataLoader(ds, batch_size=32)

In [50]:
# Pull the first batch from the loader: `a` holds the images, `b` the tokenized captions.
a, b = next(iter(dl))

In [47]:
# Display the shapes of the image batch and the token-id batch.
a.size(), b.size()

(torch.Size([32, 3, 224, 224]), torch.Size([32, 50]))