In [1]:
import json, os, glob
import h5py
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset

# Root of the local MS-COCO checkout. Set the VISDIAL_DATA_ROOT environment
# variable to run this notebook on another machine; the default keeps the
# original hard-coded location, so existing setups behave exactly as before.
dataRoot = os.environ.get("VISDIAL_DATA_ROOT", "/home/ball/dataset/mscoco/")

langFile = "test.pkl"  # not referenced by any cell visible in this notebook
# VisDial train dialogs; passed to VisDialDataset as `dialFile`.
jsonFile = os.path.join(dataRoot, "visdialog", "visdial_1.0_train.json")
# Directory whose immediate sub-directories are globbed for '*.jpg' images.
cocoDir = dataRoot
cocoFile = os.path.join(dataRoot, "visdialog", "cocofile.json")  # not referenced by any visible cell
# HDF5 file with precomputed sentence features; passed as `sentFeatureFile`.
sentFeature = "visdial_train.h5"

In [2]:
class VisDialDataset(Dataset):
    """Map-style dataset pairing VisDial dialogs with images and sentence features.

    Each item is a dict with the keys ``"index"``, ``"caption"``, ``"image"``,
    ``"questions"`` and ``"answers"``. Questions and answers are looked up in
    the feature file by the integer ids stored in each dialog round and passed
    through ``sentTransform``; the caption is looked up by dialog index and
    returned without ``sentTransform`` applied.

    When the image of the requested dialog cannot be loaded, the nearest
    loadable dialog is returned instead (searching towards index 0 first, then
    upwards). In that case ``item["index"]`` is the index of the dialog that
    was actually returned, not the one requested.

    Args:
        dialFile: Path to a JSON file whose top-level ``"data"`` entry holds a
            ``"dialogs"`` list.
        cocoDir: Directory whose immediate sub-directories contain ``*.jpg``
            images named ``<anything>_<image id>.jpg`` or ``<image id>.jpg``.
        sentFeatureFile: Path to an HDF5 file with ``"caption"``,
            ``"questions"`` and ``"answers"`` datasets.
        imgTransform: Optional callable applied to the loaded RGB PIL image.
        sentTransform: Optional callable applied to each question and answer
            feature; defaults to the identity.
    """

    # Upper bound on the number of dialog rounds read per item. Dialogs with
    # fewer rounds yield correspondingly shorter question/answer lists.
    MAX_ROUNDS = 10

    def __init__(self, dialFile, cocoDir, sentFeatureFile, imgTransform=None, sentTransform=None):
        with open(dialFile, 'r') as f:
            self.data = json.load(f)["data"]

        self.cocoDir = cocoDir
        # image_id -> path on disk, built once so lookups are O(1) afterwards.
        self.imageFile = {}
        for image_path in tqdm(glob.iglob(os.path.join(self.cocoDir, '*', '*.jpg')), desc="Preparing image paths with image_ids"):
            self.imageFile[self._image_id_from_path(image_path)] = image_path
        # Kept open for the lifetime of the dataset; release it with close().
        self.sentFeature = h5py.File(sentFeatureFile, "r")

        self.imgTransform = imgTransform
        self.sentTransform = sentTransform if sentTransform is not None else lambda s: s

    @staticmethod
    def _image_id_from_path(image_path):
        """Return the integer id encoded as the trailing number of the file name.

        Parsing the text after the last underscore (instead of slicing a fixed
        number of characters) keeps ids intact regardless of their digit count.
        """
        stem = os.path.splitext(os.path.basename(image_path))[0]
        return int(stem.rsplit("_", 1)[-1])

    @staticmethod
    def _fallback_order(idx, total):
        """Yield ``idx`` first, then the indices below it, then those above it."""
        yield from range(idx, -1, -1)
        yield from range(idx + 1, total)

    def close(self):
        """Close the HDF5 feature file; the dataset must not be indexed afterwards."""
        self.sentFeature.close()

    def __len__(self):
        return len(self.data["dialogs"])

    def getImage(self, image_id):
        """Load the image registered under ``image_id``.

        Returns:
            ``(image, True)`` on success, ``([], False)`` when the id is
            unknown, the file is missing or unreadable, or the transformed
            result reports a first dimension other than 3.
        """
        # .get() so that an unknown id is reported as a failure, not a KeyError.
        file = self.imageFile.get(image_id)
        if file is None or not os.path.isfile(file):
            return [], False
        try:
            # The context manager closes the file handle; convert() forces the
            # pixel data to be read and normalises grayscale/CMYK to 3 channels.
            with Image.open(file) as raw:
                img = raw.convert("RGB")
        except OSError:
            # Truncated or corrupt file: treat like a missing image.
            return [], False
        if self.imgTransform:
            img = self.imgTransform(img)
            # Only tensor-like results expose a callable .size(dim); PIL images
            # (tuple) and numpy arrays (int) are passed through unchecked.
            size = getattr(img, "size", None)
            if callable(size) and size(0) != 3:
                return [], False
        return img, True

    def _load_item(self, idx):
        """Build the item dict for dialog ``idx``, or ``None`` if its image fails."""
        row = self.data["dialogs"][idx]
        # Load the image first so no feature lookups are wasted on a failure.
        image, success = self.getImage(row["image_id"])
        if not success:
            return None
        rounds = row["dialog"][:self.MAX_ROUNDS]
        item = {}
        item["index"] = idx
        item["caption"] = self.sentFeature["caption"][idx]
        item["image"] = image
        item["questions"] = [self.sentTransform(self.sentFeature["questions"][turn["question"]]) for turn in rounds]
        item["answers"] = [self.sentTransform(self.sentFeature["answers"][turn["answer"]]) for turn in rounds]
        return item

    def __getitem__(self, idx):
        """Return the item at ``idx`` (or a list of items for a slice).

        Raises:
            IndexError: ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: no dialog in the dataset has a loadable image.
        """
        if isinstance(idx, slice):
            # Resolve start/stop/step against the dataset length.
            return [self[ii] for ii in range(*idx.indices(len(self)))]

        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError("VisDialDataset index out of range")

        # Iterative search: a recursive fallback would bounce between
        # indices 0 and 1 forever when both of their images fail.
        for candidate in self._fallback_order(idx, total):
            item = self._load_item(candidate)
            if item is not None:
                return item
        raise RuntimeError("VisDialDataset: no dialog with a loadable image was found")

