# COMET evaluation for the Entity-Aware Machine Translation (EA-MT) task
This notebook provides an example of how to evaluate the Entity-Aware Machine Translation (EA-MT) task using the COMET evaluation metric. The EA-MT task is a variant of the machine translation task where the source text contains named entities that need to be translated correctly.

**NOTE**: COMET scores translation quality at the sentence level, so it may not correlate well with how accurately the named entities themselves are translated. For a metric that is more sensitive to the translation quality of named entities, see our notebook on the manual Entity-level Translation Accuracy (m-ETA) metric.

In [12]:
# General imports
import json
import os

# Import the comet module for the evaluation
from comet import download_model, load_from_checkpoint

In [13]:
COMET_MODEL_NAME = "Unbabel/wmt22-comet-da"
SYSTEM_NAME = "DeepSeek(zh&tr&it)_validation"
SOURCE_LANGUAGE = "en_US"
TARGET_LANGUAGE = "it_IT"  # other options: tr_TR, zh_TW
DATA_DIR = "../data"
SPLIT = "validation"
NUM_GPUS = 1
BATCH_SIZE = 32

# The path to the references is formatted as follows:
# data/references/{split}/{target_language}.jsonl
PATH_TO_REFERENCES = os.path.join(
    DATA_DIR,
    "references",
    SPLIT,
    f"{TARGET_LANGUAGE}.jsonl",
)

# The path to the predictions is formatted as follows:
# data/predictions/{system_name}/{split}/{target_language}.jsonl
PATH_TO_PREDICTIONS = os.path.join(
    DATA_DIR,
    "predictions",
    SYSTEM_NAME,
    SPLIT,
    f"{TARGET_LANGUAGE}.jsonl",
)

# Load the data
Let's load the data that will be used for the evaluation.

## Data overview
Let's have a look at the data. The data is organized in JSONL format, where each line is a JSON object that contains the following fields (formatted for better readability):
```json
{
  "id": "Q1093267_0",
  "wikidata_id": "Q1093267",
  "entity_types": [
    "TV series"
  ],
  "source": "How many episodes are in the TV series Space Battleship Yamato II?",
  "targets": [
    {
      "translation": "Quanti episodi ci sono nella serie TV La corazzata Yamato?",
      "mention": "La corazzata Yamato"
    },
    {
      "translation": "Quanti episodi ci sono nella serie TV la corazzata Yamato?",
      "mention": "la corazzata Yamato"
    }
  ],
  "source_locale": "en",
  "target_locale": "it"
}
```
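As a minimal sketch of how one such record expands into COMET inputs, the snippet below parses a single JSONL line (abridged to the fields used here) and builds one `src`/`ref`/`mt` triple per reference translation. The helper name `expand_record` and the prediction string are illustrative assumptions, not part of the dataset or of the notebook's code.

```python
import json

# One JSONL line, abridged to the fields needed for COMET
line = json.dumps(
    {
        "id": "Q1093267_0",
        "source": "How many episodes are in the TV series Space Battleship Yamato II?",
        "targets": [
            {"translation": "Quanti episodi ci sono nella serie TV La corazzata Yamato?"},
            {"translation": "Quanti episodi ci sono nella serie TV la corazzata Yamato?"},
        ],
    }
)


def expand_record(record, prediction):
    """Hypothetical helper: one COMET input per reference translation."""
    return [
        {"src": record["source"], "ref": target["translation"], "mt": prediction}
        for target in record["targets"]
    ]


record = json.loads(line)
# Placeholder system output, made up for illustration
triples = expand_record(record, "Quanti episodi ha la serie TV La corazzata Yamato?")
print(len(triples))
```

Each triple shares the same source and prediction and differs only in its reference.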

## Things to note for data loading
Since COMET does not support multi-reference translations, we will:
1. Create one entry per reference translation, keeping track of the source it belongs to (e.g., `Q1093267_0`).
2. Duplicate the prediction for each of those reference translations.
3. Compute the COMET score for each prediction-reference pair.
4. Take the maximum COMET score over all pairs that share the same source.
5. Average these maximum scores over all sources.

In other words, each prediction is scored against every one of its references, only the best score per source is kept, and the system score is the average of those best scores.
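The aggregation described above can be sketched on toy numbers as follows. The function name `max_then_mean`, the sample IDs, and the scores are made up for illustration; they are not real COMET outputs.

```python
def max_then_mean(scores, spans):
    """Average the best score per source.

    scores: flat list with one score per prediction-reference pair.
    spans:  {source_id: (start, end)} slice of `scores` for each source.
    """
    best = [max(scores[start:end]) for start, end in spans.values()]
    return sum(best) / len(best)


# Toy values: source "a_0" has two references, source "b_0" has one
toy_scores = [0.5, 0.75, 0.25]
toy_spans = {"a_0": (0, 2), "b_0": (2, 3)}

toy_system_score = max_then_mean(toy_scores, toy_spans)
print(f"{toy_system_score:.2f}")
```

Here the best score for `a_0` is 0.75 and for `b_0` is 0.25, so the average of the maxima is 0.50, whereas a plain mean over all three pairs would also be pulled down by the weaker 0.5 match.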

In [15]:
# Load the references
references = {}

with open(PATH_TO_REFERENCES, "r", encoding="utf-8") as f:
    for line in f:
        data = json.loads(line)
        references[data["id"]] = data

print(f"Loaded {len(references)} references from {PATH_TO_REFERENCES}")

Loaded 730 references from ../data\references\validation\it_IT.jsonl


In [16]:
# Load the predictions
predictions = {}

with open(PATH_TO_PREDICTIONS, "r", encoding="utf-8") as f:
    for line in f:
        data = json.loads(line)
        predictions[data["id"]] = data

print(f"Loaded {len(predictions)} predictions from {PATH_TO_PREDICTIONS}")

Loaded 730 predictions from ../data\predictions\DeepSeek(zh&tr&it)_validation\validation\it_IT.jsonl


In [17]:
# Get all those references that have a corresponding prediction
ids = set(references.keys()) & set(predictions.keys())
num_missing_predictions = len(references) - len(ids)

if num_missing_predictions > 0:
    print(f"Missing predictions for {num_missing_predictions} references")
else:
    print("All references have a corresponding prediction")

All references have a corresponding prediction


In [18]:
# Map each sample ID to the [start, end) slice of `instances` that holds its references
instance_ids = {}
instances = []
current_index = 0

for sample_id in sorted(ids):
    reference = references[sample_id]
    prediction = predictions[sample_id]

    # One COMET instance per reference translation, all sharing the same prediction
    for target in reference["targets"]:
        instances.append(
            {
                "src": reference["source"],
                "ref": target["translation"],
                "mt": prediction["prediction"],
            }
        )

    instance_ids[sample_id] = [current_index, current_index + len(reference["targets"])]
    current_index += len(reference["targets"])

print(f"Created {len(instances)} instances")

Created 1268 instances


# Load the COMET evaluation model
Let's download the model and load it into memory.

In [20]:
# Download the model
model_path = download_model(COMET_MODEL_NAME)
# Load the model
model = load_from_checkpoint(model_path)

Lightning automatically upgraded your loaded checkpoint from v1.8.3.post1 to v2.5.0.post0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\skyfu\.cache\huggingface\hub\models--Unbabel--wmt22-comet-da\snapshots\f49d328952c3470eff6bb6f545d62bfdb6e66304\checkpoints\model.ckpt`
Encoder model frozen.
D:\anaconda\Lib\site-packages\pytorch_lightning\core\saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']


# Evaluate the predictions
We can now evaluate the predictions using the COMET evaluation model.

In [22]:
# Compute the scores
outputs = model.predict(instances, batch_size=BATCH_SIZE, gpus=NUM_GPUS)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Predicting DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████| 40/40 [01:48<00:00,  2.70s/it]


In [23]:
# Extract the per-pair scores
scores = outputs.scores
max_scores = []

for start, end in instance_ids.values():
    # Keep the best score across all references of this source
    max_score = max(scores[start:end])
    max_scores.append(max_score)

# Average the per-source maxima. Missing predictions are not penalized here;
# to count each of them as 0, use the commented-out formula instead.
# system_score = sum(max_scores) / (len(max_scores) + num_missing_predictions)
system_score = sum(max_scores) / len(max_scores)
print(f"Average COMET score: {100.*system_score:.2f}")

Average COMET score: 87.38
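The commented-out formula in the cell above describes a stricter variant in which every reference without a prediction contributes a score of 0. A minimal sketch of that variant, using a hypothetical helper name and made-up numbers:

```python
def penalized_mean(max_scores, num_missing):
    """Average of per-source maxima, counting each missing prediction as 0."""
    total = len(max_scores) + num_missing
    # Guard against an empty evaluation set
    return sum(max_scores) / total if total else 0.0


# Two scored sources plus two sources with no prediction (toy values)
penalized = penalized_mean([0.75, 0.25], num_missing=2)
print(f"{penalized:.2f}")
```

In this run all references have a prediction, so `num_missing_predictions` is 0 and both formulas give the same result.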


In [24]:
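# For comparison: COMET's own system score is the plain mean over all
# prediction-reference pairs, which is why it differs from the max-per-source average above.
outputs.system_score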

0.8607031721445476