<a href="https://colab.research.google.com/github/wesslen/seamless_sacrebleu_evaluation/blob/main/notebooks/seamless_sacrebleu_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install transformers sacrebleu tqdm torch

In [11]:
# Import required libraries
import torch
from transformers import SeamlessM4Tv2Model, AutoProcessor
from sacrebleu.metrics import BLEU
from typing import List, Union, Optional
import tqdm

class TranslationEvaluator:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """
        Initialize the translation evaluator with the Seamless model.

        Args:
            device: Device to run the model on ("cuda" or "cpu")
        """
        print(f"Using device: {device}")
        self.device = device
        print("Loading model and processor...")
        self.processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
        self.model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
        self.bleu = BLEU()
        print("Setup complete!")

    def translate_batch(self, texts: List[str], src_lang: str, tgt_lang: str, batch_size: int = 8) -> List[str]:
        """
        Translate a list of texts in batches.

        Args:
            texts: List of source texts to translate
            src_lang: Source language code (e.g., "eng", "fra")
            tgt_lang: Target language code
            batch_size: Batch size for translation

        Returns:
            List of translated texts
        """
        translations = []

        # Process in batches
        for i in tqdm.trange(0, len(texts), batch_size, desc="Translating"):
            batch = texts[i:i + batch_size]

            # Process input
            text_inputs = self.processor(
                text=batch,
                src_lang=src_lang,
                return_tensors="pt"
            ).to(self.device)

            # Generate translations
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **text_inputs,
                    tgt_lang=tgt_lang,
                    generate_speech=False
                )

            # Decode translations
            # print(len(output_tokens))
            # print(len(output_tokens[0]))
            # print(len(output_tokens[0][0]))
            # print(output_tokens[0][0].tolist())
            # print(output_tokens[0].tolist())
            # print(output_tokens[0].tolist()[0])

            batch_translations = [
                self.processor.decode(tokens, skip_special_tokens=True)
                for tokens in output_tokens[0].tolist()
            ]
            translations.extend(batch_translations)

        return translations

    def evaluate_translations(
        self,
        hypotheses: List[str],
        references: Union[List[str], List[List[str]]],
        verbose: bool = True
    ) -> BLEU:
        """
        Evaluate translations using sacreBLEU.

        Args:
            hypotheses: List of system outputs (translations)
            references: List of reference translations. For multiple references,
                       provide a list of lists where each inner list contains
                       all references for one source sentence
            verbose: Whether to print the BLEU score

        Returns:
            BLEU score object
        """
        # Handle single reference case
        if isinstance(references[0], str):
            references = [references]

        # Calculate BLEU score
        bleu_score = self.bleu.corpus_score(hypotheses, references)

        if verbose:
            print(f"BLEU score: {bleu_score.score:.2f}")
            print(f"Signature: {self.bleu.get_signature()}")

        return bleu_score

# Example usage
if __name__ == "__main__":
    # Test data
    source_texts = [
        "Hello, my dog is cute",
        "The weather is nice today",
        "I love programming"
    ]

    # Reference translations in French
    references = [
        ["Bonjour, mon chien est mignon"],
        ["Le temps est beau aujourd'hui"],
        ["J'aime la programmation"]
    ]

    print("Initializing translator...")
    # Initialize evaluator
    evaluator = TranslationEvaluator()

    print("\nTranslating texts...")
    # Translate texts
    translations = evaluator.translate_batch(
        texts=source_texts,
        src_lang="eng",
        tgt_lang="fra"
    )

    print("\nEvaluating translations...")
    # Evaluate translations
    bleu_score = evaluator.evaluate_translations(
        hypotheses=translations,
        references=references
    )

    # Print detailed results
    print("\nDetailed Results:")
    print("-" * 50)
    for src, hyp, ref in zip(source_texts, translations, references):
        print(f"\nSource: {src}")
        print(f"System: {hyp}")
        print(f"Reference: {ref[0]}")

Initializing translator...
Using device: cuda
Loading model and processor...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Setup complete!

Translating texts...


Translating: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]


Evaluating translations...
BLEU score: 75.98
Signature: nrefs:3|case:mixed|eff:no|tok:13a|smooth:exp|version:2.4.3

Detailed Results:
--------------------------------------------------

Source: Hello, my dog is cute
System: Salut, mon chien est mignon
Reference: Bonjour, mon chien est mignon

Source: The weather is nice today
System: Le temps est beau aujourd'hui.
Reference: Le temps est beau aujourd'hui

Source: I love programming
System: J'adore la programmation
Reference: J'aime la programmation



