In [None]:
import argparse
import json
import numpy as np
import os
import random
from copy import deepcopy
from pprint import pprint
from tqdm import tqdm

# Resolve the home directory robustly: os.path.expanduser('~') uses $HOME when
# it is set (so the resulting paths are unchanged in that case) and falls back
# to the platform's own lookup otherwise. os.getenv('HOME') would instead
# return None, yielding a bogus 'None/...' DATA_DIR and a TypeError below.
home_dir = os.path.expanduser('~')
main_folder_name = 'research_backup'
# Root directory of the datasets read and written by this notebook.
DATA_DIR = '{}/{}/data'.format(home_dir, main_folder_name)
# Root of the XLM dump tree; a model sub-directory is appended to it later to
# locate the generated sub-question file.
checkpoint_dir = home_dir + f'/{main_folder_name}/XLM/dumped'

In [None]:
# Command-line interface: every option is mandatory and identifies which set
# of UMT generations to load.
parser = argparse.ArgumentParser()

# (flag, type, help text) for the generation hyper-parameters, registered in
# this order so the usage/help output keeps the same layout.
_generation_options = (
    ("--model_no", int, "UMT generations model number."),
    ("--sample_temperature", float, "UMT generations sampling temperature."),
    ("--beam", int, "UMT generations beam size."),
    ("--length_penalty", float, "UMT generations length penalty."),
    ("--seed", int, "UMT generations seed."),
)
for _flag, _arg_type, _help_text in _generation_options:
    parser.add_argument(_flag, required=True, type=_arg_type, help=_help_text)

# The split is the only option restricted to a fixed set of values.
parser.add_argument("--split",
                    required=True,
                    type=str,
                    choices=['train_para', 'valid'],
                    help="Dataset split.")

args = parser.parse_args()

# Checkpoint sub-directory (relative to checkpoint_dir) for each supported
# --model_no value.
model_dirs_by_no = {
    1: 'umt.comparison.paired.wp=0.15.sa=0.5.ebs=256.lr=0.0001.ws=0.wd=0.0.wb=0.0/17507568',
    2: 'umt.comparison.paired.ebs=256.lr=0.0001.ws=0.wd=0.0.wb=0.03/17410140',
}
if args.model_no not in model_dirs_by_no:
    # Raise explicitly rather than `assert False`: assertions are stripped
    # under `python -O`, which would leave model_dir undefined further down.
    raise ValueError(
        f'Bad model_no: {args.model_no!r} (expected one of {sorted(model_dirs_by_no)})')
model_dir = model_dirs_by_no[args.model_no]

# Name of the generations file, built from the generation hyper-parameters
# and the requested split.
subqs_filename = f'hyp.st={args.sample_temperature}.bs={args.beam}.lp={args.length_penalty}.es=False.seed={args.seed}.mh-sh.{args.split}.pred.bleu.sh.txt'

subqs_filepath = os.path.join(checkpoint_dir, model_dir, subqs_filename)

# Translate the split name used for the generations into the name used for
# the HotpotQA files: 'valid' -> 'dev', 'train_para' -> 'train'.
hotpot_split = args.split.replace('valid', 'dev').replace('_para', '')
if hotpot_split not in {'train', 'dev', 'test'}:
    raise ValueError(f'Unexpected HotpotQA split {hotpot_split!r} derived from {args.split!r}')

# Load the generations file; every line of the file becomes one list entry
# (line terminators are kept, exactly as readlines() would return them).
with open(subqs_filepath) as f:
    raw_subqs = list(f)
print(f'Read {len(raw_subqs)} Sub-Q pairs')
# Echo the settings that selected this generations file.
print(
    f'model_no={args.model_no}, st={args.sample_temperature}, '
    f'beam={args.beam}, length_penalty={args.length_penalty}, seed={args.seed}'
)

