In [None]:
!pip install datasets evaluate transformers[sentencepiece] seqeval sacrebleu huggingface_hub accelerate

In [None]:
import warnings

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from huggingface_hub import notebook_login
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Suppress warnings of every category for the rest of the session.
# NOTE(review): this also hides library deprecation notices (FutureWarning etc.).
warnings.simplefilter("ignore")

# Interactive Hugging Face Hub login — required because training pushes to the Hub.
notebook_login()

In [None]:
# --- Dataset preparation ---
# Source / target language codes. TODO: set these for your language pair.
lang1, lang2 = "en", "ja"
# NOTE(review): "kde4" has historically been a script-based Hub dataset; newer
# `datasets` releases restrict loading scripts — confirm this call works with the
# installed version (pin `datasets` if it does not).
raw_datasets = load_dataset("kde4", lang1=lang1, lang2=lang2)

# Hold out 10% of the "train" split for validation; the fixed seed makes it reproducible.
split_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=8192)
# train_test_split calls the held-out part "test"; re-register it as "validation".
validation_split = split_datasets.pop("test")
split_datasets["validation"] = validation_split

# --- Tokenizer ---
# Pretrained checkpoint for the language pair. TODO: set this for your language pair.
model_checkpoint = "Helsinki-NLP/opus-tatoeba-en-ja"
# `return_tensors` is an option of individual tokenizer(...) calls, not of
# from_pretrained (where it had no effect); tensors are built later at collation time.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# --- Preprocessing function ---
def preprocess_function(examples, max_length=128):
    """Tokenize a batch of translation pairs for seq2seq fine-tuning.

    Args:
        examples: A batch from ``Dataset.map(batched=True)``. ``examples["translation"]``
            is a list of dicts keyed by language code; ``lang1`` is read as the source
            and ``lang2`` as the target.
        max_length: Truncation length applied to both source and target sequences.

    Returns:
        The tokenizer's encoding of the source texts (``input_ids`` etc.) with a
        ``labels`` field holding the tokenized target ``input_ids``.
    """
    translations = examples["translation"]
    inputs = [pair[lang1] for pair in translations]
    targets = [pair[lang2] for pair in translations]
    # `text_target=` tokenizes the targets in target mode and stores their input_ids
    # under "labels"; it replaces the deprecated `with tokenizer.as_target_tokenizer():`.
    return tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)

# --- Apply preprocessing to every split at once ---
# Drop the raw columns (e.g. "translation") so only tokenizer outputs and labels remain.
raw_columns = split_datasets["train"].column_names
tokenized_datasets = split_datasets.map(preprocess_function, batched=True, remove_columns=raw_columns)

In [None]:
# Pretrained seq2seq model for the chosen checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# sacreBLEU metric, consumed by compute_metrics.
metric = evaluate.load("sacrebleu")
# Batches tokenized examples; the model is passed so the collator can use it when
# preparing batches (see DataCollatorForSeq2Seq docs).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

def compute_metrics(eval_preds):
    """Score generated translations against the references with sacreBLEU.

    Passed to ``Seq2SeqTrainer(compute_metrics=...)``. Because the training arguments
    set ``predict_with_generate=True``, the predictions are generated token ids.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair. ``predictions`` may itself be a
            tuple whose first element holds the token ids.

    Returns:
        ``{"bleu": score}`` with the corpus-level sacreBLEU score.
    """
    preds, labels = eval_preds
    # Some models return extra outputs next to the ids; keep only the ids.
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is the ignore/padding index: labels carry it, and predictions can too when
    # batches of different generated lengths are concatenated. Negative ids cannot be
    # decoded, so map -100 back to the pad token (removed by skip_special_tokens).
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = [text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)]
    # sacreBLEU expects a list of references per prediction; here there is exactly one.
    decoded_labels = [[text.strip()] for text in tokenizer.batch_decode(labels, skip_special_tokens=True)]

    # sacreBLEU's default "13a" tokenizer splits on whitespace/punctuation, so
    # unsegmented Japanese/Chinese text stays as very long tokens and BLEU collapses
    # towards zero. Use a tokenizer suited to the target language instead
    # ("ja-mecab" is the usual Japanese choice if sacrebleu[ja] is installed;
    # "char" needs no extra packages).
    bleu_tokenize = {"ja": "char", "zh": "zh"}.get(lang2)
    bleu_kwargs = {"tokenize": bleu_tokenize} if bleu_tokenize else {}

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, **bleu_kwargs)
    return {"bleu": result["score"]}

# Training configuration. The output directory name is derived from the language pair;
# with push_to_hub=True and no hub_model_id it presumably also names the Hub repo —
# confirm in the TrainingArguments docs. TODO: adjust the prefix if you change checkpoint.
args = Seq2SeqTrainingArguments(
    f"marian-finetuned-kde4-{lang1}-to-{lang2}",
    # `eval_strategy` replaces `evaluation_strategy`, which recent transformers no
    # longer accept. No evaluation during training; it is run after train() instead.
    eval_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,  # compute_metrics decodes generated ids, not logits
    fp16=torch.cuda.is_available(),  # fp16 mixed precision needs a CUDA GPU; off on CPU
    push_to_hub=True,
)

# Build the trainer, fine-tune, evaluate on the validation split, and publish to the Hub.
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    compute_metrics=compute_metrics,
)
trainer.train()

# Generation is capped at 128 tokens, the same limit used for preprocessing truncation.
eval_metrics = trainer.evaluate(max_length=128)
trainer.log_metrics("eval", eval_metrics)

trainer.push_to_hub(tags="translation", commit_message="Training complete")