In [None]:
cd ..

In [None]:
import sys
import os
import json
import random
import pdb
import logging
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
from bisect import bisect
import yaml
from easydict import EasyDict as edict

In [None]:
from vilbert.task_utils import (
    LoadDatasets,
    LoadLosses,
    ForwardModelsTrain,
    ForwardModelsVal,
    clip_gradients,
    get_optim_scheduler)

import vilbert.utils as utils
import torch.distributed as dist

In [None]:
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [None]:
# python train_tasks_evaluate.py \
# --task_file sweeps/m4c-spatial-mask-1-2-layers-4.yml \
# --from_scratch \
# --resume_file save/TextVQA_spatial_m4c_mmt_textvqa-finetune_from_multi_task_model-local-spatial-4layers-mask-1-2/pytorch_ckpt_latest.tar \
# --config_file config/spatial_m4c_mmt_textvqa.json \
# --tasks 19 \
# --train_iter_gap 4 --save_name finetune_from_multi_task_model \
# --tag "debug"

In [None]:
args = edict({
    'bert_model': 'bert-base-uncased',
    'tasks': '19',
    'do_lower_case': True, 
    'in_memory': True,
    'gradient_accumulation_steps': 1, 
    'num_workers': 0, 
    'local_rank': -1, 
    'clean_train_sets': False, 
    'num_train_epochs': 100,
    'train_iter_multiplier': 1.0, 
    'config_file': "config/spatial_m4c_mmt_textvqa.json",
    'resume_file': "save/TextVQA_spatial_m4c_mmt_textvqa-finetune_from_multi_task_model-local-spatial-4layers-mask-1-2/pytorch_ckpt_latest.tar"
})

In [None]:
from vilbert.m4c_spatial import BertConfig, M4C

In [None]:
project_dir = '/srv/share/ykant3/common/vilbert-multi-task/'
task_file = 'sweeps/m4c-spatial-mask-1-2-layers-4.yml'

In [None]:
with open(task_file, "r") as f:
        task_cfg = edict(yaml.safe_load(f))

In [None]:
(
    task_batch_size, 
    task_num_iters, 
    task_ids, 
    task_datasets_train, 
    task_datasets_val, 
    task_dataloader_train, 
    task_dataloader_val
) = LoadDatasets(args, task_cfg, args.tasks.split("-"))

In [None]:
task_losses = LoadLosses(args, task_cfg, args.tasks.split("-"))

In [None]:
task = "TASK" + str(args.tasks)
task_id = 19

In [None]:
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
n_gpu = torch.cuda.device_count()

In [None]:
transfer_keys = ["attention_mask_quadrants", "hidden_size", "num_implicit_relations", "spatial_type", "num_hidden_layers", "num_spatial_layers", "layer_type_list"]
transfer_keys.extend(["aux_spatial_fusion", "use_aux_heads"])

with open(args.config_file, "r") as file:
    config_dict = json.load(file)

# Adding blank keys that could be dynamically replaced later
config_dict["layer_type_list"] = None

# Replace keys
for key in transfer_keys:
    if key in task_cfg["TASK19"]:
        config_dict[key] = task_cfg["TASK19"][key]
        logger.info(f"Transferring keys:  {key}, {config_dict[key]}")
mmt_config = BertConfig.from_dict(config_dict)

text_bert_config = BertConfig.from_json_file("config/m4c_textbert_textvqa.json")
model = M4C(mmt_config, text_bert_config)

In [None]:
logger.info(f"Resuming from Checkpoint: {args.resume_file}")
checkpoint = torch.load(args.resume_file, map_location="cpu")
new_dict = {}
for attr in checkpoint["model_state_dict"]:
    if attr.startswith("module."):
        new_dict[attr.replace("module.", "", 1)] = checkpoint[
            "model_state_dict"
        ][attr]
    else:
        new_dict[attr] = checkpoint["model_state_dict"][attr]
model.load_state_dict(new_dict)
del checkpoint

model = model.to(device)


In [None]:
def evaluate(
    args,
    task_dataloader_val,
    task_stop_controller,
    task_cfg,
    device,
    task_id,
    model,
    task_losses
):

    predictions = []
    scores = 0.0
    data_size = 0
    model.eval()

    for i, batch in enumerate(task_dataloader_val[task_id]):
        # batch['spatial_adj_matrix'] = torch.zeros_like(batch['spatial_adj_matrix'])
        # batch['spatial_adj_matrix'] = torch.ones_like(batch['spatial_adj_matrix'])
        # batch['spatial_adj_matrix'] = torch.transpose(batch['spatial_adj_matrix'], 2, 1)

        loss, score, batch_size,  batch_dict = ForwardModelsVal(
            args, task_cfg, device, task_id, batch, model, task_losses
        )

        scores += score * batch_size
        data_size += batch_size
        
        save_keys = ['question_id', 'textvqa_scores', 'targets']

        batch_dict_keys = list(batch_dict.keys())
        for key in batch_dict_keys:
            if key not in save_keys:
                del batch_dict[key]
            else:
                batch_dict[key] = batch_dict[key].cpu().detach().numpy()

            predictions.append(batch_dict)

        sys.stdout.write("%d/%d\r" % (i, len(task_dataloader_val[task_id])))
        sys.stdout.flush()

    print("Val Score: ", float(scores)/data_size)
    
    model.train()
    return score

In [None]:
curr_val_score = evaluate(
        args,
        task_dataloader_val,
        None,
        task_cfg,
        device,
        task,
        model,
        task_losses
    )

In [22]:
for i, batch in enumerate(task_dataloader_val[task]):
    print(batch.keys())
    break

dict_keys(['pad_obj_features', 'pad_obj_mask', 'pad_obj_bboxes', 'pad_ocr_features', 'pad_ocr_mask', 'pad_ocr_bboxes', 'segment_ids', 'co_attention_mask', 'question', 'question_id', 'image_id', 'answers', 'image_height', 'image_width', 'question_indices', 'num_question_tokens', 'question_mask', 'ocr_fasttext', 'ocr_tokens', 'ocr_length', 'ocr_phoc', 'spatial_adj_matrix', 'targets', 'train_prev_inds', 'train_loss_mask', 'train_acc_mask', 'spatial_loss_mask', 'spatial_ocr_relations'])


In [31]:
(batch['question'][0] > 0).sum()

tensor(40)

In [40]:
batch['co_attention_mask']

tensor([[[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1,

In [35]:
batch['question_indices'][0], batch['num_question_tokens'][0], batch['question_mask'][0]

(tensor([ 101, 2054, 2003, 1996, 4435, 1997, 2023, 4950, 1029,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]),
 tensor(10),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))