In [2]:
import json
import numpy as np
import os
import random
from copy import deepcopy
from pprint import pprint
from tqdm import tqdm

# Root directories used by the rest of the notebook, resolved against the
# current user's home directory. os.path.expanduser('~') is used rather than
# os.getenv('HOME') because the latter returns None when HOME is unset, which
# would silently produce paths starting with the literal text 'None'.
HOME_DIR = os.path.expanduser('~')
# Where the datasets are read from and the derived datasets are written to.
DATA_DIR = '{}/research/data'.format(HOME_DIR)
# Where the sub-question / sub-answer prediction files are read from.
subqs_dir = '{}/research/DecompRC/DecompRC/out'.format(HOME_DIR)
# Dataset splits that the loading and saving cells iterate over.
splits = ['train', 'dev']

In [3]:
### Without bridge entity plugged in
# subqs = {}
# subas = {}
# for qtype in ['b', 'i']:
#     for split in splits:
#         with open(f'{subqs_dir}/subqs_{qtype}/{split}_nbest_predictions.json') as f:
#             raw_subqs = json.load(f)
#         for qid, nbest_prediction in raw_subqs.items():
#             subqs[qid] = subqs.get(qid, {})
#             subqs[qid][qtype] = nbest_prediction[0][:2]
#         raw_subas = []
#         for subq_no in [1, 2]:
#             with open(f'{subqs_dir}/subqs_{qtype}/{split}_{qtype}_{subq_no}_nbest_predictions.json') as f:
#                 raw_subas.append(json.load(f))
#         for qid in raw_subas[0].keys():
#             subas[qid] = subas.get(qid, {})
#             subas[qid][qtype] = [raw_subas[0][qid][0]['text'], raw_subas[1][qid + ('-0' if qtype == 'b' else '')][0]['text']]

In [4]:
# Build two lookup tables keyed by question id, for the 'b' and 'i' decompositions:
#   subqs[qid][qtype] -> [first sub-question, second sub-question]
#   subas[qid][qtype] -> [top prediction text for the first, for the second]
subqs = {}
subas = {}
for qtype in ['b', 'i']:
    # Second sub-question entries are looked up under '<qid>-0' for 'b' and under the bare qid for 'i'.
    second_id_suffix = '-0' if qtype == 'b' else ''
    for split in splits:
        questions_by_id = []  # one {qa id: question text} dict per sub-question number
        nbest_by_subq = []    # one loaded n-best predictions dict per sub-question number
        for subq_no in [1, 2]:
            with open(f'{DATA_DIR}/decomposed/{split}_{qtype}.{subq_no}.json') as f:
                decomposed = json.load(f)
            with open(f'{subqs_dir}/subqs_{qtype}/{split}_{qtype}_{subq_no}_nbest_predictions.json') as f:
                nbest_by_subq.append(json.load(f))
            # Flatten the data -> paragraphs -> qas nesting into a flat id -> question map.
            questions_by_id.append({
                qa['id']: qa['question']
                for example in decomposed['data']
                for paragraph in example['paragraphs']
                for qa in paragraph['qas']
            })
        first_nbest, second_nbest = nbest_by_subq
        for qid in first_nbest:
            qid2 = qid + second_id_suffix
            per_question_subqs = subqs.setdefault(qid, {})
            per_question_subqs[qtype] = [questions_by_id[0][qid], questions_by_id[1][qid2]]
            per_question_subas = subas.setdefault(qid, {})
            per_question_subas[qtype] = [first_nbest[qid][0]['text'], second_nbest[qid2][0]['text']]

In [5]:
# Record, for every question in the subqs_o prediction files, the text of its
# top-ranked prediction as the single-element list subas[qid]['o'].
qtype = 'o'
for split in splits:
    with open(f'{subqs_dir}/subqs_{qtype}/{split}_nbest_predictions.json') as f:
        raw_subas = json.load(f)
    for qid, nbest in raw_subas.items():
        per_question_subas = subas.setdefault(qid, {})
        per_question_subas[qtype] = [nbest[0]['text']]

In [6]:
# Load the comparison ('c') decompositions: sub-questions come from the sharded
# generation files, sub-answers from the matching n-best prediction files.
qtype = 'c'
total_raw_subqs = 0
raw_subas = {}
# Drop 'c' sub-questions left over from an earlier run of this cell; the loop
# below appends, so re-running would otherwise duplicate every sub-question.
for entry in subqs.values():
    entry.pop(qtype, None)
for split in splits:
    num_shards = 100 if split == 'train' else 10
    for shard_no in range(num_shards):
        with open(f'{subqs_dir}/subqs_{qtype}/comparison_decomposed_{split}_generations.num_shards={num_shards}.shard_no={shard_no}.json') as f:
            raw_subqs = json.load(f)
        with open(f'{subqs_dir}/subqs_{qtype}/comparison_decomposed_{split}_generations.num_shards={num_shards}.shard_no={shard_no}.nbest_predictions.json') as f:
            raw_subas.update(json.load(f))
        for example in raw_subqs['data']:
            for paragraph in example['paragraphs']:
                orig_qa = paragraph['qas_orig'][0]
                qid = orig_qa['id']
                # Create the per-question entry before writing into it, so a qid
                # not already present in subqs does not raise KeyError.
                entry = subqs.setdefault(qid, {})
                entry['o'] = [orig_qa['question']]
                for qa in paragraph['qas']:
                    total_raw_subqs += 1
                    entry.setdefault(qtype, []).append(qa['question'])

# Sub-answer keys have the form '<qid>-<i>'; collect up to two answers per qid.
num_subas = 0
num_missing_subas = 0
subas_c_qids = {k.split('-')[0] for k in raw_subas.keys()}
for qid in subas_c_qids:
    subas[qid] = subas.get(qid, {})
    subas[qid][qtype] = []
    for i in [0, 1]:
        subqid = qid + '-' + str(i)
        if subqid in raw_subas:
            subas[qid][qtype].append(raw_subas[subqid][0]['text'])
            num_subas += 1
        else:
            # NOTE(review): a missing sub-answer is skipped rather than padded, so the
            # list can be shorter than subqs[qid]['c'] and positions may not line up.
            num_missing_subas += 1

