In [1]:
from __future__ import print_function, division, unicode_literals
from collections import Counter
from pathlib import Path
import string
import re
import json

from typing import Callable, List

import spacy

### Evaluation

In [2]:
nlp = spacy.load('ro_core_news_lg', disable=["tagger", "attribute_ruler", "tok2vec", "ner"])

def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace and for romanian language lemmatisation"""

    def lemma(text: str) -> str:
        my_doc = nlp(text)

        return ' '.join([token.lemma_ for token in my_doc])

    def remove_articles(text: str) -> str:
        return re.sub(r'\b(a|an|the|un|o)\b', ' ', text)

    def white_space_fix(text: str) -> str:
        return ' '.join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return lemma(white_space_fix(remove_articles(remove_punc(lower(s)))))


def f1_score(prediction: str, ground_truth: str) -> float:
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if num_same == 0:
        return 0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)

    return f1


def exact_match_score(prediction: str, ground_truth: str) -> bool:
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn: Callable, prediction: str, ground_truths: List[str]) -> float:
    scores_for_ground_truths = []

    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)

    return max(scores_for_ground_truths)


def evaluate(dir_generate: str) -> None:
    type_model = dir_generate.split('/')[-1]
    results = {}

    for i in [str(x) for x in Path(dir_generate).glob("*") if x.is_dir()]:
        for path_file in [str(x) for x in Path(i).glob('**/*.json') if x.is_file()]:
            f1, exact_match, total = 0, 0, 0
            name_eval = path_file.split('/')[-1].replace('.json', '')

            with open(path_file, 'r') as input:
                data = json.load(input)

            for example in data:
                ground_truths = [example['original']]
                prediction = example['predict']

                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
                total += 1

            exact_match = 100.0 * exact_match / total
            f1 = 100.0 * f1 / total

            results[name_eval] = {'exact_match': exact_match, 'f1': f1}
            print(f'Eval: {name_eval}, Exact_Match: {exact_match}, F1: {f1}')

    Path('../../../log/xquad/').mkdir(exist_ok=True, parents=True)
    with open(f'../../../log/xquad/{type_model}.txt', 'w+') as output_file:
        for k, v in results.items():
            output_file.write(f'{k}: {v}\n')

### Run evaluation

In [3]:
evaluate('../../../generate/xquad/normal')

Eval: base-beam-search-4, Exact_Match: 24.11764705882353, F1: 35.27481850314759
Eval: base-greedy, Exact_Match: 23.69747899159664, F1: 35.970264064383905
Eval: base-beam-search-8, Exact_Match: 24.11764705882353, F1: 35.27481850314759
Eval: medium-beam-search-4, Exact_Match: 31.596638655462186, F1: 45.32667118672392
Eval: medium-beam-search-8, Exact_Match: 31.596638655462186, F1: 45.32667118672392
Eval: medium-greedy, Exact_Match: 29.66386554621849, F1: 44.740962768922714
Eval: large-beam-search-4, Exact_Match: 29.66386554621849, F1: 43.05865381544039
Eval: large-beam-search-8, Exact_Match: 29.66386554621849, F1: 43.05865381544039
Eval: large-greedy, Exact_Match: 29.747899159663866, F1: 42.98861526665083


In [4]:
evaluate('../../../generate/xquad/translate')


Eval: large-v2-top-7, Exact_Match: 31.34453781512605, F1: 44.71327852976019
Eval: large-v2-greedy, Exact_Match: 31.34453781512605, F1: 44.71327852976019
Eval: large-v2-beam-search-4, Exact_Match: 31.596638655462186, F1: 43.532352329176156
Eval: large-v1-beam-search-8, Exact_Match: 28.73949579831933, F1: 39.71153009877786
Eval: large-v1-top-7, Exact_Match: 28.403361344537814, F1: 39.790719779626826
Eval: base-v1-greedy, Exact_Match: 23.865546218487395, F1: 34.27249347982264
Eval: base-v1-top-7, Exact_Match: 23.865546218487395, F1: 34.27249347982264
Eval: base-v1-beam-search-4, Exact_Match: 25.04201680672269, F1: 34.51705084886004
Eval: base-v1-beam-search-8, Exact_Match: 25.04201680672269, F1: 34.51705084886004
Eval: medium-v1-beam-search-8, Exact_Match: 27.647058823529413, F1: 39.11011390158698
Eval: medium-v1-greedy, Exact_Match: 27.058823529411764, F1: 39.75278341376798
Eval: medium-v1-top-7, Exact_Match: 27.058823529411764, F1: 39.75278341376798
Eval: medium-v1-beam-search-4, Exact_