In [19]:
import os
import re
import json
import h5py
import numpy as np
from collections import defaultdict
from nltk.tokenize import word_tokenize
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader

In [9]:
from time import time

In [10]:
# Oracle question/answer splits (JSON), relative to this notebook's directory.
train_file = '../../../../ivd_data/Oracle/oracle.train.json'
val_file = '../../../../ivd_data/Oracle/oracle.val.json'
test_file = '../../../../ivd_data/Oracle/oracle.test.json'

# Separate Oracle JSON file kept under the seq2seq preprocessing data folder.
small_file = '../../../seq2seq/Preprocessing/Data/oracle.small.test.json'

In [17]:
# Read the Oracle test split and keep only its list of question records.
with open(test_file) as handle:
    data = json.load(handle)['questions']

# Number of records, followed by the first record to show its fields.
print(len(data))
data[0]

121938


{'answer': 'No',
 'crop_features': '2488.jpg',
 'game_id': 2488,
 'img_features': 'COCO_train2014_000000175527.jpg',
 'obj_cat': 9,
 'question': 'is it in the sky?',
 'spatial': [0.6765, 0.7561, 0.9207, 0.9261, 0.7986, 0.8411, 0.1221, 0.085]}

In [18]:
# Question-length statistics, measured in whitespace-separated tokens.
question_lengths = [len(record['question'].split()) for record in data]

print("Avg: ", np.mean(question_lengths))
print("Median: ", np.median(question_lengths))
# Upper percentiles show how long the longest questions get.
print("80 Perecentile: ", np.percentile(question_lengths, 80))
print("90 Perecentile: ", np.percentile(question_lengths, 90))
print("99 Perecentile: ", np.percentile(question_lengths, 99))

Avg:  4.87682264757
Median:  4.0
80 Perecentile:  6.0
90 Perecentile:  8.0
99 Perecentile:  13.0


In [5]:
# Distinct object-category ids that occur in the loaded split.
obj_cat = list(set(record['obj_cat'] for record in data))

print(len(obj_cat))
print(obj_cat)

80
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]


In [6]:
# Tokenise every question: lowercase, keep only runs of word characters
# (which drops punctuation), then append a single '?' token to each question.
questions = [
    re.findall(r'\w+', datum['question'].lower()) + ['?']
    for datum in data
]
print(len(questions))

# NOTE: despite their names, these hold the longest / shortest tokenised
# questions themselves, not their lengths.
max_length = max(questions, key=len)
min_length = min(questions, key=len)

# Positions of the two extremes within `questions` ...
for extreme in (max_length, min_length):
    print(questions.index(extreme))

# ... followed by each extreme's tokens and its token count.
for extreme in (max_length, min_length):
    print(extreme)
    print(len(extreme))

121938
38362
23969
['continuing', 'to', 'move', 'right', 'is', 'it', 'the', 'person', 'behind', 'and', 'to', 'our', 'right', 'of', 'the', 'first', 'row', 'man', 'in', 'the', 'white', 'shirt', 'and', 'a', 'black', 'hat', 'with', 'a', 'stripe', 'across', 'the', 'front', '?']
33
['?']
1


In [7]:
# Word -> occurrence count over all tokenised questions.
vocab = defaultdict(int)

# Seed the special tokens with a count (6) that is above the pruning
# threshold used below, so they are never removed.
vocab['-PAD-'] = 6
vocab['-UNK-'] = 6

for line in questions:
    for word in line:
        # defaultdict(int) starts unseen words at 0, so no membership test is needed.
        vocab[word] += 1
print(len(vocab))

# Drop words seen fewer than 3 times. Iterate over a snapshot of the keys
# because entries are removed from the dict while looping.
for w in list(vocab):
    if vocab[w] < 3:
        vocab.pop(w)
print(len(vocab))
# for w in sorted(vocab, key=vocab.get, reverse=True):
#       print(w, vocab[w])

5069
1900


In [50]:
# Rank words by descending count; the rank becomes the word's integer id.
ranked_words = sorted(vocab, key=vocab.get, reverse=True)

# Forward and inverse lookup tables built from the same ranking.
ind2word = dict(enumerate(ranked_words))
word2ind = {word: position for position, word in ind2word.items()}

# print(ind2word)
# print(word2ind)

In [51]:
# json_data = {'word2ind':word2ind, 'ind2word':ind2word}

# with open('vocabOracle.json','w') as vc:
#     json.dump(json_data, vc)

In [57]:
# Display one more record as a spot check of the data format.
data[15]

{'answer': 'No',
 'crop_features': '2426.jpg',
 'game_id': 2426,
 'img_features': 'COCO_train2014_000000460809.jpg',
 'obj_cat': 44,
 'question': 'does the container have a picture?',
 'spatial': [-0.3059,
  -0.9893,
  -0.2675,
  -0.9128,
  -0.2867,
  -0.9511,
  0.0192,
  0.0382]}

In [53]:
# HDF5 file holding the image feature vectors.
filename = '../../../../ivd_data/img_features/image_features.h5'

# Keep only the dataset handle for the training features; the File object
# itself is not kept under a name.
train_data = h5py.File(filename, 'r')['train_img_features']

# First feature vector, then the number of stored vectors.
print(train_data[0])
print(len(train_data))

