In [None]:
import json
import os
import pickle
import random
import re

import matplotlib.pyplot as plt
import nltk
import numpy as np
from tqdm import tqdm

# Distribution shift (re-splitting VQA datasets into ID / OOD sets)

## read original questions - XAI

In [None]:
# CLEVR-XAI: load the simple and complex question lists.
question_dir = '../data/CLEVR-XAI_v1.0/'

# Use context managers so the file handles are closed deterministically.
_path = os.path.join(question_dir, 'CLEVR-XAI_simple_questions.json')
with open(_path) as f:
    all_simple_qs = json.load(f)['questions']
_path = os.path.join(question_dir, 'CLEVR-XAI_complex_questions.json')
with open(_path) as f:
    all_complex_qs = json.load(f)['questions']

# Both files number their questions from 0; offset the complex question
# indices by 100,000 so the two sets do not collide once concatenated.
for q in all_complex_qs:
    q['question_index'] += 100000

In [None]:
# change ans details for CLEVR-XAI
def convert_xai_ans_details(qs):
    """Normalise every question's ``"answer"`` field to a string, in place.

    Strings are left untouched, ints become their decimal string, and bools
    become ``"yes"`` / ``"no"``. The check uses the exact type, so any other
    type raises.

    Args:
        qs: list of question dicts, each with an ``"answer"`` key.

    Returns:
        The same list object that was passed in (its dicts are mutated).

    Raises:
        ValueError: if an answer is neither ``str``, ``int`` nor ``bool``.
    """
    for question in qs:
        answer = question["answer"]
        kind = type(answer)
        if kind is str:
            continue
        if kind is int:
            question["answer"] = str(answer)
        elif kind is bool:
            # A bool is exactly True or False, so two outcomes suffice.
            question["answer"] = "yes" if answer else "no"
        else:
            raise ValueError("unknown ans type")
    return qs

# Normalise answers of both question sets to strings. The function mutates the
# dicts in place and returns the same list, so the re-assignment is a no-op.
all_simple_qs = convert_xai_ans_details(all_simple_qs)
all_complex_qs = convert_xai_ans_details(all_complex_qs)

In [None]:
# Single list holding every CLEVR-XAI question (simple first, then complex).
all_qns = all_simple_qs+all_complex_qs

## read original questions - GQA

In [None]:
# Directory containing the GQA balanced question files read in the next cell.
gqa_data_root = '../data/neg_gqa/GQA/questions/'

In [None]:
# Flatten the GQA {question_id: question} dicts (train + val) into one list of
# question records, storing each id on its record as 'question_id'.
all_qns = []

for _split_file in ("train_balanced_questions.json", "val_balanced_questions.json"):
    # Context manager closes the file handle as soon as parsing is done.
    with open(os.path.join(gqa_data_root, _split_file)) as f:
        _split_qns = json.load(f)
    for qid, value in tqdm(_split_qns.items()):
        value['question_id'] = qid
        all_qns.append(value)
    # Drop the per-split dict before loading the next one to limit peak memory.
    del _split_qns

## read original questions - VQA-HAT

In [None]:
# Directory holding the VQA v2 question/annotation files and the hints/ pickles.
vqa_data_root = '../data/neg_data_vqa/'

In [None]:
# Load VQA v2 annotations and open-ended questions for the train and val splits.
# Context managers make sure every file handle is closed after parsing.
with open(os.path.join(vqa_data_root, "v2_mscoco_train2014_annotations.json"), 'r') as f:
    ann_train = json.load(f)
with open(os.path.join(vqa_data_root, "v2_mscoco_val2014_annotations.json"), 'r') as f:
    ann_val = json.load(f)
with open(os.path.join(vqa_data_root,
                       "v2_OpenEnded_mscoco_train2014_questions.json"), 'r') as f:
    qns_train = json.load(f)
with open(os.path.join(vqa_data_root,
                       "v2_OpenEnded_mscoco_val2014_questions.json"), 'r') as f:
    qns_val = json.load(f)

In [None]:
# Load the train and val hint dictionaries and merge them into `hints`
# (val entries overwrite train entries that share a key).
# Requires `pickle` to be imported in the notebook's import cell.
# SECURITY: pickle.load can execute arbitrary code; only load hint files that
# come from a trusted source.
_path = os.path.join(vqa_data_root, "hints/train_hat.pkl")
with open(_path, 'rb') as f:
    hints = pickle.load(f)

