In [64]:
import pandas as pd
import json
import collections
import math
import numpy as np
import tqdm

In [71]:
def safe_devide(a, b):
    if b == 0:
        return 0
    return a / b

def compute_f1(tp, fp, fn):
    precision = safe_devide(tp, (tp + fp))
    recall = safe_devide(tp, (tp + fn))
    f1 = 2 * safe_devide((precision * recall), (precision + recall))
    return f1

gold_e_by_id_static = {}
gold_nq_jsonl_path = '../output/1k/nq-train-part.jsonl'
with open(gold_nq_jsonl_path) as gold_nq_jsonl:
    for index, line in enumerate(gold_nq_jsonl):
        gold_e = json.loads(line, object_pairs_hook=collections.OrderedDict)
        gold_e_by_id_static[gold_e['example_id']] = gold_e

def compute_score(gold_e_by_id, answers_csv_df, tqdm=None):
    
    #answers_csv_df = pd.read_csv(answers_csv_path)
            
    tp_short = 0
    fp_short = 0
    fn_short = 0
    tn_short = 0
    tp_long = 0
    fp_long = 0
    fn_long = 0
    tn_long = 0
    is_short_fp_column = []
    is_long_fp_column = []
    
    for index, row in answers_csv_df.iterrows():
        example_id = row['example_id']
        if not example_id in gold_e_by_id:
            raise ValueError('example id not found in gold_nq_jsonl file: ' + str(example_id))
        annotations = gold_e_by_id[example_id]['annotations']
        assert(len(annotations) == 1)
        gold_answer = annotations[0]
        # short answer
        short_answer = row['short_answer']
        short_answers_gold = gold_answer['short_answers']
        yes_no_answer_gold = gold_answer['yes_no_answer']
        short_answer_score = row['short_answer_score']
        is_short_fp = False
        is_long_fp = False
#         print('short_answer:', short_answer)
#         print('short_answers_gold:', short_answers_gold)
#         print('yes_no_answer_gold:', yes_no_answer_gold)
        
        if short_answer == 'YES' or short_answer == 'NO':
            ok = short_answer == yes_no_answer_gold
            tp_short += ok
            fp_short += not ok
            is_short_fp = not ok
#             print(ok)
        elif isinstance(short_answer, str):
            assert(':' in short_answer)
            start_token, end_token = [int(x) for x in short_answer.split(':')]
            short_answer_dict = collections.OrderedDict([('start_token', start_token), ('end_token', end_token)])
            ok = short_answer_dict in short_answers_gold
            tp_short += ok
            fp_short += not ok
            is_short_fp = not ok
#             print(ok)
        elif isinstance(short_answer, float) and math.isnan(short_answer):
            # blank short answer
            ok = yes_no_answer_gold == 'NONE' and len(short_answers_gold) == 0
            tn_short += ok
            fn_short += not ok
#             print(ok)
        else:
            raise Exception('wrong short_answer type, short_answer:', short_answer)
            
        # long answer
        long_answer = row['long_answer']
        long_answer_gold = gold_answer['long_answer']
#         print('long_answer', long_answer)
#         print('long_answer_gold', long_answer_gold)
        long_answer_gold_str = str(long_answer_gold['start_token']) + ':' + str(long_answer_gold['end_token'])
        if isinstance(long_answer, str):
            ok = long_answer == long_answer_gold_str
            tp_long += ok
            fp_long += not ok
            is_long_fp = not ok
#             print(ok)
        elif isinstance(long_answer, float) and math.isnan(long_answer):
            # blank long answer
            ok = long_answer_gold_str == "-1:-1"
            tn_long += ok
            fn_long += not ok
#             print(ok)
            