[ 0.0393373   0.49363673 -0.23234019 ..., -0.1444326   0.16763082
  0.14814836]
46794


In [44]:
class OracleDataset(Dataset):
    """Oracle question/answer samples paired with image and crop features.

    Each item is a dict with the keys 'question' (LongTensor of word ids,
    padded/truncated to ``max_question_length``), 'answer' (int class id),
    'crop_features', 'img_features' (rows read from the HDF5 datasets),
    'spaital' (FloatTensor of the record's 'spatial' list; the misspelled
    key is kept because consumers index the sample by it) and 'obj_cat'.
    """

    # Class-level default so the padded length is defined even for instances
    # created without running __init__.
    max_question_length = 45

    def __init__(self, split, json_data_file, img_features_file, img2id_file, crop_features_file, crop2id_file, vocab_json_file, max_question_length=45):
        """
        split: ['train', 'val', 'test'] -- selects the '<split>2id' and
            '<split>crops2id' index maps and the '<split>_img_features' and
            '<split>_crop_features' HDF5 datasets.
        json_data_file: JSON file with a top-level 'questions' list.
        img_features_file / crop_features_file: HDF5 feature files.
        img2id_file / crop2id_file: JSON maps from file name to feature row.
        vocab_json_file: JSON file with a 'word2ind' mapping that contains
            the '-PAD-' and '-UNK-' entries.
        max_question_length: number of token ids in every returned question.
        """
        with open(json_data_file) as file:
            self.questions = json.load(file)['questions']
        with open(img2id_file) as file:
            self.img2id = json.load(file)[split+'2id']
        with open(crop2id_file) as file:
            self.crop2id = json.load(file)[split+'crops2id']

        # Only the dataset handles are stored; rows are read in __getitem__.
        self.img_features = h5py.File(img_features_file, 'r')[split+'_img_features']
        self.crop_features = h5py.File(crop_features_file, 'r')[split+'_crop_features']

        self.ans2id = {'no':0, 'yes':1, 'n/a':2}
        with open(vocab_json_file) as file:
            self.word2ind = json.load(file)['word2ind']

        self.max_question_length = max_question_length

    def __len__(self):
        """Number of question records."""
        return len(self.questions)

    def _encode_question(self, raw_question):
        """Turn a raw question string into a fixed-length list of word ids."""
        pad_id = self.word2ind['-PAD-']
        unk_id = self.word2ind['-UNK-']
        limit = self.max_question_length

        # Same tokenisation as the vocabulary: lowercase word runs plus a final '?'.
        tokens = re.findall(r'\w+', raw_question.lower()) + ['?']
        # Over-long questions are truncated instead of raising IndexError.
        ids = [self.word2ind.get(token, unk_id) for token in tokens[:limit]]
        # Pad with plain Python ints: a uint8 buffer cannot hold ids above 255.
        ids.extend([pad_id] * (limit - len(ids)))
        return ids

    def __getitem__(self, idx):
        """Build the sample dict for record ``idx``."""
        record = self.questions[idx]

        crop_features = self.crop_features[self.crop2id[record['crop_features']]]
        img_features = self.img_features[self.img2id[record['img_features']]]
        spaital = torch.FloatTensor(record['spatial'])
        obj_cat = record['obj_cat']

        question = torch.LongTensor(self._encode_question(record['question']))
        # Answers are matched case-insensitively against ans2id.
        answer = self.ans2id[record['answer'].lower()]

        sample = {'question':question, 'answer': answer, 'crop_features':crop_features, 'img_features':img_features,\
                  'spaital':spaital, 'obj_cat':obj_cat}

        return sample

In [50]:
# Which split to load, and where its question records and vocabulary live.
split = 'val'
json_data_file = '../../../../ivd_data/Oracle/oracle.val.json'
vocab_json_file = 'vocabOracle.json'

# Whole-image features and the map from image file name to feature row.
img_features_file = '../../../../ivd_data/img_features/image_features.h5'
img2id_file = '../../../../ivd_data/img_features/img_features2id.json'

# Object-crop features and the map from crop file name to feature row.
crop_features_file = '../../../../ivd_data/img_features/crop_features.h5'
crop2id_file = '../../../../ivd_data/img_features/crop_features2id.json'

od = OracleDataset(
    split=split,
    json_data_file=json_data_file,
    img_features_file=img_features_file,
    img2id_file=img2id_file,
    crop_features_file=crop_features_file,
    crop2id_file=crop2id_file,
    vocab_json_file=vocab_json_file,
)
# Fetch one sample as a smoke test.
od[1]

torch.Size([45])

In [52]:
# Time how long each of the first six batches takes to arrive.
dataloader = DataLoader(od, batch_size=128, shuffle=True, num_workers=4, pin_memory=False)

count, start = 0, time()
for sample in dataloader:
    # Seconds spent waiting for this batch since the previous one was handled.
    print(time() - start)
    if count == 5:
        # Report the batch dimension of the last timed batch, then stop.
        print(sample['crop_features'].size()[0])
        break
    count, start = count + 1, time()
    

0.4820237159729004
0.0004048347473144531
0.0011034011840820312
0.10975170135498047
0.005167484283447266
0.1707758903503418
128


In [None]:
0.3122749328613281
0.009618043899536133
0.01502084732055664
0.0002598762512207031
0.07806873321533203
0.011798858642578125