_path = os.path.join(vqa_data_root, "hints/val_hat.pkl")
with open(_path, 'rb') as f:
    hints.update(pickle.load(f))

In [None]:
# Number of entries in the merged hints mapping.
len(hints)

In [None]:
# Number of annotations in the train and val splits.
len(ann_train['annotations']), len(ann_val['annotations'])

In [None]:
# Merge each annotation into its question record (matched by position), so that
# every question dict also carries its answers.
for _anns, _qns in ((ann_train, qns_train), (ann_val, qns_val)):
    # Guard against silently leaving trailing questions without annotations.
    assert len(_anns['annotations']) == len(_qns['questions']), \
        "annotation/question count mismatch"
    for _ann, _qn in zip(_anns['annotations'], _qns['questions']):
        # The two files are expected to list questions in the same order.
        assert _ann['question_id'] == _qn['question_id']
        _qn.update(_ann)
# The annotations now live inside the question records; release the originals
# (including the loop aliases that would otherwise keep them alive).
del ann_train, ann_val, _anns, _qns

In [None]:
# Keep only the questions whose question_id has an entry in `hints`,
# train split first, then val; free each split once it has been consumed.
all_qns = [qn for qn in tqdm(qns_train['questions']) if qn['question_id'] in hints]
del qns_train

all_qns.extend(qn for qn in tqdm(qns_val['questions']) if qn['question_id'] in hints)
del qns_val

In [None]:
# Number of questions kept after filtering by hints.
len(all_qns)

In [None]:
# Selects which answer field the cells below read: 'vqa-hat' uses
# q['multiple_choice_answer'], any other value uses q['answer'].
dataset = 'vqa-hat'

## grouping QAs

In [None]:
# groups: (q_type, ans) -> group
# group: questions (list), concepts (list of unique lower-cased tokens)
# q_type is the first `prefix_length` whitespace-separated words of the question.
groups = {}
prefix_length = 3
# Raw string avoids the invalid-escape warning that "[\W]+" triggers.
_token_splitter = re.compile(r"[\W]+")
for q in tqdm(all_qns):

    q_type = ' '.join(q['question'].split()[:prefix_length])
    ans = q['multiple_choice_answer'] if dataset == 'vqa-hat' else q['answer']

    # Accumulate concepts in a set; rebuilding list(set(...)) for every
    # question made this loop quadratic in the group's vocabulary.
    group = groups.setdefault((q_type, ans), {'questions': [], 'concepts': set()})
    group['questions'].append(q)
    group['concepts'].update(_token_splitter.split(q['question'].lower()))
    group['concepts'].update(_token_splitter.split(ans.lower()))

# Hand the concepts back as lists of unique tokens, as the original cell did.
for group in groups.values():
    group['concepts'] = list(group['concepts'])
print(f"Grouped questions into {len(groups)} groups")

In [None]:
# Strip English stop words (and empty tokens) from every group's concepts;
# afterwards each group's 'concepts' is a set.
nltk.download('stopwords')
stop_words = set(nltk.corpus.stopwords.words('english'))

for group in tqdm(groups.values()):
    group['concepts'] = {
        token
        for token in group['concepts']
        if token != '' and token.lower() not in stop_words
    }

In [None]:
# Number of questions in each group, as an array (used for the stats below).
all_length = np.array([len(group['questions']) for group in groups.values()])

In [None]:
# Counts of groups with at least `cutoff` questions, with fewer than `cutoff`,
# and with exactly one question.
cutoff = 10
np.sum(all_length>=cutoff), np.sum(all_length<cutoff), np.sum(all_length==1)

In [None]:
# Histogram of group sizes (questions per group).
plt.hist(all_length)
plt.show()

In [None]:
# Inspect the fields of one example question. Any group works: the previous
# hard-coded CLEVR key ('There is a', 'cylinder') raised KeyError for datasets
# that have no such group. Every group holds at least one question.
next(iter(groups.values()))['questions'][0].keys()

## re-split

In [None]:
# (number of questions, number of groups)
total_size = len(all_qns)
total_size, len(groups.keys())

In [None]:
# Recompute the total from the groups themselves: sum of all group sizes.
total_size = sum(len(group['questions']) for group in groups.values())
total_size

In [None]:
import ipdb

