In [1]:
import json, os, glob
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset

# Machine-specific locations used by the cells below.
# cocoDir is handed to VisDialDataset as the image root; jsonFile is the dialog JSON it loads.
cocoDir = "/home/ball/dataset/mscoco/"
jsonFile = "/home/ball/dataset/mscoco/visdialog/visdial_1.0_train.json"
# cocoFile is not referenced by any cell visible in this notebook.
cocoFile = "/home/ball/dataset/mscoco/visdialog/cocofile.json"
# Target path passed to lang.save().
langFile = "test.pkl"

In [10]:
class VisDialDataset(Dataset):
    """Dataset over the dialogs of a VisDial-style JSON file.

    The JSON file must have a top-level ``"data"`` object holding
    ``"questions"`` and ``"answers"`` (lists indexed by the integer ids stored
    in each dialog round) and ``"dialogs"`` (a list of records with
    ``"caption"``, ``"image_id"`` and ``"dialog"`` entries).

    Args:
        dialFile: Path to the dialog JSON file.
        cocoDir: Directory whose immediate sub-directories contain the
            ``*.jpg`` images.
        imgTransform: Optional callable applied to each opened PIL image.
        sentTransform: Optional callable applied once, eagerly, to every
            caption, question and answer (see ``setSentTransform``).
    """

    # Upper bound on the number of rounds returned per dialog by __getitem__.
    MAX_ROUNDS = 10

    def __init__(self, dialFile, cocoDir, imgTransform=None, sentTransform=None):
        with open(dialFile, 'r') as f:
            self.data = json.load(f)["data"]

        self.cocoDir = cocoDir
        # Map image id -> file path for every jpg one directory level below cocoDir.
        self.imageFile = {}
        pattern = os.path.join(self.cocoDir, '*', '*.jpg')
        for image_path in tqdm(glob.iglob(pattern), desc="Preparing image paths with image_ids"):
            # The 8 characters in front of the 4-character ".jpg" suffix are parsed
            # as the id; assumes file names end in a zero-padded number - TODO confirm.
            self.imageFile[int(image_path[-12:-4])] = image_path

        self.imgTransform = imgTransform
        # Must be None before setSentTransform runs, otherwise its guard fires.
        self.sentTransform = None
        self.setSentTransform(sentTransform)

    def setSentTransform(self, sentTransform):
        """Install ``sentTransform`` and apply it in place to all stored text.

        Every question, every answer and every dialog caption in ``self.data``
        is replaced by the transform's result, so the transform may only be
        installed once.

        Args:
            sentTransform: Callable taking one stored sentence, or a falsy
                value to leave the data untouched.

        Raises:
            RuntimeError: If a transform has already been installed.
        """
        if self.sentTransform:
            # Raising a bare string is itself a TypeError in Python; raise a real
            # exception so the caller sees the intended message.
            raise RuntimeError("sentTransform already exist")

        self.sentTransform = sentTransform
        if not self.sentTransform:
            return

        for kType in ("questions", "answers"):
            self.data[kType] = [self.sentTransform(sent) for sent in self.data[kType]]

        for dialog in self.data["dialogs"]:
            dialog["caption"] = self.sentTransform(dialog["caption"])

    def __len__(self):
        """Return the number of dialogs."""
        return len(self.data["dialogs"])

    def getImage(self, image_id):
        """Open the image registered under ``image_id``.

        Returns the PIL image, passed through ``imgTransform`` when one is set.
        Raises ``KeyError`` if no file was indexed for ``image_id``.
        """
        img = Image.open(self.imageFile[image_id])
        if self.imgTransform:
            img = self.imgTransform(img)
        return img

    def __getitem__(self, idx):
        """Return the dialog at ``idx`` as a dict.

        Keys: ``"caption"``, ``"image"``, ``"questions"`` and ``"answers"``.
        The two lists hold the stored sentence for each round, in round order,
        limited to the first ``MAX_ROUNDS`` rounds; dialogs with fewer rounds
        yield correspondingly shorter lists instead of raising ``IndexError``.
        """
        row = self.data["dialogs"][idx]
        rounds = row["dialog"][:self.MAX_ROUNDS]
        questions = self.data["questions"]
        answers = self.data["answers"]
        return {
            "caption": row["caption"],
            "image": self.getImage(row["image_id"]),
            "questions": [questions[turn["question"]] for turn in rounds],
            "answers": [answers[turn["answer"]] for turn in rounds],
        }

    def getAllSentences(self):
        """Return a new list: all dialog captions, then all questions, then all answers."""
        sentences = [dialog["caption"] for dialog in self.data["dialogs"]]
        sentences += self.data["questions"]
        sentences += self.data["answers"]
        return sentences

In [3]:
from utils.token import Lang
# Build the Lang object that later cells feed sentences into and save.
# NOTE(review): split=" " is presumably the token delimiter - confirm in utils.token.
lang = Lang("Visdial", split = " ")

In [11]:
# No transforms are passed, so captions/questions/answers stay as loaded from the JSON.
dataset = VisDialDataset(dialFile=jsonFile, cocoDir=cocoDir)

Preparing image paths with image_ids: 133351it [00:00, 373159.53it/s]


In [5]:
from tqdm import tqdm
# Hand every caption, question and answer to lang.addSentance.
# The enumerate() index was never used, so iterate over the sentences directly.
for sent in tqdm(dataset.getAllSentences(), desc="Create Dict"):
    lang.addSentance(sent)

Create Dict: 100%|██████████| 836896/836896 [00:02<00:00, 291913.82it/s]


In [6]:
# Write the populated Lang object to langFile.
# NOTE(review): the on-disk format is decided by Lang.save in utils.token - the ".pkl" name suggests pickle; confirm.
lang.save(langFile)

In [7]:
# Rewrites every stored caption, question and answer in place via lang.sentenceToVector.
# This cell is not idempotent: setSentTransform refuses a second install on the same dataset.
dataset.setSentTransform(sentTransform=lang.sentenceToVector)

In [12]:
# Preview the first item (caption, image, questions, answers).
# NOTE(review): the saved output shows raw strings although the previous cell installs a
# sentence transform; the execution counts are out of order, so the output may be stale - re-run top to bottom to confirm.
dataset[0]

{'answers': ['adult',
  'male',
  'inside',
  'yes, but there is a blanket in between them and the floor',
  'it is tile',
  'red and white',
  'orange red',
  'boxer',
  'yes',
  'tan'],
 'caption': 'a person that is laying next to a dog',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=612x612 at 0x7F2C327988D0>,
 'questions': ['is this a child or adult',
  'male or female',
  'are they inside or outside',
  'are they laying on the floor',
  'is the floor carpeted or wooden',
  'what color is the blanket',
  'what color is the tile',
  'what breed is the dog',
  'does the dog look healthy and happy',
  'what color is the dog']}