In [9]:
import pandas as pd
import json
import collections
import math
from collections import defaultdict

In [10]:
answers_csv_df = pd.read_csv('../output/test_answers_df_v1.csv')

In [11]:
ScoreSummary = collections.namedtuple("ScoreSummary", ["short_span_score", "cls_token_score",
                                                       "answer_type_logits", "answer_type",
                                                       "start_logits", "end_logits", "unique_id",
                                                      "start_idx_in_chunk", "end_idx_in_chunk"])
def empty_score_summary():
    return ScoreSummary(None, None, None, None, None, None, None, None, None)
    
Span = collections.namedtuple("Span", ["start_token_idx", "end_token_idx", "score", "summary"])

In [12]:
def compute_score(gold_nq_jsonl_path, answers_csv_path, tqdm=None):
    
    answers_csv_df = pd.read_csv(answers_csv_path)
    gold_e_by_id = {}
    with open(gold_nq_jsonl_path) as gold_nq_jsonl:
        if tqdm is not None:
            gold_nq_jsonl = tqdm(gold_nq_jsonl)
        for index, line in enumerate(gold_nq_jsonl):
            gold_e = json.loads(line, object_pairs_hook=collections.OrderedDict)
            gold_e_by_id[gold_e['example_id']] = gold_e
            
    tp_short = 0
    fp_short = 0
    fn_short = 0
    tn_short = 0
    tp_long = 0
    fp_long = 0
    fn_long = 0
    tn_long = 0
    is_short_fp_column = []
    is_long_fp_column = []
    wrong_answers = []
    
    for index, row in answers_csv_df.iterrows():
        example_id = row['example_id']
        if not example_id in gold_e_by_id:
            raise ValueError('example id not found in gold_nq_jsonl file: ' + str(example_id))
        annotations = gold_e_by_id[example_id]['annotations']
        assert(len(annotations) == 1)
        gold_answer = annotations[0]
        # short answer
        short_answer = row['short_answer']
        short_answers_gold = gold_answer['short_answers']
        yes_no_answer_gold = gold_answer['yes_no_answer']
        is_short_fp = False
        is_long_fp = False
        
        if short_answer == 'YES' or short_answer == 'NO':
            # Check for empty gold short answer, otherwise it should be SHORT
            ok = short_answer == yes_no_answer_gold and len(short_answers_gold) == 0
            tp_short += ok
            fp_short += not ok
            is_short_fp = not ok
        elif isinstance(short_answer, str):
            assert(':' in short_answer)
            start_token, end_token = [int(x) for x in short_answer.split(':')]
            
            ok = len(short_answers_gold) != 0 and \
            start_token == short_answers_gold[0]['start_token'] and \
            end_token == short_answers_gold[-1]['end_token']
            
            tp_short += ok
            fp_short += not ok
            is_short_fp = not ok
        elif isinstance(short_answer, float) and math.isnan(short_answer):
            # blank short answer
            ok = yes_no_answer_gold == 'NONE' and len(short_answers_gold) == 0
            tn_short += ok
            fn_short += not ok
        else:
            raise Exception('wrong short_answer type, short_answer:', short_answer)
            
        # long answer
        long_answer = row['long_answer']
        long_answer_gold = gold_answer['long_answer']
        long_answer_gold_str = str(long_answer_gold['start_token']) + ':' + str(long_answer_gold['end_token'])
        if isinstance(long_answer, str):
            ok = long_answer == long_answer_gold_str
            tp_long += ok
            fp_long += not ok
            is_long_fp = not ok
        elif isinstance(long_answer, float) and math.isnan(long_answer):
            # blank long answer
            ok = long_answer_gold_str == "-1:-1"
            tn_long += ok
            fn_long += not ok
        
        is_short_fp_column.append(is_short_fp)
        is_long_fp_column.append(is_long_fp)
        
        if is_short_fp:
            predictions_summary_str = row['predictions']
            predictions_summary = eval(predictions_summary_str)
            short_spans = predictions_summary['short_spans']
            long_spans = predictions_summary['long_spans']
            info_by_unique_id = defaultdict(list)
#             info_by_unique_id= {}
            
            for ss in short_spans:
                summary = ss.summary
                
                gold_start_token = -1
                gold_end_token = -1
                if len(short_answers_gold) != 0:
                    gold_start_token = short_answers_gold[0]['start_token']
                    gold_end_token = short_answers_gold[-1]['end_token']
                    
                answer_start_token = -1
                answer_end_token = -1
                if isinstance(short_answer, str):
                    assert(':' in short_answer)
                    answer_start_token, answer_end_token = [int(x) for x in short_answer.split(':')]
                
                start_logits_shift = ss.start_token_idx - summary.start_idx_in_chunk
                end_logits_shift = ss.end_token_idx - summary.end_idx_in_chunk
                start_logits = [0.0] * start_logits_shift + summary.start_logits
                end_logits = [0.0] * end_logits_shift + summary.end_logits
                info_by_unique_id[summary.unique_id].append( (start_logits, end_logits, gold_start_token, gold_end_token,
                                                        answer_start_token, answer_end_token,
                                                        ss.start_token_idx, ss.end_token_idx,
                                                        summary.short_span_score, summary.cls_token_score) )                                                            
            
            wrong_answers.append(info_by_unique_id)
            