print(total_raw_subqs, num_subas, num_missing_subas)

37851 37851 7


In [7]:
# Read the SQuAD-formatted HotpotQA JSON of every split into data[split].
hotpot_data_folder = 'hotpot-squad'
data = {}
for split in splits:
    with open(f'{DATA_DIR}/{hotpot_data_folder}/{split}.json') as split_file:
        loaded_split = json.load(split_file)
    data[split] = loaded_split

In [8]:
### Save Comparison SubQs UNMT
# For every split, gather three parallel lists from the questions that have
# exactly two comparison ('c') sub-questions:
#   unmt_qs    - the lower-cased original question
#   unmt_subqs - the two stripped sub-questions joined by a single space
#   unmt_qids  - the question id
unmt_qs, unmt_subqs, unmt_qids = {}, {}, {}
for split in data.keys():
    split_qs, split_subqs, split_qids = [], [], []
    unmt_qs[split] = split_qs
    unmt_subqs[split] = split_subqs
    unmt_qids[split] = split_qids
    for article in tqdm(data[split]['data']):
        for paragraph in article['paragraphs']:
            if '_id' not in paragraph:
                continue  # SQuAD question: skip
            qid = paragraph['_id']
            decomposition = subqs[qid]
            if 'c' not in decomposition.keys() or len(decomposition['c']) != 2:
                continue
            first_subq, second_subq = decomposition['c']
            split_subqs.append(first_subq.strip() + ' ' + second_subq.strip())
            split_qs.append(paragraph['qas'][0]['question'].lower())
            split_qids.append(qid)
            # Only the first qualifying paragraph of an article contributes an entry.
            break
    assert len(split_qs) == len(split_subqs)

100%|██████████| 90889/90889 [00:00<00:00, 139696.60it/s]
100%|██████████| 7440/7440 [00:00<00:00, 135279.60it/s]


In [18]:
# Spot-check: display the entry at index 5 of the lower-cased dev questions in unmt_qs.
unmt_qs['dev'][5]

'which writer was from england, henry roth or robert erskine childers?'

In [None]:
# NOTE(review): displays the entire unmt_qs dict (every stored question of every split);
# consider a bounded preview such as a slice or per-split lengths instead.
unmt_qs

In [19]:
seed = 42
shuffler = random.Random(seed)

# Save UNMT data
# For each requested split two text files are written, one example per line and
# without a trailing newline:
#   <save_split>.mh - the original questions (unmt_qs)
#   <save_split>.sh - the concatenated sub-question pairs (unmt_subqs)
save_folder = f'{DATA_DIR}/umt/comparison.paired'
os.makedirs(save_folder, exist_ok=True)
for requested_split in ['train_para']:
    # Name used in the output file names ('dev' is saved as 'valid').
    save_split = requested_split.replace('dev', 'valid')
    # Key used to look up the data: 'train_para' reads the 'train' examples.
    split = 'train' if requested_split == 'train_para' else requested_split

    for extension, examples in (('mh', unmt_qs[split]), ('sh', unmt_subqs[split])):
        # Only the plain 'train' output is shuffled, and the two files are shuffled
        # independently of each other; every other output keeps the original,
        # line-aligned order.
        save_data = examples if save_split != 'train' else shuffler.sample(examples, len(examples))
        with open(f'{save_folder}/{save_split}.{extension}', 'w') as f:
            # write(), not writelines(): the payload is one already-joined string.
            f.write('\n'.join(save_data))

    if save_split in {'train', 'valid'}:
        # NOTE(review): qids are written in their original order, so for 'train'
        # they do not line up with the shuffled .mh/.sh lines - confirm intended.
        with open(f'{save_folder}/{save_split}.qids.txt', 'w') as f:
            f.write('\n'.join(unmt_qids[split]))
    # TODO: Also save test file and symlink to parallel data file

In [64]:
# Whitespace-separated word count of every string in save_data
# (save_data is whatever the save cell assigned last).
q_lens = np.array([len(q.split()) for q in save_data])

In [78]:
# Expose the last 1500 'train' entries of each table under a 'test' key.
# The slice is a copy; the entries also remain part of 'train'.
for pairs in (unmt_qs, unmt_subqs):
    pairs['test'] = pairs['train'][-1500:]

In [44]:
# Options controlling what is appended to each question.
use_subq = True            # append sub-question text
use_suba = True            # append sub-answer text
skip_qtypes = ['c', 'o']   # decomposition types left out of the augmentation
limitsubqs = False         # if True, omit b/i sub-questions when a 'c' decomposition exists

q_start = '//'  # marker placed before every appended sub-question
a_start = '/'   # marker placed before every appended sub-answer
skip_qtypes.sort()

assert use_subq or use_suba, 'Adding no sub-Qs or sub-As. Are you sure?'
question_augmenteds = []
print('Copying data...')
# Work on a deep copy so the loaded `data` keeps the original questions.
data_copy = deepcopy(data)
for split in data_copy.keys():
    for article in tqdm(data_copy[split]['data']):
        for paragraph in article['paragraphs']:
            if '_id' not in paragraph:
                continue  # SQuAD question: skip
            qid = paragraph['_id']
            for qa in paragraph['qas']:
                # Resulting layout: "<question> / <o answer> // <subq> / <suba> // ..."
                question_augmented = qa['question'].strip()
                if use_suba and ('o' not in skip_qtypes):
                    question_augmented += ' ' + a_start + ' ' + subas[qid]['o'][0].strip()
                # A tuple rather than a set: iteration order over a set of strings can
                # differ between interpreter sessions, which would change the order in
                # which sub-questions are appended.
                for qtype in ('b', 'i', 'c'):
                    if (qtype not in skip_qtypes) and (qtype in subqs[qid]):
                        if limitsubqs and ('c' in subqs[qid]) and (qtype in {'b', 'i'}):
                            continue  # don't add b/i subqs if there's a c decomposition
                        # zip stops at the shorter list, so unmatched items are dropped.
                        for subq, suba in zip(subqs[qid][qtype], subas[qid][qtype]):
                            if use_subq:
                                question_augmented += ' ' + q_start + ' ' + subq.strip()
                            if use_suba:
                                question_augmented += ' ' + a_start + ' ' + suba.strip()
                qa['question'] = question_augmented
                question_augmenteds.append(question_augmented)