In [None]:
def get_concepts_from_question_list(qs):
    """Collect the set of concepts (non-stop-word tokens) used by a list of questions.

    Each question's text and answer are lower-cased and split on runs of
    non-word characters; empty tokens and tokens found in the module-level
    ``stop_words`` set are dropped.

    Reads two module-level names: ``dataset`` (``'vqa-hat'`` selects the
    ``'multiple_choice_answer'`` field, anything else selects ``'answer'``)
    and ``stop_words``.

    Args:
        qs: iterable of question dicts with a ``'question'`` key and the
            answer key described above.

    Returns:
        set of str: the concepts appearing in any question or answer.
    """
    tokens = set()
    for q in qs:
        ans = q['multiple_choice_answer'] if dataset == 'vqa-hat' else q['answer']
        # Raw-string pattern: "[\W]+" is an invalid escape sequence in a normal
        # string literal. update() mutates in place instead of copying the
        # whole set on every question as union() + rebinding did.
        tokens.update(re.split(r"[\W]+", q['question'].lower()))  # concepts in question
        tokens.update(re.split(r"[\W]+", ans.lower()))  # concepts in answer
    # remove stopwords and the empty token produced by leading/trailing separators
    return {w for w in tokens if w != '' and w.lower() not in stop_words}

In [None]:
from tqdm import tqdm
from IPython.display import clear_output

In [None]:
# Build the OOD validation set by assigning whole (q_type, answer) groups to it
# until it holds more than `ood_ratio` of all questions. With USE_GREEDY, each
# val group is followed by a search for the unvisited group whose concepts
# overlap most with the val concepts not yet seen in train; that group goes to
# train. Groups never visited here are not assigned by this cell.
# NOTE(review): no random seed is set anywhere in the visible notebook, so the
# resulting split differs between runs -- confirm whether that is intended.
USE_GREEDY = True
# reduce_simple_ood = False
# Soft shift splits every group between train and val instead of moving it whole.
USE_SOFT_SHIFT = False
# Fraction of a group's questions placed in the smaller ("20") part of a soft split.
soft_shift_ratio = 0.2
# Target size of the val set as a fraction of total_size.
ood_ratio = 0.15
# Val-set fill level (relative to the target) recorded after every assignment.
progress_list = []
# Greedy search is skipped while at most this many val concepts are uncovered.
speed_up_val_concepts_size = 20

# questions
all_val_questions = []
all_train_questions = []
# concepts
cur_train_concepts = set()
cur_val_concepts = set()
# val concepts that no train group chosen so far contains
remaining_val_concepts = set()
# key
# NOTE: a list, so every `key in visited_keys` test below is a linear scan.
visited_keys = []