#         if len(wrong_answers) > 100:
#             break

    print('short tp:', tp_short, 'fp:', fp_short, 'fn:', fn_short, 'tn:', tn_short,
          'all:', tp_short + fp_short + fn_short + tn_short)
    print('long tp:', tp_long, 'fp:', fp_long, 'fn:', fn_long, 'tn:', tn_long,
         'all:', tp_long + fp_long + fn_long + tn_long)
    tp = tp_short + tp_long
    fp = fp_short + fp_long
    fn = fn_short + fn_long
    print('overall tp:', tp, 'fp:', fp, 'fn:', fn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * (precision * recall) / (precision + recall)
    print('f1:', '{0:.2f}'.format(f1), 'precision:', '{0:.2f}'.format(precision), 'recall:', '{0:.2f}'.format(recall))
    
    #answers_csv_df['is_short_fp'] = is_short_fp_column
    #answers_csv_df['is_long_fp'] = is_long_fp_column
    
    return wrong_answers
    
wrong_answers = compute_score('../output/1k/nq-train-part.jsonl', '../output/test_answers_df_v1.csv')

short tp: 86 fp: 101 fn: 59 tn: 304 all: 550
long tp: 168 fp: 209 fn: 25 tn: 148 all: 550
overall tp: 254 fp: 310 fn: 84
f1: 0.56 precision: 0.45 recall: 0.75


In [13]:
len(wrong_answers)

101

In [27]:
import plotly.graph_objects as go

docs_n = 0
docs_without_gold = 0
docs_with_another_chunk_is_gold = 0
docs_with_second_chunk_is_gold = 0
docs_without_good_spans_but_with_gold = 0

def plot_document(info_by_unique_id, plot=True):
    global docs_without_gold
    global docs_with_another_chunk_is_gold
    global docs_with_second_chunk_is_gold
    global docs_without_good_spans_but_with_gold
    
    fig = go.Figure()
    
    some_span_is_good = False
    gold_is_presented = False
    for chunk_index, (unique_id, chunk_answers) in enumerate(info_by_unique_id.items()):
        the_best = chunk_index == 0
        best_chunk_answer = chunk_answers[0]
        (start_logits, end_logits, gold_start_token, gold_end_token,
            answer_start_token, answer_end_token, span_start, span_end,
            short_span_score, cls_token_score) = best_chunk_answer
        score = short_span_score - cls_token_score
        x = [i for i in range(len(start_logits))]
        
        gold_pair = (gold_start_token, gold_end_token)
        answer = (answer_start_token, answer_end_token)
        span = (span_start, span_end)
#         if the_best:
#             print('gold:', gold_pair)
#             print('answer:', answer)
#         print('span:', span)
#         print('score:', score)
#         print('cls_token_score:', cls_token_score)
#         print('chunk_answers:', len(chunk_answers))
        if gold_pair == span:
            some_span_is_good = True
            if chunk_index == 1:
                print('Best span in second chunk is the answer!')
                docs_with_second_chunk_is_gold += 1
            else:
                print('Best span in another chunk is the answer!')
                docs_with_another_chunk_is_gold += 1
        print('')
        
        if the_best:
            if gold_pair == (-1, -1):
                docs_without_gold += 1
            else:
                gold_is_presented = True
        
        mode = 'lines+markers' if the_best else 'lines'
        
        if plot:        
            fig.add_trace(go.Scatter(x=x, y=start_logits,
                                mode=mode,
                                name='start_logits'))
            fig.add_trace(go.Scatter(x=x, y=end_logits,
                                mode=mode,
                                name='end_logits'))
            fig.add_trace(go.Scatter(x=[answer_start_token, answer_end_token], y=[0.5, 0.5],
                                mode='markers',
                                name='start/end',
                                marker=dict(size=[15, 15])))
            fig.add_trace(go.Scatter(x=[gold_start_token, gold_end_token], y=[1.0, 1.0],
                                mode='markers',
                                name='gold start/end',
                                marker=dict(size=[15, 15])))
        
    if plot:
        fig.show()
        
    if not some_span_is_good and gold_is_presented:
        docs_without_good_spans_but_with_gold += 1
        

for index, info_by_unique_id in enumerate(wrong_answers):
    print('doc_i:', index)
    plot_document(info_by_unique_id, plot=False)
    docs_n += 1
#     if index > 10:
#         break

print('docs_n:', docs_n)
print('docs_without_gold:', docs_without_gold)
print('docs_with_second_chunk_is_gold:', docs_with_second_chunk_is_gold)
print('docs_with_another_chunk_is_gold:', docs_with_another_chunk_is_gold)
print('docs_without_good_spans_but_with_gold:', docs_without_good_spans_but_with_gold)

doc_i: 0

Best span in second chunk is the answer!




doc_i: 1

doc_i: 2







doc_i: 3










doc_i: 4


doc_i: 5

Best span in second chunk is the answer!

doc_i: 6



doc_i: 7


doc_i: 8




doc_i: 9

Best span in second chunk is the answer!






doc_i: 10






doc_i: 11




doc_i: 12






doc_i: 13






doc_i: 14


doc_i: 15



doc_i: 16







doc_i: 17

Best span in second chunk is the answer!



doc_i: 18


doc_i: 19


doc_i: 20







doc_i: 21

doc_i: 22



doc_i: 23







doc_i: 24

doc_i: 25



doc_i: 26




doc_i: 27




doc_i: 28


doc_i: 29


doc_i: 30


doc_i: 31

doc_i: 32








doc_i: 33









doc_i: 34



doc_i: 35





doc_i: 36








doc_i: 37




doc_i: 38



doc_i: 39




doc_i: 40







doc_i: 41

doc_i: 42






doc_i: 43


doc_i: 44




doc_i: 45





doc_i: 46



doc_i: 47



doc_i: 48









doc_i: 49



doc_i: 50


doc_i: 51






doc_i: 52




doc_i: 53

Best span in second chunk is the answer!

doc_i: 54

Best span in second chunk is th