In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from tqdm.auto import tqdm
import evaluate

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load Model
Select which model variant to load (standard / fully fine-tuned model, or base model + LoRA adapter) by running only the corresponding cell below.

In [None]:
# Load Original Model / Full finetuned model
model_repo = "model/nllb-600m-scb-3epochs"

model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)
# Set the language pair explicitly, as the LoRA cell does. The NLLB tokenizer
# prepends a source-language tag to every input. Without src_lang it falls back
# to whatever the saved tokenizer config holds (NLLB's built-in default is
# eng_Latn), which would mislabel the Thai input.
tokenizer = AutoTokenizer.from_pretrained(model_repo, src_lang="tha_Thai", tgt_lang="eng_Latn")

In [2]:
# Load LoRA model: the base NLLB checkpoint with the fine-tuned LoRA adapter applied on top.
from peft import PeftModel

# Base checkpoint (also the source of the tokenizer) and the adapter directory.
model_repo = "facebook/nllb-200-distilled-600M"
lora_repo = "model/epochs3.0"

# Move the base model to the target device first, then wrap it with the adapter weights.
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)
model = PeftModel.from_pretrained(base_model, lora_repo)

# The tokenizer comes from the base checkpoint; translation direction is Thai -> English.
tokenizer = AutoTokenizer.from_pretrained(
    model_repo,
    src_lang="tha_Thai",
    tgt_lang="eng_Latn",
)

# Load test data

In [3]:
# IWSLT 2015 Thai-English test sets (tst2010-2013). The two files are expected
# to be line-aligned: line i of the .th file is the source for line i of the .en file.
# encoding="utf-8" is explicit so the Thai text decodes the same way regardless
# of the operating system's locale default.
with open("../data/iwslt_2015/tst2010-2013_th-en.en", "r", encoding="utf-8") as f:
    en_text = [line.strip() for line in f]
with open("../data/iwslt_2015/tst2010-2013_th-en.th", "r", encoding="utf-8") as f:
    th_text = [line.strip() for line in f]

# Fail fast on misaligned files, before the expensive generation step, rather
# than at BLEU time.
assert len(th_text) == len(en_text), f"line count mismatch: th={len(th_text)}, en={len(en_text)}"

print(f"th_text: {len(th_text)}, en_text: {len(en_text)}")

th_text: 4242, en_text: 4242


In [4]:
# Translate the Thai test sentences into English in fixed-size batches.
# forced_bos_token_id makes generation start with the English language tag.
# convert_tokens_to_ids works across transformers versions, whereas
# `tokenizer.lang_code_to_id` was deprecated and later removed.
tgt_lang_id = tokenizer.convert_tokens_to_ids("eng_Latn")

predictions = []
batch_size = 128
# NOTE(review): caps the generated length. Long sentences may be cut off, which
# shortens the system output and can lower BLEU via the brevity penalty.
max_length = 64
for i in tqdm(range(0, len(th_text), batch_size)):
    batch = th_text[i:i + batch_size]

    inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device)

    translated_tokens = model.generate(**inputs, forced_bos_token_id=tgt_lang_id, max_length=max_length)
    predictions += tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)

# One hypothesis per source sentence is required for corpus BLEU below.
assert len(predictions) == len(th_text)

  0%|          | 0/67 [00:00<?, ?it/s]

In [6]:
# Corpus-level SacreBLEU. Each hypothesis has exactly one reference translation,
# so every reference is wrapped in a single-element list.
metric = evaluate.load("sacrebleu")
en_ref = list(map(lambda ref_line: [ref_line], en_text))
metric.compute(references=en_ref, predictions=predictions)

{'score': 24.014071578976555,
 'counts': [44220, 22616, 12885, 7546],
 'totals': [77806, 73564, 69322, 65085],
 'precisions': [56.83366321363391,
  30.743298352455007,
  18.587172903263035,
  11.594069294000153],
 'bp': 0.969425559265115,
 'sys_len': 77806,
 'ref_len': 80222}