To run this notebook, download [mocha.tar.gz](https://github.com/anthonywchen/MOCHA/blob/main/data/mocha.tar.gz) and extract it to `data/mocha` (the code below reads it from `../data/mocha/`, relative to the working directory). The notebook uses the following frameworks:
- `spaCy` for tokenization;
- Hugging Face `datasets` for the evaluation metrics.

In [None]:
# Data loading
import json

# Tokenization
import spacy

# Evaluation
import datasets

In [None]:
# Data directory: where the MOCHA split files (e.g. train.json, dev.json) live
DATA_DIR = "../data/mocha/"

# Split inspected in the exploratory cells below
SPLIT = "dev"
# Full filepath to load; derived from SPLIT so the two can never disagree
FILEPATH = f"{DATA_DIR}/{SPLIT}.json"

# Output directory for the per-split / per-dataset metric CSVs
OUTPUT_DIR = "../outputs/proxy_metrics"

# Evaluation tokenizer: spaCy small English pipeline. Parser and NER are
# disabled because only the token strings are used downstream.
EVALUATION_TOKENIZER = spacy.load("en_core_web_sm", disable=["parser", "ner"])

In [None]:
# Close the file deterministically instead of leaking the handle.
with open(FILEPATH, encoding="utf-8") as fp:
    data = json.load(fp)
# Top-level keys are dataset names; each value maps example ids to examples.
print("Number of datasets:", len(data))
print("Number of examples:", sum(len(examples) for examples in data.values()))
next(iter(data.items()))

### Token overlap metrics

Current QA evaluation relies on string matching or token overlap metrics.

In [None]:
# Load every metric once up front; the per-example loop in `compute_metrics`
# reuses these objects instead of reloading them for each example.
BLEU = datasets.load_metric("bleu", keep_in_memory=True)
BLEURT = datasets.load_metric("bleurt", keep_in_memory=True)
# ^Note: requires installing bleurt: !pip install git+https://github.com/google-research/bleurt.git
ROUGE = datasets.load_metric("rouge", keep_in_memory=True)
# ^Note: requires installing rouge-score (!pip install rouge-score)

METEOR = datasets.load_metric("meteor", keep_in_memory=True)
EXACT_MATCH = datasets.load_metric("exact_match", keep_in_memory=True)

BERT_SCORE = datasets.load_metric("bertscore", keep_in_memory=True)
# ^Note: requires installing bert-score: https://pypi.org/project/bert-score/
# NOTE(review): not used by `compute_metrics`, and a later cell rebinds this
# name via `from bert_score import score as BERT_SCORE`.

# Translation Edit Rate (TER); `compute_edit_score` rescales its score.
EDIT_RATIO = datasets.load_metric("ter", keep_in_memory=True)
# ^Note: requires installing sacrebleu
# pip install sacrebleu

from collections import Counter
from typing import Dict, List, Union
import logging


# A tokenized text: the sequence of token strings.
Tokens = List[str]
# Anything the comparison helpers accept: raw text or its tokenized form.
Text = Union[str, List[str]]


def exact_match(y_true: Text, y_pred: Text) -> int:
    """Return 1 when two texts (or two token sequences) are identical, else 0.

    Two strings are compared directly. Two lists/tuples are compared element
    by element (a list and a tuple with equal items match); sequences of
    different lengths never match and are reported at DEBUG level.

    Raises
    ------
    ValueError
        If the arguments are not both strings or both lists/tuples.
    """
    sequence_types = (list, tuple)
    both_strings = isinstance(y_true, str) and isinstance(y_pred, str)
    both_sequences = isinstance(y_true, sequence_types) and isinstance(y_pred, sequence_types)

    if both_strings:
        return int(y_true == y_pred)

    if not both_sequences:
        error_msg = f"y_true ({type(y_true)}) and y_pred ({type(y_pred)})"
        raise ValueError(
            f"Cannot compare `exact_match` for argument types: {error_msg}"
        )

    if len(y_true) == len(y_pred):
        return int(all(true_tok == pred_tok for true_tok, pred_tok in zip(y_true, y_pred)))

    logging.debug(f"Dimension mismatch (default value is 0): {y_true} vs {y_pred}")
    return 0


def first_error_position(
    y_true: Tokens, y_pred: Tokens, no_err_val: Union[int, None] = None
) -> Union[int, None]:
    """Return the 0-based position of the first error in the predicted sequence.

    Notes
    -----
    If both token sequences are equal, ``no_err_val`` is returned. Otherwise
    tokens are compared pairwise over the common length and the index of the
    first mismatch is returned. If one sequence is a prefix of the other,
    the error is the first missing/extra token, i.e. the length of the
    shorter sequence.

    Examples
    --------
    >>> y_true = ["The", "sky", "is", "blue"]
    >>> y_pred = ["A", "sky", "is", "blue"]
    >>> first_error_position(y_true, y_pred)
    0
    >>> y_pred = ["The", "sky", "IS", "blue"]
    >>> first_error_position(y_true, y_pred)
    2
    >>> first_error_position(y_true, y_true[:2])
    2
    >>> first_error_position(y_true, y_true, no_err_val=-1)
    -1

    Raises
    ------
    AssertionError
        If either argument is not a non-empty list/tuple.
    """
    # Kept as asserts so the exception type is unchanged for existing
    # callers; note they are stripped when running under ``python -O``.
    assert isinstance(y_true, (list, tuple)) and len(y_true) != 0
    assert isinstance(y_pred, (list, tuple)) and len(y_pred) != 0

    # When no error occurs return the `no_err_val`
    if exact_match(y_true, y_pred):
        return no_err_val

    # 1. Token mismatch inside the common length of the two sequences.
    for position, (true_tok, pred_tok) in enumerate(zip(y_true, y_pred)):
        if true_tok != pred_tok:
            return position
    # 2. The common part matches, so the sequences only differ in length.
    return min(len(y_true), len(y_pred))


def _precision(tp, fp, tn, fn) -> float:
    return 0 if tp == 0 else tp / (tp + fp)


def _recall(tp, fp, tn, fn) -> float:
    return 0 if tp == 0 else tp / (tp + fn)


def _critical_success_index(tp, fp, tn, fn):
    "Ratio of positives w.r.t. number of errors (also dubbed threat score)."
    return 0 if tp == 0 else tp / (tp + fn + fp)


def _f1_score(precision=None, recall=None, **kwargs) -> float:
    if precision is not None and recall is not None:
        p = precision
        r = recall
        # return if precision or recall are 0
        if p == 0 or r == 0:
            return 0
    else:
        p = _precision(**kwargs)
        r = _recall(**kwargs)

    return (2 * p * r) / (p + r)


def f_metrics(references: Tokens, predictions: Tokens) -> Dict[str, float]:
    """Bag-of-tokens precision, recall, F1 and CSI between two token lists.

    Tokens are matched as multisets (``Counter`` intersection): a token that
    occurs k_ref times in ``references`` and k_pred times in ``predictions``
    contributes min(k_ref, k_pred) true positives. True negatives are not
    meaningful here and are fixed at 0.
    """
    ref_counts = Counter(references)
    pred_counts = Counter(predictions)

    tp = sum((ref_counts & pred_counts).values())
    fp = len(predictions) - tp
    fn = len(references) - tp
    # Invariant: matched + unmatched tokens equals the multiset union size.
    assert tp + fp + fn == sum((ref_counts | pred_counts).values())

    confusion = dict(tp=tp, fp=fp, tn=0, fn=fn)
    precision = _precision(**confusion)
    recall = _recall(**confusion)

    scores = {}
    scores["precision"] = precision
    scores["recall"] = recall
    scores["f1_score"] = _f1_score(precision=precision, recall=recall)
    scores["csi"] = _critical_success_index(**confusion)
    return scores

In [None]:
def tokenize(texts, tokens=True, sep: str = " ", tokenizer=EVALUATION_TOKENIZER) -> list:
    """Tokenize one text or a collection of texts.

    Parameters
    ----------
    texts : str or iterable of str
    tokens : bool
        If True, each text becomes a list of token strings. Otherwise each
        text becomes a single-element list holding its tokens joined by ``sep``.
    sep : str
        Separator used when ``tokens`` is False.
    tokenizer : callable
        Maps a string to an iterable of tokens; defaults to the module's
        spaCy pipeline.

    Returns
    -------
    list
        The per-text result when ``texts`` is a ``str``, otherwise a list
        with one per-text result per input text.
    """
    single_text = isinstance(texts, str)
    batch = [texts] if single_text else texts

    def _tokenize_one(text):
        words = [str(token) for token in tokenizer(text)]
        return words if tokens else [sep.join(words)]

    outputs = [_tokenize_one(text) for text in batch]
    return outputs[0] if single_text else outputs


def remove_punc(s):
    """Delete every '?', '.' and '!' from ``s``; all other characters are kept."""
    return s.translate(str.maketrans("", "", "?.!"))


def normalize_answer(s: str) -> str:
    """Normalize an answer string for lenient comparison.

    The steps are applied in this order: lowercase, drop all
    ``string.punctuation`` characters, replace the articles "a"/"an"/"the"
    with spaces, then collapse runs of whitespace into single spaces.
    """
    import re
    import string

    punctuation = set(string.punctuation)
    text = "".join(ch for ch in s.lower() if ch not in punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


# Sanity check: a single string yields one flat token list; a list of
# strings yields one token list per string, in the same order.
assert tokenize("Hello, world!") == ["Hello", ",", "world", "!"]
assert tokenize(["Hello, world!"]) == [["Hello", ",", "world", "!"]]
assert tokenize(["Hello, world!", "Ola, mundo!"]) == [
    ["Hello", ",", "world", "!"],
    ["Ola", ",", "mundo", "!"],
]

In [None]:
def get_examples(data):
    """Split an id -> example mapping into three id-keyed dicts.

    Returns ``(references, candidates, human_judgement)``, which map each
    example id to that example's ``"reference"``, ``"candidate"`` and
    ``"score"`` fields respectively.
    """
    def _column(field):
        return {example_id: example[field] for example_id, example in data.items()}

    return _column("reference"), _column("candidate"), _column("score")


def apply(d: dict, fn):
    """Return a new dict with the same keys as ``d`` and ``fn`` applied to each value."""
    return dict(zip(d, map(fn, d.values())))

In [None]:
def compute_rouge(**kwargs):
    """Run the ROUGE metric and keep, per ROUGE type, the ``mid`` F-measure."""
    fmeasures = {}
    for rouge_type, aggregate in ROUGE.compute(**kwargs).items():
        fmeasures[rouge_type] = aggregate.mid.fmeasure
    return fmeasures

def compute_bleu(max_order=4, **kwargs):
    """Compute BLEU-1 through BLEU-``max_order`` with the ``datasets`` BLEU metric.

    Returns a dict ``{"bleu1": ..., ..., f"bleu{max_order}": ...}``. An order
    for which the metric raises (presumably on degenerate inputs such as
    empty token lists) is recorded as 0, so one bad example does not abort
    the whole evaluation. The failure is logged at DEBUG level.
    """
    bleu = {}
    for order in range(1, max_order + 1):
        key = f"bleu{order}"
        try:
            bleu[key] = BLEU.compute(**kwargs, max_order=order)["bleu"]
        except Exception as err:
            # Best effort as before, but no longer swallows
            # KeyboardInterrupt/SystemExit and keeps the reason visible.
            logging.debug("BLEU-%d failed, defaulting to 0: %r", order, err)
            bleu[key] = 0
    return bleu

def compute_edit_score(**kwargs):
    """Run TER and return its result with ``score`` rescaled to ``edit_score``.

    The raw TER ``score`` is divided by 100 and stored as ``edit_score``;
    ``ref_length`` is dropped and any other keys (e.g. ``num_edits``) are kept.
    Example raw result: {'score': 133.33333333333331, 'num_edits': 4, 'ref_length': 3.0}
    """
    ter = EDIT_RATIO.compute(**kwargs)
    raw_score = ter.pop("score")
    del ter["ref_length"]
    ter["edit_score"] = raw_score / 100
    return ter

In [None]:
def compute_metrics(data):
    """Compute per-example proxy metrics for one dataset.

    Parameters
    ----------
    data : dict
        Maps example ids to example dicts that have at least ``"reference"``,
        ``"candidate"`` and ``"score"`` keys (see ``get_examples``).

    Returns
    -------
    list of dict
        One dict per example, in the iteration order of ``data``. Each holds
        the punctuation-stripped reference/candidate, their tokens, the
        original human score and a rescaled copy, plus the outputs of exact
        match, METEOR, ROUGE, BLEURT, BLEU-1..4, token-level
        precision/recall/F1/CSI and TER.
    """
    # Separate data in references, candidates and human scores (all id-keyed)
    references, candidates, human_judgements = get_examples(data)
    assert len(references) == len(candidates)

    ## Normalize the texts used by the string-based metrics.
    ## `normalize_answer` (which also removes articles) is deliberately NOT used:
    # references_norm = apply(references, fn=normalize_answer)
    # candidates_norm = apply(candidates, fn=normalize_answer)
    ## ^NOTE: This was causing some inconsistency in the results since
    ## there were a few examples where the candidate answer is simply
    ## "the" or "an" --> therefore becoming empty str
    ## Only '?', '.' and '!' are stripped instead.
    references_norm = apply(references, fn=remove_punc)
    candidates_norm = apply(candidates, fn=remove_punc)

    # Tokenize the normalized texts (one token list per example id)
    references_tokens = apply(references_norm, fn=tokenize)
    candidates_tokens = apply(candidates_norm, fn=tokenize)

    metric_results = []
    for example_id, correctness in human_judgements.items():
        reference = references_norm[example_id]
        candidate = candidates_norm[example_id]

        reference_tks = references_tokens[example_id]
        candidate_tks = candidates_tokens[example_id]

        metrics = {
            "example_id": example_id,

            "reference": reference,
            "candidate": candidate,

            "reference_tokens": reference_tks,
            "candidate_tokens": candidate_tks,

            "human_correctness_original": correctness,
            # Linearly maps a score on a 1..5 scale onto [0, 1]; the bounds
            # 1 and 5 are hard-coded here.
            "human_correctness": (correctness - 1) / (5-1),
        }

        # Text-level metrics: one prediction vs one reference string.
        text_args = {"predictions": [candidate], "references": [reference]}
        # NOTE(review): this also stores the raw `predictions`/`references`
        # lists as output columns, duplicating `candidate`/`reference`.
        metrics.update(text_args)
        metrics.update(EXACT_MATCH.compute(**text_args))
        metrics.update(METEOR.compute(**text_args))
        metrics.update(compute_rouge(**text_args))
        # BLEURT is invoked once per example (one score in "scores").
        metrics.update({"bleurt": BLEURT.compute(**text_args)["scores"][0]})

        # Debug aid: report examples whose candidate is empty after
        # punctuation removal.
        if not candidate:
            print(data[example_id], metrics)

        # keyword arguments for BLEU-like metrics: token lists, with each
        # prediction paired with a list of reference token lists.
        token_args_w_mult_refs_args = {"predictions": [candidate_tks], "references": [[reference_tks]]}
        metrics.update(compute_bleu(**token_args_w_mult_refs_args))

        # Bag-of-tokens precision/recall/F1/CSI on the flat token lists.
        token_args = {"predictions": candidate_tks, "references": reference_tks}
        metrics.update(f_metrics(**token_args))


        text_w_mult_refs_args = {"predictions": [candidate], "references": [[reference]]}
        # TER score (num_edits / sum_ref_lengths * 100), rescaled by
        # compute_edit_score into `edit_score`.
        metrics.update(compute_edit_score(**text_w_mult_refs_args))

        metric_results.append(metrics)

    return metric_results

# sanity check (: run the full metric pipeline on the "drop" dataset of the
# split loaded above before processing every split/dataset.
d = compute_metrics(data["drop"])

In [None]:
import pandas as pd 

In [None]:
def write_predictions(metrics, split, dataset=None, output_dir=OUTPUT_DIR):
    """Write a metrics DataFrame as a gzip-compressed CSV.

    The file is written to ``<output_dir>/<split>_<dataset>_metrics.csv.gz``.
    ``output_dir``, including any missing parent directories, is created if
    needed.

    Parameters
    ----------
    metrics : pandas.DataFrame
        Table to write (its index is written as the first column).
    split : str
        Split name used in the filename.
    dataset : str, optional
        Dataset name used in the filename (``None`` is formatted as "None").
    output_dir : str
        Destination directory.
    """
    import os

    # makedirs also creates missing parents and tolerates an existing
    # directory (no check-then-create race, unlike isdir + mkdir).
    os.makedirs(output_dir, exist_ok=True)

    output_file = f"{output_dir}/{split}_{dataset}_metrics.csv.gz"
    print("Writing metrics at", output_file)
    metrics.to_csv(output_file, compression="gzip")

In [None]:
def persist_metrics(split, dataset=None, input_dir=DATA_DIR, output_dir=OUTPUT_DIR):
    """Compute and save proxy metrics for the (selected) datasets of a split.

    Reads ``<input_dir>/<split>.json``, which maps dataset names to example
    dicts. Computes per-example metrics for each dataset and writes one CSV
    per dataset. When more than one dataset is processed, it also writes a
    combined ``all_datasets`` CSV with every dataset's rows.

    Parameters
    ----------
    split : str
        Split name, e.g. "train" or "dev".
    dataset : str or collection of str, optional
        Restrict processing to these dataset names; all datasets when ``None``.
    input_dir, output_dir : str
        Where the split JSON is read from / where the CSVs are written.

    Returns
    -------
    pandas.DataFrame
        Metrics of all processed datasets concatenated with a fresh index,
        or an empty DataFrame if no dataset was selected.
    """
    with open(f"{input_dir}/{split}.json", encoding="utf-8") as fp:
        data = json.load(fp)

    selected = [dataset] if isinstance(dataset, str) else dataset
    data = {k: v for k, v in data.items() if (selected is None) or (k in selected)}
    print("Number of datasets:", len(data))

    all_metrics = []
    # `examples` (not `dataset`) so the `dataset` parameter is not shadowed.
    for dataset_name, examples in data.items():
        print("Computing metrics", len(examples), "examples of dataset", dataset_name)
        metrics = pd.DataFrame(compute_metrics(examples))

        metrics["dataset"] = dataset_name
        metrics["split"] = split

        write_predictions(metrics, split=split, dataset=dataset_name, output_dir=output_dir)
        all_metrics.append(metrics)

    if not all_metrics:
        # pd.concat([]) raises ValueError; nothing was computed.
        return pd.DataFrame()

    all_metrics = pd.concat(all_metrics).reset_index(drop=True)

    if len(data) > 1:
        # Write the combined table (previously only the last dataset's
        # `metrics` ended up in the "all_datasets" file).
        write_predictions(all_metrics, split=split, dataset="all_datasets", output_dir=output_dir)

    return all_metrics

In [None]:
%%time
# Compute, write and summarize the metrics of every dataset in the train split.
train_metrics = persist_metrics("train")
train_metrics.describe()

In [None]:
%%time
# Compute, write and summarize the metrics of every dataset in the dev split.
dev_metrics = persist_metrics("dev")
dev_metrics.describe()

In [None]:
# Re-load the combined dev metrics from disk to check what was written.
pd.read_csv(f"{OUTPUT_DIR}/dev_all_datasets_metrics.csv.gz", index_col=0).describe()

In [None]:
# Re-load the combined train metrics from disk (this cell used to be a
# verbatim copy of the dev check in the previous cell).
pd.read_csv(f"{OUTPUT_DIR}/train_all_datasets_metrics.csv.gz", index_col=0).describe()

In [None]:
from bert_score import score as BERT_SCORE
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.bleu.bleu import Bleu

In [None]:
# Cross-check one example against the pycocoevalcap implementations.
# `reference`/`candidate` only exist inside `compute_metrics`, so take them
# from the last example of the sanity-check run `d`.
reference, candidate = d[-1]["reference"], d[-1]["candidate"]
gts, res = {0: [reference]}, {0: [candidate]}

for n in range(1, 5):
    # Bleu(n).compute_score returns (corpus_scores, per_sentence_scores) where
    # per_sentence_scores[k] holds BLEU-(k+1); index n - 1 selects BLEU-n
    # (index 0 would always print BLEU-1).
    print(Bleu(n).compute_score(gts, res)[1][n - 1])
print(Meteor().compute_score(gts, res)[1])
print(Rouge().compute_score(gts, res)[1])