<a href="https://colab.research.google.com/github/yugonojima/NLP100Exercise-Chapter10-Transformer/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install fugashi[unidic-lite] sentencepiece datasets transformers sacrebleu

In [None]:
%cd /content/drive/MyDrive/NLP100本ノック/Chapter10/

In [3]:
from datasets import Dataset, DatasetDict

In [4]:
# Tokenized parallel corpus, laid out as ./tok/kyoto-<split>.<lang> with one sentence per line.
# "train.cln" is presumably a cleaned copy of the training split.
_tok_dir = './tok'

train_en_path = f'{_tok_dir}/kyoto-train.en'
train_ja_path = f'{_tok_dir}/kyoto-train.ja'
train_cln_en_path = f'{_tok_dir}/kyoto-train.cln.en'
train_cln_ja_path = f'{_tok_dir}/kyoto-train.cln.ja'
tune_en_path = f'{_tok_dir}/kyoto-tune.en'
tune_ja_path = f'{_tok_dir}/kyoto-tune.ja'
dev_en_path = f'{_tok_dir}/kyoto-dev.en'
dev_ja_path = f'{_tok_dir}/kyoto-dev.ja'
test_en_path = f'{_tok_dir}/kyoto-test.en'
test_ja_path = f'{_tok_dir}/kyoto-test.ja'

# Random seed for the experiment.
seed = 0

## Creating the dataset

In [5]:
class NMTDataset:
    """Builds Hugging Face ``Dataset`` objects from line-aligned parallel text files.

    Each file holds one whitespace-tokenized sentence per line; line ``i`` of the
    source file is expected to be the translation pair of line ``i`` of the target file.
    """

    def __init__(self, problem_type='ja-en'):
        # Store the requested direction (previously the argument was ignored and 'ja_en' hard-coded).
        self.problem_type = problem_type
        # Placeholders; no method of this class reads or sets them.
        self.inp_lang_tokenizer = None
        self.targ_lang_tokenizer = None

    def inp_Preprocess(self, text):
        """Reverse the whitespace-separated token order of ``text`` and wrap it in <start>/<end>.

        Only used when ``Read`` is called with ``inp=True``; ``make_dataset`` does not do so.
        """
        text = text.replace('\n', '')
        tokens = list(reversed(text.split()))
        return "<start> " + ' '.join(tokens) + " <end>"

    def targ_Preprocess(self, text):
        """Strip newlines from ``text`` and wrap it in literal <start>/<end> markers."""
        text = text.replace('\n', '')
        return '<start> ' + text + ' <end>'

    def Read(self, path, num_examples=None, inp=False):
        """Read ``path`` and preprocess each line.

        Args:
            path: UTF-8 text file with one sentence per line.
            num_examples: if not None, only the first ``num_examples`` lines are kept.
            inp: use ``inp_Preprocess`` when True, otherwise ``targ_Preprocess``.

        Returns:
            list[str]: the preprocessed lines, in file order.
        """
        preprocess = self.inp_Preprocess if inp else self.targ_Preprocess
        # Explicit UTF-8: the corpus contains Japanese and must not depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        if num_examples is not None:
            # Slice before preprocessing so discarded lines are never processed.
            lines = lines[:num_examples]
        return [preprocess(line) for line in lines]

    def make_dataset(self, inp_path, targ_path):
        """Pair source and target lines into a Dataset with a single ``translation`` column.

        Each row is ``{'translation': {'ja': <source line>, 'en': <target line>}}``. Both files
        are read with ``Read``'s default ``inp=False``, so both sides are only wrapped in
        <start>/<end> (no token reversal).

        Raises:
            ValueError: if the files have different line counts. ``zip`` would otherwise
                silently drop the tail of the longer file.
        """
        inp_texts = self.Read(inp_path)
        targ_texts = self.Read(targ_path)
        if len(inp_texts) != len(targ_texts):
            raise ValueError(
                f"Parallel files differ in length: {inp_path} has {len(inp_texts)} lines, "
                f"{targ_path} has {len(targ_texts)} lines."
            )

        translation_list = [
            {'ja': ja_text, 'en': en_text}
            for ja_text, en_text in zip(inp_texts, targ_texts)
        ]
        return Dataset.from_dict({'translation': translation_list})

In [6]:
# One Dataset per split: the cleaned train files for training, the "tune" files for
# validation and the "dev" files as the held-out test split.
dataset = NMTDataset()