Copying data...


100%|██████████| 90889/90889 [00:06<00:00, 14991.03it/s]
100%|██████████| 7440/7440 [00:00<00:00, 13885.36it/s]


In [45]:
# Length statistics of the augmented questions: whitespace-separated word counts,
# scaled by 1.2 before printing (presumably a words-to-tokens factor - TODO confirm).
qlens = np.array([len(augmented.split()) for augmented in question_augmenteds])
mean_tokens = int(round(1.2 * qlens.mean()))
print('Mean # Tokens:', mean_tokens)

# Sort so the value 90% of the way through the array can be read off by index.
qlens = np.sort(qlens)
p90_index = int(.9 * qlens.size)
print('90 %-ile # Tokens:', int(round(1.2 * qlens[p90_index])))

Mean # Tokens: 92
90 %-ile # Tokens: 131


In [46]:
# Derive the output folder suffix from the augmentation options, then write one
# JSON file per split into a new folder.
name_parts = ['q']
if use_subq:
    name_parts.append('-predsubqs')
if use_suba:
    name_parts.append('-predsubas')
if skip_qtypes:
    name_parts.append('-no-' + ''.join(skip_qtypes))
if limitsubqs:
    name_parts.append('-limitsubqs')
input_name = ''.join(name_parts)

save_dir = f'{DATA_DIR}/{hotpot_data_folder}-{input_name}'
print(f'Saving to {save_dir}...')
# exist_ok=False: raises if the folder already exists, so existing files are never overwritten.
os.makedirs(save_dir, exist_ok=False)
for split in splits:
    with open(f'{save_dir}/{split}.json', 'w') as out_file:
        json.dump(data_copy[split], out_file, indent=2)

Saving to /Users/ethanperez/research/data/hotpot-squad-q-predsubqs-predsubas-no-co...


In [9]:
# Inspect one question: its sub-questions, its sub-answers, and the first paragraph
# of the augmented train article at index 9889.
example_qid = '5add68665542992200553afe'
for table in (subqs, subas):
    pprint(table[example_qid])
pprint(data_copy['train']['data'][9889]['paragraphs'][0])

{'b': ['chicken cha cha cha or roborally receive more awards?', 'did roborally'], 'i': ['did chicken cha cha cha or roborally ?', 'did receive more awards? ?'], 'o': ['Did Chicken Cha Cha Cha or RoboRally receive more awards?']}
{'b': ['roborally', 'no'], 'i': ['chicken cha cha cha', 'no'], 'o': ['roborally']}
{'_id': '5add68665542992200553afe', 'type': 'comparison', 'level': 'hard', 'supporting_facts': [], 'context': 'yes no "Cha Cha Twist" is a 1960 song by Brice Coefield. A song called Cha Cha Twist was also recorded by Connie Francis in 1962 on MGM, and by Paul Gallis on Heartbreak Records.', 'qas': [{'question': 'Did Chicken Cha Cha Cha or RoboRally receive more awards? / roborally // chicken cha cha cha or roborally receive more awards? / roborally // did roborally / no // did chicken cha cha cha or roborally ? / chicken cha cha cha // did receive more awards? ? / no', 'id': '5add68665542992200553afe.0', 'answers': [], 'is_impossible': True}]}


In [47]:
pprint({'squad_f1': 86.75510941296, 'squad_NoAns_total': 4098, 'gn_bridge_recall': 89.06078678438656, 'gn_recall': 88.85925636059851, 'gn_bridge_em': 83.51158645276293, 'joint_em': 0.0, 'comparison_f1': 88.53704369171999, 'gn_bridge_prec': 89.57426075607897, 'squad_best_exact_thresh': -0.9808313846588135, 'gn_comparison_em': 84.82014388489209, 'gn_joint_prec': 0.0, 'gn_multihop_recall': 88.85925636059851, 'squad_best_exact': 81.20489174017642, 'gn_joint_f1': 0.0, 'gn_comparison_recall': 88.04588536962635, 'multihop_prec': 90.11983318269031, 'bridge_recall': 89.78116327490373, 'squad_NoAns_exact': 92.48413860419717, 'recall': 89.60801523371296, 'squad_best_f1': 86.84361294698306, 'gn_prec': 89.37692965764394, 'gn_multihop_prec': 89.37692965764394, 'gn_comparison_prec': 88.58050702295309, 'squad_HasAns_f1': 82.76096827214838, 'gn_multihop_f1': 88.41911217653674, 'comparison_em': 85.39568345323741, 'gn_f1': 88.41911217653674, 'bridge_f1': 89.28423578735266, 'joint_f1': 0.0, 'multihop_recall': 89.60801523371296, 'gn_comparison_f1': 87.74429766875814, 'gn_bridge_f1': 88.58631220609308, 'gn_joint_em': 0.0, 'multihop_f1': 89.13586478550567, 'squad_NoAns_f1': 92.48413860419717, 'gn_joint_recall': 0.0, 'f1': 89.13586478550567, 'bridge_prec': 90.31614869582788, 'gn_sp_prec': 0.0, 'comparison_recall': 88.90919472214436, 'gn_sp_em': 0.0, 'multihop_em': 84.41428571428573, 'bridge_em': 84.1711229946524, 'squad_HasAns_total': 5878, 'sp_prec': 0.0, 'joint_prec': 0.0, 'squad_HasAns_exact': 73.00102075535897, 'joint_recall': 0.0, 'prec': 90.11983318269031, 'sp_em': 0.0, 'gn_sp_recall': 0.0, 'gn_em': 83.77142857142857, 'sp_f1': 0.0, 'comparison_prec': 89.32750942103462, 'gn_multihop_em': 83.77142857142857, 'squad_exact': 81.00441058540497, 'squad_best_f1_thresh': -0.9490156173706055, 'sp_recall': 0.0, 'em': 84.41428571428573, 'squad_total': 9976, 'gn_sp_f1': 0.0})

