In [2]:
import collections
import json
import os
import pickle

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
# `tqdm` is rebound to a bool flag in the hyper-parameter cell; this alias keeps
# a reference to the progress-bar function that survives that rebinding.
from tqdm import tqdm as progress_bar
from transformers import BertTokenizer

from utils import load_obj_tsv
# Load part of the dataset for fast checking.
# These limits count images, not questions: every question about a loaded
# image is kept.
TINY_IMG_NUM = 512
FAST_IMG_NUM = 5000

# Directory holding the pre-extracted '<coco split>_obj36.tsv' feature files.
# Set the MSCOCO_IMGFEAT_ROOT environment variable to point elsewhere instead of
# editing this cell (e.g. '/home/jupyter/vcr/lxmert/data/mscoco_imgfeat/').
MSCOCO_IMGFEAT_ROOT = os.environ.get(
    'MSCOCO_IMGFEAT_ROOT',
    '/home/u37216/project/lxmert/data/mscoco_img_feat/',
)

# VQA split name -> MS COCO image set whose features that split uses.
SPLIT2NAME = {
    'train': 'train2014',
    'valid': 'val2014',
    'minival': 'val2014',
    'nominival': 'val2014',
    'test': 'test2015',
}

# Report whether a GPU is visible to this kernel.
torch.cuda.is_available()

I0109 21:30:20.533334 140713261415296 file_utils.py:35] PyTorch version 1.3.1+cpu available.
  from ._conv import register_converters as _register_converters


False

In [2]:
# from param import args;

## Hyper-parameters (hard-coded in place of `param.args`)

In [3]:
train = 'train'
valid = 'valid'
test = None
bs = 128
optim = torch.optim.Adam
lr = 1e-4
epochs = 10
dpt = 0.1
seed = 9595
output_dir = 'test'
fast = False
tiny = False
tqdm = True  # NOTE: shadows the `tqdm` progress-bar function imported in the first cell
load = None
load_lxmert = None
load_lxmert_qa = None  # NOTE: later rebound by the function of the same name
from_scratch = False
mce_loss = False
llayers = 9
xlayers = 5
rlayers = 5
taskMatched = False
taskMaskLM = False
taskObjPredict = False
taskQA = False
visualLosses = 'obj,attr,feat'
qaSets = None
wordMaskRate = 0.15
objMaskRate = 0.15
multiGPU = False
num_workers = 4

# `seed` was otherwise never applied. Seed the global RNGs so the DataLoader
# shuffle and the answer-head weight initialisation (both drawn from torch's
# RNG) are reproducible across kernel restarts.
torch.manual_seed(seed)
np.random.seed(seed)

## Base class for dataset

In [4]:
class VQADataset:
    """
    VQA question/answer metadata loaded from json split files.

    A VQA data example in json file:
        {
            "answer_type": "other",
            "img_id": "COCO_train2014_000000458752",
            "label": {
                "net": 1
            },
            "question_id": 458752000,
            "question_type": "what is this",
            "sent": "What is this photo taken looking through?"
        }

    :param path: directory containing ``<split>.json``,
        ``trainval_ans2label.json`` and ``trainval_label2ans.json``
        (a trailing slash is optional).
    :param splits: comma-separated split names, e.g. ``'train,nominival'``.
    """
    def __init__(self, path, splits: str):
        self.path = path
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets
        self.data = []
        for split in self.splits:
            self.data.extend(self._load_json("%s.json" % split))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # Convert list to dict (for evaluation)
        self.id2datum = {datum['question_id']: datum for datum in self.data}

        # Answers: answer string -> label index, and label index -> answer string
        self.ans2label = self._load_json("trainval_ans2label.json")
        self.label2ans = self._load_json("trainval_label2ans.json")
        assert len(self.ans2label) == len(self.label2ans)

    def _load_json(self, filename):
        """Parse ``filename`` relative to ``self.path``, closing the file afterwards."""
        with open(os.path.join(self.path, filename)) as f:
            return json.load(f)

    @property
    def num_answers(self):
        return len(self.ans2label)

    def __len__(self):
        return len(self.data)

    def plot_img(self, idx):
        """
        Show the image of ``self.data[idx]`` and return ``(AxesImage, datum)``.

        The image is read from ``<path>/<folder>/<img_id>.jpg``. ``folder`` is the
        COCO image set of the first split, looked up in ``SPLIT2NAME``
        (e.g. ``'minival'`` -> ``'val2014'``), falling back to ``<split>2014``
        for splits not listed there.

        :raises FileNotFoundError: if OpenCV cannot read the image file.
        """
        datum = self.data[idx]
        split = self.splits[0]
        folder = SPLIT2NAME.get(split, split + '2014')
        img_path = os.path.join(self.path, folder, datum['img_id'] + '.jpg')
        im = cv.imread(img_path)
        if im is None:  # cv.imread returns None instead of raising on a bad path
            raise FileNotFoundError("Cannot read image %s" % img_path)
        # OpenCV decodes to BGR channel order; matplotlib expects RGB.
        im = plt.imshow(cv.cvtColor(im, cv.COLOR_BGR2RGB))
        return im, datum

In [5]:
#dset_train_init.plot_img(125)

In [6]:
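# ans2label / label2ans: answer string -> label index and label index -> answer
# string, loaded by VQADataset from trainval_ans2label.json / trainval_label2ans.json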

In [7]:
#dset_train_init.data[100]

In [8]:
class VQATorchDataset(Dataset):
    """
    torch Dataset that pairs VQA questions with pre-extracted object features.

    Only questions whose image has features in the loaded tsv file(s) are kept.
    Items are ``(ques_id, feats, boxes, ques, target)`` when the datum carries a
    ``'label'``, otherwise ``(ques_id, feats, boxes, ques)``.
    """
    def __init__(self, dataset: VQADataset):
        super().__init__()
        self.raw_dataset = dataset

        # Number of images to read per feature file (None means all of them).
        topk = TINY_IMG_NUM if tiny else (FAST_IMG_NUM if fast else None)

        # Index detection features by image id, one tsv file per split.
        self.imgid2img = {}
        for split in dataset.splits:
            # Minival is the first 5K images of val2014_obj36.tsv (used to evaluate
            # VQA / LXMERT pre-training), so only that prefix is read for it.
            if split == 'minival' and topk is None:
                limit = 5000
            else:
                limit = topk
            feat_file = os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj36.tsv' % (SPLIT2NAME[split]))
            for img_datum in load_obj_tsv(feat_file, topk=limit):
                self.imgid2img[img_datum['img_id']] = img_datum

        # Keep only the questions whose image features were loaded.
        self.data = [datum for datum in self.raw_dataset.data
                     if datum['img_id'] in self.imgid2img]
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item: int):
        datum = self.data[item]
        img_info = self.imgid2img[datum['img_id']]

        # Copies, so the in-place normalisation below never touches the cache.
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        assert img_info['num_boxes'] == len(boxes) == len(feats)

        # Scale box columns 0/2 by image width and 1/3 by height (to 0 ~ 1).
        width, height = img_info['img_w'], img_info['img_h']
        boxes[:, (0, 2)] /= width
        boxes[:, (1, 3)] /= height
        np.testing.assert_array_less(boxes, 1 + 1e-5)
        np.testing.assert_array_less(-boxes, 0 + 1e-5)

        ques_id, ques = datum['question_id'], datum['sent']
        if 'label' not in datum:
            return ques_id, feats, boxes, ques

        # Dense soft-score target over the whole answer vocabulary.
        target = torch.zeros(self.raw_dataset.num_answers)
        for ans, score in datum['label'].items():
            target[self.raw_dataset.ans2label[ans]] = score
        return ques_id, feats, boxes, ques, target

In [9]:
class VQAEvaluator:
    """Scores predicted answers against the soft labels of a VQADataset."""

    def __init__(self, dataset: "VQADataset"):
        # Only ``dataset.id2datum`` (question_id -> datum) is used.
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """
        Mean soft accuracy of the predictions.

        :param quesid2ans: dict of question_id -> predicted answer string.
        :return: average over questions of the label score of the predicted
            answer (0 when the answer is not in the label); 0.0 for an empty dict.
        :raises KeyError: for an unknown question id or a datum without 'label'.
        """
        if not quesid2ans:
            return 0.
        score = 0.
        for quesid, ans in quesid2ans.items():
            label = self.dataset.id2datum[quesid]['label']
            score += label.get(ans, 0.)
        return score / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump results to a json file, which could be submitted to the VQA online evaluation.
        VQA json file submission requirement:
            results = [result]
            result = {
                "question_id": int,
                "answer": str
            }
        :param quesid2ans: dict of quesid --> ans
        :param path: The desired path of saved file.
        """
        # Build the payload before opening (and truncating) the output file.
        result = [{'question_id': ques_id, 'answer': ans}
                  for ques_id, ans in quesid2ans.items()]
        with open(path, 'w') as f:
            json.dump(result, f, indent=4, sort_keys=True)

In [10]:
VQA_DATA_ROOT = '/home/jupyter/vcr/lxmert/data/vqa/'
# NOTE(review): train and valid both read 'minival', so validation scores are
# measured on the training questions. Full run: TRAIN_SPLITS = 'train,nominival'.
TRAIN_SPLITS = 'minival'
VALID_SPLITS = 'minival'
dset_train_init = VQADataset(VQA_DATA_ROOT, splits=TRAIN_SPLITS)
dset_valid_init = VQADataset(VQA_DATA_ROOT, splits=VALID_SPLITS)

Load 25994 data from split(s) minival.
Load 25994 data from split(s) minival.


In [11]:
# Items: (ques_id, feats, boxes, ques, target)
train_data = VQATorchDataset(dataset=dset_train_init)

0it [00:00, ?it/s]

Start to load Faster-RCNN detected objects from /home/jupyter/vcr/lxmert/data/mscoco_imgfeat/val2014_obj36.tsv


4993it [00:25, 198.58it/s]

Loaded 5000 images in file /home/jupyter/vcr/lxmert/data/mscoco_imgfeat/val2014_obj36.tsv in 25 seconds.
Use 25994 data in torch dataset



In [12]:
# Evaluator (not a dataset) scoring predictions against the validation labels.
eval_dset = VQAEvaluator(dataset=dset_valid_init)

## Load LXMERT QA

In [13]:
from pretrain.qa_answer_table import load_lxmert_qa

In [14]:
from lxrt.entry import LXRTEncoder
from torch.nn.functional import gelu
from transformers.modeling_bert import BertLayerNorm

In [15]:
# Max question length, including <bos> and <eos>;
# passed to LXRTEncoder as max_seq_length.
MAX_VQA_LENGTH = 20

In [16]:
class GeLU(nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu`` so it can sit in ``nn.Sequential``."""

    def forward(self, x):
        activation = torch.nn.functional.gelu
        return activation(x)

In [17]:
class Args():
    """
    Minimal stand-in for LXMERT's ``param.args``, passed to LXRTEncoder.

    NOTE(review): only these four attributes are provided; confirm LXRTEncoder
    reads nothing else.

    :param l_layers: number of language layers (stored as ``llayers``).
    :param x_layers: number of cross-modality layers (stored as ``xlayers``).
    :param r_layers: number of vision (relation) layers (stored as ``rlayers``).
    :param from_scratch: stored as ``from_scratch``; defaults to False as before.
    """
    def __init__(self, l_layers, x_layers, r_layers, from_scratch=False):
        self.llayers = l_layers
        self.xlayers = x_layers
        self.rlayers = r_layers
        self.from_scratch = from_scratch

In [18]:
# Take the layer counts from the hyper-parameter cell instead of repeating literals.
args = Args(llayers, xlayers, rlayers)

In [19]:
class VQAModel(nn.Module):
    """
    LXRT encoder followed by an MLP answer classifier.

    The encoder is built from the module-level ``args`` and ``MAX_VQA_LENGTH``.
    """
    def __init__(self, num_answers):
        super().__init__()

        # Cross-modality encoder; its output feeds the answer head in forward().
        self.lxrt_encoder = LXRTEncoder(args, max_seq_length=MAX_VQA_LENGTH)
        width = self.lxrt_encoder.dim

        # Answer head. Keep the layer order: load_lxmert_qa addresses the final
        # projection by its Sequential index ('logit_fc.3.*').
        head_layers = [
            nn.Linear(width, 2 * width),
            GeLU(),
            BertLayerNorm(2 * width, eps=1e-12),
            nn.Linear(2 * width, num_answers),
        ]
        self.logit_fc = nn.Sequential(*head_layers)
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size
        :param feat: (b, o, f) object features
        :param pos:  (b, o, 4) box coordinates
        :param sent: (b,) Type -- list of string
        :return: (b, num_answer) raw logits for each answer.
        """
        encoded = self.lxrt_encoder(sent, (feat, pos))
        return self.logit_fc(encoded)

In [20]:
# def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
#     dset = VQADataset(splits)
#     tset = VQATorchDataset(dset)
#     evaluator = VQAEvaluator(dset)
#     data_loader = DataLoader(
#         tset, batch_size=bs,
#         shuffle=shuffle, num_workers=args.num_workers,
#         drop_last=drop_last, pin_memory=True
#     )

#     return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)

In [21]:
# Shuffled mini-batches of (ques_id, feats, boxes, sent, target) for training.
data_loader = DataLoader(
    train_data,
    batch_size=bs,
    shuffle=True,
    num_workers=num_workers,
    drop_last=False,
    pin_memory=True,
)

In [22]:
# One output logit per answer in the fine-tuning vocabulary.
model = VQAModel(num_answers=dset_train_init.num_answers)

4993it [00:40, 198.58it/s]

LXRT encoder with 9 l_layers, 5 x_layers, and 5 r_layers.


In [23]:
# Snapshot prefix; load_lxmert_qa reads '<prefix>_LXRT.pth'.
load_lxmert_qa_path = '/home/jupyter/vcr/lxmert/snap/pretrained/model'

In [24]:
class AnswerTable:
    """
    Answer vocabulary of the LXMERT pre-trained QA head.

    Reads a json list of ``{'ans': str, 'dsets': [str, ...]}`` records. The
    position of an answer in ``self.anss`` is its id, which ``load_lxmert_qa``
    uses as the row index into the pre-trained answer classifier.

    :param dsets: optional iterable of dataset names; when given, only answers
        whose ``'dsets'`` intersect it are kept.
    :param path: location of ``all_ans.json``; defaults to ``DEFAULT_ALL_ANS_PATH``.
    """
    # Final rewrites applied by convert_ans after lower-casing and article stripping.
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
        'grey': 'gray',
    }

    DEFAULT_ALL_ANS_PATH = "/home/jupyter/vcr/lxmert/data/lxmert/all_ans.json"

    def __init__(self, dsets=None, path=None):
        if path is None:
            path = self.DEFAULT_ALL_ANS_PATH
        with open(path) as f:
            self.all_ans = json.load(f)
        if dsets is not None:
            dsets = set(dsets)
            # If the answer is used in the dsets
            self.anss = [ans['ans'] for ans in self.all_ans if
                         len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        # Fails if the vocabulary contains duplicate answers.
        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        """Normalise an answer: lower-case, drop a trailing '.', strip a leading
        'a '/'an '/'the ', then apply ANS_CONVERT."""
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        return self._ans2id_map.copy()

    def id2ans_map(self):
        return self._id2ans_map.copy()

    def used(self, ans):
        return ans in self.ans_set

    def all_answers(self):
        return self.anss.copy()

    @property
    def num_answers(self):
        return len(self.anss)


def load_lxmert_qa(path, model, label2ans):
    """
    Load LXMERT pre-trained weights into ``model`` and re-map its QA head.

    NOTE(review): this definition shadows the ``load_lxmert_qa`` imported from
    ``pretrain.qa_answer_table`` earlier in the notebook.

    Encoder weights (checkpoint keys starting with ``bert.``) go into
    ``model.lxrt_encoder.model``. The pre-trained answer head (``answer_head.*``)
    goes into ``model``. Rows of its final projection ``logit_fc.3`` are
    re-indexed so that row ``label`` holds the pre-trained row of
    ``label2ans[label]`` (after ``AnswerTable.convert_ans``). Answers missing
    from the pre-training vocabulary get zero weight and bias.

    :param path: snapshot prefix; the file read is ``<path>_LXRT.pth``.
    :param model: module exposing ``lxrt_encoder.model`` and ``logit_fc``.
    :param label2ans: list or dict of label index -> answer string, like
        {0: 'cat', 1: 'dog', ...}
    :return: None; ``model`` is updated in place.
    """
    print("Load QA pre-trained LXMERT from %s " % path)
    # Map every stored tensor onto CPU, whatever device it was saved from.
    checkpoint = torch.load("%s_LXRT.pth" % path, map_location=lambda storage, loc: storage)
    own_state = model.state_dict()

    # Handle Multi-GPU pre-training --> Single GPU fine-tuning: drop "module.".
    for key in list(checkpoint.keys()):
        checkpoint[key.replace("module.", '')] = checkpoint.pop(key)

    # Split the checkpoint into encoder weights and answer-head weights.
    bert_state_dict = {k: v for k, v in checkpoint.items() if k.startswith('bert.')}
    answer_state_dict = {k.replace('answer_head.', ''): v
                         for k, v in checkpoint.items() if k.startswith("answer_head.")}

    # Pre-trained final-layer rows, indexed by AnswerTable id.
    pretrained_w = answer_state_dict['logit_fc.3.weight']
    pretrained_b = answer_state_dict['logit_fc.3.bias']
    import copy
    # Start from the model's own rows; rows for labels in label2ans are overwritten.
    weight = copy.deepcopy(own_state['logit_fc.3.weight'])
    bias = copy.deepcopy(own_state['logit_fc.3.bias'])
    table = AnswerTable()

    if type(label2ans) is list:
        label2ans = dict(enumerate(label2ans))

    loaded = unload = 0
    for label, ans in label2ans.items():
        normalized = table.convert_ans(ans)
        if not table.used(normalized):
            # No pre-trained row for this answer: start it from zero.
            weight[label] = 0.
            bias[label] = 0.
            unload += 1
            continue
        src_row = table.ans2id(normalized)
        weight[label] = pretrained_w[src_row]
        bias[label] = pretrained_b[src_row]
        loaded += 1
    print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload))
    print()
    answer_state_dict['logit_fc.3.weight'] = weight
    answer_state_dict['logit_fc.3.bias'] = bias

    # Every encoder parameter must be present in the checkpoint.
    missing = set(model.lxrt_encoder.model.state_dict().keys()) - set(bert_state_dict.keys())
    assert len(missing) == 0
    model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)

    # Every answer-head key must exist in the model.
    unexpected = set(answer_state_dict.keys()) - set(model.state_dict().keys())
    assert len(unexpected) == 0
    model.load_state_dict(answer_state_dict, strict=False)

In [25]:
# Initialise encoder + answer head from the pre-trained LXMERT QA snapshot.
load_lxmert_qa(load_lxmert_qa_path, model, dset_train_init.label2ans)

Load QA pre-trained LXMERT from /home/jupyter/vcr/lxmert/snap/pretrained/model 
Loaded 3124 answers from LXRTQA pre-training and 5 not



In [26]:
# k = torch.load(load_lxmert_qa_path+'_LXRT.pth')
# k.keys()

In [27]:
# Use the GPU when one is visible, otherwise stay on CPU (the first cell can
# report False, and an unconditional .cuda() would then raise).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

VQAModel(
  (lxrt_encoder): LXRTEncoder(
    (model): LXRTFeatureExtraction(
      (bert): LXRTModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768, padding_idx=0)
          (token_type_embeddings): Embedding(2, 768, padding_idx=0)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): LXRTEncoder(
          (visn_fc): VisualFeatEncoder(
            (visn_fc): Linear(in_features=2048, out_features=768, bias=True)
            (visn_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (box_fc): Linear(in_features=4, out_features=768, bias=True)
            (box_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (layer): ModuleList(
            (0): BertLayer(
     

In [28]:
# Multi-label loss on raw logits against soft answer scores, averaged over entries.
bce_loss = nn.BCEWithLogitsLoss(reduction='mean')

In [29]:
# Optimizer instance. This rebinds `optim` (the optimizer class in the
# hyper-parameter cell); the learning rate now comes from `lr`, not a literal.
optim = torch.optim.Adam(model.parameters(), lr=lr)

In [30]:
# Wrap the training-batch iterator in a progress bar when the `tqdm` flag is set.
# The hyper-parameter cell rebinds `tqdm` to a bool, so the bar comes from the
# `progress_bar` alias; `data_loader` is the loader defined above.
iter_wrapper = (lambda x: progress_bar(x, total=len(data_loader))) if tqdm else (lambda x: x)

In [31]:
# Best validation score seen so far. The epoch count comes from the `epochs`
# hyper-parameter and is not reset here, so edits to that cell take effect.
best_valid = 0

In [32]:
# for epoch in tqdm(range(epochs)):
#     quesid2ans = {}
    
#     for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
#         model.train()
#         optim.zero_grad()
        
#         feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
        
        

In [33]:
# Pull a single shuffled batch to step through one training iteration by hand.
ques_id, feats, boxes, sent, target = next(iter(data_loader))

In [34]:
# Sanity-check batch shapes: (b,), (b, o, f), (b, o, 4), b questions, (b, num_answers).
ques_id.shape, feats.shape, boxes.shape, len(sent), target.shape

(torch.Size([128]),
 torch.Size([128, 36, 2048]),
 torch.Size([128, 36, 4]),
 128,
 torch.Size([128, 3129]))

In [35]:
# Move the batch to wherever the model's parameters live instead of assuming CUDA.
model_device = next(model.parameters()).device
feats, boxes, target = feats.to(model_device), boxes.to(model_device), target.to(model_device)

In [36]:
# Forward pass: answer logits for every question in the batch.
logit = model(feat=feats, pos=boxes, sent=sent)

In [37]:
# Expect (batch_size, num_answers).
logit.shape

torch.Size([128, 3129])

In [38]:
# BCEWithLogitsLoss needs logits and targets of identical shape, not just equal rank.
assert logit.shape == target.shape, (logit.shape, target.shape)

In [39]:
# Mean binary cross-entropy over every (example, answer) entry against the soft targets.
loss = bce_loss(logit,target)

In [40]:
# Recompute from the mean BCE so re-running this cell does not keep multiplying
# `loss`. Scaling by the number of answers turns the per-entry mean into the
# per-example sum averaged over the batch.
loss = bce_loss(logit, target) * logit.size(1)

In [41]:
# Clear gradients left from earlier runs so they do not accumulate, then backprop.
optim.zero_grad()
loss.backward()

In [42]:
# Clip the global gradient norm to 5; the displayed value is the norm before clipping.
nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.)

9287.52677944882

In [43]:
# Apply the (clipped) gradients.
optim.step()

In [44]:
# Highest logit and its answer index for each question.
score, label = logit.max(dim=1)

In [45]:
# question_id -> predicted answer string
quesid2ans = dict()

In [46]:
# Map each question id to the answer string of its arg-max label.
# NOTE: leaves `qid`, `l` and `ans` bound to the last batch element.
for qid, l in zip(ques_id, label.cpu().numpy()):
    ans = dset_train_init.label2ans[l]
    quesid2ans[qid.item()] = ans

In [47]:
# Leftover from the loop above: the last predicted answer in the batch.
ans

'no'

In [48]:
# Predicted answer per question id for this batch.
quesid2ans

{85527011: 'yes',
 539971002: 'no',
 480268000: 'unknown',
 575916019: 'no',
 127074002: 'in grass',
 54796004: 'tennis racket',
 178606003: 'kites',
 143554006: 'nothing',
 138179001: 'toothbrush',
 513765001: 'yes',
 15029000: '2',
 441264000: 'open',
 133418000: 'yes',
 488915003: 'no',
 476005000: 'buses',
 553776001: 'yellow',
 451674001: 'kite',
 69911001: 'nothing',
 205101012: '1',
 211743002: 'yes',
 504297013: "mcdonald's",
 39900004: 'no',
 140636001: 'black',
 486491004: 'no',
 12818001: 'yes',
 216636002: 'chocolate',
 443818001: '1',
 339336002: 'yes',
 489907005: 'no',
 454148002: 'no',
 168706002: 'no',
 223874008: '2',
 122166003: 'right',
 377845000: '1',
 161877007: 'brown',
 101180000: 'red',
 85527013: 'yes',
 338683002: 'black',
 204525008: 'pink',
 38828003: 'leather',
 563337001: 'brown',
 445008004: '5',
 332775002: 'suitcase',
 8498000: 'afternoon',
 514396000: 'no',
 23584001: 'orange',
 331455008: 'yes',
 163290001: 'no',
 160580004: 'balcony',
 483130008: '

In [49]:
# Debug: print only the first (question id tensor, predicted label index) pair.
for qid, l in zip(ques_id, label.cpu().numpy()):
    print(qid,l)
    break

tensor(85527011) 425


In [51]:
# NOTE(review): repeats the earlier reset-and-decode cells verbatim; one copy can go.
quesid2ans = {}
for qid, l in zip(ques_id, label.cpu().numpy()):
    ans = dset_train_init.label2ans[l]
    quesid2ans[qid.item()] = ans

In [54]:
# Leftover from the loop above: the last question id tensor in the batch.
qid

tensor(79969004)

In [52]:
# Leftover from the loop above: the last predicted answer in the batch.
ans

'no'

In [53]:
# Predicted answer per question id for this batch.
quesid2ans

{85527011: 'yes',
 539971002: 'no',
 480268000: 'unknown',
 575916019: 'no',
 127074002: 'in grass',
 54796004: 'tennis racket',
 178606003: 'kites',
 143554006: 'nothing',
 138179001: 'toothbrush',
 513765001: 'yes',
 15029000: '2',
 441264000: 'open',
 133418000: 'yes',
 488915003: 'no',
 476005000: 'buses',
 553776001: 'yellow',
 451674001: 'kite',
 69911001: 'nothing',
 205101012: '1',
 211743002: 'yes',
 504297013: "mcdonald's",
 39900004: 'no',
 140636001: 'black',
 486491004: 'no',
 12818001: 'yes',
 216636002: 'chocolate',
 443818001: '1',
 339336002: 'yes',
 489907005: 'no',
 454148002: 'no',
 168706002: 'no',
 223874008: '2',
 122166003: 'right',
 377845000: '1',
 161877007: 'brown',
 101180000: 'red',
 85527013: 'yes',
 338683002: 'black',
 204525008: 'pink',
 38828003: 'leather',
 563337001: 'brown',
 445008004: '5',
 332775002: 'suitcase',
 8498000: 'afternoon',
 514396000: 'no',
 23584001: 'orange',
 331455008: 'yes',
 163290001: 'no',
 160580004: 'balcony',
 483130008: '