train_dataset = dataset.make_dataset(train_cln_ja_path, train_cln_en_path)
validation_dataset = dataset.make_dataset(tune_ja_path, tune_en_path)
test_dataset = dataset.make_dataset(dev_ja_path, dev_en_path)

splits = {
    "train": train_dataset,
    "validation": validation_dataset,
    "test": test_dataset,
}
raw_datasets = DatasetDict(splits)

In [None]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    """Display ``num_examples`` distinct, randomly chosen rows of ``dataset`` as an HTML table.

    ``datasets.ClassLabel`` columns are shown by label name instead of integer id.

    Raises:
        AssertionError: if more rows are requested than ``dataset`` contains.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices directly, replacing the retry-until-unique loop.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [None]:
# Eyeball a handful of training pairs to check the preprocessing.
show_random_elements(raw_datasets["train"], num_examples=5)

Unnamed: 0,translation
0,"{'en': '<start> Please think about it . ' <end>', 'ja': '<start> 「 ぜひ 思いとどま り くださ い 。 」 <end>'}"
1,"{'en': '<start> 1951 : Kyoto Prefectural University Women 's College was established . <end>', 'ja': '<start> 1951 年 女子 短期 大学 部 を 設置 <end>'}"
2,"{'en': '<start> His second son was born . <end>', 'ja': '<start> 次男 ・ 不二 彦 誕生 。 <end>'}"
3,"{'en': '<start> The kamekan burial system quickly declined in the end period and was replaced by hole-shaped grave with stone lid and hakoshiki-sekkanbo ( box-shaped stone coffin grave ) . <end>', 'ja': '<start> 甕棺 墓制 は 後期 に は 急速 に 衰退 し て 石蓋 土壙 墓 ・ 箱 式 石棺 墓 など に 取 っ て 代わ ら れ た 。 <end>'}"
4,"{'en': '<start> The size of ihai is measured in the shakkan-ho ( the system of measuring length in shaku and weight in kan ) . <end>', 'ja': '<start> 位牌 の サイズ は 尺貫 法 で 表 さ れ る 。 <end>'}"


In [None]:
from datasets import load_metric

# Corpus-level BLEU via sacrebleu, used by compute_metrics during evaluation.
metric = load_metric(path="sacrebleu")

## Preprocessing the data

In [7]:
# Pretrained Japanese→English checkpoint (Helsinki-NLP OPUS-MT) to fine-tune.
model_checkpoint = "Helsinki-NLP/opus-mt-ja-en"

In [None]:
from transformers import AutoTokenizer
import sentencepiece as spm

# Tokenizer matching the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [13]:
# mBART tokenizers need explicit source/target language codes; other checkpoints skip this.
if "mbart" in model_checkpoint:
    tokenizer.src_lang, tokenizer.tgt_lang = "ja_XX", "en_XX"

In [15]:
max_input_length = 128
max_target_length = 128
source_lang = "ja"
target_lang = "en"
# Optional text prepended to every source sentence; empty for this checkpoint.
# (It was referenced below but never defined, so the cell failed on a fresh kernel.)
prefix = ""

def preprocess_function(examples, tok=None):
    """Tokenize a batch of translation pairs for seq2seq fine-tuning.

    Args:
        examples: a batch as passed by ``Dataset.map(..., batched=True)``;
            ``examples["translation"]`` is a list of dicts keyed by language code.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer's encoding of the (prefixed) source texts, truncated to
        ``max_input_length``, plus a ``labels`` entry holding the target-side
        ``input_ids`` truncated to ``max_target_length``.
    """
    if tok is None:
        tok = tokenizer
    inputs = [prefix + ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    model_inputs = tok(inputs, max_length=max_input_length, truncation=True)

    # Switch the tokenizer to its target-side configuration while encoding the labels.
    with tok.as_target_tokenizer():
        labels = tok(targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize every split in batches; the original "translation" column is kept alongside.
tokenized_datasets = raw_datasets.map(function=preprocess_function, batched=True)

## Fine-tuning the model

In [None]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Pretrained seq2seq model that will be fine-tuned.
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
import torch

batch_size = 16
num_train_epochs = 1
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    # The epoch suffix is derived from the actual epoch count
    # (the directory used to say "epoch2" while training for 1 epoch).
    f"{model_name}-finetuned-{source_lang}-to-{target_lang}-epoch{num_train_epochs}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    # fp16 mixed precision requires a CUDA device; use full precision elsewhere instead of erroring.
    fp16=torch.cuda.is_available(),
    # Use the notebook-level seed defined with the corpus paths (it was previously unused).
    seed=seed,
)

In [None]:
# Dynamically pads inputs and labels per batch; label padding uses -100 by default,
# which compute_metrics maps back to the pad token before decoding.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
import numpy as np

def postprocess_text(preds, labels):
    """Trim whitespace and shape decoded strings the way sacrebleu expects.

    Returns a ``(predictions, references)`` tuple: predictions is a list of stripped
    strings; references wraps each stripped label in its own one-element list
    (one reference translation per sentence).
    """
    stripped_predictions = []
    for p in preds:
        stripped_predictions.append(p.strip())

    wrapped_references = []
    for ref in labels:
        wrapped_references.append([ref.strip()])

    return stripped_predictions, wrapped_references

def compute_metrics(eval_preds):
    """Compute sacrebleu BLEU and mean generated length for a Seq2SeqTrainer evaluation.

    Args:
        eval_preds: ``(predictions, label_ids)``. ``predictions`` may itself be a tuple
            whose first element holds the generated token ids.

    Returns:
        dict with ``"bleu"`` (sacrebleu ``score``) and ``"gen_len"`` (mean count of
        non-pad tokens per prediction), each rounded to 4 decimal places.

    NOTE(review): both sides still contain the literal "<start>"/"<end>" markers added during
    dataset construction (they survive decoding, see the sample outputs), which inflates BLEU.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 is not a valid token id. Map it to the pad token before decoding; predictions
    # gathered across batches may also carry -100 padding, depending on the transformers version.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Labels use -100 for positions ignored by the loss; replace them the same way.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Former -100 positions now equal pad_token_id, so they are no longer counted as generated tokens.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# Wire model, arguments, data and metrics together; evaluation runs on the validation split.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Using cuda_amp half precision backend


