In [None]:
import argparse
import json
import numpy as np
import os
import random
from copy import deepcopy
from pprint import pprint
from tqdm import tqdm

# Root folder holding the research data, located under the current user's home
# directory. `os.path.expanduser('~')` is used instead of reading $HOME directly
# so that an unset HOME does not silently yield the literal path 'None/...'.
DATA_DIR = '{}/research_backup/data'.format(os.path.expanduser('~'))

In [None]:
# Command-line interface: where the generated sub-questions live, which dataset
# splits to process, how many sub-answers to consider per sub-question, and
# whether to additionally write subsampled training sets.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--subqs_dir", type=str, required=True,
    help="Directory containing UMT-Generated Sub-Qs.",
)
parser.add_argument(
    "--splits", nargs='+', required=True,
    help="Dataset splits to process (e.g., dev).",
)
parser.add_argument(
    "--num_subas", type=int, default=1,
    help="Number of Sub-answers per Sub-Q.",
)
parser.add_argument(
    "--subsample_data", default=False, action='store_true',
    help="Make subsets of train with fewer Sub-Qs?",
)
args = parser.parse_args()

# Read every requested split of the SQuAD-formatted HotpotQA data into `data`,
# keyed by split name.
hotpot_data_folder = 'hotpot-squad'
data = {}
for split in args.splits:
    split_path = f'{DATA_DIR}/{hotpot_data_folder}/{split}.json'
    with open(split_path) as split_file:
        data[split] = json.load(split_file)

In [None]:
# subqs:     original question id -> {question type -> [sub-question text, ...]}
# subas:     same layout for sub-answers (populated further below)
# raw_subas: contents of the n-best prediction files, merged across all splits
subqs, subas, raw_subas = {}, {}, {}
qtype = 'c'
total_raw_subqs = 0
for split in args.splits:
    with open(f'{args.subqs_dir}/{split}.json') as subqs_file:
        raw_subqs = json.load(subqs_file)
    with open(f'{args.subqs_dir}/bert_predict.nbest=10/{split}.nbest_predictions.json') as preds_file:
        raw_subas.update(json.load(preds_file))
    for example in raw_subqs['data']:
        for paragraph in example['paragraphs']:
            paragraph_qas = paragraph['qas']
            # The original question id is the part of the first sub-question id
            # before the first '-'.
            qid = paragraph_qas[0]['id'].split('-')[0]
            collected = subqs.setdefault(qid, {}).setdefault(qtype, [])
            collected.extend(qa['question'] for qa in paragraph_qas)
            total_raw_subqs += len(paragraph_qas)

# Group the raw sub-answer predictions by original question id. Prediction keys
# have the form '<qid>-<sub-question index>'; only indices 0 and 1 are collected.
total_subas = 0
num_missing_subas = 0
subas_c_qids = {k.split('-')[0] for k in raw_subas.keys()}
for qid in subas_c_qids:
    subas[qid] = subas.get(qid, {})
    subas[qid][qtype] = []
    for i in [0, 1]:
        subqid = qid + '-' + str(i)
        if subqid in raw_subas:
            subas[qid][qtype].append(raw_subas[subqid])
            total_subas += 1
        else:
            # Keep an empty placeholder so that position i in this list always
            # corresponds to sub-question i. Without it, a missing '<qid>-0'
            # would shift the predictions for '<qid>-1' into slot 0 and they
            # would be paired with the wrong sub-question.
            subas[qid][qtype].append([])
            num_missing_subas += 1
print(f'total_raw_subqs={total_raw_subqs}, total_subas={total_subas}, num_missing_subas={num_missing_subas}')

In [None]:
# # Plot histogram of sub-A model confidence
# plt.hist(suba_probs, 100, density=True, range=(0,1))
# plt.xlabel('Confidence in Sub-Answer')
# plt.ylabel('% of Dev Examples')
# plt.title('Histogram of Single-hop QA Model Confidence (Best sub-A)')
# plt.grid(True)
# plt.show()

In [None]:
# Build one augmented copy of the dataset for every (rank of sub-answer 1,
# rank of sub-answer 2) combination. Each question gets its sub-questions and
# the chosen-rank sub-answers appended, and the result is written to disk.

# Settings shared by every rank combination.
use_subq = True
use_suba = True
skip_qtypes = ['b', 'i', 'o']
limitsubqs = True

q_start = '//'  # marker inserted before each appended sub-question
a_start = '/'   # marker inserted before each appended sub-answer
skip_qtypes.sort()

assert use_subq or use_suba, 'Adding no sub-Qs or sub-As. Are you sure?'

