In [None]:
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Literal
import random
import os
import numpy as np
import torch

from datasets import load_dataset, DatasetDict, Dataset
import evaluate

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    set_seed,
)

In [None]:
@dataclass
class Config:
    """Settings for the seq2seq fine-tuning run defined in this notebook.

    The fields cover the checkpoint and dataset to load, the text columns and
    language codes used for tokenization, and the values handed to
    ``Seq2SeqTrainingArguments``.
    """

    # --- Model and data -----------------------------------------------------
    # Checkpoint name passed to the tokenizer and model ``from_pretrained`` calls.
    model_name: str = "facebook/nllb-200-distilled-600M"
    # Dataset identifier (or a path ending in "jsonl") given to the loader helper.
    dataset_path: str = "alexantonov/chuvash_russian_parallel"
    # Extra keyword arguments forwarded to ``load_dataset``; None means no extras.
    load_kwargs: Dict[str, Any] | None = None
    # Size of the test split carved out of "train" when the dataset has no "test" split.
    test_size: Optional[float] = 0.1
    # Dataset columns holding the source and target sentences.
    source_column: str = "ru"
    target_column: str = "chv"
    # Language codes handed to the tokenizer as ``src_lang`` / ``tgt_lang``.
    src_lang: str = "rus_Cyrl"
    tgt_lang: str = "chv_Cyrl"
    # Truncation limits (in tokens) applied during preprocessing.
    max_source_length: int = 256
    max_target_length: int = 256

    # --- Optimisation -------------------------------------------------------
    output_dir: str = "./nllb_run"
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1

    # --- Logging / evaluation / checkpoint schedule (in optimizer steps) -----
    logging_steps: int = 1_000
    save_steps: int = 9_000
    eval_steps: int = 3_000
    save_total_limit: int = 2
    evaluation_strategy: str = "steps"
    logging_dir: str = "./logs"

    # --- Generation during evaluation ---------------------------------------
    predict_with_generate: bool = True
    generation_max_length: int = 256
    generation_num_beams: int = 1

    # --- Runtime ------------------------------------------------------------
    seed: int = 42
    fp16: bool = True
    bf16: bool = False
    dataloader_num_workers: int = 4
    # When True, the tokenization step drops the original dataset columns.
    remove_unused_columns: bool = True
    push_to_hub: bool = False
    report_to: str = "none"

    # --- Best-checkpoint selection ------------------------------------------
    load_best_model_at_end: bool = True
    # Must match a key returned by ``compute_metrics`` ("bleu" is one of them).
    metric_for_best_model: str = "bleu"
    greater_is_better: bool = True

In [None]:
def _load_dataset_from_path(
    path: str,
    test_size: float | None = None,
    load_kwargs: Dict[str, Any] | None = None,
    seed: int = 42,
) -> DatasetDict:
    """Load a dataset and, if needed, carve a ``test`` split out of ``train``.

    Args:
        path: A path ending in ``jsonl`` (loaded through the ``json`` builder
            with ``data_files=path``) or any other identifier accepted by
            ``load_dataset``.
        test_size: Size of the test split created when the loaded dataset has
            a ``train`` split but no ``test`` split. ``None`` disables the
            automatic split.
        load_kwargs: Extra keyword arguments forwarded to ``load_dataset``.
        seed: Seed for the train/test split. Defaults to 42, the value that
            used to be hard-coded, so existing callers get the same split.

    Returns:
        The loaded splits. When a split is created, the result contains only
        the new ``train`` and ``test`` splits.
    """
    load_kwargs = load_kwargs or {}
    if path.endswith("jsonl"):
        dataset = load_dataset("json", data_files=path, **load_kwargs)
    else:
        dataset = load_dataset(path, **load_kwargs)

    # ``load_dataset(..., split=...)`` returns a bare Dataset, which has no
    # ``keys()``. Treat it as the training data so the logic below still applies.
    if isinstance(dataset, Dataset):
        dataset = DatasetDict({"train": dataset})

    needs_split = test_size is not None and "test" not in dataset and "train" in dataset
    if needs_split:
        dataset = dataset["train"].train_test_split(
            test_size, seed=seed, load_from_cache_file=True
        )

    return dataset


