To run this script, download the [MOCHA.tar.gz](https://github.com/anthonywchen/MOCHA/blob/main/data/mocha.tar.gz) and extract it to `data/mocha`. This script will use the following frameworks:
- `SPACY` for tokenization;
- HuggingFace `datasets` for evaluation.

In [1]:
# Data loading
import json

# Tokenization
import spacy

# Evaluation
import datasets

In [2]:
# the (lowercase) name of the dataset to run the analysis for
DATASET = "narrativeqa"

# Data directory: where to look for a model
DATA_DIR = "../data/mocha/"

# Full filepath to load
FILEPATH = f"{DATA_DIR}/dev.json"


# ouput directory
OUTPUT_DIR = "../output/predictions"

# Evaluation tokenizer
EVALUATION_TOKENIZER = spacy.load("en_core_web_sm", disable=["parser", "ner"])

In [3]:
data = json.load(open(FILEPATH))
data = data[DATASET] if DATASET is not None else data
print("Number of examples:", len(data))
next(iter(data.items()))

Number of examples: 890


('0005c7718ff653683df879622efb02d1',
 {'candidate': 'his distant relative pascal rougon',
  'context': "The plot centres on the neurotic young priest Serge Mouret, first seen in La ConquĂŞte de Plassans, as he takes his orders and becomes the parish priest for the uninterested village of Artauds. The inbred villagers have no interest in religion and Serge is portrayed giving several wildly enthusiastic Masses to his completely empty, near-derelict church. Serge not only seems unperturbed by this state of affairs but actually appears to have positively sought it out especially, for it gives him time to contemplate religious affairs and to fully experience the fervour of his faith. Eventually he has a complete nervous breakdown and collapses into a near-comatose state, whereupon his distant relative, the unconventional doctor Pascal Rougon (the central character of the last novel in the series, 1893's Le Docteur Pascal), places him in the care of the inhabitants of a nearby derelict stat

### Token overlap metrics

Current QA evaluation relies on string matching or token overlap metrics.

In [34]:
BLEU = datasets.load_metric("bleu")
ROUGE = datasets.load_metric("rouge") 
#^Note: requires installing rouge-score (!pip install rouge-score)

METEOR = datasets.load_metric("meteor")
EXACT_MATCH = datasets.load_metric("exact_match")

BERT_SCORE = datasets.load_metric("bertscore")
#^Note: requires installing bert-score: https://pypi.org/project/bert-score/



from collections import Counter
from typing import Dict, List, Union
import logging


Tokens = List[str]
Text = Union[str, Tokens]


def exact_match(y_true: Text, y_pred: Text) -> int:
    """Determine whether two texts (or sequences of tokens) are equal."""
    if isinstance(y_true, str) and isinstance(y_pred, str):
        return int(y_true == y_pred)

    elif isinstance(y_true, (list, tuple)) and isinstance(y_pred, (list, tuple)):
        if len(y_true) != len(y_pred):
            logging.debug(
                f"Dimension mismatch (default value is 0): {y_true} vs {y_pred}"
            )
            return 0
        return int(all(map(lambda t1, t2: t1 == t2, y_true, y_pred)))
    else:
        error_msg = f"y_true ({type(y_true)}) and y_pred ({type(y_pred)})"
        raise ValueError(
            f"Cannot compare `exact_match` for argument types: {error_msg}"
        )


def first_error_position(y_true: Tokens, y_pred: Tokens, no_err_val: int = None) -> int:
    """Determine the position in the predicted sequence of the first error.
    Notes
    -----
    If both text sequences are equivalent we return ``no_err_val`` as the position.
    Otherwise, we iterate for each token in ``y_pred`` and look for the first
    mismatch between ``y_pred`` and ``y_true`` tokens returning that position.
    Examples
    --------
    >>> y_true = ["The", "sky", "is", "blue"]
    >>> y_pred = ["A", "sky", "is", "blue"]
    >>> first_error_position(y_true, y_pred)
    1
    >>> y_pred = ["The", "sky", "IS", "blue"]
    >>> first_error_position(y_true, y_pred)
    3
    >>> first_error_position(y_true, y_true, no_err_val=-1)
    -1
    """
    assert isinstance(y_true, (list, tuple)) and len(y_true) != 0
    assert isinstance(y_pred, (list, tuple)) and len(y_pred) != 0

    # When no error occurs return the `no_err_val`
    if exact_match(y_true, y_pred):
        return no_err_val

    # If there are differences then they are one of two types:
    # 1. Token mismatch: which will occur in the common length of
    # the two sequences. Values can vary between 0 and min(lengths)
    # 2. Misnumber of tokens: one of the sequences is longer than the
    # other, causing them to be wrong.
    max_mismatch_ix = min(len(y_true), len(y_pred))

    for i in range(max_mismatch_ix):
        if y_true[i] != y_pred[i]:
            return i
    return max_mismatch_ix


def _precision(tp, fp, tn, fn) -> float:
    return 0 if tp == 0 else tp / (tp + fp)


def _recall(tp, fp, tn, fn) -> float:
    return 0 if tp == 0 else tp / (tp + fn)


def _critical_success_index(tp, fp, tn, fn):
    "Ratio of positives w.r.t. number of errors (also dubbed threat score)."
    return 0 if tp == 0 else tp / (tp + fn + fp)


def _f1_score(precision=None, recall=None, **kwargs) -> float:
    if precision is not None and recall is not None:
        p = precision
        r = recall
        # return if precision or recall are 0
        if p == 0 or r == 0:
            return 0
    else:
        p = _precision(**kwargs)
        r = _recall(**kwargs)

    return (2 * p * r) / (p + r)


def f_metrics(references: Tokens, predictions: Tokens) -> Dict[str, float]:
    true_tokens, pred_tokens = Counter(references), Counter(predictions)
    tp = sum((true_tokens & pred_tokens).values())
    fp = len(predictions) - tp
    fn = len(references) - tp
    tn = 0
    assert tp + fp + fn == sum((true_tokens | pred_tokens).values())

    prec = _precision(tp=tp, fp=fp, tn=tn, fn=fn)
    rec = _recall(tp=tp, fp=fp, tn=tn, fn=fn)
    return {
        "precision": prec,
        "recall": rec,
        "f1_score": _f1_score(precision=prec, recall=rec),
        "csi": _critical_success_index(tp=tp, fp=fp, tn=tn, fn=fn),
    }

[nltk_data] Downloading package wordnet to /home/kat/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/kat/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/kat/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
Couldn't find a directory or a metric named 'exact_match' in this version. It was picked from the master branch on github instead.


In [29]:
def tokenize(texts, tokens=True, sep: str = " ", tokenizer=EVALUATION_TOKENIZER) -> list:
    arg_is_str = isinstance(texts, str)
    if arg_is_str:
        texts = [texts]

    results = []
    for text in texts:
        text_tokens = tokenizer(text)
        text_tokens = [str(t) for t in text_tokens]
        results.append(text_tokens if tokens else [sep.join(text_tokens)])

    return results[0] if arg_is_str else results


def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""
    import re, string

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


# Sanity check
assert tokenize("Hello, world!") == ['Hello', ',', 'world', '!']
assert tokenize(["Hello, world!"]) == [['Hello', ',', 'world', '!']]
assert tokenize(["Hello, world!", "Ola, mundo!"]) == [['Hello', ',', 'world', '!'], ["Ola", ",", "mundo", "!"]]

In [30]:
def get_examples(data):
    references = {example_id: (example["reference"]) for example_id, example in data.items()}
    candidates = {example_id: (example["candidate"]) for example_id, example in data.items()}
    
    return references, candidates


def apply(d: dict, fn):
    return {k: fn(elem) for k, elem in d.items()}


# Separate data in references and candidates
references, candidates = get_examples(data)
assert len(references) == len(candidates)

# Compute metrics that rely on strings ()
references_norm = apply(references, fn=normalize_answer)
candidates_norm = apply(candidates, fn=normalize_answer)

# Tokenize data
references_tokens = apply(references_norm, fn=tokenize)
candidates_tokens = apply(candidates_norm, fn=tokenize)

In [31]:
def compute_rouge(**kwargs):
    results = ROUGE.compute(**kwargs)
    return {rouge_type: rouge.mid.fmeasure for rouge_type, rouge in results.items()}

def compute_bleu(max_order=4, **kwargs):
    bleu = {}
    for i in range(1, max_order+1):
        results = BLEU.compute(**kwargs, max_order=i)["bleu"]
        bleu[f"bleu{i}"] = results
        
    return bleu

In [35]:
metric_results = []

for k in references.keys():
    reference = references_norm[k]
    candidate = candidates_norm[k]

    reference_tks = references_tokens[k]
    candidate_tks = candidates_tokens[k]    

    metrics = {
        "example_id": k, 
        "reference": reference,
        "candidate": candidate,
        
        "reference_tokens": reference_tks,
        "candidate_tokens": candidate_tks,
    }
    
    text_args = {"predictions": [candidate], "references": [reference]}
    
    metrics.update(text_args)
    metrics.update(EXACT_MATCH.compute(**text_args))
    metrics.update(METEOR.compute(**text_args))
    metrics.update(compute_rouge(**text_args))
    
    # keyword arguments for BLEU-like metrics
    token_args = {"predictions": [candidate_tks], "references": [[reference_tks]]}
    metrics.update(compute_bleu(**token_args))
    
    token_args = {"predictions": candidate_tks, "references": reference_tks}
    metrics.update(f_metrics(**token_args))
    
    
    break
    metric_results.append(metrics)

In [36]:
metrics

{'example_id': '0005c7718ff653683df879622efb02d1',
 'reference': 'le docteur pascal',
 'candidate': 'his distant relative pascal rougon',
 'reference_tokens': ['le', 'docteur', 'pascal'],
 'candidate_tokens': ['his', 'distant', 'relative', 'pascal', 'rougon'],
 'predictions': ['his distant relative pascal rougon'],
 'references': ['le docteur pascal'],
 'exact_match': 0.0,
 'meteor': 0.15625,
 'rouge1': 0.25,
 'rouge2': 0.0,
 'rougeL': 0.25,
 'rougeLsum': 0.25,
 'bleu1': 0.2,
 'bleu2': 0.0,
 'bleu3': 0.0,
 'bleu4': 0.0,
 'precision': 0.2,
 'recall': 0.3333333333333333,
 'f1_score': 0.25,
 'csi': 0.14285714285714285}