In [None]:
!pip3 install datasets transformers sacrebleu unbabel-comet polars -q

In [None]:
import os

# Set at the very top of the notebook, before any tokenizer is created.
# NOTE(review): presumably this silences the Hugging Face tokenizers parallelism/fork warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Data

In [None]:
# Full language names (not ISO codes); they are interpolated verbatim into the translation prompt
full_src_lang = "Czech"
full_tgt_lang = "German"

In [None]:
from datasets import load_dataset

dataset_name = "ymoslem/news-commentary-cs-de"  # sentence-level data

# Load the "train" split, shuffle it with a fixed seed, carve out 500 held-out rows
# with a fixed seed, and keep only that held-out ("test") part for evaluation.
dataset = (
    load_dataset(dataset_name, split="train")
    .shuffle(seed=0)
    .train_test_split(test_size=500, seed=0)["test"]
)

dataset

In [None]:
# One prompt per source sentence: the instruction line, the sentence, and a trailing newline
source_sentences = dataset["source"]
prompt = f"Translate the following text from {full_src_lang} to {full_tgt_lang}:"
prompts = [f"{prompt}\n{sent}\n" for sent in source_sentences]
print(prompts[0])

In [None]:
# Reference translations, taken from the "target" column of the same rows as the sources
references = dataset["target"]
references[0]

In [None]:
def define_max_len(sentences):
    """Return twice the word count of the longest sentence, and that sentence's index.

    Length is measured in whitespace-separated words. When several sentences
    share the maximum word count, the one with the highest index is reported.

    Args:
        sentences: iterable of strings.

    Returns:
        A tuple ``(max_len, longest_idx)`` where ``max_len`` is the largest
        word count multiplied by 2.

    Raises:
        ValueError: if ``sentences`` is empty.
    """
    # Comparing (word_count, index) tuples makes max() break ties on the index
    word_count, longest_idx = max(
        (len(sentence.split()), position)
        for position, sentence in enumerate(sentences)
    )
    return word_count * 2, longest_idx

# max_len is used further down as the max_new_tokens generation budget
max_len, longest_idx = define_max_len(source_sentences)

print(max_len)

# Model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Pick the GPU when one is visible; the assertion at the end of this cell still requires CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"

layers = 24  # 16, 20, 24

model_id = "CohereLabs/aya-expanse-8b"
# model_id = f"ymoslem/wmt25-cs-de-{layers}layers-2e-05-100k-news-commentary-sentences"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in bfloat16, move them to the target device, then switch to evaluation mode
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
model = model.to(device)
model = model.eval()

# NOTE(review): generate() is later called on this compiled wrapper; confirm that the compiled
# forward is actually the one used during generation.
model = torch.compile(model, mode="reduce-overhead")

# model.config.max_position_embeddings = 4096  # just to match our vLLM eval

assert model.device.type == "cuda"

In [None]:
# Translate in batches

from tqdm.auto import tqdm
import torch

# Batched generation with a decoder-only model needs LEFT padding: every prompt must end at the
# same column so that (a) generation continues directly after the prompt and (b) slicing at the
# padded prompt length below returns only newly generated tokens. Set it explicitly rather than
# relying on the tokenizer's default.
tokenizer.padding_side = "left"

# padding=True requires a pad token; fall back to EOS, which is also what generate() pads with below
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Translating {len(prompts)} prompts...")

batch_size = dataset.num_rows  # 500 or try 256 for low memory
translations = []

for i in tqdm(range(0, len(prompts), batch_size)):
    batch_prompts = prompts[i:i+batch_size]

    # Wrap every prompt as a single-turn chat conversation
    batch_messages = [[{"role": "user", "content": prompt}] for prompt in batch_prompts]

    # Tokenize the entire batch and get attention mask
    batch_inputs = tokenizer.apply_chat_template(
        batch_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        padding=True,
        return_dict=True  # This returns both input_ids and attention_mask
    )

    input_ids = batch_inputs['input_ids'].to(device)
    attention_mask = batch_inputs['attention_mask'].to(device)

    # Padded prompt length, shared by every row of the batch; generated tokens start at this column
    original_length = input_ids.shape[1]

    # Greedy decoding for the entire batch; the attention mask hides the padding positions
    with torch.no_grad():
        gen_tokens = model.generate(
            input_ids,
            attention_mask=attention_mask,  # Pass the attention mask
            max_new_tokens=max_len,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True
        )

    # Decode only the continuation of each row (everything after the padded prompt)
    decoded = tokenizer.batch_decode(gen_tokens[:, original_length:], skip_special_tokens=True)
    translations.extend(text.strip() for text in decoded)

In [None]:
# Show the first translation as a quick sanity check
translations[0]

In [None]:
# # Optional: Save translations to a file
# with open("output.txt", "w") as output:
#     for sentence in translations:
#         output.write(sentence.strip() + "\n")

In [None]:
# Release memory
import gc

# Drop the reference to the translation model ...
model = None
# ... and to the tensors the translation loop left behind at notebook scope: they live on the GPU
# and would otherwise keep their memory allocated while COMET runs.
batch_inputs = input_ids = attention_mask = gen_tokens = None
tokens = new_tokens = None

gc.collect()
# Return the now-unused cached blocks to the GPU (no autograd context is needed for this)
torch.cuda.empty_cache()

# Evaluation

In [None]:
from sacrebleu.metrics import CHRF

# word_order=2 is the chrF++ configuration (matches the "chrF++" column of the results table)
chrf = CHRF(word_order=2)

# Corpus-level score against the single reference set, rounded to two decimals
chrf_result = chrf.corpus_score(translations, [references])
chrf_score = round(chrf_result.score, 2)

# Start the score list afresh with chrF++ as its first entry
all_scores = [chrf_score]

chrf_score

In [None]:
from comet import download_model, load_from_checkpoint

# Download and load a COMET model
comet_model_names = ["Unbabel/wmt20-comet-da", "Unbabel/wmt22-comet-da"]

# zip() would silently truncate on a length mismatch, so check alignment up front
assert len(source_sentences) == len(translations) == len(references)

# Prepare the data once: the (src, mt, ref) triples are the same for every COMET model
data = [
    {"src": src, "mt": mt, "ref": ref}
    for src, mt, ref in zip(source_sentences, translations, references)
]

# Keep only the first entry (chrF++) so that re-running this cell replaces the COMET scores
# instead of stacking extra ones, which would break the three-column results table.
del all_scores[1:]

for comet_model_name in comet_model_names:

    model_path = download_model(comet_model_name)
    comet_model = load_from_checkpoint(model_path).to("cuda")

    assert comet_model.device.type == "cuda"

    # Calculate COMET scores
    model_output = comet_model.predict(data, batch_size=8, gpus=1)
    comet_scores = model_output.scores  # segment-level scores
    # System-level score, scaled by 100 and rounded to two decimals
    comet_corpus_score = round(model_output.system_score * 100, 2)
    all_scores.append(comet_corpus_score)

    print(comet_model_name)
    print(f"Corpus COMET score: {comet_corpus_score}")

In [None]:
import polars as pl

print(model_id)

# One row of scores; the column order must match the order in which all_scores was filled
score_columns = ["chrF++", "COMET20", "COMET22"]
df = pl.DataFrame([all_scores], schema=score_columns, orient="row")

df