#         print('')

        
        is_short_fp_column.append(is_short_fp)
        is_long_fp_column.append(is_long_fp)

    print('short tp:', tp_short, 'fp:', fp_short, 'fn:', fn_short, 'tn:', tn_short,
          'all:', tp_short + fp_short + fn_short + tn_short)
    print('long tp:', tp_long, 'fp:', fp_long, 'fn:', fn_long, 'tn:', tn_long,
         'all:', tp_long + fp_long + fn_long + tn_long)
    
    f1_short = compute_f1(tp=tp_short, fp=fp_short, fn=fn_short)
    f1_long = compute_f1(tp=tp_long, fp=fp_long, fn=fn_long)
    
    tp = tp_short + tp_long
    fp = fp_short + fp_long
    fn = fn_short + fn_long
    f1 = compute_f1(tp=tp, fp=fp, fn=fn)
    
    return f1_short, f1_long, f1

In [81]:
def create_short_answer(entry, threshold, check_type):
    answer = []    
    if entry['answer_type'] == 0:
        return ""
    if check_type and entry['answer_type'] == 4:
        return ""
    
    elif entry['answer_type'] == 1:
        return 'YES'
    
    elif entry['answer_type'] == 2:
        return 'NO'
        
    elif entry["short_answer_score"] < threshold:
        return ""
    
    else:
        for short_answer in entry["short_answers"]:
            if short_answer["start_token"] > -1:
                answer.append(str(short_answer["start_token"]) + ":" + str(short_answer["end_token"]))
    
        return " ".join(answer)

def create_long_answer(entry, threshold, check_type):
    
    answer = []
    
    if entry['answer_type'] == 0:
        return ''
    if check_type and entry['answer_type'] == 4:
        return ""
    
    elif entry["long_answer_score"] < threshold:
        return ""

    elif entry["long_answer"]["start_token"] > -1:
        answer.append(str(entry["long_answer"]["start_token"]) + ":" + str(entry["long_answer"]["end_token"]))
        return " ".join(answer)

In [82]:
predictions_json_static = pd.read_json("../output/1k/predictions.json")

def test_params(threshold_short, check_type_short, threshold0_long, check_type_long):

    create_short_answer_l = lambda q: create_short_answer(q, threshold=threshold_short,
                                                       check_type=check_type_short)
    create_long_answer_l = lambda q: create_long_answer(q, threshold=threshold0_long, check_type=check_type_long)
    
    test_answers_df = predictions_json_static
    for var_name in ['long_answer_score','short_answer_score','answer_type']:
        test_answers_df[var_name] = test_answers_df['predictions'].apply(lambda q: q[var_name])
    test_answers_df["long_answer"] = test_answers_df["predictions"].apply(create_long_answer_l)
    test_answers_df["short_answer"] = test_answers_df["predictions"].apply(create_short_answer_l)
    test_answers_df["example_id"] = test_answers_df["predictions"].apply(lambda q: str(q["example_id"]))

    long_answers = dict(zip(test_answers_df["example_id"], test_answers_df["long_answer"]))
    short_answers = dict(zip(test_answers_df["example_id"], test_answers_df["short_answer"]))

    test_answers_df.example_id = test_answers_df.example_id.astype('int64')
    test_answers_df = test_answers_df.replace(r'', np.NaN)
    return compute_score(gold_e_by_id_static, test_answers_df)

In [94]:
threshold_short = 8.35
check_type_short = True
threshold_long_params = np.arange(3, 5, 0.05).tolist()
check_type_long_params = [False]

results = []
for threshold_long in threshold_long_params:
    for check_type_long in check_type_long_params:
        res = test_params(threshold_short, check_type_short, threshold_long, check_type_long)
        res_with_params = (res, (threshold_short, check_type_short, threshold_long, check_type_long))
        print(res_with_params)
        results.append(res_with_params)
results.sort(reverse=True)