def _sample_dataset(
    dataset: Dataset,
    mode: Literal['random', 'sequential'],
    num_samples: int | None = None,
    ds_ratio: float | None = None,
    seed: int | None = None,
) -> Dataset:
    """Return a subset of ``dataset`` selected by row index.

    Args:
        dataset: Dataset to sample from; only ``len()`` and ``select()`` are used.
        mode: ``'sequential'`` takes the first rows; ``'random'`` draws rows
            without replacement.
        num_samples: Absolute number of rows to keep.
        ds_ratio: Fraction of the dataset to keep (truncated to an integer
            row count). Exactly one of ``num_samples`` / ``ds_ratio`` must be
            given.
        seed: Optional seed for ``'random'`` mode. When ``None`` (the default)
            the global ``random`` module state is used, as before.

    Returns:
        The selected subset.

    Raises:
        ValueError: If ``mode`` is unknown, if neither or both of
            ``num_samples`` / ``ds_ratio`` are given, or if the requested size
            is negative or larger than the dataset.
    """
    if mode not in ("random", "sequential"):
        # Previously any unknown mode silently fell through to random sampling.
        raise ValueError(f"mode must be 'random' or 'sequential', got {mode!r}")
    if num_samples is None and ds_ratio is None:
        raise ValueError("Either num_samples or ds_ratio must be specified")
    if num_samples is not None and ds_ratio is not None:
        raise ValueError("Only one of num_samples or ds_ratio should be specified, not both")

    total_samples = len(dataset)
    total_to_select = num_samples if num_samples is not None else int(total_samples * ds_ratio)

    # Fail early with one clear message for both modes instead of an IndexError
    # from ``select`` (sequential) or a generic error from ``random.sample``.
    if not 0 <= total_to_select <= total_samples:
        raise ValueError(
            f"Cannot select {total_to_select} samples from a dataset with {total_samples} rows"
        )

    if mode == "sequential":
        indices = range(total_to_select)
    else:
        # A private generator keeps seeded sampling independent of global state.
        rng = random if seed is None else random.Random(seed)
        indices = rng.sample(range(total_samples), total_to_select)

    return dataset.select(indices)


