In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import numpy as np
import random

import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# !pip install lmdb tensorboardX

In [3]:
class HybridLoader:
    def __init__(self, db_path, ext):
        self.db_path = db_path
        self.ext = ext
        self.loader = lambda x: torch.from_numpy(np.load(x))

    def get(self, key):
        f_input = os.path.join(self.db_path, key + self.ext)
        feat = self.loader(f_input)
        return feat

In [4]:
class YFCC_3M(Dataset):
    def __init__(self, opt):
        self.opt = opt
        self.seq_per_img = opt.seq_per_img
        self.seq_length = 16

        self.ext = ".npy"
        self.fc_folder = self.opt.input_fc_dir
        self.att_folder = self.opt.input_att_dir

        # feature related options
        self.use_fc = getattr(opt, 'use_fc', True)
        self.use_att = getattr(opt, 'use_att', True)
        self.use_box = getattr(opt, 'use_box', 0)
        self.use_ps = getattr(opt,'use_ps',True)

        self.use_ps = getattr(opt,'use_ps',True)
        self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
        self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
        self.use_dencap = True

        if self.opt.input_label_h5 != 'none':
            self.label = np.load(self.opt.input_label_h5)
            seq_size = self.label.shape
            self.seq_length = seq_size[1]
            self.label_start_ix = np.load(self.opt.input_label_start_idx)
            self.label_end_ix = np.load(self.opt.input_label_end_idx)

        if self.opt.perss_onehot_h5 != 'none':
            self.perss_onehot = np.load(self.opt.perss_onehot_h5)

        if self.use_dencap:
            self.densecap = np.load(self.opt.densecap_dir)

        # load the json file which contains additional information about the dataset
        self.info = json.load(open(self.opt.metadata_json))
        self.img = json.load(open(self.opt.input_json))

        if 'ix_to_word' in self.info:
            self.ix_to_word = self.info['ix_to_word']
            self.vocab_size = len(self.ix_to_word)
            print('vocab size is ', self.vocab_size)

        if 'pix_to_personality' in self.info:
            self.pix_to_personality = self.info['pix_to_personality']
            self.perss_size = len(self.pix_to_personality)
            print('personality size is ', self.perss_size)

        # open the hdf5 file
        print('loading json files: ', opt.input_fc_dir, opt.input_att_dir, opt.input_label_h5)

        self.num_images = len(self.img)

    def get_vocab_size(self):
        return self.vocab_size

    def get_vocab(self):
        return self.ix_to_word

    def get_personality(self):
        return self.pix_to_personality

    def get_seq_length(self):
        return self.seq_length

    def __len__(self):
        return len(self.img)

    def get_captions(self, ix, seq_per_img):
        # fetch the sequence labels
        ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
        ix2 = self.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1 # number of captions available for this image
        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1,ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]
        else:
            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
            seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]

        return seq

    def get_labels(self, index):
        sequence = np.int32(self.get_captions(index, self.seq_per_img))
        tmp_label = np.zeros([self.seq_per_img, self.seq_length + 2], dtype = 'int')
        tmp_label[:, 1 : self.seq_length + 1] = sequence
        return tmp_label

    def get_gts(self, index):
        gts = self.label[self.label_start_ix[index] - 1: self.label_end_ix[index]]
        return np.int32(gts)

    def get_masks(self, sequence):
        nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, sequence)))
        mask_batch = np.zeros([sequence.shape[0], self.seq_length + 2], dtype = 'float32')
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1
        return mask_batch

    def __getitem__(self, index):
        """
        This function returns a tuple that is further passed to collate_fn
        """ 
        indices = np.array([index]).astype("int")
        ix = index
        indicator = np.zeros([self.vocab_size+1])
        img_hash = str(self.img[ix]['id'])

        att_feat = torch.from_numpy(np.load(os.path.join(self.att_folder, img_hash + self.ext)))
        # Reshape to K x C
        att_feat = att_feat.squeeze()
        if (att_feat.shape[-1] == self.opt.att_feat_size):
            print(self.opt.att_feat_size)
        else:
            att_feat = att_feat.permute(1,2,0)
        att_feat = att_feat.reshape(-1, att_feat.shape[-1])
        if self.norm_att_feat:
            att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)

        fc_feat = torch.from_numpy(np.load(os.path.join(self.fc_folder, img_hash + self.ext)))
        seq = torch.from_numpy(self.get_labels(index))
        inds = seq[seq > 0]
        indicator[inds] = 1
        ind = torch.from_numpy(indicator)
        gts = torch.from_numpy(self.get_gts(index))
        masks = torch.from_numpy(self.get_masks(seq))
        personality = torch.from_numpy(np.float32(self.perss_onehot[index]))
        densecap = torch.from_numpy(np.int32(self.densecap[index]))

        info_dict = {}
        info_dict['id'] = self.img[ix]['id']
        info_dict['personality'] = self.img[ix].get('personality', '')

        return (fc_feat, att_feat, densecap, seq.squeeze(0), gts, masks.squeeze(0), personality, ind, indices, info_dict)


