# get hints for GQA

### get train subset

In [None]:
import json 
import pickle

In [None]:
_path = '../data/neg_gqacp/questions/train_questions.json'
# Use a context manager so the file handle is closed as soon as the JSON is parsed
# (json.load(open(...)) leaves it open until garbage collection).
with open(_path) as f:
    train_qns = json.load(f)

In [None]:
len(train_qns['questions']), train_qns['questions'][0]

In [None]:
# Shuffle the question list in place so that the slice taken afterwards is a
# random sample. No seed is set in this cell, so the sample differs between runs.
import random
random.shuffle(train_qns['questions'])

In [None]:
train_qns['questions'][0]

In [None]:
# Keep the leading sixth of the question list as the training subset.
_subset_size = int(len(train_qns['questions']) / 6)
train_qns_subset = {'questions': train_qns['questions'][:_subset_size]}

In [None]:
len(train_qns_subset['questions'])

In [None]:
_path = '../data/neg_gqacp/questions/train_annotations.json'
# Use a context manager so the file handle is closed as soon as the JSON is parsed
# (json.load(open(...)) leaves it open until garbage collection).
with open(_path) as f:
    train_anns = json.load(f)

In [None]:
# Question ids of the subset, as a set for constant-time membership tests.
subset_qid = {question['question_id'] for question in train_qns_subset['questions']}

In [None]:
len(subset_qid)

In [None]:
# Keep only the annotations whose question is part of the subset,
# in the same order as in the full annotation list.
train_anns_subset = {
    'annotations': [
        annotation
        for annotation in train_anns['annotations']
        if annotation['question_id'] in subset_qid
    ]
}

In [None]:
len(train_anns_subset['annotations'])

In [None]:
# Save the question subset.
# NOTE(review): the "100k" in the file name is nominal - the subset size is one
# sixth of the training questions, whatever that number turns out to be.
_path = '../data/neg_gqacp/questions/train-100k_questions.json'
with open(_path, 'w') as f:
    json.dump(train_qns_subset, f)

In [None]:
# Save the annotations that belong to the question subset alongside it.
_path = '../data/neg_gqacp/questions/train-100k_annotations.json'
with open(_path, 'w') as f:
    json.dump(train_anns_subset, f)

### convert hints

In [None]:
import pickle 
_path = '../data/neg_gqacp/hints/gqacp_hints_random.pkl'
with open(_path, 'rb') as f:
    hints = pickle.load(f)

In [None]:
# Print a single entry to inspect what the stored hints look like; the loop
# stops after the first key.
for qid in hints:
    print(hints[qid])
    break
    # Disabled conversion of each hint via .numpy(); it would be unreachable
    # after the break even if re-enabled.
    # hints[qid] = hints[qid].numpy()

In [None]:
# Write `hints` back to `_path` with the highest pickle protocol.
# NOTE(review): `_path` is whatever was assigned last; if the cells ran in order
# this overwrites the file the hints were loaded from - confirm before running.
with open(_path, 'wb') as handle:
    pickle.dump(hints, handle, protocol=pickle.HIGHEST_PROTOCOL)

## generate importance map using scene graph

In [None]:
import re
import os 
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json
from tqdm import tqdm

In [None]:
def get_impt_objs_id_from_qns(data_root, split):
    """Collect the numbers mentioned in each question's semantic program.

    Reads ``<data_root>/questions/<split>_balanced_questions.json`` and, for
    every question, gathers all digit runs found in the ``'argument'`` string
    of each entry of its ``'semantic'`` list.

    Args:
        data_root: Directory that contains the ``questions`` folder.
        split: Split name used to build the question file name.

    Returns:
        A tuple ``(qid2impt_objs_ids, qid2imgid)``:
        ``qid2impt_objs_ids`` maps question id -> list of digit strings (every
        digit run is kept, so callers must filter out numbers that are not
        object ids); ``qid2imgid`` maps question id -> the question's
        ``'imageId'``.
    """
    print(f"Reading {split} question file...")
    questions_path = os.path.join(data_root, 'questions', f'{split}_balanced_questions.json')
    # Context manager: close the question file right after parsing instead of
    # leaking the handle.
    with open(questions_path) as f:
        gqa_questions = json.load(f)

    print("Finding impt objs...")
    digit_run = re.compile(r'[0-9]+')  # compiled once, reused for every argument
    qid2impt_objs_ids = {}
    qid2imgid = {}
    for qid, qns in tqdm(gqa_questions.items()):
        obj_ids = []
        for s in qns['semantic']:
            obj_ids += digit_run.findall(s['argument'])
        qid2impt_objs_ids[qid] = obj_ids
        qid2imgid[qid] = qns['imageId']
    return qid2impt_objs_ids, qid2imgid

