In [1]:
import re
import json
import collections

In [2]:
# Path of the results file; every analysis cell below passes it to load_results,
# which decodes it one JSON document per line.
RESULTS_PATH = 'results.jsonl'

In [3]:
def load_results(file_path):
    """Lazily yield the decoded JSON value of each non-blank line in a file.

    The file is opened as UTF-8 text and decoded one line at a time, so it is
    never held in memory as a whole. Because this is a generator, the file is
    only opened once iteration starts and is closed when iteration finishes.

    Args:
        file_path: Path of the JSON Lines file to read.

    Yields:
        The object produced by ``json.loads`` for each line that contains
        something other than whitespace.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'rt', encoding='utf-8') as fobj:
        for line in fobj:
            # A blank separator or trailing empty line is not a JSON document;
            # json.loads would raise on it, so skip it instead of aborting.
            if not line.strip():
                continue
            yield json.loads(line)

In [4]:
# Tally the raw, unparsed 'answer' strings to see which forms occur.
collections.Counter(
    record['answer']
    for record in load_results(RESULTS_PATH)
)

Counter({'X:': 375, 'Y:': 369})

In [5]:
# Any single character that is recognised as an answer label.
RE_ANSWER = re.compile(r'[ABXY12]')

# Collapses every recognised label onto the canonical 'A' / 'B' pair.
COALESCE = {
    '1': 'A',
    '2': 'B',
    'A': 'A',
    'B': 'B',
    'Y': 'A',
    'X': 'B',
}


def parse_answer(answer):
    """Map the first recognised label character in ``answer`` to 'A' or 'B'.

    The string is scanned left to right for the first character matched by
    RE_ANSWER, and that character is translated through COALESCE.

    Returns:
        'A' or 'B', or None when the string contains no recognised label.
    """
    first_hit = RE_ANSWER.search(answer)
    if first_hit is None:
        return None
    return COALESCE.get(first_hit.group())

# Cross-tabulate the canonical answer ('A'/'B'/None) against the swap flag.
collections.Counter(
    (parse_answer(record['answer']), record['swap'])
    for record in load_results(RESULTS_PATH)
)

Counter({('B', True): 341,
         ('A', False): 338,
         ('B', False): 34,
         ('A', True): 31})

In [6]:
# Overall distribution of canonical answers, ignoring the swap flag.
collections.Counter(
    map(
        parse_answer,
        (record['answer'] for record in load_results(RESULTS_PATH)),
    )
)

Counter({'B': 375, 'A': 369})

In [7]:
def is_correct(result):
    """Tell whether a result's parsed answer equals the expected label.

    The expected label is 'B' when ``result['swap']`` is truthy and 'A'
    otherwise.

    Returns:
        True or False for the comparison, or None when parse_answer finds no
        answer in ``result['answer']``.
    """
    parsed = parse_answer(result['answer'])
    if parsed is not None:
        return parsed == ('B' if result['swap'] else 'A')
    return None

In [8]:
def aggregate_swapped_answers(results):
    """Group correctness verdicts by 'article_sent' and by swap flag.

    Returns:
        A dict mapping each ``result['article_sent']`` to an inner dict that
        maps ``result['swap']`` to the is_correct verdict of that result. A
        later result with the same pair of keys overwrites the earlier one.
    """
    verdicts_by_article = {}
    for record in results:
        # The verdict is computed before the outer dict is touched.
        verdict = is_correct(record)
        per_swap = verdicts_by_article.setdefault(record['article_sent'], {})
        per_swap[record['swap']] = verdict
    return verdicts_by_article


def stable_score(correct_answers):
    """Count how many entries fall into each unswapped/swapped outcome bucket.

    Each value of ``correct_answers`` is looked up under the keys ``False``
    (verdict for the unswapped run) and ``True`` (verdict for the swapped run).
    The pair of verdicts selects a bucket:

    * 'AB' - both verdicts are True
    * 'AA' - unswapped True, swapped False
    * 'BB' - unswapped False, swapped True
    * 'BA' - both verdicts are False
    * 'U'  - either verdict is None, is not a bool, or is missing entirely

    NOTE(review): the two-letter names presumably spell the answers given in
    the unswapped and swapped runs (is_correct expects 'A' unswapped and 'B'
    swapped) - confirm.

    Args:
        correct_answers: Mapping whose values are indexable by ``False`` and
            ``True``, as produced by aggregate_swapped_answers.

    Returns:
        A dict with the keys 'AB', 'AA', 'BA', 'BB' and 'U', in that order,
        mapping each bucket to its count.
    """
    bucket_for = {
        (True, True): 'AB',
        (False, True): 'BB',
        (True, False): 'AA',
        (False, False): 'BA',
    }
    score = {'AB': 0, 'AA': 0, 'BA': 0, 'BB': 0, 'U': 0}
    for answers in correct_answers.values():
        try:
            unswapped, swapped = answers[False], answers[True]
        except KeyError:
            # Only one of the two runs was recorded for this entry, so its
            # stability cannot be judged: count it as undetermined.
            unswapped = swapped = None
        # Only genuine booleans select a lettered bucket; None (unparseable
        # answer) or any other value falls through to 'U'.
        if isinstance(unswapped, bool) and isinstance(swapped, bool):
            bucket = bucket_for[(unswapped, swapped)]
        else:
            bucket = 'U'
        score[bucket] += 1
    return score


# Bucket every 'article_sent' group by its unswapped/swapped verdicts; accuracy
# is the truncated percentage of groups that land in the 'AB' bucket.
scores = stable_score(
    aggregate_swapped_answers(
        load_results(RESULTS_PATH)
    )
)
acc = int(scores['AB'] * 100 / sum(scores.values()))
scores_fmt = '  '.join([f'{k}={v}' for k, v in scores.items()])
print(f'Accuracy: {acc}% Breakdown: {scores_fmt}')

Accuracy: 84% Breakdown: AB=179  AA=8  BA=11  BB=14  U=0


In [9]:
# Per-result correctness tally (True / False / None), without pairing runs.
collections.Counter(map(is_correct, load_results(RESULTS_PATH)))

Counter({True: 679, False: 65})