# shuffle
items = list(groups.items())
random.shuffle(items)
for index,(key, value) in enumerate(items):
    # randomly select one group
    # (skip groups the greedy step already assigned to train)
    if key in visited_keys:
        continue
    qs = value['questions']
    concepts = value['concepts']
    
    if USE_SOFT_SHIFT:
        # split qs 8:2
        # NOTE: shuffles the group's question list in place and rebinds `cutoff`.
        random.shuffle(qs)
        cutoff = int(len(qs)*soft_shift_ratio)
        qs_20 = qs[:cutoff]
        qs_80 = qs[cutoff:]
        concepts_20 = get_concepts_from_question_list(qs_20)
        concepts_80 = get_concepts_from_question_list(qs_80)
        # randomly assign to train/val
        if np.random.rand()>0.5:
            # val
            all_val_questions += qs_20
            cur_val_concepts = cur_val_concepts.union(concepts_20)
            # train
            all_train_questions += qs_80
            cur_train_concepts = cur_train_concepts.union(concepts_80)
            # remaining val
            remaining_val_concepts = remaining_val_concepts.union(concepts_20)
            remaining_val_concepts = remaining_val_concepts - concepts_80
        else:
            # val
            all_val_questions += qs_80
            cur_val_concepts = cur_val_concepts.union(concepts_80)
            # train
            all_train_questions += qs_20
            cur_train_concepts = cur_train_concepts.union(concepts_20)
            # remaining val
            remaining_val_concepts = remaining_val_concepts.union(concepts_80)
            remaining_val_concepts = remaining_val_concepts - concepts_20
    else: # hard shift
        # the whole group goes to val
        all_val_questions += qs
        cur_val_concepts = cur_val_concepts.union(concepts)
        # remaining 
        l = len(remaining_val_concepts)
        remaining_val_concepts = remaining_val_concepts.union(concepts)
        print(f"increase concepts by {len(remaining_val_concepts) - l}")
    
    # update
    visited_keys.append(key)
    progress_list.append(len(all_val_questions) / (total_size*ood_ratio))
    # print(f"visited key: f{key}")
    
    # end when reaching desired ood size
    if len(all_val_questions) > total_size*ood_ratio:
        break
    
    if USE_GREEDY: # find group with max coverage
        # get uncovered val concepts
        if len(remaining_val_concepts) <= speed_up_val_concepts_size: # if train concepts cover all val concepts
            # print(index, len(all_val_questions) / total_size*ood_ratio)
            continue
        print(f"start greedy search with {len(remaining_val_concepts)} remaining concepts...")
        max_intersect = -1
        max_key = None

        # candidates are scanned in random order; ties keep the first one seen
        next_items = list(groups.items())
        random.shuffle(next_items)
        for next_key, next_value in next_items:
            # `key` was appended to visited_keys above, so the first test is
            # already covered by the second one.
            if next_key == key: # ignore current
                continue
            elif next_key in visited_keys: # ignore visited keys
                continue 
            else: 
                if USE_SOFT_SHIFT:
                    # split qs 8:2
                    # NOTE: every candidate's question list is shuffled in
                    # place and re-split at random, whether or not it is chosen.
                    qs = next_value['questions']
                    random.shuffle(qs)
                    cutoff = int(len(qs)*soft_shift_ratio)
                    qs_20 = qs[:cutoff]
                    qs_80 = qs[cutoff:]
                    concepts_20 = get_concepts_from_question_list(qs_20)
                    concepts_80 = get_concepts_from_question_list(qs_80)
                    # random assign
                    if np.random.rand()>0.5:
                        train_qs = qs_20
                        train_concepts = concepts_20
                        val_qs = qs_80
                        val_concepts = concepts_80
                    else:
                        train_qs = qs_80
                        train_concepts = concepts_80
                        val_qs = qs_20
                        val_concepts = concepts_20
                else:
                    train_qs = next_value['questions']
                    train_concepts = next_value['concepts']

                # score = number of still-uncovered val concepts this candidate's
                # train part would cover
                len_intersect = len(remaining_val_concepts.intersection(train_concepts))
                # update max
                if len_intersect > max_intersect:
                    max_intersect = len_intersect
                    max_key = next_key
                    max_train_qs = train_qs
                    max_train_concepts = train_concepts
                    if USE_SOFT_SHIFT:
                        max_val_qs = val_qs
                        max_val_concepts = val_concepts

        if max_key == None: # no group left
            break
        # add to train/val
        assert(max_key not in visited_keys)
        if USE_SOFT_SHIFT:
            # add val
            all_val_questions += max_val_qs
            cur_val_concepts = cur_val_concepts.union(max_val_concepts)
            # remaining
            remaining_val_concepts = remaining_val_concepts.union(max_val_concepts)
        # add train
        all_train_questions += max_train_qs
        cur_train_concepts = cur_train_concepts.union(max_train_concepts)
        # remaining
        remaining_val_concepts = remaining_val_concepts - max_train_concepts
        # update
        visited_keys.append(max_key)
        progress_list.append(len(all_val_questions) / (total_size*ood_ratio))

    # replace the previous iteration's output with the latest progress line;
    # only reached when a greedy search ran (or USE_GREEDY is off)
    clear_output(wait=True)
    print(index, len(all_val_questions) / (total_size*ood_ratio))
    if USE_GREEDY:
        print(f"max intersection {max_intersect}")
    

print("finished assigning val set", len(all_train_questions), len(all_val_questions))   

In [None]:
# (train questions, val questions, groups visited so far, total groups)
len(all_train_questions), len(all_val_questions), len(visited_keys), len(groups)

In [None]:
# Val-set fill level (relative to the target size) after each recorded assignment.
plt.plot(np.arange(len(progress_list)), progress_list)
print(len(progress_list))
plt.show()

In [None]:
# Assign every group not visited by the split loop to train.
# Snapshot the visited keys in a set: testing membership against the
# visited_keys list was a linear scan per group (quadratic overall).
# Group keys are unique, so the snapshot does not need updating in the loop.
_already_visited = set(visited_keys)
for key, value in groups.items():
    if key in _already_visited:
        continue
    visited_keys.append(key)
    all_train_questions += value['questions']
print("finished assigning train set", len(all_train_questions), len(all_val_questions))   

