In [1]:
import pandas as pd
import json
import collections
import math

In [3]:
def _load_gold_by_id(gold_nq_jsonl_path, tqdm=None):
    """Read a gold JSONL file into a dict keyed by each record's ``example_id``."""
    gold_e_by_id = {}
    with open(gold_nq_jsonl_path) as gold_nq_jsonl:
        lines = gold_nq_jsonl if tqdm is None else tqdm(gold_nq_jsonl)
        for line in lines:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            gold_e = json.loads(line, object_pairs_hook=collections.OrderedDict)
            gold_e_by_id[gold_e['example_id']] = gold_e
    return gold_e_by_id


def _is_blank(value):
    """Return True for the float NaN that pandas uses for an empty CSV cell."""
    return isinstance(value, float) and math.isnan(value)


def _classify_short(short_answer, gold_answer):
    """Label one predicted short answer as 'tp', 'fp', 'fn' or 'tn'."""
    short_answers_gold = gold_answer['short_answers']
    yes_no_answer_gold = gold_answer['yes_no_answer']

    if short_answer == 'YES' or short_answer == 'NO':
        # A yes/no prediction is only right when gold has no span answer.
        ok = short_answer == yes_no_answer_gold and len(short_answers_gold) == 0
        return 'tp' if ok else 'fp'

    if isinstance(short_answer, str):
        assert ':' in short_answer, 'short_answer must look like "start:end"'
        start_token, end_token = [int(x) for x in short_answer.split(':')]
        # The prediction must run from the start of the first gold span to
        # the end of the last gold span.
        ok = (
            len(short_answers_gold) != 0
            and start_token == short_answers_gold[0]['start_token']
            and end_token == short_answers_gold[-1]['end_token']
        )
        return 'tp' if ok else 'fp'

    if _is_blank(short_answer):
        # Blank prediction: correct only if gold has no short answer at all.
        ok = yes_no_answer_gold == 'NONE' and len(short_answers_gold) == 0
        return 'tn' if ok else 'fn'

    raise TypeError('wrong short_answer type, short_answer: ' + repr(short_answer))


def _classify_long(long_answer, gold_answer):
    """Label one predicted long answer as 'tp', 'fp', 'fn' or 'tn'."""
    long_answer_gold = gold_answer['long_answer']
    long_answer_gold_str = (
        str(long_answer_gold['start_token']) + ':' + str(long_answer_gold['end_token'])
    )

    if isinstance(long_answer, str):
        return 'tp' if long_answer == long_answer_gold_str else 'fp'

    if _is_blank(long_answer):
        # "-1:-1" is how the gold record encodes "no long answer".
        return 'tn' if long_answer_gold_str == "-1:-1" else 'fn'

    # Previously such rows were silently left out of every long-answer count.
    raise TypeError('wrong long_answer type, long_answer: ' + repr(long_answer))