In [None]:
# Break each generated line into its sub-questions. Pieces are delimited by
# ' ?'; pieces that are empty after stripping are dropped, and each kept piece
# gets a '?' appended again.
subqs = [
    [
        piece.strip() + '?'
        for piece in raw_line.strip('\n').strip().split(' ?')
        if len(piece.strip()) > 0
    ]
    for raw_line in raw_subqs
]

# Number of sub-questions recovered from each line.
subq_lens = np.array(list(map(len, subqs)))
print('Mean # of Sub-Qs:', round(subq_lens.mean(), 3))
# Show the first example's sub-questions as the cell output.
subqs[0]

In [None]:
# Read the question IDs for this split, one per line.
with open(f'{DATA_DIR}/umt/comparison.paired/{args.split}.qids.txt') as f:
    qids = [qid.strip('\n') for qid in f]

# QIDs and generations are paired purely by position. zip() would silently
# truncate to the shorter list, and a dict silently keeps only the last entry
# for a repeated key, so fail loudly on either instead of dropping examples.
if len(qids) != len(subqs):
    raise ValueError(
        f'Number of QIDs ({len(qids)}) does not match number of Sub-Q lists ({len(subqs)})')
if len(set(qids)) != len(qids):
    raise ValueError('Duplicate QIDs found; Sub-Qs would be silently overwritten')

# Map each question ID to the list of sub-questions generated for it.
qid2subqs = dict(zip(qids, subqs))
print(f'Read {len(qids)} QIDs')

In [None]:
# Load the HotpotQA file for the selected split.
with open(f'{DATA_DIR}/hotpot-all/{hotpot_split}.json') as hotpot_file:
    data_hotpot = json.load(hotpot_file)

num_hotpot_examples = len(data_hotpot['data'])
print(f'Read {num_hotpot_examples} HotpotQA examples')

# Index every example by the id of the first question of its first paragraph.
qid2examples = {
    entry['paragraphs'][0]['qas'][0]['id']: entry
    for entry in data_hotpot['data']
}

In [None]:
# Build a copy of the dataset in which the question of each example is
# replaced by its generated sub-questions.
new_data_hotpot = {'data': []}
for qid in qid2subqs:
    subqpair = qid2subqs[qid]
    assert len(subqpair) > 0, 'There must be at least 1 Sub-Q!'

    # Work on a deep copy so the entries of qid2examples stay untouched.
    example = deepcopy(qid2examples[qid])
    first_paragraph = example['paragraphs'][0]

    # Each sub-question gets as many (empty) answer slots as the original
    # question had answers.
    num_answers = len(first_paragraph['qas'][0]['answers'])

    subq_entries = []
    for i, subq in enumerate(subqpair):
        subq_entries.append({
            'question': subq,
            'answers': [[] for _ in range(num_answers)],
            # Sub-question ids are the original id plus '-<index>'.
            'id': f'{qid}-{i}',
        })
    first_paragraph['qas'] = subq_entries

    new_data_hotpot['data'].append(example)

In [None]:
# Output directory, named after the model and generation hyper-parameters.
# The split is NOT part of the directory name; it only selects the file name
# inside the directory.
save_dir = f'{DATA_DIR}/hotpot-all.umt.comparison.paired.model={args.model_no}.st={args.sample_temperature}.beam={args.beam}.lp={float(args.length_penalty)}.seed={args.seed}'
print('Saving to', save_dir)

# The directory is shared by all splits, so it may already exist from an
# earlier run on another split. With exist_ok=False that second run always
# crashed, and the per-file check below could never trigger.
os.makedirs(save_dir, exist_ok=True)

save_path = f'{save_dir}/{hotpot_split}.json'
# Refuse to overwrite an existing output for this split. An explicit raise is
# used instead of assert so the guard survives `python -O`.
if os.path.isfile(save_path):
    raise FileExistsError(f'File already exists! {save_path}')
# Mode 'x' creates the file exclusively, so a file appearing between the
# check above and this open still cannot be overwritten.
with open(save_path, 'x') as f:
    json.dump(new_data_hotpot, f, indent=2)