In [None]:
# Sanity check: the two numbers are expected to match if every question was
# assigned to exactly one of train / val.
len(all_train_questions) + len(all_val_questions), total_size

In [None]:
# Shuffle the train pool and cut it into train / dev / test-id at a 6 : 1 : 1.5
# ratio; test-id takes everything after the first 7 units.
random.shuffle(all_train_questions)
unit_length = int(len(all_train_questions) / 8.5)
_dev_start, _test_id_start = unit_length*6, unit_length*7
real_train_questions = all_train_questions[:_dev_start]
dev_questions = all_train_questions[_dev_start:_test_id_start]
test_id_questions = all_train_questions[_test_id_start:]
len(real_train_questions), len(dev_questions), len(test_id_questions)

In [None]:
# Concepts (non-stop-word tokens of questions and answers) in the final train split.
real_train_concepts = get_concepts_from_question_list(real_train_questions)
'train concept size', len(real_train_concepts)

In [None]:
# Concepts present in the dev split.
dev_concepts = get_concepts_from_question_list(dev_questions)
'dev concept size', len(dev_concepts)

In [None]:
# Concepts in the in-distribution test split, and how many of them also occur in train.
test_id_concepts = get_concepts_from_question_list(test_id_questions)
print(f'test-id concept size: {len(test_id_concepts)}')
print(f'covered concept size id: {len(real_train_concepts.intersection(test_id_concepts))}')

In [None]:
# Concepts in the OOD test split (the val questions), and how many of them also occur in train.
all_val_concepts = get_concepts_from_question_list(all_val_questions)
print(f'test-ood concept size: {len(all_val_concepts)}')
print(f'covered concept size ood: {len(real_train_concepts.intersection(all_val_concepts))}')

In [None]:
# Number of OOD concepts that also appear in train (same value as printed in the previous cell).
'covered concept size', len(real_train_concepts.intersection(all_val_concepts))

# save - XAI

In [None]:
# Reload the full CLEVR-XAI files (not just their 'questions') so the 'info'
# header can be reused; context managers close the file handles.
_path = os.path.join(question_dir, 'CLEVR-XAI_simple_questions.json')
with open(_path) as f:
    simple_qs = json.load(f)
_path = os.path.join(question_dir, 'CLEVR-XAI_complex_questions.json')
with open(_path) as f:
    complex_qs = json.load(f)

In [None]:
# Check that both files share the same 'info' header; only simple_qs['info'] is reused below.
complex_qs['info'] == simple_qs['info']

In [None]:
# Wrap each split in the CLEVR-XAI file layout: the shared 'info' header
# followed by that split's 'questions'.
new_train_qs = {'info': simple_qs['info'], 'questions': real_train_questions}
new_dev_qs = {'info': simple_qs['info'], 'questions': dev_questions}
new_test_id_qs = {'info': simple_qs['info'], 'questions': test_id_questions}
new_test_ood_qs = {'info': simple_qs['info'], 'questions': all_val_questions}

In [None]:
# Sizes of the train and OOD test question lists about to be written.
len(new_train_qs['questions']), len(new_test_ood_qs['questions'])

In [None]:
# Output directory for the re-split CLEVR-XAI question files.
save_dir = '../data/neg_data_xaicp/questions/'

In [None]:
# Write the train split.
_path = os.path.join(save_dir, 'CLEVRXAICP_train_questions.json')
with open(_path, 'w') as out_file:
    json.dump(new_train_qs, out_file)

In [None]:
# Write the dev split.
_path = os.path.join(save_dir, 'CLEVRXAICP_dev_questions.json')
with open(_path, 'w') as out_file:
    json.dump(new_dev_qs, out_file)

In [None]:
# Write the in-distribution test split.
_path = os.path.join(save_dir, 'CLEVRXAICP_test-id_questions.json')
with open(_path, 'w') as out_file:
    json.dump(new_test_id_qs, out_file)

In [None]:
# Write the out-of-distribution test split.
_path = os.path.join(save_dir, 'CLEVRXAICP_test-ood_questions.json')
with open(_path, 'w') as out_file:
    json.dump(new_test_ood_qs, out_file)

# save - VQA-HAT

In [None]:
# Reload the original VQA question files; they serve as templates whose
# 'questions' entry is replaced with the new splits. Context managers close
# the file handles.
_path = os.path.join(vqa_data_root, 'v2_OpenEnded_mscoco_train2014_questions.json')
with open(_path) as f:
    qns_train = json.load(f)