In [5]:
opt = {"vocab_size": 10451, "seq_length": 16, "seq_per_img": 1, "caption_model": "densepembed", "beam_size": 4, 
       "att_feat_size": 2048, "batch_size": 32, "input_encoding_size" : 1024, "fc_feat_size": 2048, 
       "rnn_size": 2048, "att_hid_size": 512, "num_layers": 2, "drop_prob_lm": 0.5, 
       "start_from": "log_added_new1/log_densepembed2_added",
       "metadata_json": "data/personalised_captions/i2w_personality_mapping.json",
       "input_fc_dir": "data/yfcc_images/resnext101_32x48d_wsl", 
       "input_att_dir": "data/yfcc_images/resnext101_32x48d_wsl_spatial_att",
       "input_json": "data/personalised_captions/testing_ids.json", 
       "input_label_h5": "data/personalised_cap_labels/training_labels.npy",
       "input_label_start_idx": "data/personalised_cap_labels/training_start_ix.npy", 
       "input_label_end_idx": "data/personalised_cap_labels/training_end_ix.npy",
       "perss_onehot_h5": "data/personalities_onehot/training.npy", 
       "densecap_dir": "data/dense_captions/training.npy"}

In [6]:
class AttrDict(dict):
    def __getattr__(self, attr):
        if attr in self:
            return self[attr]
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'")


vals = AttrDict(opt)

validation_dataset = YFCC_3M(vals)

vocab size is  10451
personality size is  216
loading json files:  data/yfcc_images/resnext101_32x48d_wsl data/yfcc_images/resnext101_32x48d_wsl_spatial_att data/personalised_cap_labels/training_labels.npy


In [7]:
validation_dataset[0][0].shape, validation_dataset[0][1].shape, validation_dataset[0][2].shape, \
validation_dataset[0][3].shape, validation_dataset[0][4].shape, validation_dataset[0][5].shape, \
validation_dataset[0][6].shape

(torch.Size([2048, 1, 1]),
 torch.Size([49, 2048]),
 torch.Size([5, 16]),
 torch.Size([18]),
 torch.Size([1, 16]),
 torch.Size([18]),
 torch.Size([217]))

In [8]:
train_loader = DataLoader(validation_dataset, batch_size=8, shuffle=False, 
                              num_workers=16, pin_memory=True, drop_last=True)

In [10]:
for a, b, c, d, e, f, g, h, i, j in train_loader:
    print(a.shape)

torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048, 1, 1])
torch.Size([8, 2048,

KeyboardInterrupt: 

Exception in thread Thread-6:
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 49, in _pin_memory_loop
    do_one_step()
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 26, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 305, in rebuild_storage_fd
    fd = df.