In [None]:
# Fine-tune `model` on tokenized_datasets["train"]; evaluation and checkpointing follow the
# strategies configured in `args`. The returned training summary is displayed as the cell output.
trainer.train()

## Testing the fine-tuned model

In [10]:
from transformers import AutoModelForSeq2SeqLM

# Checkpoint from an earlier training run. NOTE(review): this directory does not match the
# output_dir built in the training-arguments cell; point it at the run you want to evaluate.
finetuned_checkpoint = "./opus-mt-ja-en-finetuned-ja-to-en/checkpoint-20500"
# from_pretrained is a classmethod. Calling it on the Auto class avoids having to load the base
# `model` first; the concrete model class is resolved from the checkpoint's config.
saved_model = AutoModelForSeq2SeqLM.from_pretrained(finetuned_checkpoint)

In [18]:
from transformers import pipeline

# Inference pipeline around the fine-tuned weights and the original tokenizer.
translate = pipeline(task='translation', model=saved_model, tokenizer=tokenizer)

In [19]:
# Translate the first few validation sentences and print them next to their references.
num_samples = 3
validation_split = raw_datasets["validation"]
for i in range(num_samples):
    pair = validation_split[i]['translation']
    ja_text = pair['ja']
    en_text = pair['en']
    translated = translate(ja_text)
    print(' ')
    print(f'ja_text:{ja_text}')
    print(f'en_text:{en_text}')
    print(translated[0])



 
ja_text:<start> 浄土 真宗 （ じょうど しんしゅう 、 Shin - Buddhism , Pure Land Buddhism ） は 、 日本 の 仏教 の 宗派 の ひと つ で 、 鎌倉 時代 初期 、 法然 の 弟子 ・ 親鸞 が 、 法然 の 教え （ 浄土 宗 ） を 継承 発展 さ せ た 教団 で あ る 。 <end>
en_text:<start> Jodo Shinshu ( Shin-Buddhism / True Pure Land Sect ) is one of the sects of Japanese Buddhism , and a religious community that Shinran , an apprentice of Honen , succeeded and which developed Honen 's doctrine ( Jodo Shu / Pure Land Buddhism ) in the early Kamakura period . <end>
{'translation_text': '<start> Jodo Shinshu is one of the Buddhist sects in Japan , and a disciple of Honen in the early Kamakura period , Shinran developed the teachings of Honen ( Jodo sect ) . <end>'}
 
ja_text:<start> 宗派 名 の 成り立ち の 歴史 的 経緯 から 、 現在 、 同宗 に 属 する 宗派 の 多く が 宗旨 名 と し て は 真宗 を 名乗 る 。 <end>
en_text:<start> Due to the historical background of the origin of the sect name , most sects belonging to this religion call themselves Shinshu as a sect name . <end>
{'translation_text': '<start> Due to the historica