def preprocess_function(batch):
    """Tokenize one batch of parallel sentences for seq2seq training.

    Reads the notebook-level ``config`` (column names, length limits) and
    ``tokenizer``. Intended to be used with ``Dataset.map(..., batched=True)``.

    Args:
        batch: Mapping from column name to a list of strings; must contain
            ``config.source_column`` and ``config.target_column``.

    Returns:
        The tokenizer output for the source texts with an extra ``"labels"``
        entry holding the target token ids.
    """
    model_inputs = tokenizer(
        batch[config.source_column],
        max_length=config.max_source_length,
        truncation=True,
    )

    # ``text_target`` replaces the deprecated ``as_target_tokenizer()`` context
    # manager. A separate call keeps the target length limit independent of
    # the source limit.
    labels = tokenizer(
        text_target=batch[config.target_column],
        max_length=config.max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


def compute_metrics(eval_preds):
    """Compute BLEU, chrF++ and mean generated length for an evaluation pass.

    Relies on the notebook-level ``tokenizer``, ``sacrebleu`` and ``chrf``
    objects.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays. The
            predictions may be wrapped in a tuple, in which case the first
            element is used.

    Returns:
        Dict with keys ``"bleu"``, ``"chrf++"`` and ``"gen_len"``.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is the ignore index used for padding. It is not a valid token id,
    # so it must be swapped for the pad id before decoding, in predictions as
    # well as labels. Left in the predictions, it would also count towards
    # ``gen_len``.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = [
        text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)
    ]
    decoded_labels = [
        text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)
    ]

    # Both metrics take one list of references per prediction.
    references = [[reference] for reference in decoded_labels]

    bleu_res = sacrebleu.compute(predictions=decoded_preds, references=references)
    chrf_res = chrf.compute(
        predictions=decoded_preds,
        references=references,
        word_order=2  # chrF++
    )

    gen_lens = [np.count_nonzero(pred != pad_id) for pred in preds]

    return {
        "bleu": bleu_res["score"],
        "chrf++": chrf_res["score"],
        "gen_len": float(np.mean(gen_lens)),
    }

In [None]:
# Single settings object read by the helper functions and the cells below.
config = Config()

# Switch off the tokenizers library's own parallelism for this process
# (presumably to avoid clashes with forked DataLoader workers — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Seed the random number generators before any sampling happens.
set_seed(config.seed)

# Metric objects used by compute_metrics.
sacrebleu = evaluate.load("sacrebleu")
chrf = evaluate.load("chrf")

In [None]:
# Load the parallel corpus; a test split is created if the dataset has none.
raw_datasets = _load_dataset_from_path(
    config.dataset_path,
    test_size=config.test_size,
    load_kwargs=config.load_kwargs
)

tokenizer = AutoTokenizer.from_pretrained(
    config.model_name,
    src_lang=config.src_lang,
    tgt_lang=config.tgt_lang,
)

model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)

# A language code that is missing from the checkpoint's vocabulary resolves to
# the <unk> id, so the language tag on every sequence would silently be <unk>.
# Register missing codes as special tokens so they get a dedicated id (and
# embedding row) that can be learned during fine-tuning.
missing_lang_codes = [
    code
    for code in dict.fromkeys((config.src_lang, config.tgt_lang))
    if tokenizer.convert_tokens_to_ids(code) == tokenizer.unk_token_id
]
if missing_lang_codes:
    tokenizer.add_tokens(missing_lang_codes, special_tokens=True)
    # Re-assign so the tokenizer refreshes the language id it cached at load time.
    tokenizer.src_lang = config.src_lang
    tokenizer.tgt_lang = config.tgt_lang
    # Only grow the embedding matrix; never shrink a checkpoint that already
    # has spare rows.
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))

# Make generation (including evaluation with predict_with_generate) always
# start with the target language tag instead of letting the model choose one.
model.generation_config.forced_bos_token_id = tokenizer.convert_tokens_to_ids(config.tgt_lang)

In [None]:
# Upper bounds on the number of sentence pairs used for training / evaluation.
NUM_TRAIN_SAMPLES = 120_000
NUM_EVAL_SAMPLES = 3_000

column_names_train = raw_datasets["train"].column_names

# Re-seed right before sampling so that re-running this cell draws the same
# subsets, no matter how much of the global random stream earlier cells used.
random.seed(config.seed)

# Clamp to the split size so smaller datasets are used in full, not rejected.
sample_train = _sample_dataset(
    raw_datasets["train"],
    mode='random',
    num_samples=min(NUM_TRAIN_SAMPLES, len(raw_datasets["train"])),
)
sample_eval = _sample_dataset(
    raw_datasets["test"],
    mode='random',
    num_samples=min(NUM_EVAL_SAMPLES, len(raw_datasets["test"])),
)

# Each split drops its own original columns, so a test split whose columns
# differ from the train split does not break the map call.
tokenized_train = sample_train.map(
    preprocess_function,
    batched=True,
    remove_columns=sample_train.column_names if config.remove_unused_columns else None,
    desc="Tokenizing train split"
)

tokenized_eval = sample_eval.map(
    preprocess_function,
    batched=True,
    remove_columns=sample_eval.column_names if config.remove_unused_columns else None,
    desc="Tokenizing val split"
)

In [None]:
# Batch collator: pads each batch to its longest sequence and pads labels with
# -100 (the value compute_metrics later swaps back to the pad id for decoding).
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="longest",
    label_pad_token_id=-100,
)

In [None]:
# Training arguments assembled from `config`, so that Config is the single
# place to tune a run.
training_args = Seq2SeqTrainingArguments(
    output_dir=config.output_dir,
    overwrite_output_dir=True,
    num_train_epochs=config.num_train_epochs,
    learning_rate=config.learning_rate,
    per_device_train_batch_size=config.per_device_train_batch_size,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    weight_decay=config.weight_decay,
    warmup_ratio=config.warmup_ratio,

    # Evaluation / checkpoint schedule. With load_best_model_at_end,
    # save_steps should stay a multiple of eval_steps.
    eval_strategy=config.evaluation_strategy,
    logging_steps=config.logging_steps,
    save_steps=config.save_steps,
    eval_steps=config.eval_steps,
    save_total_limit=config.save_total_limit,

    predict_with_generate=config.predict_with_generate,
    generation_max_length=config.generation_max_length,
    generation_num_beams=config.generation_num_beams,

    # Mixed precision follows Config instead of being hard-coded here
    # (the Config defaults are the previously hard-coded values).
    fp16=config.fp16,
    bf16=config.bf16,
    dataloader_num_workers=config.dataloader_num_workers,

    seed=config.seed,
    report_to=config.report_to,
    push_to_hub=config.push_to_hub,

    # Column pruning is done by the tokenization step, so the Trainer is told
    # not to drop columns itself.
    remove_unused_columns=False,
    load_best_model_at_end=config.load_best_model_at_end,
    metric_for_best_model=config.metric_for_best_model,
    greater_is_better=config.greater_is_better,
    logging_dir=config.logging_dir,
)

In [None]:
# Wire the model, tokenized splits, collator and metric function together.
# NOTE(review): recent transformers releases appear to rename the `tokenizer`
# argument to `processing_class` — confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning with the schedule configured in training_args.
trainer.train()