short tp: 150 fp: 200 fn: 121 tn: 529 all: 1000
long tp: 327 fp: 363 fn: 48 tn: 262 all: 1000
((0.4830917874396135, 0.6140845070422535, 0.5658362989323843), (8.35, True, 3.0, False))
short tp: 150 fp: 200 fn: 121 tn: 529 all: 1000
long tp: 327 fp: 361 fn: 48 tn: 264 all: 1000
((0.4830917874396135, 0.6152398871119474, 0.5665083135391924), (8.35, True, 3.05, False))
short tp: 150 fp: 200 fn: 121 tn: 529 all: 1000
long tp: 327 fp: 360 fn: 48 tn: 265 all: 1000
((0.4830917874396135, 0.615819209039548, 0.5668449197860963), (8.35, True, 3.0999999999999996, False))
short tp: 150 fp: 200 fn: 121 tn: 529 all: 1000
long tp: 326 fp: 360 fn: 49 tn: 265 all: 1000
((0.4830917874396135, 0.6145146088595664, 0.5659928656361475), (8.35, True, 3.1499999999999995, False))
short tp: 150 fp: 200 fn: 121 tn: 529 all: 1000
long tp: 326 fp: 355 fn: 49 tn: 270 all: 1000
((0.4830917874396135, 0.6174242424242424, 0.56768038163387), (8.35, True, 3.1999999999999993, False))
short tp: 150 fp: 200 fn: 121 tn: 529 all:

In [95]:
results

[((0.4830917874396135, 0.6190009794319294, 0.5676004872107187),
  (8.35, True, 3.9499999999999966, False)),
 ((0.4830917874396135, 0.6184586108468125, 0.5681818181818181),
  (8.35, True, 3.3499999999999988, False)),
 ((0.4830917874396135, 0.6177908113391984, 0.5669099756690997),
  (8.35, True, 3.899999999999997, False)),
 ((0.4830917874396135, 0.6175024582104228, 0.5665445665445665),
  (8.35, True, 3.9999999999999964, False)),
 ((0.4830917874396135, 0.6174242424242424, 0.56768038163387),
  (8.35, True, 3.1999999999999993, False)),
 ((0.4830917874396135, 0.6173320350535539, 0.566747572815534),
  (8.35, True, 3.799999999999997, False)),
 ((0.4830917874396135, 0.6172839506172839, 0.5675029868578255),
  (8.35, True, 3.299999999999999, False)),
 ((0.4830917874396135, 0.6172839506172839, 0.5675029868578255),
  (8.35, True, 3.249999999999999, False)),
 ((0.4830917874396135, 0.6172106824925816, 0.5661764705882353),
  (8.35, True, 4.249999999999996, False)),
 ((0.4830917874396135, 0.61721068249

In [96]:
results2 = [(x[0][2], x) for x in results]
results2.sort(reverse=True)
results2

[(0.5681818181818181,
  ((0.4830917874396135, 0.6184586108468125, 0.5681818181818181),
   (8.35, True, 3.3499999999999988, False))),
 (0.56768038163387,
  ((0.4830917874396135, 0.6174242424242424, 0.56768038163387),
   (8.35, True, 3.1999999999999993, False))),
 (0.5676004872107187,
  ((0.4830917874396135, 0.6190009794319294, 0.5676004872107187),
   (8.35, True, 3.9499999999999966, False))),
 (0.5675029868578255,
  ((0.4830917874396135, 0.6172839506172839, 0.5675029868578255),
   (8.35, True, 3.299999999999999, False))),
 (0.5675029868578255,
  ((0.4830917874396135, 0.6172839506172839, 0.5675029868578255),
   (8.35, True, 3.249999999999999, False))),
 (0.5669099756690997,
  ((0.4830917874396135, 0.6177908113391984, 0.5669099756690997),
   (8.35, True, 3.899999999999997, False))),
 (0.5668449197860963,
  ((0.4830917874396135, 0.615819209039548, 0.5668449197860963),
   (8.35, True, 3.0999999999999996, False))),
 (0.566747572815534,
  ((0.4830917874396135, 0.6173320350535539, 0.5667475728

### Found best params

In [104]:
#threshold_short = 8.35
threshold_short = 3.5
check_type_short = False
#threshold_long = 3.35
threshold_long = 3.5
check_type_long = False
test_params(threshold_short, check_type_short, threshold_long, check_type_long)

short tp: 166 fp: 512 fn: 15 tn: 307 all: 1000
long tp: 319 fp: 343 fn: 60 tn: 278 all: 1000


(0.38649592549476136, 0.6128722382324687, 0.5105263157894736)

In [100]:
if False:
    a = 2
else:
    b = 4
print(a)
print(b)

2
4
