In [None]:
!pip install datasets evaluate transformers[sentencepiece] seqeval sacrebleu huggingface_hub accelerate

In [None]:
from accelerate import Accelerator
from datasets import load_dataset
import evaluate
from huggingface_hub import get_full_repo_name, notebook_login, Repository
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AdamW, AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, get_scheduler, pipeline, Seq2SeqTrainingArguments, Seq2SeqTrainer
import warnings

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings(action='ignore')

# Authenticate with the Hugging Face Hub; the training cell pushes to it.
notebook_login()

In [None]:
# Language pair: lang1 sentences are the model inputs, lang2 sentences the targets.
lang1 = "en"
lang2 = "ja"

# Load the "kde4" dataset for that pair.
raw_datasets = load_dataset("kde4", lang1=lang1, lang2=lang2)

# Keep 90% of the "train" split for training; the fixed seed makes the split
# repeatable across runs.
split_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=8192)

# The held-out part comes back under the key "test"; store it as "validation".
split_datasets["validation"] = split_datasets.pop("test")

In [None]:
# Checkpoint that both the tokenizer (here) and the model (next cell) are loaded from.
model_checkpoint = "Helsinki-NLP/opus-tatoeba-en-ja"

# NOTE(review): `return_tensors` is normally an argument of the tokenizer *call*,
# not of `from_pretrained` — it is presumably a no-op here; confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    return_tensors="pt",
)

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of translation pairs for seq2seq training.

    Args:
        examples: A batch whose ``"translation"`` column is a list of dicts
            indexed by language code (``lang1`` and ``lang2`` are read).

    Returns:
        The tokenizer output for the ``lang1`` sentences, with an additional
        ``"labels"`` entry holding the token ids of the ``lang2`` sentences.
        Inputs and targets are each truncated to at most 128 tokens.
    """
    pairs = examples["translation"]
    inputs = [pair[lang1] for pair in pairs]
    targets = [pair[lang2] for pair in pairs]

    # `text_target` tokenizes the targets in target-language mode and stores
    # their ids under "labels". It replaces the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    return tokenizer(inputs, text_target=targets, max_length=128, truncation=True)

# Run the preprocessing over every split in batches, dropping the original
# columns so only the tokenizer outputs (including "labels") remain.
tokenized_datasets = split_datasets.map(
    function=preprocess_function,
    batched=True,
    remove_columns=split_datasets["train"].column_names,
)

In [None]:
# Load the seq2seq model weights from the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# Collator used as `collate_fn` by both dataloaders below.
# NOTE(review): the model is presumably passed so the collator can prepare
# decoder inputs from the labels — confirm against the transformers docs.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# sacreBLEU metric; the evaluation loop feeds it decoded predictions/references
# batch by batch and reads the "score" entry of its result once per epoch.
metric = evaluate.load("sacrebleu")

In [None]:
# Optimizer over all model parameters.
optimizer = AdamW(model.parameters(), lr=2e-5)

# Make the tokenized datasets return torch tensors.
tokenized_datasets.set_format("torch")

# Training batches are reshuffled; evaluation batches keep dataset order.
# Both use the seq2seq collator and a batch size of 8.
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=8,
    collate_fn=data_collator,
)

# Hand everything to Accelerate; from here on use the returned objects only.
accelerator = Accelerator()
(
    model,
    optimizer,
    train_dataloader,
    eval_dataloader,
) = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)

In [None]:
num_train_epochs = 3
# The training loop takes one optimizer step per batch, so the length of the
# (already prepared) dataloader is the number of updates per epoch.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_update_steps_per_epoch * num_train_epochs

# Linear learning-rate schedule over the whole run, without warmup.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=0,
)

# The local checkpoint directory shares its name with the Hub repository.
model_name = "marian-finetuned-kde4-en-to-ja-accelerate"
output_dir = model_name

# Resolve the full repo name for the logged-in account and clone that repo
# into the local output directory.
repo_name = get_full_repo_name(model_name)
repo = Repository(local_dir=output_dir, clone_from=repo_name)

def postprocess(predictions, labels):
    """Turn generated token ids and label ids into text for the BLEU metric.

    Args:
        predictions: Tensor of generated token ids.
        labels: Tensor of reference token ids; positions equal to -100 are
            treated as padding.

    Returns:
        A pair ``(decoded_preds, decoded_labels)``: a list of stripped
        prediction strings, and a list of single-element lists, each holding
        the stripped reference string for the matching prediction.
    """
    pred_ids = predictions.cpu().numpy()
    label_ids = labels.cpu().numpy()

    # -100 is not a real token id, so swap it for the pad token before decoding.
    label_ids = np.where(label_ids == -100, tokenizer.pad_token_id, label_ids)

    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Each prediction gets a list of references (here always exactly one).
    return (
        [text.strip() for text in pred_texts],
        [[text.strip()] for text in label_texts],
    )

In [None]:
# One tick per optimizer step across all epochs.
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # ---- Training: one optimizer and scheduler step per batch ----
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # ---- Evaluation: generate translations and accumulate BLEU ----
    model.eval()
    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            # generate() is called on the unwrapped model, not on the object
            # returned by accelerator.prepare.
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=128,
            )
        labels = batch["labels"]

        # Pad along the sequence dimension before gathering: predictions with
        # the pad token, labels with -100 (which postprocess maps back to pad).
        generated_tokens = accelerator.pad_across_processes(
            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
        )
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        predictions_gathered = accelerator.gather(generated_tokens)
        labels_gathered = accelerator.gather(labels)

        decoded_preds, decoded_labels = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=decoded_preds, references=decoded_labels)

    # The target language is Japanese, which is written without spaces between
    # words. sacreBLEU's default tokenization would see almost every sentence
    # as a single token, so score on characters instead. ("ja-mecab" is the
    # word-level alternative but needs the optional MeCab dependencies.)
    results = metric.compute(tokenize="char")
    print(f"epoch {epoch}, BLEU score: {results['score']:.2f}")

    # ---- Checkpoint: save locally, then upload from the main process ----
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

        # Intermediate epochs upload in the background so training continues.
        # The last epoch uploads synchronously with the final commit message,
        # so the cell only finishes once the trained model is on the Hub.
        # This replaces a trailing `model.push_to_hub(...)` call that passed
        # no repository id and ran on every process.
        is_last_epoch = epoch == num_train_epochs - 1
        if is_last_epoch:
            commit_message = "Training complete"
        else:
            commit_message = f"Training in progress epoch {epoch}"
        repo.push_to_hub(commit_message=commit_message, blocking=is_last_epoch)

In [None]:
# Checkpoint used for inference. Note that this rebinds `model_checkpoint`
# (previously the base checkpoint) and names a different repository than the
# "-accelerate" one the training cell pushes to.
model_checkpoint = "Hoax0930/marian-finetuned-kde4-en-to-ja"

# Translation pipeline built from that checkpoint.
translator = pipeline(task="translation", model=model_checkpoint)

In [None]:
# Read a source sentence from the user.
sentence = input('原文 : ')

# The translation pipeline returns a list with one dict per input sentence;
# print only the translated text instead of that raw structure.
translation = translator(sentence)[0]['translation_text']
print(f'訳 : {translation}')