def compute_score(gold_nq_jsonl_path, answers_csv_path, tqdm=None):
    """Score predicted short/long answers against single-annotation gold data.

    Prints the short, long and overall confusion counts followed by F1,
    precision and recall, and returns the predictions frame with two extra
    boolean columns marking false positives.

    Args:
        gold_nq_jsonl_path: Path to a JSONL file; each record needs
            ``example_id`` and an ``annotations`` list of length one whose
            item has ``short_answers``, ``yes_no_answer`` and ``long_answer``.
        answers_csv_path: Path to a CSV with ``example_id``, ``short_answer``
            and ``long_answer`` columns. Answers are ``"start:end"`` token
            spans, ``"YES"``/``"NO"`` (short only), or blank.
        tqdm: Optional callable used to wrap the gold-file line iterator
            (for example ``tqdm.tqdm``) to show progress.

    Returns:
        The DataFrame read from ``answers_csv_path`` with ``is_short_fp`` and
        ``is_long_fp`` columns added.

    Raises:
        ValueError: If a CSV ``example_id`` is missing from the gold file.
        AssertionError: If a gold example does not have exactly one
            annotation, or a short-answer string contains no ``:``.
        TypeError: If an answer cell is neither a string nor blank.
    """
    answers_csv_df = pd.read_csv(answers_csv_path)
    gold_e_by_id = _load_gold_by_id(gold_nq_jsonl_path, tqdm)

    short_counts = collections.Counter()
    long_counts = collections.Counter()
    is_short_fp_column = []
    is_long_fp_column = []

    # Iterate the columns directly instead of DataFrame.iterrows(): iterrows
    # coerces each row to one dtype, so when every answer cell is blank the
    # 64-bit example_id is upcast to float and no longer matches its gold key.
    rows = zip(
        answers_csv_df['example_id'],
        answers_csv_df['short_answer'],
        answers_csv_df['long_answer'],
    )
    for example_id, short_answer, long_answer in rows:
        if example_id not in gold_e_by_id:
            raise ValueError('example id not found in gold_nq_jsonl file: ' + str(example_id))
        annotations = gold_e_by_id[example_id]['annotations']
        assert len(annotations) == 1, 'expected exactly one gold annotation'
        gold_answer = annotations[0]

        short_label = _classify_short(short_answer, gold_answer)
        long_label = _classify_long(long_answer, gold_answer)
        short_counts[short_label] += 1
        long_counts[long_label] += 1
        is_short_fp_column.append(short_label == 'fp')
        is_long_fp_column.append(long_label == 'fp')

    tp_short, fp_short = short_counts['tp'], short_counts['fp']
    fn_short, tn_short = short_counts['fn'], short_counts['tn']
    tp_long, fp_long = long_counts['tp'], long_counts['fp']
    fn_long, tn_long = long_counts['fn'], long_counts['tn']

    print('short tp:', tp_short, 'fp:', fp_short, 'fn:', fn_short, 'tn:', tn_short,
          'all:', tp_short + fp_short + fn_short + tn_short)
    print('long tp:', tp_long, 'fp:', fp_long, 'fn:', fn_long, 'tn:', tn_long,
          'all:', tp_long + fp_long + fn_long + tn_long)
    tp = tp_short + tp_long
    fp = fp_short + fp_long
    fn = fn_short + fn_long
    print('overall tp:', tp, 'fp:', fp, 'fn:', fn)

    # Report 0.0 rather than raising ZeroDivisionError when a denominator is
    # empty (no rows, or no positive predictions / positive gold answers).
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0
    print('f1:', '{0:.2f}'.format(f1), 'precision:', '{0:.2f}'.format(precision), 'recall:', '{0:.2f}'.format(recall))

    answers_csv_df['is_short_fp'] = is_short_fp_column
    answers_csv_df['is_long_fp'] = is_long_fp_column

    return answers_csv_df
    
# Score the v1 predictions against the gold file; the returned frame carries
# the per-row is_short_fp / is_long_fp flags.
answers_csv_df = compute_score(
    gold_nq_jsonl_path='../output/1k/nq-train-part.jsonl',
    answers_csv_path='../output/test_answers_df_v1.csv',
)

short tp: 151 fp: 202 fn: 120 tn: 527 all: 1000
long tp: 326 fp: 361 fn: 49 tn: 264 all: 1000
overall tp: 477 fp: 563 fn: 169
f1: 0.57 precision: 0.46 recall: 0.74


In [113]:
# Quick size check of answers_csv_df as (n_rows, n_columns).
answers_csv_df.shape

(1000, 9)

In [153]:
# Re-read the raw predictions CSV from disk.
# NOTE(review): this rebinds answers_csv_df, so the frame returned by
# compute_score (with is_short_fp / is_long_fp added) is no longer reachable
# under this name — confirm that is intended.
answers_csv_df = pd.read_csv('../output/test_answers_df_v1.csv')

In [154]:
# Display the current answers_csv_df (pandas truncates the middle rows).
answers_csv_df

Unnamed: 0,predictions,example_id,long_answer,short_answer
0,"{'example_id': -9209839852162522524, 'short_sp...",-9209839852162522524,,
1,"{'example_id': -9188885911445781635, 'short_sp...",-9188885911445781635,,
2,"{'example_id': -9111510312671706854, 'short_sp...",-9111510312671706854,189:282,262:266
3,"{'example_id': -9110190923673509457, 'short_sp...",-9110190923673509457,201:329,202:215
4,"{'example_id': -9100123296297706673, 'short_sp...",-9100123296297706673,519:620,598:601
...,...,...,...,...
995,"{'example_id': 9159823208162950721, 'short_spa...",9159823208162950721,,
996,"{'example_id': 9160648621472984761, 'short_spa...",9160648621472984761,565:701,599:601
997,"{'example_id': 9175842193790270809, 'short_spa...",9175842193790270809,713:843,
998,"{'example_id': 9176819453396564614, 'short_spa...",9176819453396564614,430:551,
