# M-ETA evaluation for the Entity-Aware Machine Translation (EA-MT) task
This notebook provides an evaluation script for the Entity-Aware Machine Translation (EA-MT) task. The script is based on the M-ETA metric, which takes into account the entity information in the reference and candidate translations.

## Overview of the M-ETA metric
M-ETA stands for *Manual Entity Translation Accuracy* and was proposed in **[1]**. At a high level, given a set of gold entity translations and a set of predicted entity translations, M-ETA computes the proportion of reference entities that are translated correctly in the predictions.

```
M-ETA = (Number of correctly translated entities) / (Number of entities in the reference translations)
```
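As a quick arithmetic check of the formula, the sketch below computes M-ETA as a percentage from the two counts. The helper name `m_eta` is hypothetical; the counts 598 and 730 are the it_IT validation results reported at the end of this notebook.

```python
def m_eta(correct: int, total: int) -> float:
    """Return M-ETA as a percentage: correct / total * 100 (0.0 when total is 0)."""
    return 100.0 * correct / total if total > 0 else 0.0


# Counts taken from the it_IT validation run at the end of this notebook.
score = m_eta(598, 730)
print(f"{score:.2f}")  # prints 81.92
```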

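In general, a predicted entity translation is considered correct if it exactly matches at least one of the reference entity translations, which were manually annotated by human evaluators. In this notebook the check is relaxed: an instance counts as correct if at least one reference entity translation appears as a case-insensitive substring of the full predicted translation (see `compute_entity_name_translation_accuracy`).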
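The minimal sketch below contrasts the strict exact-match definition with the case-insensitive substring check that `compute_entity_name_translation_accuracy` applies to the full translation. The helper names and the reference strings are hypothetical examples, not taken from the dataset.

```python
from typing import Set


def exact_match(predicted_entity: str, references: Set[str]) -> bool:
    # Strict definition: the predicted entity string equals one of the references.
    return predicted_entity in references


def substring_match(prediction: str, references: Set[str]) -> bool:
    # Relaxed check: any reference occurs, ignoring case, inside the full translation.
    text = prediction.casefold()
    return any(ref.casefold() in text for ref in references)


refs = {"La Gioconda", "Monna Lisa"}  # hypothetical reference mentions
print(exact_match("la gioconda", refs))                         # False (case differs)
print(substring_match("Ho visto la Gioconda al Louvre.", refs))  # True
```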

**Note** that:
- the M-ETA metric is computed at the entity level, not at the token level.
- the M-ETA metric provides an entity-level score rather than a score for the overall quality of the translation (as BLEU or METEOR do). It should therefore be used together with other metrics to obtain a comprehensive evaluation of translation quality.

**[1]** [Towards Cross-Cultural Machine Translation with Retrieval-Augmented Generation from Multilingual Knowledge Graphs](https://aclanthology.org/2024.emnlp-main.914)

## Data
This notebook expects the data to be organized in the following way:
```shell
data
├── predictions
│   └── <your_model_name>
│       └── validation
│           ├── ar_AE.jsonl
│           ├── de_DE.jsonl
│           ├── es_ES.jsonl
│           ├── fr_FR.jsonl
│           ├── it_IT.jsonl
│           ├── ja_JP.jsonl
│           ├── ko_KR.jsonl
│           ├── th_TH.jsonl
│           ├── tr_TR.jsonl
│           └── zh_TW.jsonl
└── references
    ├── sample
    │   ├── ar_AE.jsonl
    │   ├── de_DE.jsonl
    │   ├── es_ES.jsonl
    │   ├── fr_FR.jsonl
    │   ├── it_IT.jsonl
    │   ├── ja_JP.jsonl
    │   ├── ko_KR.jsonl
    │   ├── th_TH.jsonl
    │   ├── tr_TR.jsonl
    │   └── zh_TW.jsonl
    ├── test
    │   ├── ar_AE.jsonl
    │   ├── de_DE.jsonl
    │   ├── es_ES.jsonl
    │   ├── fr_FR.jsonl
    │   ├── it_IT.jsonl
    │   ├── ja_JP.jsonl
    │   ├── ko_KR.jsonl
    │   ├── th_TH.jsonl
    │   ├── tr_TR.jsonl
    │   └── zh_TW.jsonl
    └── validation
        ├── ar_AE.jsonl
        ├── de_DE.jsonl
        ├── es_ES.jsonl
        ├── fr_FR.jsonl
        ├── it_IT.jsonl
        ├── ja_JP.jsonl
        ├── ko_KR.jsonl
        ├── th_TH.jsonl
        ├── tr_TR.jsonl
        └── zh_TW.jsonl
```
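Before running the evaluation, it can help to confirm that the files are where the notebook expects them. This is a sketch under the assumption that the layout matches the tree above exactly; `missing_files` is a hypothetical helper, not part of the original notebook.

```python
import os
from typing import List

# Target-language codes listed in the directory tree above.
LANGUAGES = ["ar_AE", "de_DE", "es_ES", "fr_FR", "it_IT",
             "ja_JP", "ko_KR", "th_TH", "tr_TR", "zh_TW"]


def missing_files(data_dir: str, system_name: str, split: str) -> List[str]:
    """Return the expected prediction and reference paths that do not exist on disk."""
    missing = []
    for lang in LANGUAGES:
        for path in (
            os.path.join(data_dir, "predictions", system_name, split, f"{lang}.jsonl"),
            os.path.join(data_dir, "references", split, f"{lang}.jsonl"),
        ):
            if not os.path.isfile(path):
                missing.append(path)
    return missing
```

For example, `missing_files("../data", "<your_model_name>", "validation")` returns an empty list when every prediction and reference file for that split is present.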

### Data format for the predictions
The data should be stored in JSONL format. Each line in the JSONL file should be a JSON object with the following keys:
- `id`: a unique identifier for the translation that corresponds to the `id` in the reference file. The loader matches this value against a prefix of the form `Q<number>_<number>` (e.g. `Q123_0`), keeps only that prefix, and raises a `ValueError` if the ID does not start with it.
- `source_language`: the source language of the translation.
- `target_language`: the target language of the translation.
- `text`: the original text in the source language.
- `prediction`: the translated text in the target language.

For example:
```json
{"id": "Q123_0", "source_language": "English", "target_language": "German", "text": "Hello, how are you?", "prediction": "Hallo, wie geht es dir?"}
{"id": "Q456_0", "source_language": "English", "target_language": "German", "text": "I am fine, thank you.", "prediction": "Mir geht es gut, danke."}
```
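A minimal sketch for producing a predictions file in this format, assuming your system's outputs are already collected as a list of dictionaries with the five keys above. `write_predictions` is a hypothetical helper; it validates every record before opening the file so that a bad record does not leave a half-written file behind.

```python
import json
from typing import Dict, List

REQUIRED_KEYS = ("id", "source_language", "target_language", "text", "prediction")


def write_predictions(path: str, records: List[Dict[str, str]]) -> None:
    """Write records as JSONL, one JSON object per line."""
    # Validate everything first so a bad record does not produce a partial file.
    for record in records:
        missing = [key for key in REQUIRED_KEYS if key not in record]
        if missing:
            raise KeyError(f"Record {record.get('id')!r} is missing keys: {missing}")

    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            # ensure_ascii=False keeps scripts such as Arabic, Japanese or Thai readable.
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```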

### Note
Remember to change the following variables to match your data:
- `PATH_TO_DATA_DIR`: the path to the directory containing the predictions and references.
- `SYSTEM_NAME`: the name of your model.
- `SPLIT`: the split of the data (e.g., `validation`, `test`).
- `TARGET_LANGUAGE`: the target language of the translations; one of `ar_AE`, `de_DE`, `es_ES`, `fr_FR`, `it_IT`, `ja_JP`, `ko_KR`, `th_TH`, `tr_TR`, `zh_TW`.

In [3]:
# Importing the required libraries
import json
import os
import re
from typing import Dict, List, Set

In [4]:
# Whether to use the verbose mode.
VERBOSE = False

# List of entity types to be evaluated.
# Used to filter the evaluation to a specific entity type.
ENTITY_TYPES = [
    "Musical work",
    "Artwork",
    "Food",
    "Animal",
    "Plant",
    "Book",
    "Book series",
    "Fictional entity",
    "Landmark",
    "Movie",
    "Place of worship",
    "Natural place",
    "TV series",
    "Person",
]

PATH_TO_DATA_DIR = "../data"

# Change the split to "validation" or "test" to evaluate the predictions on the respective split.
SPLIT = "validation"

# Change the target language to evaluate the predictions on the respective language.
TARGET_LANGUAGE = "it_IT"

# Change the system name to evaluate the predictions generated by the respective system.
SYSTEM_NAME = "DeepSeek(zh&tr&it)_validation"

# Path to the references file.
PATH_TO_REFERENCES = os.path.join(
    PATH_TO_DATA_DIR,
    "references",
    SPLIT,
    f"{TARGET_LANGUAGE}.jsonl",
)

# Path to the predictions file.
PATH_TO_PREDICTIONS = os.path.join(
    PATH_TO_DATA_DIR,
    "predictions",
    SYSTEM_NAME,
    SPLIT,
    f"{TARGET_LANGUAGE}.jsonl",
)

In [5]:
def load_references(input_path: str, entity_types: List[str]) -> List[dict]:
    """
    Load data from the input file (JSONL) and return a list of dictionaries, one for each instance in the dataset.

    Args:
        input_path (str): Path to the input file.
        entity_types (List[str]): List of entity types to filter the evaluation.

    Returns:
        List[dict]: List of dictionaries, one for each instance in the dataset.
    """
    data = []

    with open(input_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            line_data = json.loads(line)

            # Skip instances with empty target list and log a warning.
            if not line_data["targets"]:
                print(f"Empty target list for instance {line_data['id']}")
                continue

            # Filter the evaluation to the specified entity types if provided.
            if entity_types and not any(
                e in line_data["entity_types"] for e in entity_types
            ):
                continue

            data.append(line_data)

    return data
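To see how the two filters in `load_references` behave, the sketch below applies the same rules to in-memory records. The records are hypothetical and contain only the fields the loader reads (`id`, `targets`, `entity_types`); the real reference files may contain more. Note that an empty `entity_types` list disables the type filter, because the `entity_types and ...` condition short-circuits.

```python
from typing import List


def filter_by_entity_type(records: List[dict], entity_types: List[str]) -> List[dict]:
    """Apply the same skip rules as load_references to already-parsed records."""
    kept = []
    for record in records:
        # Skip instances without any reference target.
        if not record["targets"]:
            continue
        # Keep the instance only if it shares at least one entity type with the filter.
        if entity_types and not any(e in record["entity_types"] for e in entity_types):
            continue
        kept.append(record)
    return kept


# Hypothetical records with only the fields the loader reads.
records = [
    {"id": "Q1_0", "targets": [{"mention": "Monna Lisa"}], "entity_types": ["Artwork"]},
    {"id": "Q2_0", "targets": [{"mention": "Colosseo"}], "entity_types": ["Landmark"]},
    {"id": "Q3_0", "targets": [], "entity_types": ["Food"]},
]
print([r["id"] for r in filter_by_entity_type(records, ["Artwork"])])  # ['Q1_0']
print([r["id"] for r in filter_by_entity_type(records, [])])           # ['Q1_0', 'Q2_0']
```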

In [6]:
def load_predictions(input_path: str) -> Dict[str, str]:
    """
    Load data from the input file (JSONL) and return a dictionary with the instance ID as key and the prediction as value.

    Args:
        input_path (str): Path to the input file.

    Returns:
        Dict[str, str]: Dictionary with the instance ID as key and the prediction as value.
    """
    data = {}

    # Instance IDs start with a prefix such as "Q123_0"; compile the pattern once.
    # The suffix uses [0-9]+ so that IDs ending in two or more digits are not truncated.
    pattern = re.compile(r"Q[0-9]+_[0-9]+")

    with open(input_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            line_data = json.loads(line)
            prediction = line_data["prediction"]

            # Extract the instance ID from the leading prefix of the ID.
            match = pattern.match(line_data["id"])
            if not match:
                raise ValueError(f"Invalid instance ID: {line_data['id']}")

            instance_id = match.group(0)
            data[instance_id] = prediction

    return data
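The regular expression in `load_predictions` keeps only the leading `Q<number>_<number>` prefix of each ID. The sketch below compares a single-digit suffix pattern with a one-or-more-digit one on hypothetical IDs; it shows that the single-digit form truncates a suffix such as `_12`, and that an ID like `"1"` is rejected by both. `extract_instance_id` is a hypothetical helper.

```python
import re
from typing import Optional

SINGLE_DIGIT = re.compile(r"Q[0-9]+_[0-9]")
MULTI_DIGIT = re.compile(r"Q[0-9]+_[0-9]+")


def extract_instance_id(raw_id: str, pattern: re.Pattern) -> Optional[str]:
    """Return the matched prefix of raw_id, or None if raw_id does not start with it."""
    match = pattern.match(raw_id)  # match() anchors at the start of the string
    return match.group(0) if match else None


print(extract_instance_id("Q42_12", SINGLE_DIGIT))      # Q42_1
print(extract_instance_id("Q42_12", MULTI_DIGIT))       # Q42_12
print(extract_instance_id("Q42_3_extra", MULTI_DIGIT))  # Q42_3
print(extract_instance_id("1", MULTI_DIGIT))            # None
```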

In [7]:
def compute_entity_name_translation_accuracy(
    predictions: Dict[str, str],
    mentions: Dict[str, Set[str]],
    verbose: bool = False,
) -> dict:
    """
    Compute the entity name translation accuracy.

    Args:
        predictions (Dict[str, str]): Predictions of the model.
        mentions (Dict[str, Set[str]]): Ground truth entity mentions.
        verbose (bool): Set to True to print every wrong match.

    Returns:
        dict: Dictionary with the following keys:
            - correct: Number of correct matches.
            - total: Total number of instances.
            - accuracy: Accuracy of the model.
    """
    correct, total = 0, 0

    for instance_id, instance_mentions in mentions.items():
        # Check that there is at least one entity mention for the instance.
        assert instance_mentions, f"No mentions for instance {instance_id}"

        # Increment the total count of instances (for recall calculation).
        total += 1

        # Check that there is a prediction for the instance.
        if instance_id not in predictions:
            if verbose:
                print(
                    f"No prediction for instance {instance_id}. Check that this is expected behavior, as it may affect the evaluation."
                )
            continue

        prediction = predictions[instance_id]
        normalized_translation = prediction.casefold()
        entity_match = False

        for mention in instance_mentions:
            normalized_mention = mention.casefold()

            # Check if the normalized mention is a substring of the normalized translation.
            # If it is, consider the prediction (the entity name translation) correct.
            if normalized_mention in normalized_translation:
                correct += 1
                entity_match = True
                break

        # Log the prediction and the ground truth mentions for every wrong match if verbose is set.
        if not entity_match and verbose:
            print(f"Prediction: {prediction}")
            print(f"Ground truth mentions: {instance_mentions}")
            print("")

    return {
        "correct": correct,
        "total": total,
        "accuracy": correct / total if total > 0 else 0.0,
    }
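Because the check above is a plain substring test, a short mention can also match inside a longer word. The sketch below compares it with a stricter whole-mention variant based on Unicode word-character lookarounds; this variant is an illustrative alternative, not part of M-ETA or of the original notebook, and the example sentences are hypothetical.

```python
import re


def substring_match(prediction: str, mention: str) -> bool:
    # Same criterion as compute_entity_name_translation_accuracy.
    return mention.casefold() in prediction.casefold()


def whole_mention_match(prediction: str, mention: str) -> bool:
    # Reject matches that are glued to other word characters on either side.
    pattern = r"(?<!\w)" + re.escape(mention.casefold()) + r"(?!\w)"
    return re.search(pattern, prediction.casefold()) is not None


print(substring_match("Un film romantico", "Roma"))        # True
print(whole_mention_match("Un film romantico", "Roma"))    # False
print(whole_mention_match("Vado a Roma domani.", "Roma"))  # True
```

The caveat is that Japanese, Chinese and Thai do not separate words with spaces, so the lookarounds reject valid matches such as `東京` in `東京に行く`. The stricter variant is therefore only reasonable for space-delimited target languages, which may be why a simple substring test is used across all languages.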

In [8]:
def get_mentions_from_references(data: List[dict]) -> Dict[str, Set[str]]:
    """
    Load the ground truth entity mentions from the data.

    Args:
        data (List[dict]): List of dictionaries, one for each instance in the dataset.

    Returns:
        Dict[str, Set[str]]: Dictionary with the instance ID as key and the set of entity mentions as value.
    """
    mentions = {}

    for instance in data:
        instance_id = instance["id"]
        instance_mentions = set()

        for target in instance["targets"]:
            mention = target["mention"]
            instance_mentions.add(mention)

        mentions[instance_id] = instance_mentions

    return mentions
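The same extraction can be written as a single comprehension; the sketch below shows that duplicate mentions collapse into one set element. The record is hypothetical. Because the result is a dictionary keyed by `id`, a repeated instance ID would silently overwrite an earlier entry, which is what the `assert len(mentions) == len(reference_data)` check in the next cell guards against.

```python
from typing import Dict, List, Set


def mentions_by_instance(records: List[dict]) -> Dict[str, Set[str]]:
    # Map each instance ID to the set of its acceptable reference mentions.
    return {record["id"]: {t["mention"] for t in record["targets"]} for record in records}


records = [
    {"id": "Q1_0", "targets": [{"mention": "Monna Lisa"}, {"mention": "La Gioconda"}, {"mention": "Monna Lisa"}]},
]
mentions = mentions_by_instance(records)
print(len(mentions["Q1_0"]))  # prints 2: the duplicate "Monna Lisa" is stored once
```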

In [9]:
print(f"Loading data from {PATH_TO_REFERENCES}...")
reference_data = load_references(PATH_TO_REFERENCES, ENTITY_TYPES)
# reference_data = load_references(PATH_TO_REFERENCES, ENTITY_TYPES)[:50]
mentions = get_mentions_from_references(reference_data)
assert len(mentions) == len(reference_data)
print(f"Loaded {len(reference_data)} instances.")

Loading data from ../data\references\validation\it_IT.jsonl...
Loaded 730 instances.


In [10]:
print(f"Loading data from {PATH_TO_PREDICTIONS}...")
prediction_data = load_predictions(PATH_TO_PREDICTIONS)
print(f"Loaded {len(prediction_data)} predictions.")

Loading data from ../data\predictions\DeepSeek(zh&tr&it)_validation\validation\it_IT.jsonl...
Loaded 730 predictions.


In [11]:
print("Computing entity name translation accuracy...")
entity_name_translation_accuracy = compute_entity_name_translation_accuracy(
    prediction_data,
    mentions,
    verbose=VERBOSE,
)

Computing entity name translation accuracy...


In [12]:
print("=============================================")
print(f"Evaluation results in {TARGET_LANGUAGE}")
print(f"Correct instances   = {entity_name_translation_accuracy['correct']}")
print(f"Total instances     = {entity_name_translation_accuracy['total']}")

accuracy = entity_name_translation_accuracy["accuracy"] * 100.0
print("-----------------------------")
print(f"m-ETA               = {accuracy:.2f}")
print("=============================================")
print("")

print("Evaluation completed.")

Evaluation results in it_IT
Correct instances   = 598
Total instances     = 730
-----------------------------
m-ETA               = 81.92

Evaluation completed.

