In [1]:
import base64
import json
import logging
import os
import os.path as op

import numpy as np
import torch

In [2]:
def generate_lineidx_file(filein, idxout):
    """Write the byte offset of every line of ``filein`` to ``idxout``, one per line.

    The index is written to a temporary file and then moved into place, so an
    interrupted run never leaves a truncated ``.lineidx`` behind.
    """
    idxout_tmp = idxout + '.tmp'
    # Binary mode: offsets must be byte positions, which is what
    # TSVFile later passes to ``fp.seek``.
    with open(filein, 'rb') as tsvin, open(idxout_tmp, 'w') as tsvout:
        fpos = 0
        for line in tsvin:
            tsvout.write('{}\n'.format(fpos))
            fpos += len(line)
    os.replace(idxout_tmp, idxout)


def read_to_character(fp, c):
    """Read from ``fp``'s current position up to, but not including, the first ``c``.

    Reads in small chunks so only the start of a (possibly very long) row is
    consumed. Also stops at a newline or end of file, so a row without ``c``
    yields its remaining content instead of running into the next row.
    """
    pieces = []
    while True:
        chunk = fp.read(32)
        if not chunk:
            break
        cut = min((pos for pos in (chunk.find(c), chunk.find('\n')) if pos >= 0), default=-1)
        if cut >= 0:
            pieces.append(chunk[:cut])
            break
        pieces.append(chunk)
    return ''.join(pieces)


