# In [1]:
import torch
import torch.utils.data as data
import os
import nltk
from PIL import Image
import json


class ROCODataset(data.Dataset):
    """ROCO image-caption dataset compatible with torch.utils.data.DataLoader.

    The annotation JSON maps a top-level key (compared against
    ``dataset_type``; presumably a split name such as "train" -- confirm) to a
    list of dicts shaped ``{<id>: {"image_path": ..., "caption": ...}}``.
    Only the first entry of each such dict is read.
    """

    def __init__(self, data_json, vocab, dataset_type=None, transform=None):
        """Load caption / image-path pairs from an annotation JSON file.

        Args:
            data_json: Path to the annotation JSON file.
            vocab: Callable mapping a token string to an integer id; it is
                called with ``'<start>'``, ``'<end>'`` and every caption token.
            dataset_type: If given, keep only entries whose top-level JSON key
                equals this value; ``None`` keeps every entry.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.vocab = vocab
        self.transform = transform
        # Close the file deterministically; JSON is UTF-8 by specification.
        with open(data_json, encoding='utf-8') as fh:
            annotations = json.load(fh)

        # Annotation paths may use Windows '\' separators; rewrite them to '/'
        # on non-Windows platforms so the image files can be opened.
        convert_separators = os.name != 'nt'

        self.captions = []
        self.image_paths = []
        for split, items in annotations.items():
            if dataset_type is not None and split != dataset_type:
                continue
            for item in items:
                # Each item is {<id>: {...}}; only the first entry is used.
                entry = list(item.values())[0]
                image_path = entry['image_path']
                if convert_separators:
                    image_path = image_path.replace('\\', '/')
                self.captions.append(entry['caption'])
                self.image_paths.append(image_path)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair for the sample at ``index``.

        ``image`` is the RGB PIL image, passed through ``self.transform`` when
        one is set. ``target`` is a ``torch.Tensor`` (float32, as produced by
        ``torch.Tensor``) holding the word ids of ``<start>``, each lower-cased
        NLTK token of the caption, and ``<end>``.
        """
        vocab = self.vocab
        # Open inside a context manager so the file handle is released as soon
        # as the pixels are loaded; convert() loads the data and returns a new
        # image (a copy when the source is already RGB).
        with Image.open(self.image_paths[index]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(self.captions[index]).lower())
        word_ids = [vocab('<start>')]
        word_ids.extend(vocab(token) for token in tokens)
        word_ids.append(vocab('<end>'))
        # NOTE: torch.Tensor yields float32 ids; collate_fn casts them to long.
        target = torch.Tensor(word_ids)
        return image, target

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.image_paths)

    @staticmethod
    def collate_fn(data):
        """Merge a list of ``(image, caption)`` samples into a padded batch.

        A custom collate function is needed because the default one cannot
        pad variable-length captions. Samples are ordered by descending
        caption length.

        Args:
            data: List of ``(image, caption)`` tuples. Every ``image`` is a
                tensor of the same shape (so they can be stacked), and
                ``caption`` is a 1-D tensor of variable length. The list
                itself is not modified.

        Returns:
            images: Tensor of shape ``(batch_size, *image.shape)``.
            targets: Long tensor of shape ``(batch_size, max_length)``,
                zero-padded on the right.
            lengths: List of unpadded caption lengths, in descending order.

        Raises:
            ValueError: If ``data`` is empty.
        """
        if not data:
            raise ValueError("collate_fn received an empty batch")
        # sorted() (stable, like list.sort) leaves the caller's list untouched.
        batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
        images, captions = zip(*batch)

        # Merge images (from tuple of 3D tensor to 4D tensor).
        images = torch.stack(images, 0)

        # Merge captions (from tuple of 1D tensor to zero-padded 2D tensor).
        lengths = [len(cap) for cap in captions]
        targets = torch.zeros(len(captions), max(lengths), dtype=torch.long)
        for row, (cap, length) in enumerate(zip(captions, lengths)):
            targets[row, :length] = cap
        return images, targets, lengths