for si in range(args.num_subas):
    for sj in range(args.num_subas):
        suba_rank = [si, sj]

        question_augmenteds = []
        print('Copying data...')
        # Work on a fresh copy so each rank combination starts from the
        # unmodified questions.
        data_copy = deepcopy(data)
        for split in data_copy.keys():
            for article in tqdm(data_copy[split]['data']):
                for paragraph in article['paragraphs']:
                    if '_id' not in paragraph:
                        continue  # SQuAD question: skip
                    qid = paragraph['_id']
                    for qa in paragraph['qas']:
                        question_augmented = qa['question'].strip()
                        # Inactive while 'o' is listed in skip_qtypes.
                        if use_suba and ('o' not in skip_qtypes):
                            question_augmented += ' ' + a_start + ' ' + subas[qid]['o'][0].strip()
                        # A tuple gives a fixed iteration order (a set literal
                        # would not), and the dedicated loop name avoids
                        # overwriting the module-level `qtype`.
                        for subq_type in ('b', 'i', 'c'):
                            if (subq_type in skip_qtypes) or (subq_type not in subqs.get(qid, {})):
                                continue
                            if limitsubqs and ('c' in subqs[qid]) and (subq_type in {'b', 'i'}):
                                print('Limited!')
                                continue  # don't add b/i subqs if there's a c decomposition
                            type_subqs = subqs[qid][subq_type]
                            # A question may have sub-questions but no predicted
                            # sub-answers; treat that as "nothing to pair"
                            # instead of raising KeyError.
                            type_subas = subas.get(qid, {}).get(subq_type, [])
                            for subq_no, (subq, suba) in enumerate(zip(type_subqs, type_subas)):
                                if use_subq:
                                    question_augmented += ' ' + q_start + ' ' + subq.strip()
                                if use_suba and (len(suba) > suba_rank[subq_no]):
                                    chosen_suba = suba[suba_rank[subq_no]]
                                    question_augmented += ' ' + a_start + ' ' + chosen_suba['text'].strip()
                                    qa.setdefault('probability_subas', []).append(chosen_suba['probability'])
                        qa['question'] = question_augmented
                        question_augmenteds.append(question_augmented)

        # Length statistics over whitespace-split words, scaled by 1.2
        # (presumably a words -> model-tokens inflation factor; TODO confirm).
        qlens = np.array([len(question.split()) for question in question_augmenteds])
        if qlens.size == 0:
            # Mean and percentiles are undefined without any questions.
            print('# Examples:', 0)
        else:
            print('Mean # Tokens:', int(round(1.2 * qlens.mean())))
            print('# Examples:', len(qlens))

            qlens.sort()
            print('90 %-ile # Tokens:', int(round(1.2 * qlens[int(.9 * qlens.size)])))
            print('99 %-ile # Tokens:', int(round(1.2 * qlens[int(.99 * qlens.size)])))

        input_name = 'q'
        if use_subq:
            input_name += '-predsubqs'
        if use_suba:
            input_name += '-predsubas'

        save_dir = f'{args.subqs_dir}.{input_name}.num_subas={args.num_subas}'
        print(f'Saving to {save_dir}...')
        os.makedirs(save_dir, exist_ok=True)
        for split in args.splits:
            with open(f'{save_dir}/{split}.suba1={suba_rank[0]}.suba2={suba_rank[1]}.json', 'w') as f:
                json.dump(data_copy[split], f, indent=2)

In [None]:
### Generate (subset of) comparison-only Q's
if not args.subsample_data:
    print('Done! (Not making subsamples of data)')
    # `exit()` is a helper installed by the `site` module for interactive use
    # and is not guaranteed to exist in every run mode; raising SystemExit
    # ends the program with the same exit status (0) everywhere.
    raise SystemExit(0)

# Keep only the articles that contain at least one paragraph whose '_id' has
# predicted sub-answers (i.e. appears in `subas_c_qids`).
data_copy = deepcopy(data)
num_comp_qs = {}
comp_data = {}
for split in args.splits:
    comp_articles = [
        article for article in data_copy[split]['data']
        if any(paragraph.get('_id') in subas_c_qids for paragraph in article['paragraphs'])
    ]
    comp_data[split] = {'data': comp_articles, 'version': data_copy[split]['version']}
    # Counted per kept article.
    num_comp_qs[split] = len(comp_articles)
print('num_comp_qs:', num_comp_qs)

In [None]:
# Write the comparison-only dataset. Only the 'train' split is filtered; every
# other split is written out unfiltered. The directory name records how many
# training articles were kept, and the directory must not already exist.
save_dir = f'{DATA_DIR}/hotpot-all'
comp_save_dir = f"{save_dir}.comparison-only.num-train={num_comp_qs['train']}"
print(f'Saving to {comp_save_dir}...')
os.makedirs(comp_save_dir, exist_ok=False)
for split in args.splits:
    split_source = comp_data if split == 'train' else data_copy
    with open(f'{comp_save_dir}/{split}.json', 'w') as out_file:
        json.dump(split_source[split], out_file, indent=2)

In [None]:
# Pick a fixed-seed random subset of the comparison training articles.
shuffled_qidxs_train = list(range(num_comp_qs['train']))
random.Random(42).shuffle(shuffled_qidxs_train)
train_data_frac = .015625

num_train_kept = round(len(shuffled_qidxs_train) * train_data_frac)
train_qidxs_subset = set(shuffled_qidxs_train[:num_train_kept])

# Walk the training articles again, numbering the comparison ones in order of
# appearance and keeping those whose running index was selected above.
num_comp_qs_filtered = {}
comp_data = {}
for split in ['train']:
    split_version = data_copy[split]['version']
    kept_articles = []
    comp_idx = 0
    for article in data_copy[split]['data']:
        has_comp_paragraph = any(
            paragraph.get('_id') in subas_c_qids for paragraph in article['paragraphs']
        )
        if not has_comp_paragraph:
            continue
        if comp_idx in train_qidxs_subset:
            kept_articles.append(article)
        comp_idx += 1
    num_comp_qs_filtered[split] = comp_idx
    comp_data[split] = {'data': kept_articles, 'version': split_version}
print('Train Comp Qs remaining:', len(comp_data['train']['data']))

In [None]:
# Write the subsampled dataset to a directory named after the number of
# training articles that remain. Non-train splits are written unfiltered, and
# the directory must not already exist.
num_train_remaining = len(comp_data['train']['data'])
comp_save_dir = f'{save_dir}.comparison-only.num-train={num_train_remaining}'
print(f'Saving to {comp_save_dir}...')
os.makedirs(comp_save_dir, exist_ok=False)
for split in args.splits:
    split_source = comp_data if split == 'train' else data_copy
    with open(f'{comp_save_dir}/{split}.json', 'w') as out_file:
        json.dump(split_source[split], out_file, indent=2)