class TSVFile(object):
    """Random-access reader for a tab-separated file with a ``.lineidx`` side index.

    The side index lives next to ``tsv_file`` (same path, extension replaced by
    ``.lineidx``) and holds one integer byte offset per row, so row ``idx`` is
    read with a single seek. The index and the file handle are loaded lazily.
    """

    def __init__(self, tsv_file, generate_lineidx=False):
        """
        Args:
            tsv_file: path to the ``.tsv`` file.
            generate_lineidx: build the ``.lineidx`` file if it does not exist yet.
        """
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # pid of the process that opened ``_fp``. If the current pid differs
        # (e.g. after a fork), the file is re-opened so processes do not share
        # one file position.
        self.pid = None
        if generate_lineidx and not op.isfile(self.lineidx):
            generate_lineidx_file(self.tsv_file, self.lineidx)

    def __del__(self):
        # getattr: __init__ may have failed before ``_fp`` was assigned.
        fp = getattr(self, '_fp', None)
        if fp is not None:
            fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        """Return the number of rows listed in the ``.lineidx`` file."""
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def _row_offset(self, idx):
        """Open the TSV and index if needed, and return the byte offset of row ``idx``."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            return self._lineidx[idx]
        except Exception:
            # Record which file/row failed, then propagate the original error.
            logging.info('{}-{}'.format(self.tsv_file, idx))
            raise

    def seek(self, idx):
        """Return row ``idx`` as a list of whitespace-stripped column strings."""
        self._fp.seek(self._row_offset(idx)) if False else None  # placeholder removed below
        pos = self._row_offset(idx)
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx):
        """Return the (unstripped) first column of row ``idx`` without reading the whole row."""
        pos = self._row_offset(idx)
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            logging.info('loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, 'r') as fp:
                # Skip blank lines (e.g. a trailing empty line) instead of failing on int('').
                self._lineidx = [int(line) for line in fp if line.strip()]

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()
        elif self.pid != os.getpid():
            logging.info('re-open {} because the process id changed'.format(self.tsv_file))
            # Close this process's copy of the inherited handle before re-opening.
            self._fp.close()
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()
            
            
def prepare_image_keys(tsv):
    """Collect the key (first column) of every row of ``tsv``, in row order.

    Each row is read in full through ``tsv.seek``, so this touches the whole file.
    """
    keys = []
    for row_idx in range(tsv.num_rows()):
        row = tsv.seek(row_idx)
        keys.append(row[0])
    return keys


In [3]:
# NOTE(review): the next cell passes this path to torch.load, and the recorded
# output is a dict of tensors, but a .tsv is not a torch-serialized file. The
# path looks like it was edited after that cell ran; confirm before re-running.
img_feat_path = op.join("/data/home/zmykevin/vinvl_data/Flickr30K", "features.tsv")

In [4]:
# Load onto the CPU so this also works on machines without the GPU the tensors
# may have been saved from; later cells move every tensor to the CPU anyway.
# torch.load unpickles its input -- only use it on trusted files.
img_feats = torch.load(img_feat_path, map_location='cpu')

In [5]:
# Sanity check on the loaded object; the next cell iterates over its .values().
print("{}".format(type(img_feats)))

<class 'dict'>


In [6]:
# sample_img_feats = img_feats[241].cpu().numpy()

In [7]:
# Print the shapes of the first 7 feature tensors (the loop stops after index 6).
# Reading .shape directly avoids copying each tensor to host memory and
# converting it to numpy just to inspect its size; the printed tuple is the same.
for i, sample_img_feat in enumerate(img_feats.values()):
    print(tuple(sample_img_feat.shape))
    if i > 5:
        break
    

(55, 2054)
(75, 2054)
(44, 2054)
(36, 2054)
(27, 2054)
(32, 2054)
(27, 2054)


In [8]:
#Load the text feature
import json
# Load the VQA question/answer records (the next cell shows one record).
vqa_text_path = "/data/home/zmykevin/vinvl_data/vqa/val2014_qla_mrcnn.json"
# Use a context manager so the file handle is closed after parsing.
with open(vqa_text_path, "r") as vqa_text_file:
    vqa_text = json.load(vqa_text_file)
print(type(vqa_text))

<class 'list'>


In [9]:
# Show the number of records and the first raw record, one per line.
print(len(vqa_text), vqa_text[0], sep="\n")

10631
{'q': 'What is he sitting on?', 'o': 'person person bottle cup person cup remote couch handbag couch frisbee couch person potted plant person', 'an': [487, 2969, 2898], 's': [0.9, 0.6, 1.0], 'img_id': 241}


In [10]:
import pickle
pickle_file = "/data/home/zmykevin/vinvl_data/vqa/trainval_ans2label.pkl"
# pickle.load can execute arbitrary code -- only load trusted files.
# The context manager closes the handle once the mapping is loaded.
with open(pickle_file, "rb") as ans2label_fp:
    ans2label = pickle.load(ans2label_fp)

In [11]:
# Sanity check on the unpickled answer mapping.
print("{}".format(type(ans2label)))

<class 'dict'>


In [12]:
# Number of answer labels (len of a dict counts its keys; .keys() is redundant).
print(len(ans2label))

3129


In [4]:
import os
# Paths to the VinVL region-feature and prediction TSVs. Note: despite the
# `coco_` variable names, these currently point at the Flickr30K export.
coco_data_path = "/data/home/zmykevin/vinvl_data/Flickr30K"
coco_feature_tsv = os.path.join(coco_data_path, "features.tsv")
coco_prediction_tsv = os.path.join(coco_data_path, "predictions.tsv")

In [5]:
# Readers open lazily; each TSV needs a matching .lineidx file next to it.
feat_tsv = TSVFile(tsv_file=coco_feature_tsv)
prediction_tsv = TSVFile(tsv_file=coco_prediction_tsv)

In [6]:
# Decode the first feature row as a sanity check of the TSV layout:
# column 1 holds the box count, column 2 the base64-encoded float32 features.
for i in range(feat_tsv.num_rows()):
    # Read the row once; the base64 column is large and was previously fetched twice.
    feat_row = feat_tsv.seek(i)
    num_boxes = int(feat_row[1])
    print(num_boxes)
    features = np.frombuffer(base64.b64decode(feat_row[2]), np.float32).reshape((num_boxes, -1))
    # NOTE(review): the recorded output shows 2054 values per box -- presumably a
    # 2048-d region feature plus 6 box-geometry values; confirm before relying on it.
    print(features.shape)
    break

39
(39, 2054)


In [8]:
assert prediction_tsv.num_rows() == feat_tsv.num_rows(), "features/predictions row counts differ"
# Peek at the first 12 prediction rows (the loop stops after index 11):
# column 0 is the image key, column 1 a JSON dict of detector outputs.
for i in range(prediction_tsv.num_rows()):
    # Read each row once instead of re-seeking it for every print.
    pred_row = prediction_tsv.seek(i)
    print(pred_row[0])
    print(json.loads(pred_row[1]).keys())
    print(type(pred_row[0]))
    if i > 10:
        break

1000092795
dict_keys(['image_h', 'image_w', 'num_boxes', 'objects', 'predicates', 'relations'])
[{'class': 'bush', 'conf': 0.8448252081871033, 'rect': [0.0, 216.36618041992188, 152.89178466796875, 387.9640808105469]}, {'class': 'shirt', 'conf': 0.837348222732544, 'rect': [175.1438446044922, 159.7999267578125, 207.78457641601562, 222.9622344970703]}, {'class': 'shirt', 'conf': 0.8345832824707031, 'rect': [206.04058837890625, 146.81971740722656, 257.3687744140625, 242.1309356689453]}, {'class': 'pant', 'conf': 0.7661622762680054, 'rect': [175.64266967773438, 224.15505981445312, 216.13381958007812, 319.3218994140625]}, {'class': 'man', 'conf': 0.7620166540145874, 'rect': [128.4736785888672, 186.67681884765625, 264.5747985839844, 358.3706359863281]}, {'class': 'sky', 'conf': 0.7266457080841064, 'rect': [115.32189178466797, 0.0, 332.44500732421875, 66.3604965209961]}, {'class': 'hair', 'conf': 0.7246615886688232, 'rect': [197.54745483398438, 111.19523620605469, 236.4049072265625, 145.136322