_path = os.path.join(vqa_data_root, 'v2_OpenEnded_mscoco_val2014_questions.json')
with open(_path) as f:
    qns_val = json.load(f)

In [None]:
# Reload the original VQA annotation files; they serve as templates whose
# 'annotations' entry is replaced with the new splits. Context managers close
# the file handles.
_path = os.path.join(vqa_data_root, 'v2_mscoco_train2014_annotations.json')
with open(_path) as f:
    ann_train = json.load(f)
_path = os.path.join(vqa_data_root, 'v2_mscoco_val2014_annotations.json')
with open(_path) as f:
    ann_val = json.load(f)

In [None]:
def split_into_ann_qns(full_qns):
    """Split merged question records back into question-only and annotation-only dicts.

    Args:
        full_qns: iterable of dicts containing both the question fields and
            the annotation fields listed below.

    Returns:
        tuple ``(qns_only, anns_only)``: two lists aligned with the input
        order. Each question dict holds ``image_id``, ``question`` and
        ``question_id``; each annotation dict holds ``question_type``,
        ``multiple_choice_answer``, ``answers``, ``image_id``,
        ``answer_type`` and ``question_id``.

    Raises:
        KeyError: if a record lacks one of the copied fields.
    """
    question_fields = ('image_id', 'question', 'question_id')
    annotation_fields = ('question_type', 'multiple_choice_answer', 'answers',
                         'image_id', 'answer_type', 'question_id')

    qns_only, anns_only = [], []
    for record in full_qns:
        # Field tuples fix the key order of the emitted dicts.
        qns_only.append({field: record[field] for field in question_fields})
        anns_only.append({field: record[field] for field in annotation_fields})
    return qns_only, anns_only

In [None]:
# Separate every split back into VQA-style question records and annotation records.
train_qns_only, train_anns_only = split_into_ann_qns(real_train_questions)
dev_qns_only, dev_anns_only = split_into_ann_qns(dev_questions)
test_id_qns_only, test_id_anns_only = split_into_ann_qns(test_id_questions)
test_ood_qns_only, test_ood_anns_only = split_into_ann_qns(all_val_questions)

In [None]:
# Build the question-file containers: train reuses the train template; dev,
# test-id and test-ood are shallow copies of the val template with their own
# 'questions' list.
qns_train['questions'] = train_qns_only

# Drop the original val questions first so the copies do not carry them.
qns_val['questions'] = None
qns_dev = dict(qns_val, questions=dev_qns_only)
qns_test_id = dict(qns_val, questions=test_id_qns_only)
qns_test_ood = dict(qns_val, questions=test_ood_qns_only)

In [None]:
# Write the four question files next to the original VQA data.
for _file_name, _payload in (
    ('hatcp_train_questions.json', qns_train),
    ('hatcp_dev_questions.json', qns_dev),
    ('hatcp_test-id_questions.json', qns_test_id),
    ('hatcp_test-ood_questions.json', qns_test_ood),
):
    _path = os.path.join(vqa_data_root, _file_name)
    with open(_path, 'w') as outfile:
        json.dump(_payload, outfile)

In [None]:
# Build the annotation-file containers: train reuses the train template; dev,
# test-id and test-ood are shallow copies of the val annotation template with
# their own 'annotations' list.
ann_train['annotations'] = train_anns_only

# Drop the original val annotations first so the copies do not carry them.
ann_val['annotations'] = None
# Copy from ann_val, not qns_val: copying the question template wrote the
# question file's metadata and a stray 'questions' entry into every
# annotation file.
ann_dev = ann_val.copy()
ann_test_id = ann_val.copy()
ann_test_ood = ann_val.copy()

ann_dev['annotations'] = dev_anns_only
ann_test_id['annotations'] = test_id_anns_only
ann_test_ood['annotations'] = test_ood_anns_only

In [None]:
# Write the four annotation files next to the original VQA data.
for _file_name, _payload in (
    ('hatcp_train_annotations.json', ann_train),
    ('hatcp_dev_annotations.json', ann_dev),
    ('hatcp_test-id_annotations.json', ann_test_id),
    ('hatcp_test-ood_annotations.json', ann_test_ood),
):
    _path = os.path.join(vqa_data_root, _file_name)
    with open(_path, 'w') as outfile:
        json.dump(_payload, outfile)