In [3]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
from utils.token import Lang
# `lang` is not used by any later cell visible in this notebook.
# presumably a vocabulary helper that splits sentences on single spaces — TODO confirm against utils.token
lang = Lang("Visdial", split = " ")

In [5]:
# Build the dataset from the dialog JSON, the image directory tree and the HDF5 sentence features.
# No imgTransform/sentTransform is passed, so images are returned as PIL objects and
# question/answer features are returned unchanged.
dataset = VisDialDataset(jsonFile, cocoDir, sentFeature)

Preparing image paths with image_ids: 133351it [00:00, 378332.22it/s]


In [7]:
# Inspect the per-round question features returned for index 0 (a list with one array per dialog round).
dataset[0]["questions"]

[array([  1.65973604e-01,   2.49807566e-01,   2.66155843e-02,
          5.16678952e-02,  -6.34984672e-01,   7.12899983e-01,
         -1.51408777e-01,   6.40379488e-01,  -4.55886900e-01,
          5.20757847e-02,   2.10901201e-01,  -4.00793821e-01,
          7.39611030e-01,   2.84152478e-02,   1.54116035e-01,
         -9.84949246e-02,   2.12099031e-01,  -2.87105322e-01,
         -4.09477465e-02,  -3.42358857e-01,   6.26736283e-02,
         -5.08400023e-01,  -1.13314047e-01,   4.50587213e-01,
         -2.26804703e-01,  -5.92652619e-01,  -1.24831155e-01,
         -1.39598346e+00,  -1.27448857e-01,  -7.15494394e-01,
          9.51318324e-01,  -1.00327659e+00,  -4.65513021e-01,
         -1.18390048e+00,  -2.95330852e-01,   7.87597477e-01,
         -1.36334807e-01,  -2.86360562e-01,   1.77683800e-01,
         -9.38660979e-01,  -5.25065601e-01,   1.03542790e-01,
         -5.43581657e-02,  -1.86129898e-01,  -1.06903911e-03,
          1.56664580e-01,   6.48907840e-01,  -3.39020729e-01,
        

In [16]:
# Read one row of the "questions" dataset straight from the HDF5 file, bypassing __getitem__.
# NOTE(review): 73533 is a magic index — presumably a question id taken from one of the dialogs; confirm.
dataset.sentFeature["questions"][73533]

array([  1.96883023e-01,   5.39050698e-01,  -2.05519661e-01,
         1.33453652e-01,  -5.40044963e-01,  -2.34149843e-02,
        -8.68591070e-01,  -4.06872958e-01,   6.75128639e-01,
         1.17788434e+00,  -1.31362483e-01,  -2.41897196e-01,
        -2.77660251e-01,  -3.27154756e-01,  -8.09699297e-01,
        -8.39026809e-01,   2.34797195e-01,   1.23398893e-01,
        -7.08724186e-02,   1.72853291e-01,   4.12494987e-01,
        -8.02401304e-01,  -4.39597845e-01,   2.24190876e-02,
        -8.74427855e-02,  -1.50478527e-01,   7.46765077e-01,
         2.60591865e-01,   3.44970077e-01,   3.39807749e-01,
         5.31286716e-01,  -3.21395427e-01,   1.98489368e-01,
        -1.51936507e+00,  -8.10594678e-01,   5.45917571e-01,
        -3.88886601e-01,  -2.96401531e-01,  -6.75955653e-01,
        -4.81433235e-03,   3.94184217e-02,   8.06256890e-01,
        -3.21778774e-01,  -1.06613718e-01,   2.05297992e-01,
         4.74448740e-01,   8.20497811e-01,   5.84032714e-01,
         2.74912477e-01,