def get_impt_map(data_root, split):
    """Render one binary importance mask per question from the scene graph.

    For every question of ``split`` a black image of the scene's size is
    created and the bounding box of each referenced scene-graph object is
    filled with white. The result is saved as
    ``<data_root>/masks/<split>/GQA_<qid>.png``.

    Args:
        data_root: Directory containing the ``questions`` and ``sceneGraph``
            folders; masks are written below it.
        split: Split name used to build the input and output paths.

    Raises:
        KeyError: If a question's image, or an id longer than three digits,
            is missing from the scene graph.
    """
    qid2impt_objs_ids, qid2imgid = get_impt_objs_id_from_qns(data_root, split)

    print(f"Reading {split} scene graph file...")
    scene_graph_path = os.path.join(data_root, 'sceneGraph', f'{split}_sceneGraphs.json')
    # Context manager: do not leak the handle of the scene-graph file.
    with open(scene_graph_path) as f:
        gqa_scenegraph = json.load(f)

    # Image.save does not create directories, so make sure the target exists.
    mask_dir = os.path.join(data_root, "masks", split)
    os.makedirs(mask_dir, exist_ok=True)

    print(f"Generating {split} masks...")
    for qid, ids_list in tqdm(qid2impt_objs_ids.items()):
        cur_scene = gqa_scenegraph[qid2imgid[qid]]
        img = np.zeros((cur_scene['height'], cur_scene['width']))

        for obj_id in ids_list:
            # The id extraction also picks up short numbers that are not
            # object ids; skip anything of three digits or fewer.
            if len(obj_id) <= 3:
                continue
            obj = cur_scene['objects'][obj_id]
            top, left = obj['y'], obj['x']
            # Slicing clips the box to the image; a scalar assignment fills it.
            img[top:top + obj['h'], left:left + obj['w']] = 1

        # Scale {0, 1} to {0, 255} and store as an RGB PNG.
        mask_img = Image.fromarray(img * 255).convert('RGB')
        mask_img.save(os.path.join(mask_dir, "GQA_" + qid + ".png"))

In [None]:
data_root = "../data/neg_gqa/GQA/"
get_impt_map(data_root, 'train')

## sanity check the masks

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import json
import os
gqa_data_path = '../data/neg_gqa/GQA/'
split = 'train'
# Context manager: close the question file once it has been parsed.
with open(os.path.join(gqa_data_path,f'./questions/{split}_balanced_questions.json')) as f:
    gqa_questions = json.load(f)

In [None]:
import random
qid = random.choice(list(gqa_questions))
print(gqa_questions[qid]['question'])
img_id = gqa_questions[qid]['imageId']
img_ori = Image.open(os.path.join(gqa_data_path,f'images/images/{img_id}.jpg'))
img_mask = Image.open(os.path.join(gqa_data_path,f'masks/{split}/GQA_{qid}.png'))

In [None]:
dst = Image.new('RGB', (img_ori.width + img_mask.width, img_mask.height))
dst.paste(img_ori, (0, 0))
dst.paste(img_mask, (img_ori.width, 0))

plt.imshow(dst)
plt.show()

## get hints from masks

In [None]:
import h5py
import numpy as np
import torch
import os 
import sys

import json
import pickle

from tqdm import tqdm
from PIL import Image
import cv2
import matplotlib.pyplot as plt

