In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from tqdm.auto import tqdm
import evaluate

# Local fine-tuning checkpoint directory; both the model weights and the
# tokenizer are loaded from this same location.
model_repo = "model/mt5-small-scb-mt-th-en-bf16/checkpoint-62500"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_repo)

# Load the seq2seq model, then move its parameters to the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo)
model = model.to(device)

In [7]:
# Read the two sides of the test set, one sentence per line with surrounding
# whitespace stripped. The encoding is given explicitly so the Thai text is
# decoded the same way on every platform instead of relying on the locale
# default (assumes the files are UTF-8 -- TODO confirm against the data source).
with open("../data/iwslt_2015/tst2010-2013_th-en.en", "r", encoding="utf-8") as f:
    en_text = [line.strip() for line in f]
with open("../data/iwslt_2015/tst2010-2013_th-en.th", "r", encoding="utf-8") as f:
    th_text = [line.strip() for line in f]

# Later cells pair th_text[i] with en_text[i] when scoring, so fail fast here
# if the two files do not contain the same number of lines.
assert len(th_text) == len(en_text), (
    f"line count mismatch: th_text={len(th_text)}, en_text={len(en_text)}"
)

print(f"th_text: {len(th_text)}, en_text: {len(en_text)}")

th_text: 4242, en_text: 4242


In [8]:
batch_size = 64
predictions = []

# Translate the Thai sentences in consecutive slices of `batch_size`,
# collecting the decoded outputs in input order.
for start in tqdm(range(0, len(th_text), batch_size)):
    batch = th_text[start:start + batch_size]

    # Tokenize the slice as one padded batch of PyTorch tensors on `device`.
    inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device)

    # Generation is capped at max_length=64.
    translated_tokens = model.generate(**inputs, max_length=64)

    # Decode back to text, dropping special tokens, and append to the results.
    predictions.extend(
        tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
    )

  0%|          | 0/67 [00:00<?, ?it/s]

In [9]:
# Wrap every English sentence in its own single-element list, so each
# prediction is scored against exactly one reference.
en_ref = [[reference] for reference in en_text]

metric = evaluate.load("sacrebleu")

# Left as the final expression so the notebook displays the result dict.
metric.compute(references=en_ref, predictions=predictions)

{'score': 12.13760170977748,
 'counts': [35746, 13070, 5538, 2444],
 'totals': [76552, 72310, 68068, 63833],
 'precisions': [46.69505695474971,
  18.074955054625917,
  8.13598166539343,
  3.828740620055457],
 'bp': 0.9531897675267356,
 'sys_len': 76552,
 'ref_len': 80222}