{'bridge_em': 84.1711229946524,
 'bridge_f1': 89.28423578735266,
 'bridge_prec': 90.31614869582788,
 'bridge_recall': 89.78116327490373,
 'comparison_em': 85.39568345323741,
 'comparison_f1': 88.53704369171999,
 'comparison_prec': 89.32750942103462,
 'comparison_recall': 88.90919472214436,
 'em': 84.41428571428573,
 'f1': 89.13586478550567,
 'gn_bridge_em': 83.51158645276293,
 'gn_bridge_f1': 88.58631220609308,
 'gn_bridge_prec': 89.57426075607897,
 'gn_bridge_recall': 89.06078678438656,
 'gn_comparison_em': 84.82014388489209,
 'gn_comparison_f1': 87.74429766875814,
 'gn_comparison_prec': 88.58050702295309,
 'gn_comparison_recall': 88.04588536962635,
 'gn_em': 83.77142857142857,
 'gn_f1': 88.41911217653674,
 'gn_joint_em': 0.0,
 'gn_joint_f1': 0.0,
 'gn_joint_prec': 0.0,
 'gn_joint_recall': 0.0,
 'gn_multihop_em': 83.77142857142857,
 'gn_multihop_f1': 88.41911217653674,
 'gn_multihop_prec': 89.37692965764394,
 'gn_multihop_recall': 88.85925636059851,
 'gn_prec': 89.37692965764394,
 'gn

In [49]:
traineval_metrics = [
    {'gn_bridge_recall': 91.60731652975377, 'squad_HasAns_total': 5878, 'multihop_em': 87.37142857142857, 'sp_recall': 0.0, 'em': 87.37142857142857, 'comparison_recall': 90.85592824441746, 'multihop_prec': 92.8075581957724, 'squad_best_f1_thresh': -0.5998347774147987, 'sp_em': 0.0, 'gn_sp_prec': 0.0, 'multihop_recall': 92.45646185162228, 'sp_prec': 0.0, 'gn_multihop_recall': 91.11525514166554, 'bridge_recall': 92.85302900207049, 'squad_NoAns_f1': 93.87506100536848, 'bridge_prec': 93.25601438569348, 'comparison_prec': 90.99760191846524, 'gn_recall': 91.11525514166554, 'squad_total': 9976, 'gn_sp_recall': 0.0, 'joint_recall': 0.0, 'squad_best_f1': 88.46500089762212, 'gn_f1': 90.66373461486035, 'gn_comparison_em': 86.2589928057554, 'gn_multihop_em': 86.41428571428571, 'squad_HasAns_f1': 84.62789703230199, 'recall': 92.45646185162228, 'gn_joint_prec': 0.0, 'gn_sp_f1': 0.0, 'gn_em': 86.41428571428571, 'prec': 92.8075581957724, 'comparison_f1': 90.44577811336171, 'squad_NoAns_total': 4098, 'gn_sp_em': 0.0, 'squad_HasAns_exact': 74.0217761143246, 'joint_f1': 0.0, 'squad_exact': 82.17722534081797, 'gn_comparison_recall': 89.12930953938147, 'squad_NoAns_exact': 93.87506100536848, 'squad_f1': 88.42650147913729, 'gn_bridge_f1': 91.14906820520687, 'squad_best_exact_thresh': -0.6045594215393066, 'gn_prec': 91.54644815501955, 'joint_prec': 0.0, 'f1': 91.8559012659031, 'bridge_f1': 92.20529006840424, 'gn_bridge_prec': 92.12183073413019, 'gn_joint_em': 0.0, 'gn_bridge_em': 86.45276292335116, 'gn_multihop_f1': 90.66373461486035, 'comparison_em': 87.9136690647482, 'gn_joint_recall': 0.0, 'gn_comparison_prec': 89.22422062350121, 'gn_multihop_prec': 91.54644815501955, 'gn_joint_f1': 0.0, 'multihop_f1': 91.8559012659031, 'joint_em': 0.0, 'gn_comparison_f1': 88.70494221065529, 'bridge_em': 87.23707664884135, 'sp_f1': 0.0, 'squad_best_exact': 82.31756214915798},
    {'squad_HasAns_f1': 84.56928243295023, 'bridge_em': 87.55793226381462, 'gn_prec': 91.94441648386749, 'gn_em': 86.68571428571428, 'f1': 91.76399198840184, 'joint_prec': 0.0, 'gn_multihop_f1': 91.03721407566636, 'bridge_f1': 92.24199185177223, 'squad_HasAns_total': 5878, 'sp_recall': 0.0, 'gn_sp_recall': 0.0, 'gn_sp_f1': 0.0, 'em': 87.38571428571429, 'joint_f1': 0.0, 'sp_em': 0.0, 'gn_sp_em': 0.0, 'squad_best_f1_thresh': -0.14468765258789062, 'squad_best_exact_thresh': -0.14468765258789062, 'squad_total': 9976, 'squad_best_f1': 88.37744523416566, 'multihop_prec': 92.69292143130097, 'prec': 92.69292143130097, 'gn_comparison_f1': 87.93192059738824, 'gn_bridge_em': 87.13012477718361, 'joint_em': 0.0, 'comparison_f1': 89.83479829522997, 'bridge_recall': 92.92044700692175, 'sp_f1': 0.0, 'bridge_prec': 93.20703992426937, 'sp_prec': 0.0, 'gn_bridge_recall': 92.39452603247133, 'squad_f1': 88.35186869896596, 'squad_NoAns_exact': 93.77745241581259, 'squad_HasAns_exact': 74.49812861517523, 'squad_best_exact': 82.45789895749799, 'gn_joint_prec': 0.0, 'squad_NoAns_f1': 93.77745241581259, 'gn_joint_f1': 0.0, 'comparison_prec': 90.61795398845757, 'gn_multihop_em': 86.68571428571428, 'gn_comparison_recall': 88.3511320813479, 'multihop_recall': 92.38930209074829, 'gn_comparison_em': 84.89208633093526, 'gn_joint_recall': 0.0, 'recall': 92.38930209074829, 'gn_bridge_prec': 92.73920843905823, 'gn_recall': 91.59162351931968, 'gn_multihop_recall': 91.59162351931968, 'multihop_em': 87.38571428571429, 'comparison_em': 86.6906474820144, 'gn_sp_prec': 0.0, 'squad_NoAns_total': 4098, 'gn_comparison_prec': 88.7366590244288, 'gn_joint_em': 0.0, 'gn_f1': 91.03721407566636, 'multihop_f1': 91.76399198840184, 'joint_recall': 0.0, 'squad_exact': 82.4178027265437, 'gn_multihop_prec': 91.94441648386749, 'gn_bridge_f1': 91.80661834212015, 'comparison_recall': 90.24561649381792},
    {'f1': 91.30730265640072, 'squad_NoAns_f1': 93.04538799414348, 'sp_recall': 0.0, 'bridge_f1': 91.93537928541225, 'squad_total': 9976, 'multihop_f1': 91.30730265640072, 'joint_em': 0.0, 'gn_sp_em': 0.0, 'gn_comparison_recall': 88.30487817538175, 'gn_f1': 90.6864919062686, 'gn_comparison_em': 85.03597122302158, 'comparison_f1': 88.77240345585669, 'sp_prec': 0.0, 'recall': 91.72339969948874, 'squad_best_exact': 81.23496391339214, 'squad_best_f1_thresh': -0.005460262298583984, 'bridge_recall': 92.3269625056832, 'squad_HasAns_total': 5878, 'comparison_prec': 89.3999968856084, 'gn_recall': 91.06093036326935, 'gn_multihop_prec': 91.69873376390181, 'gn_bridge_em': 86.8270944741533, 'squad_best_exact_thresh': -0.005460262298583984, 'bridge_em': 87.3083778966132, 'em': 87.05714285714285, 'gn_sp_f1': 0.0, 'gn_joint_prec': 0.0, 'prec': 92.3853558439693, 'comparison_em': 86.0431654676259, 'multihop_em': 87.05714285714285, 'joint_prec': 0.0, 'comparison_recall': 89.28743758240158, 'gn_multihop_f1': 90.6864919062686, 'gn_comparison_prec': 88.4503565978386, 'bridge_prec': 93.12504371422273, 'squad_exact': 81.23496391339214, 'gn_bridge_f1': 91.3949539089353, 'joint_recall': 0.0, 'gn_prec': 91.69873376390181, 'sp_f1': 0.0, 'gn_multihop_recall': 91.06093036326935, 'squad_HasAns_f1': 83.24791722267227, 'squad_NoAns_exact': 93.04538799414348, 'squad_HasAns_exact': 73.00102075535897, 'gn_comparison_f1': 87.82715965090065, 'gn_joint_f1': 0.0, 'joint_f1': 0.0, 'squad_NoAns_total': 4098, 'sp_em': 0.0, 'squad_f1': 87.27257993533169, 'gn_em': 86.47142857142858, 'gn_bridge_recall': 91.74380247399377, 'squad_best_f1': 87.2725799353317, 'gn_bridge_prec': 92.50359013838099, 'gn_sp_prec': 0.0, 'gn_multihop_em': 86.47142857142858, 'multihop_recall': 91.72339969948874, 'gn_joint_em': 0.0, 'multihop_prec': 92.3853558439693, 'gn_sp_recall': 0.0, 'gn_joint_recall': 0.0},
    {'joint_f1': 0.0, 'gn_comparison_f1': 87.71162929796024, 'bridge_prec': 93.40768235041796, 'squad_NoAns_f1': 94.04587603709126, 'gn_joint_em': 0.0, 'gn_bridge_f1': 91.12933567265434, 'sp_prec': 0.0, 'gn_prec': 91.35997884265697, 'comparison_recall': 89.79742958879645, 'gn_multihop_prec': 91.35997884265697, 'squad_best_exact_thresh': -0.0821218490600586, 'squad_NoAns_total': 4098, 'squad_best_f1_thresh': -0.02561628818511963, 'gn_sp_f1': 0.0, 'gn_multihop_em': 86.18571428571428, 'gn_comparison_prec': 88.31883584041856, 'prec': 92.7552780672419, 'gn_joint_recall': 0.0, 'joint_prec': 0.0, 'squad_HasAns_exact': 74.25995236474992, 'gn_bridge_recall': 91.5262700064613, 'f1': 91.75783171476677, 'gn_sp_recall': 0.0, 'em': 87.41428571428571, 'gn_joint_prec': 0.0, 'gn_joint_f1': 0.0, 'squad_best_exact': 82.38773055332798, 'gn_bridge_em': 86.524064171123, 'gn_recall': 90.83881585079774, 'squad_f1': 88.28936197899624, 'gn_em': 86.18571428571428, 'sp_recall': 0.0, 'gn_multihop_recall': 90.83881585079774, 'joint_em': 0.0, 'squad_HasAns_f1': 84.27605905111693, 'squad_NoAns_exact': 94.04587603709126, 'multihop_recall': 92.20345979019169, 'recall': 92.20345979019169, 'comparison_em': 86.33093525179856, 'squad_total': 9976, 'gn_multihop_f1': 90.45067683539384, 'gn_comparison_em': 84.82014388489209, 'gn_sp_em': 0.0, 'squad_best_f1': 88.28936197899627, 'bridge_f1': 92.333976113574, 'joint_recall': 0.0, 'multihop_em': 87.41428571428571, 'comparison_prec': 90.12219315456721, 'gn_comparison_recall': 88.06427066139294, 'gn_bridge_prec': 92.11348842788183, 'bridge_recall': 92.79960631067993, 'gn_sp_prec': 0.0, 'gn_f1': 90.45067683539384, 'sp_em': 0.0, 'sp_f1': 0.0, 'multihop_f1': 91.75783171476677, 'multihop_prec': 92.7552780672419, 'squad_exact': 82.38773055332798, 'bridge_em': 87.68270944741533, 'squad_HasAns_total': 5878, 'comparison_f1': 89.43252950087484},
    {'joint_recall': 0.0, 'gn_comparison_recall': 87.48273588921069, 'sp_recall': 0.0, 'gn_f1': 90.53800592473229, 'squad_f1': 88.21156668539685, 'comparison_recall': 89.31207241998607, 'recall': 92.23665572917332, 'gn_prec': 91.45613187769277, 'squad_NoAns_f1': 93.87506100536848, 'gn_joint_em': 0.0, 'squad_HasAns_exact': 74.05580129295679, 'gn_bridge_recall': 91.93491849799747, 'em': 87.3, 'bridge_f1': 92.38683233713562, 'gn_sp_recall': 0.0, 'bridge_recall': 92.96128510524642, 'gn_joint_recall': 0.0, 'gn_multihop_prec': 91.45613187769277, 'comparison_prec': 89.77458033573143, 'gn_em': 86.22857142857143, 'squad_NoAns_exact': 93.87506100536848, 'comparison_em': 86.18705035971223, 'joint_em': 0.0, 'multihop_prec': 92.6862640074678, 'gn_sp_f1': 0.0, 'sp_em': 0.0, 'gn_comparison_em': 84.46043165467626, 'gn_bridge_em': 86.66666666666667, 'squad_best_exact': 82.2373696872494, 'multihop_recall': 92.23665572917332, 'gn_bridge_prec': 92.30473971666954, 'gn_comparison_f1': 87.2254072865584, 'sp_f1': 0.0, 'squad_best_exact_thresh': -0.4903706908226013, 'squad_total': 9976, 'bridge_em': 87.57575757575758, 'gn_recall': 91.05084223710979, 'gn_multihop_recall': 91.05084223710979, 'squad_best_f1_thresh': -0.00813150405883789, 'gn_sp_em': 0.0, 'squad_HasAns_f1': 84.2631148781077, 'gn_sp_prec': 0.0, 'squad_NoAns_total': 4098, 'gn_joint_f1': 0.0, 'gn_multihop_em': 86.22857142857143, 'gn_multihop_f1': 90.53800592473229, 'multihop_f1': 91.71518609750078, 'sp_prec': 0.0, 'gn_comparison_prec': 88.03117505995205, 'prec': 92.6862640074678, 'bridge_prec': 93.4076972166859, 'joint_f1': 0.0, 'joint_prec': 0.0, 'gn_joint_prec': 0.0, 'multihop_em': 87.3, 'squad_HasAns_total': 5878, 'f1': 91.71518609750078, 'comparison_f1': 89.00444120228293, 'squad_exact': 82.19727345629511, 'gn_bridge_f1': 91.35877457126719, 'squad_best_f1': 88.21156668539683},
    {'gn_comparison_f1': 88.09614446305093, 'squad_best_exact_thresh': -0.06386780738830566, 'sp_f1': 0.0, 'squad_NoAns_total': 4098, 'gn_comparison_prec': 88.76136782881387, 'squad_best_exact': 81.88652766639936, 'squad_NoAns_f1': 93.8506588579795, 'gn_joint_em': 0.0, 'joint_recall': 0.0, 'joint_prec': 0.0, 'gn_f1': 91.29181420057996, 'gn_prec': 92.33604894708468, 'gn_bridge_recall': 92.46648267555095, 'gn_comparison_recall': 88.51271190839533, 'squad_HasAns_f1': 83.70719661179737, 'gn_joint_f1': 0.0, 'gn_multihop_recall': 91.68137676607292, 'squad_HasAns_exact': 73.52841102415788, 'multihop_f1': 91.79883338170357, 'sp_prec': 0.0, 'squad_best_f1': 87.88401179672691, 'comparison_em': 86.2589928057554, 'bridge_em': 87.77183600713012, 'gn_sp_prec': 0.0, 'sp_recall': 0.0, 'gn_joint_recall': 0.0, 'gn_multihop_prec': 92.33604894708468, 'squad_best_f1_thresh': -0.030368804931640625, 'multihop_em': 87.47142857142856, 'comparison_prec': 90.07581627006088, 'gn_joint_prec': 0.0, 'joint_f1': 0.0, 'sp_em': 0.0, 'gn_bridge_prec': 93.22175425089864, 'gn_multihop_em': 86.94285714285715, 'f1': 91.79883338170357, 'bridge_prec': 93.52360732948968, 'squad_f1': 87.87398773898826, 'gn_multihop_f1': 91.29181420057996, 'recall': 92.23852812661714, 'gn_bridge_em': 87.43315508021391, 'multihop_recall': 92.23852812661714, 'gn_recall': 91.68137676607292, 'multihop_prec': 92.83897453340309, 'gn_em': 86.94285714285715, 'comparison_f1': 89.46013898172171, 'gn_sp_em': 0.0, 'gn_sp_recall': 0.0, 'bridge_f1': 92.3782959870464, 'squad_exact': 81.87650360866078, 'gn_comparison_em': 84.96402877697842, 'em': 87.47142857142856, 'squad_HasAns_total': 5878, 'gn_bridge_f1': 92.08361115871972, 'comparison_recall': 89.94076946235218, 'gn_sp_f1': 0.0, 'prec': 92.83897453340309, 'joint_em': 0.0, 'bridge_recall': 92.80784800956333, 'squad_NoAns_exact': 93.8506588579795, 'squad_total': 9976},
    {'multihop_recall': 92.26259553539884, 'joint_recall': 0.0, 'squad_HasAns_total': 5878, 'joint_em': 0.0, 'gn_joint_prec': 0.0, 'bridge_prec': 93.2590267476631, 'gn_joint_recall': 0.0, 'squad_NoAns_total': 4098, 'squad_HasAns_f1': 84.09833812277128, 'comparison_prec': 90.6714628297362, 'multihop_prec': 92.74521048396048, 'gn_bridge_prec': 92.40370352068214, 'gn_comparison_recall': 88.90748699741503, 'gn_multihop_recall': 91.25035347190673, 'gn_bridge_em': 86.73796791443851, 'gn_sp_prec': 0.0, 'f1': 91.72846505431819, 'gn_sp_em': 0.0, 'squad_best_exact': 82.17722534081797, 'bridge_f1': 92.15239814098479, 'squad_NoAns_f1': 93.97266959492435, 'comparison_em': 87.33812949640289, 'gn_prec': 91.7540497399426, 'recall': 92.26259553539884, 'squad_best_exact_thresh': -0.1645965576171875, 'bridge_recall': 92.7005517209836, 'sp_prec': 0.0, 'bridge_em': 87.29055258467024, 'comparison_recall': 90.4950169734342, 'sp_f1': 0.0, 'gn_bridge_f1': 91.37286071989341, 'squad_total': 9976, 'gn_em': 86.54285714285714, 'squad_f1': 88.15457412646872, 'gn_multihop_em': 86.54285714285714, 'gn_f1': 90.79830714500058, 'em': 87.3, 'gn_sp_recall': 0.0, 'gn_bridge_recall': 91.83084979981112, 'squad_exact': 82.13712910986368, 'gn_comparison_prec': 89.13206577595068, 'gn_joint_f1': 0.0, 'sp_em': 0.0, 'comparison_f1': 90.01748331604449, 'gn_sp_f1': 0.0, 'squad_HasAns_exact': 73.88567539979584, 'gn_comparison_f1': 88.47942545064848, 'gn_comparison_em': 85.75539568345324, 'gn_multihop_prec': 91.7540497399426, 'joint_prec': 0.0, 'joint_f1': 0.0, 'multihop_em': 87.3, 'multihop_f1': 91.72846505431819, 'gn_joint_em': 0.0, 'sp_recall': 0.0, 'prec': 92.74521048396048, 'squad_best_f1': 88.18711376005066, 'gn_multihop_f1': 90.79830714500058, 'gn_recall': 91.25035347190673, 'squad_NoAns_exact': 93.97266959492435, 'squad_best_f1_thresh': -0.1645965576171875},
    {'gn_comparison_recall': 89.4987230994425, 'sp_f1': 0.0, 'gn_em': 86.47142857142858, 'gn_sp_prec': 0.0, 'gn_joint_em': 0.0, 'multihop_em': 87.2, 'comparison_recall': 89.57666074932258, 'gn_f1': 90.68286439557572, 'sp_prec': 0.0, 'bridge_f1': 92.03469587001183, 'gn_bridge_f1': 91.03793784702495, 'gn_multihop_recall': 91.12716773700676, 'gn_sp_em': 0.0, 'em': 87.2, 'squad_HasAns_f1': 83.81330802442062, 'joint_prec': 0.0, 'gn_multihop_f1': 90.68286439557572, 'squad_HasAns_total': 5878, 'squad_best_f1': 87.96825293045448, 'comparison_f1': 89.34229559409417, 'gn_recall': 91.12716773700676, 'gn_bridge_recall': 91.53065045469202, 'squad_f1': 87.95655819642612, 'gn_comparison_f1': 89.24979816346725, 'multihop_prec': 92.43726573799681, 'squad_NoAns_f1': 93.89946315275745, 'squad_total': 9976, 'joint_em': 0.0, 'comparison_em': 86.76258992805755, 'bridge_recall': 92.6539174382531, 'joint_recall': 0.0, 'joint_f1': 0.0, 'f1': 91.50006210093692, 'squad_best_exact_thresh': -0.05420732498168945, 'gn_comparison_prec': 89.97290167865708, 'squad_best_exact': 81.80633520449078, 'gn_joint_f1': 0.0, 'gn_comparison_em': 86.76258992805755, 'squad_NoAns_exact': 93.89946315275745, 'gn_multihop_em': 86.47142857142858, 'gn_bridge_prec': 91.96791330080103, 'bridge_em': 87.3083778966132, 'prec': 92.43726573799681, 'gn_sp_recall': 0.0, 'squad_exact': 81.78628708901363, 'gn_prec': 91.57176099297523, 'multihop_f1': 91.50006210093692, 'comparison_prec': 90.08081534772184, 'gn_multihop_prec': 91.57176099297523, 'gn_joint_prec': 0.0, 'multihop_recall': 92.04286218145126, 'bridge_prec': 93.02112777765498, 'sp_em': 0.0, 'squad_HasAns_exact': 73.34127254168084, 'gn_sp_f1': 0.0, 'sp_recall': 0.0, 'recall': 92.04286218145126, 'gn_joint_recall': 0.0, 'squad_best_f1_thresh': -0.05420732498168945, 'gn_bridge_em': 86.39928698752229, 'squad_NoAns_total': 4098},
    {'gn_multihop_f1': 90.35842539568402, 'gn_bridge_recall': 91.42685476793827, 'comparison_recall': 89.93317553749206, 'joint_recall': 0.0, 'bridge_prec': 93.11856894343524, 'gn_comparison_f1': 87.81504985821533, 'gn_multihop_recall': 90.76620513027348, 'squad_HasAns_f1': 82.23161424393254, 'sp_prec': 0.0, 'squad_NoAns_f1': 92.2157149829185, 'comparison_f1': 89.52111357866755, 'squad_best_f1_thresh': -0.29443836212158203, 'squad_best_exact': 80.41299117882919, 'squad_NoAns_total': 4098, 'gn_sp_recall': 0.0, 'gn_multihop_em': 86.24285714285715, 'squad_NoAns_exact': 92.2157149829185, 'gn_sp_f1': 0.0, 'gn_bridge_em': 86.54188948306596, 'gn_comparison_recall': 88.09984220415873, 'comparison_em': 86.61870503597122, 'squad_total': 9976, 'bridge_recall': 92.69067642100134, 'gn_bridge_prec': 91.93353101641873, 'gn_sp_em': 0.0, 'gn_prec': 91.24207617779042, 'squad_best_f1': 86.3754878063881, 'joint_f1': 0.0, 'gn_recall': 90.76620513027348, 'sp_recall': 0.0, 'squad_HasAns_total': 5878, 'comparison_prec': 90.17680401133639, 'em': 87.32857142857144, 'gn_joint_prec': 0.0, 'gn_comparison_prec': 88.45138434706779, 'multihop_prec': 92.53441847834702, 'gn_joint_em': 0.0, 'joint_prec': 0.0, 'gn_joint_recall': 0.0, 'squad_HasAns_exact': 72.0823409322899, 'sp_em': 0.0, 'multihop_em': 87.32857142857144, 'prec': 92.53441847834702, 'gn_joint_f1': 0.0, 'recall': 92.143115531276, 'squad_f1': 86.3329419131754, 'squad_best_exact_thresh': -0.29443836212158203, 'joint_em': 0.0, 'squad_exact': 80.35284683239776, 'multihop_recall': 92.143115531276, 'gn_em': 86.24285714285715, 'gn_comparison_em': 85.03597122302158, 'gn_sp_prec': 0.0, 'f1': 91.60646812258955, 'gn_f1': 90.35842539568402, 'gn_bridge_f1': 90.98860222225804, 'sp_f1': 0.0, 'gn_multihop_prec': 91.24207617779042, 'multihop_f1': 91.60646812258955, 'bridge_f1': 92.12316024666272, 'bridge_em': 87.50445632798574},
    {'gn_comparison_f1': 88.34061182622337, 'squad_f1': 88.3522698105272, 'comparison_f1': 89.54205067514422, 'squad_NoAns_f1': 94.5095168374817, 'squad_best_exact_thresh': -1.0360791683197021, 'f1': 91.9595954516586, 'sp_prec': 0.0, 'gn_sp_em': 0.0, 'gn_em': 86.64285714285714, 'squad_best_exact': 82.54811547714515, 'gn_bridge_prec': 92.64451487254254, 'gn_joint_recall': 0.0, 'squad_total': 9976, 'squad_best_f1_thresh': -0.1478966474533081, 'multihop_recall': 92.43658543438875, 'gn_comparison_em': 85.75539568345324, 'gn_joint_f1': 0.0, 'gn_prec': 91.92777805481165, 'gn_multihop_recall': 91.44409319836392, 'bridge_em': 87.77183600713012, 'squad_NoAns_total': 4098, 'sp_em': 0.0, 'em': 87.58571428571429, 'squad_HasAns_f1': 84.05958551034648, 'gn_sp_recall': 0.0, 'gn_multihop_em': 86.64285714285714, 'gn_multihop_f1': 90.95842873616856, 'recall': 92.43658543438875, 'joint_prec': 0.0, 'comparison_prec': 90.21610403984505, 'prec': 92.96569901107293, 'bridge_recall': 93.06659180092748, 'joint_f1': 0.0, 'gn_sp_prec': 0.0, 'squad_best_f1': 88.37731159579351, 'gn_comparison_recall': 88.68286669365804, 'squad_NoAns_exact': 94.5095168374817, 'joint_recall': 0.0, 'sp_recall': 0.0, 'gn_joint_prec': 0.0, 'gn_bridge_f1': 91.60705003827604, 'sp_f1': 0.0, 'gn_recall': 91.44409319836392, 'squad_exact': 82.43785084202085, 'squad_HasAns_total': 5878, 'gn_bridge_recall': 92.12824735906642, 'joint_em': 0.0, 'squad_HasAns_exact': 74.0217761143246, 'gn_bridge_em': 86.86274509803921, 'comparison_recall': 89.89389786871799, 'gn_f1': 90.95842873616856, 'gn_comparison_prec': 89.03504888396976, 'bridge_prec': 93.64697120537005, 'gn_multihop_prec': 91.92777805481165, 'gn_sp_f1': 0.0, 'multihop_f1': 91.9595954516586, 'gn_joint_em': 0.0, 'comparison_em': 86.83453237410072, 'bridge_f1': 92.55859495956481, 'multihop_em': 87.58571428571429, 'multihop_prec': 92.96569901107293},
]

In [50]:
# Average each metric over all entries of traineval_metrics; the metric names are
# taken from the first entry, so every entry must contain those keys.
traineval_metrics_summary = {
    metric: np.mean(np.array([entry[metric] for entry in traineval_metrics]))
    for metric in traineval_metrics[0].keys()
}
pprint(traineval_metrics_summary)

{'bridge_em': 87.50089126559715,
 'bridge_f1': 92.23506148605688,
 'bridge_prec': 93.29727795949024,
 'bridge_recall': 92.77809153213306,
 'comparison_em': 86.6978417266187,
 'comparison_f1': 89.53730327132783,
 'comparison_prec': 90.21333287815301,
 'comparison_recall': 89.93380049207386,
 'em': 87.34142857142857,
 'f1': 91.69936378341801,
 'gn_bridge_em': 86.75757575757576,
 'gn_bridge_f1': 91.39288126863559,
 'gn_bridge_prec': 92.39542744174638,
 'gn_bridge_recall': 91.85899185977365,
 'gn_comparison_em': 85.37410071942445,
 'gn_comparison_f1': 88.13820888050682,
 'gn_comparison_prec': 88.81140156605986,
 'gn_comparison_recall': 88.50339572497845,
 'gn_em': 86.48285714285714,
 'gn_f1': 90.74659632299303,
 'gn_joint_em': 0.0,
 'gn_joint_f1': 0.0,
 'gn_joint_prec': 0.0,
 'gn_joint_recall': 0.0,
 'gn_multihop_em': 86.48285714285714,
 'gn_multihop_f1': 90.74659632299303,
 'gn_multihop_prec': 91.68374230357433,
 'gn_multihop_recall': 91.19266634157859,
 'gn_prec': 91.68374230357433,
 'gn

In [52]:
# 'f1' of every entry in traineval_metrics. Iterates the list directly, so the
# cell does not depend on the list holding exactly 10 entries.
[entry['f1'] for entry in traineval_metrics]

[91.8559012659031,
 91.76399198840184,
 91.30730265640072,
 91.75783171476677,
 91.71518609750078,
 91.79883338170357,
 91.72846505431819,
 91.50006210093692,
 91.60646812258955,
 91.9595954516586]