In [None]:
def visualization(img_id, bbox_scores, spatials):
    """Show an image, its detected boxes and the per-pixel score mask.

    The image is read from ``<gqa_data_path>/images/images/<img_id>.jpg``
    (``gqa_data_path`` is a module-level variable). Three figures are shown:
    the image, the image with all boxes drawn, and the normalised score mask.

    Args:
        img_id: Image id used to build the image file name.
        bbox_scores: One score per box; each item must be convertible with
            ``float()``.
        spatials: Iterable of rows ``(x1, y1, x2, y2, _, _)``; the first four
            entries are multiplied by the image width/height, i.e. treated as
            normalised coordinates.

    Returns:
        ``(masked_img, mask)``: the image as loaded by OpenCV multiplied by
        the mask, and the mask as a tensor expanded to the image's shape.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image_path = os.path.join(gqa_data_path,f'images/images/{img_id}.jpg')
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    h, w, _ = img.shape

    # OpenCV loads BGR while matplotlib expects RGB; convert for display only.
    rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb_img)
    plt.show()

    # bbox image (drawn on the RGB copy, so (255, 0, 0) is displayed as red)
    bbox_img = rgb_img.copy()
    for obj in spatials:
        x1, y1, x2, y2, _, _ = obj
        cv2.rectangle(bbox_img,
                      (int(x1*w), int(y1*h)),
                      (int(x2*w), int(y2*h)),
                      (255,0,0), 2)
    plt.imshow(bbox_img)
    plt.show()

    mask = torch.zeros(h, w)
    # Where boxes overlap, every pixel keeps the highest score covering it.
    for index in range(len(bbox_scores)):
        x1, y1, x2, y2, _, _ = spatials[index]
        top, bottom = int(y1*h), int(y2*h)
        left, right = int(x1*w), int(x2*w)
        region = mask[top:bottom, left:right]
        # float() accepts NumPy scalars, 0-dim tensors and plain floats alike.
        score = float(bbox_scores[index])
        mask[top:bottom, left:right] = torch.max(torch.full_like(region, score), region)

    # Min-max normalise; a constant mask would otherwise divide by zero and
    # turn into NaN, so it is mapped to all zeros instead.
    lowest, highest = mask.min(), mask.max()
    if highest > lowest:
        mask = (mask - lowest) / (highest - lowest)
    else:
        mask = torch.zeros_like(mask)
    mask_norm = mask.numpy()

    # get masked img: broadcast the single-channel mask over the colour channels
    mask = mask.unsqueeze(-1).expand(img.shape)
    masked_img = img * mask.numpy()

    print(mask_norm.shape, mask.shape)
    plt.imshow(mask.numpy())
    plt.show()
    return masked_img, mask

In [None]:
# help func: calculate importance score
def calc_att_score(bbox, att_map, SOFT=False):
    """Score how strongly the attention map concentrates inside a box.

    The score is ``d_in / (d_in + d_out)``, where ``d_in`` is the summed
    attention inside the box divided by the box area and ``d_out`` is the
    summed attention outside divided by the remaining area. Areas are taken
    in normalised units, so the whole map has area 1.

    Args:
        bbox: ``(x1, y1, x2, y2)`` in normalised coordinates (they are scaled
            by the map's width/height to obtain pixel indices).
        att_map: 2-D array of attention values.
        SOFT: If true, every zero entry of the map is first raised to 0.1.

    Returns:
        The importance as a NumPy scalar. Degenerate inputs (zero-area box,
        map without any attention) yield 0.0 instead of NaN, and a box that
        covers the whole map is scored without dividing by a zero outside
        area.
    """
    if SOFT:
        # Give zero pixels a small floor so that they still contribute.
        att_map = att_map + 0.1 * (att_map == 0)
    # bbox: x1, y1, x2, y2 (scaled location)
    x1, y1, x2, y2 = bbox
    region_area = np.abs(x1 - x2) * np.abs(y1 - y2)
    assert(len(att_map.shape) == 2)
    h, w = att_map.shape

    score_inside = np.sum(att_map[int(y1*h):int(y2*h), int(x1*w):int(x2*w)])
    score_outside = np.sum(att_map) - score_inside

    # Turn the sums into densities, guarding against empty regions
    # (zero-area box, or a box that covers the entire map).
    outside_area = 1.0 - region_area
    density_inside = score_inside / region_area if region_area > 0 else 0.0
    density_outside = score_outside / outside_area if outside_area > 0 else 0.0

    total = density_inside + density_outside
    if not total > 0:
        # No attention anywhere: report "not important" rather than 0/0 = NaN.
        # A NumPy scalar is returned so callers may keep using .item().
        return np.float64(0.0)
    return density_inside / total

In [None]:
def get_hint_scores_from_masks(split):
    """Compute per-box hint scores for every question of a split from its mask.

    For each question the mask ``<gqa_data_path>/masks/<split>/GQA_<qid>.png``
    is loaded, its colour channels are summed, and every row of the image's
    spatial features is scored against it with ``calc_att_score``. The result
    is pickled to ``../data/neg_gqa/hints/<split>_hints.pkl``.

    Relies on the module-level names ``gqa_data_path``, ``calc_att_score``
    and ``visualization``.

    Args:
        split: Split name used to build all input and output paths.

    Returns:
        Dict mapping question id -> NumPy array with one score per box.

    Raises:
        FileNotFoundError: If the mask of a question cannot be read.
    """
    # read questions
    with open(f"../data/neg_gqa/questions/{split}_questions.json") as f:
        gqa_questions = json.load(f)['questions']
    # read img_id2idx
    with open(f"../data/neg_gqa/{split}36_imgid2img.pkl", 'rb') as f:
        image_id2ix = pickle.load(f)

    qid2hints = {}
    VISUALIZE = False
    # read spatials; the HDF5 file must stay open while rows are being read
    h5_path = f"../data/neg_gqa/{split}36.hdf5"
    with h5py.File(h5_path, 'r') as hf:
        spatials = hf.get('spatial_features')
        for qn in tqdm(gqa_questions):
            # read
            img_id = qn['image_id']
            qid = qn['question_id']
            spatial = spatials[image_id2ix[img_id]]

            # read mask
            mask_path = os.path.join(gqa_data_path,f'masks/{split}/GQA_{qid}.png')
            img_mask = cv2.imread(mask_path)
            if img_mask is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read mask image: {mask_path}")
            img_mask = img_mask.sum(2)

            bbox_impt = [calc_att_score(spatial[i, :4], img_mask)
                         for i in range(spatial.shape[0])]

            if VISUALIZE:
                print(qn['question'])
                plt.imshow(img_mask)
                plt.show()
                visualization(img_id, bbox_impt, spatial)

            # Store inside the loop: previously this line sat after the loop,
            # so only the last question ever received a hint.
            qid2hints[qid] = np.array(bbox_impt)

    with open(f"../data/neg_gqa/hints/{split}_hints.pkl", 'wb') as f:
        pickle.dump(qid2hints, f)
    return qid2hints

In [None]:
gqa_data_path = '../data/neg_gqa/GQA/'
split = 'train'
get_hint_scores_from_masks(split)

## compare two methods for impt score

In [None]:
import h5py
import numpy as np
import torch
import os 
import sys
import re

import json
import pickle

from tqdm import tqdm
from PIL import Image
import cv2
import matplotlib.pyplot as plt

In [None]:
def visualization(img_id, bbox_scores, spatials, MASK_ONLY=False):
    """Display an image, its boxes and the per-pixel maximum box score.

    The image is read from ``<gqa_data_root>/images/images/<img_id>.jpg``
    (``gqa_data_root`` is a module-level variable).

    Args:
        img_id: Image id used to build the image file name.
        bbox_scores: One score per box; every item must provide ``.item()``.
        spatials: Rows of ``(x1, y1, x2, y2, _, _)``; the first four values
            are scaled by the image width/height to obtain pixel positions.
        MASK_ONLY: When true, only the score mask is displayed.

    Returns:
        The score mask as a tensor expanded to the shape of the image.
    """
    def _pixel_box(row):
        # Scale the first four entries of a row to integer pixel positions.
        x1, y1, x2, y2, _, _ = row
        return int(x1*w), int(y1*h), int(x2*w), int(y2*h)

    def _display(picture, **imshow_kwargs):
        # Show one picture without axes, after OpenCV's BGR -> RGB conversion.
        plt.imshow(cv2.cvtColor(picture, cv2.COLOR_BGR2RGB), **imshow_kwargs)
        plt.axis("off")
        plt.show()

    # original image
    img = cv2.imread(os.path.join(gqa_data_root,f'images/images/{img_id}.jpg'))
    h, w, _ = img.shape
    show_images = not MASK_ONLY
    if show_images:
        _display(img)

    # bbox image
    bbox_img = img.copy()
    for row in spatials:
        left, top, right, bottom = _pixel_box(row)
        cv2.rectangle(bbox_img, (left, top), (right, bottom), (255,0,0), 2)
    if show_images:
        _display(bbox_img)

    # Every pixel keeps the highest score among the boxes that cover it.
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index, raw_score in enumerate(bbox_scores):
        left, top, right, bottom = _pixel_box(spatials[index])
        covered = mask[top:bottom, left:right]
        candidate = torch.ones_like(covered)*raw_score.item()
        mask[top:bottom, left:right] = torch.max(candidate, covered)

    # Repeat the single-channel mask across the colour channels of the image.
    mask = mask.unsqueeze_(-1).expand(img.shape)

    _display(mask.numpy(), vmin=0, vmax=1)
    return mask

In [None]:
def get_impt_objs_id_from_qns(data_root, split):
    """Collect the numbers mentioned in each question's semantic program.

    Reads ``<data_root>/questions/<split>_balanced_questions.json`` and, for
    every question, gathers all digit runs found in the ``'argument'`` string
    of each entry of its ``'semantic'`` list.

    Args:
        data_root: Directory that contains the ``questions`` folder.
        split: Split name used to build the question file name.

    Returns:
        A tuple ``(qid2impt_objs_ids, qid2imgid)``:
        ``qid2impt_objs_ids`` maps question id -> list of digit strings (every
        digit run is kept, so callers must filter out numbers that are not
        object ids); ``qid2imgid`` maps question id -> the question's
        ``'imageId'``.
    """
    print(f"Reading {split} question file...")
    questions_path = os.path.join(data_root, 'questions', f'{split}_balanced_questions.json')
    # Context manager: close the question file right after parsing instead of
    # leaking the handle.
    with open(questions_path) as f:
        gqa_questions = json.load(f)

    print("Finding impt objs...")
    digit_run = re.compile(r'[0-9]+')  # compiled once, reused for every argument
    qid2impt_objs_ids = {}
    qid2imgid = {}
    for qid, qns in tqdm(gqa_questions.items()):
        obj_ids = []
        for s in qns['semantic']:
            obj_ids += digit_run.findall(s['argument'])
        qid2impt_objs_ids[qid] = obj_ids
        qid2imgid[qid] = qns['imageId']
    return qid2impt_objs_ids, qid2imgid

gqa_data_root = "../data/neg_gqacp/GQA/"
split = 'val'
qid2impt_objs_ids, qid2imgid = get_impt_objs_id_from_qns(gqa_data_root, split)

In [None]:
# Context manager: close the scene-graph file once it has been parsed.
with open(os.path.join(gqa_data_root, 'sceneGraph', f'{split}_sceneGraphs.json')) as f:
    gqa_scenegraph = json.load(f)

In [None]:
# get gt bbox
# For every question that mentions at least one number, look up the referenced
# scene-graph objects and store their boxes as [x1, y1, x2, y2], divided by the
# image width/height. Questions without any candidate id get no entry at all.
qid2gt_impt_bbox = {}
for qid, ids_list in tqdm(qid2impt_objs_ids.items()):
    if not ids_list:
        continue

    img_id = qid2imgid[qid]
    cur_scene = gqa_scenegraph[img_id]
    w = cur_scene['width']
    h = cur_scene['height']

    gt_impt_bbox_list = []
    for obj_id in ids_list:
        if len(obj_id) <= 3: # the id extraction also yields numbers that are not obj ids
            continue
        # get obj info
        obj = cur_scene['objects'][obj_id]
        obj_x = obj['x']
        obj_y = obj['y']

        # Clip the box extent to the image. Slicing a range applies exactly the
        # same clipping rules as slicing an (h, w) array, without allocating a
        # full-resolution array for every question.
        obj_h = len(range(h)[obj_y:obj_y + obj['h']])
        obj_w = len(range(w)[obj_x:obj_x + obj['w']])
        gt_impt_bbox_list.append([obj_x / w,
                                  obj_y / h,
                                  (obj_x+obj_w) / w,
                                  (obj_y+obj_h) / h])
    qid2gt_impt_bbox[qid] = gt_impt_bbox_list

In [None]:
import torchvision.ops.boxes as bops

In [None]:
split = 'dev'
# read questions (context manager: no leaked file handle)
with open(f"../data/neg_gqacp/questions/{split}_questions.json") as f:
    gqa_questions = json.load(f)['questions']
# read spatials
# The HDF5 file is intentionally left open: `spatials` is a dataset handle that
# is indexed by later cells, which requires the file to still be open.
h5_path = f"../data/neg_gqacp/{split}36.hdf5"
hf = h5py.File(h5_path, 'r')
spatials = hf.get('spatial_features')
# read img_id2idx (context manager: no leaked file handle)
with open(f"../data/neg_gqacp/{split}36_imgid2img.pkl", 'rb') as f:
    image_id2ix = pickle.load(f)

In [None]:
import random
qn = random.choice(gqa_questions)

img_id = qn['image_id']
qid = qn['question_id']
spatial = spatials[image_id2ix[img_id]]
print(qn['question'])

In [None]:
# method #2
# Score every detected box by its best IoU with any ground-truth box of the
# sampled question.
# NOTE(review): qid2gt_impt_bbox only has entries for questions with at least
# one candidate id and was built from the 'val' balanced questions, while `qid`
# is sampled from the 'dev' question file - this lookup raises KeyError when the
# sampled question is not covered; confirm or re-sample.
gt_bbox_list = qid2gt_impt_bbox[qid]

impt_scores = torch.zeros((spatial.shape[0],))
for index, detected_bbox in enumerate(spatial[:, :4]):
    for gt_bbox in gt_bbox_list:
        # box_iou compares two batches of boxes, so each single box gets a
        # leading batch dimension; the result is a 1x1 tensor.
        iou = bops.box_iou(torch.tensor(detected_bbox).unsqueeze(0), 
                          torch.tensor(gt_bbox).unsqueeze(0))
        # keep the running maximum over all ground-truth boxes
        impt_scores[index] = max(iou, impt_scores[index])

In [None]:
_ = visualization(img_id, impt_scores, spatial)

In [None]:
# METHOD 1
def calc_att_score(bbox, att_map, SOFT=False):
    """Score how strongly the attention map concentrates inside a box.

    The score is ``d_in / (d_in + d_out)``, where ``d_in`` is the summed
    attention inside the box divided by the box area and ``d_out`` is the
    summed attention outside divided by the remaining area. Areas are taken
    in normalised units, so the whole map has area 1.

    Args:
        bbox: ``(x1, y1, x2, y2)`` in normalised coordinates (they are scaled
            by the map's width/height to obtain pixel indices).
        att_map: 2-D array of attention values.
        SOFT: If true, every zero entry of the map is first raised to
            ``0.05 * 255``.

    Returns:
        The importance as a NumPy scalar. Degenerate inputs (zero-area box,
        map without any attention) yield 0.0 instead of NaN, and a box that
        covers the whole map is scored without dividing by a zero outside
        area.
    """
    if SOFT:
        # Give zero pixels a small floor so that they still contribute.
        att_map = att_map + (0.05*255) * (att_map == 0)
    # bbox: x1, y1, x2, y2 (scaled location)
    x1, y1, x2, y2 = bbox
    region_area = np.abs(x1 - x2) * np.abs(y1 - y2)
    assert(len(att_map.shape) == 2)
    h, w = att_map.shape

    score_inside = np.sum(att_map[int(y1*h):int(y2*h), int(x1*w):int(x2*w)])
    score_outside = np.sum(att_map) - score_inside

    # Turn the sums into densities, guarding against empty regions
    # (zero-area box, or a box that covers the entire map).
    outside_area = 1.0 - region_area
    density_inside = score_inside / region_area if region_area > 0 else 0.0
    density_outside = score_outside / outside_area if outside_area > 0 else 0.0

    total = density_inside + density_outside
    if not total > 0:
        # No attention anywhere: report "not important" rather than 0/0 = NaN.
        # A NumPy scalar is returned so callers may keep using .item().
        return np.float64(0.0)
    return density_inside / total

# Load the rendered mask of the sampled question and average its colour
# channels into a single intensity map.
img_mask = cv2.imread(os.path.join(gqa_data_root,f'masks/{split}/GQA_{qid}.png')).mean(2)
plt.imshow(img_mask,cmap='gray')
plt.axis("off")
plt.show()

# Score every detected box against the mask, once on the raw map and once
# with the SOFT floor applied to zero pixels.
bbox_impt = [
    calc_att_score(spatial[row, :4], img_mask)
    for row in range(spatial.shape[0])
]
bbox_impt_soft = [
    calc_att_score(spatial[row, :4], img_mask, SOFT=True)
    for row in range(spatial.shape[0])
]

In [None]:
mask = visualization(img_id, bbox_impt, spatial, MASK_ONLY=True)
mask = visualization(img_id, bbox_impt_soft, spatial, MASK_ONLY=True)

## get hints using IoU

In [None]:
import h5py
import numpy as np
import torch
import os 
import sys
import re

import json
import pickle

from tqdm import tqdm
from PIL import Image
import cv2
import matplotlib.pyplot as plt

In [None]:
def get_impt_objs_id_from_qns(data_root, split):
    """Collect the numbers mentioned in each question's semantic program.

    Reads ``<data_root>/questions/<split>_balanced_questions.json`` and, for
    every question, gathers all digit runs found in the ``'argument'`` string
    of each entry of its ``'semantic'`` list.

    Args:
        data_root: Directory that contains the ``questions`` folder.
        split: Split name used to build the question file name.

    Returns:
        A tuple ``(qid2impt_objs_ids, qid2imgid)``:
        ``qid2impt_objs_ids`` maps question id -> list of digit strings (every
        digit run is kept, so callers must filter out numbers that are not
        object ids); ``qid2imgid`` maps question id -> the question's
        ``'imageId'``.
    """
    print(f"Reading {split} question file...")
    questions_path = os.path.join(data_root, 'questions', f'{split}_balanced_questions.json')
    # Context manager: close the question file right after parsing instead of
    # leaking the handle.
    with open(questions_path) as f:
        gqa_questions = json.load(f)

    print("Finding impt objs...")
    digit_run = re.compile(r'[0-9]+')  # compiled once, reused for every argument
    qid2impt_objs_ids = {}
    qid2imgid = {}
    for qid, qns in tqdm(gqa_questions.items()):
        obj_ids = []
        for s in qns['semantic']:
            obj_ids += digit_run.findall(s['argument'])
        qid2impt_objs_ids[qid] = obj_ids
        qid2imgid[qid] = qns['imageId']
    return qid2impt_objs_ids, qid2imgid

gqa_data_root = "../data/neg_gqacp/GQA/"
split = 'val'
qid2impt_objs_ids_val, qid2imgid_val = get_impt_objs_id_from_qns(gqa_data_root, split)
split = 'train'
qid2impt_objs_ids_train, qid2imgid_train = get_impt_objs_id_from_qns(gqa_data_root, split)

In [None]:
# merge qid2impt_objs_ids & qid2imgid
# dict.update mutates the train dicts in place (val entries win on duplicate
# keys); the merged names below are aliases of those same dict objects, not
# copies.
qid2impt_objs_ids_train.update(qid2impt_objs_ids_val)
qid2impt_objs_ids = qid2impt_objs_ids_train
qid2imgid_train.update(qid2imgid_val)
qid2imgid = qid2imgid_train

In [None]:
del qid2imgid_val, qid2imgid_train

In [None]:
# Context managers: close each scene-graph file once it has been parsed.
with open(os.path.join(gqa_data_root, 'sceneGraph', 'train_sceneGraphs.json')) as f:
    gqa_scenegraph_train = json.load(f)
with open(os.path.join(gqa_data_root, 'sceneGraph', 'val_sceneGraphs.json')) as f:
    gqa_scenegraph_val = json.load(f)

In [None]:
# merge gqa_scenegraph
gqa_scenegraph_train.update(gqa_scenegraph_val)
gqa_scenegraph = gqa_scenegraph_train

In [None]:
del gqa_scenegraph_train, gqa_scenegraph_val

In [None]:
# get gt bbox
# For every question that mentions at least one number, look up the referenced
# scene-graph objects and store their boxes as [x1, y1, x2, y2], divided by the
# image width/height. Questions without any candidate id get no entry at all.
qid2gt_impt_bbox = {}
for qid, ids_list in tqdm(qid2impt_objs_ids.items()):
    if not ids_list:
        continue

    img_id = qid2imgid[qid]
    cur_scene = gqa_scenegraph[img_id]
    w = cur_scene['width']
    h = cur_scene['height']

    gt_impt_bbox_list = []
    for obj_id in ids_list:
        if len(obj_id) <= 3: # the id extraction also yields numbers that are not obj ids
            continue
        # get obj info
        obj = cur_scene['objects'][obj_id]
        obj_x = obj['x']
        obj_y = obj['y']

        # Clip the box extent to the image. Slicing a range applies exactly the
        # same clipping rules as slicing an (h, w) array, without allocating a
        # full-resolution array for every question.
        obj_h = len(range(h)[obj_y:obj_y + obj['h']])
        obj_w = len(range(w)[obj_x:obj_x + obj['w']])
        gt_impt_bbox_list.append([obj_x / w,
                                  obj_y / h,
                                  (obj_x+obj_w) / w,
                                  (obj_y+obj_h) / h])
    qid2gt_impt_bbox[qid] = gt_impt_bbox_list

In [None]:
# Debug lookup: print the training question with one specific id.
# NOTE(review): gqa_questions_train is only assigned by the "read questions"
# cell further down, so that cell has to be executed before this one.
for qn in gqa_questions_train:
    if qn['question_id'] == '08902400':
        print(qn)
        break

In [None]:
'08902400' in qid2gt_impt_bbox

In [None]:
# read questions
def _load_questions(split_name):
    """Return the 'questions' list of one split; the file is closed afterwards."""
    with open(f"../data/neg_gqacp/questions/{split_name}_questions.json") as f:
        return json.load(f)['questions']


gqa_questions_train = _load_questions('train')
gqa_questions_dev = _load_questions('dev')
gqa_questions_test_id = _load_questions('test-id')
gqa_questions_test_ood = _load_questions('test-ood')

# read spatials
# The HDF5 files are intentionally left open: the `spatials_*` objects are
# dataset handles that are indexed later on, which requires an open file.
h5_path = f"../data/neg_gqacp/train36.hdf5"
hf1 = h5py.File(h5_path, 'r')
spatials_train = hf1.get('spatial_features')

h5_path = f"../data/neg_gqacp/dev36.hdf5"
hf2 = h5py.File(h5_path, 'r')
spatials_dev = hf2.get('spatial_features')

h5_path = f"../data/neg_gqacp/test-id36.hdf5"
hf3 = h5py.File(h5_path, 'r')
spatials_test_id = hf3.get('spatial_features')

h5_path = f"../data/neg_gqacp/test-ood36.hdf5"
hf4 = h5py.File(h5_path, 'r')
spatials_test_ood = hf4.get('spatial_features')

# read img_id2idx
def _load_image_id2ix(split_name):
    """Return the pickled image-id -> feature-row mapping of one split."""
    with open(f"../data/neg_gqacp/{split_name}36_imgid2img.pkl", 'rb') as f:
        return pickle.load(f)


image_id2ix_train = _load_image_id2ix('train')
image_id2ix_dev = _load_image_id2ix('dev')
image_id2ix_test_id = _load_image_id2ix('test-id')
image_id2ix_test_ood = _load_image_id2ix('test-ood')

In [None]:
from torchvision.ops import box_iou

In [None]:
def get_iou_score(gqa_questions, gqa_spatials, gqd_image_id2ix):
    """Score each detected box by its best IoU with the question's GT boxes.

    Ground-truth boxes are taken from the module-level ``qid2gt_impt_bbox``;
    questions without an entry there are skipped.

    Args:
        gqa_questions: Iterable of question dicts with ``'image_id'`` and
            ``'question_id'``.
        gqa_spatials: Indexable by feature row; each item is a 2-D array whose
            first four columns are the detected box ``(x1, y1, x2, y2)``.
        gqd_image_id2ix: Mapping from image id to the row in ``gqa_spatials``.

    Returns:
        Dict mapping question id -> float32 tensor with one score per detected
        box (all zeros when the question's ground-truth list is empty).
    """
    qid2iou_score = {}
    for qn in tqdm(gqa_questions):
        qid = qn['question_id']
        # Decide whether to skip before touching gqa_spatials, so questions
        # without ground truth do not cost a feature read.
        if qid not in qid2gt_impt_bbox: # if no gt bbox, ignore
            continue
        gt_bbox_list = qid2gt_impt_bbox[qid]
        spatial = gqa_spatials[gqd_image_id2ix[qn['image_id']]]

        if gt_bbox_list:
            detected = torch.as_tensor(spatial[:, :4])
            gt = torch.tensor(gt_bbox_list, dtype=detected.dtype)
            # One batched call yields the full (num_detected, num_gt) IoU
            # matrix instead of one call per box pair.
            iou = box_iou(detected, gt)
            # Two zero-area boxes give 0/0 = NaN; count that as no overlap.
            iou = torch.nan_to_num(iou, nan=0.0)
            impt_scores = iou.max(dim=1).values.to(torch.float32)
        else:
            # Entry exists but holds no usable boxes: nothing can overlap.
            impt_scores = torch.zeros((spatial.shape[0],))
        qid2iou_score[qid] = impt_scores
    return qid2iou_score

In [None]:
from tqdm import tqdm
hint_train = get_iou_score(gqa_questions_train, spatials_train, image_id2ix_train)

In [None]:
hint_dev = get_iou_score(gqa_questions_dev, spatials_dev, image_id2ix_dev)

In [None]:
hint_test_id = get_iou_score(gqa_questions_test_id, spatials_test_id, image_id2ix_test_id)

In [None]:
hint_test_ood = get_iou_score(gqa_questions_test_ood, spatials_test_ood, image_id2ix_test_ood)

In [None]:
len(hint_train), len(hint_dev), len(hint_test_id), len(hint_test_ood)

In [None]:
_path = '../data/neg_gqacp/hints/train_hints.pkl'
with open(_path, 'wb') as handle:
    pickle.dump(hint_train, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
_path = '../data/neg_gqacp/hints/dev_hints.pkl'
with open(_path, 'wb') as handle:
    pickle.dump(hint_dev, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
_path = '../data/neg_gqacp/hints/test-id_hints.pkl'
with open(_path, 'wb') as handle:
    pickle.dump(hint_test_id, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
_path = '../data/neg_gqacp/hints/test-ood_hints.pkl'
with open(_path, 'wb') as handle:
    pickle.dump(hint_test_ood, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Fold the dev / test-id / test-ood hints into hint_train in place. From here on
# hint_train holds the hints of all splits (on duplicate question ids the later
# update wins), so it no longer represents the training split alone.
hint_train.update(hint_dev)
hint_train.update(hint_test_id)
hint_train.update(hint_test_ood)

In [None]:
_path = '../data/neg_gqacp/hints/gqacp_hints.pkl'
with open(_path, 'wb') as handle:
    pickle.dump(hint_train, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Random baseline: for every question id in hint_train, draw a vector of 36
# values with np.random.rand.
hints_random = {question_id: np.random.rand(36) for question_id in hint_train}

In [None]:
_path = '../data/neg_gqacp/hints/gqacp_hints_random.pkl'
with open(_path, 'wb') as handle:
    pickle.dump(hints_random, handle, protocol=pickle.HIGHEST_PROTOCOL)