<a href="https://colab.research.google.com/github/tienhuynh96/Low_Resource_NMT/blob/main/%5BDemo%5D_NMT_Back_Translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Back-Translation
- Train a model to translate Vietnamese to English (VI-EN).
- Synthetic data (create more data):
  + Use the trainer (VI-EN) to create samples.
  + Use inference to create samples one at a time or in batches.
- Train the EN-VI model on both the original data and the created data.

In [None]:
!pip install -q transformers sentencepiece datasets accelerate evaluate sacrebleu

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
import evaluate
from transformers import (
    MBart50TokenizerFast,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

  from .autonotebook import tqdm as notebook_tqdm


## **Vi-EN**

In [None]:
class NMTDataset(Dataset):
    """Parallel-text dataset built from the ``iwslt2015-en-vi`` corpus.

    Every sentence of the requested split is tokenized once in ``__init__``
    and kept as two id tensors (source and target), each padded or truncated
    to ``cfg.max_len``.

    Args:
        cfg: Configuration object. The attributes read here are ``src_lang``
            and ``tgt_lang`` (keys of each sample's ``"translation"`` dict),
            ``tokenizer`` and ``max_len``.
        data_type: Split name passed to ``load_dataset`` (e.g. ``"train"``).
    """

    def __init__(self, cfg, data_type="train"):
        super().__init__()
        self.cfg = cfg

        self.src_texts, self.tgt_texts = self.read_data(data_type)

        self.src_input_ids = self.texts_to_sequences(self.src_texts)
        self.labels = self.texts_to_sequences(self.tgt_texts)

    def read_data(self, data_type):
        """Load one split and return ``(source_texts, target_texts)`` as lists of str."""
        data = load_dataset(
            "mt_eng_vietnamese",
            "iwslt2015-en-vi",
            split=data_type
        )
        src_texts = [sample["translation"][self.cfg.src_lang] for sample in data]
        tgt_texts = [sample["translation"][self.cfg.tgt_lang] for sample in data]
        return src_texts, tgt_texts

    def texts_to_sequences(self, texts):
        """Tokenize ``texts`` and return the ``input_ids`` tensor.

        Each row is padded or truncated to exactly ``cfg.max_len`` tokens.
        """
        data_inputs = self.cfg.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=self.cfg.max_len,
            return_tensors='pt'
        )
        return data_inputs.input_ids

    def __getitem__(self, idx):
        """Return one example as a dict of tensors.

        Keys:
            input_ids: Source token ids, padded to ``cfg.max_len``.
            attention_mask: 1 for real source tokens, 0 for padding, so the
                padded positions are not attended to.
            labels: Target token ids with padding replaced by -100, the label
                value that is ignored by the loss (``compute_metrics`` in this
                notebook maps -100 back to the pad id before decoding).
        """
        pad_id = self.cfg.tokenizer.pad_token_id
        input_ids = self.src_input_ids[idx]
        labels = self.labels[idx]
        return {
            "input_ids": input_ids,
            "attention_mask": (input_ids != pad_id).long(),
            # masked_fill returns a new tensor; the stored labels stay untouched.
            "labels": labels.masked_fill(labels == pad_id, -100),
        }

    def __len__(self):
        """Number of sentence pairs in the split."""
        return self.src_input_ids.shape[0]

In [None]:
class BaseConfig:
    """Minimal config container: keyword arguments become instance attributes."""

    def __init__(self, **kwargs):
        # Copy every supplied option onto the instance under its own name,
        # overriding any class-level default with the same name.
        for name in kwargs:
            setattr(self, name, kwargs[name])

class NMTConfig(BaseConfig):
    """Settings for the Vietnamese -> English (VI-EN) fine-tuning run."""

    # ---- data ----
    # src_lang / tgt_lang select the two sides inside each sample's "translation" dict.
    src_lang = 'vi'
    tgt_lang = 'en'
    # Passed to the tokenizer as max_length (sequences are padded/truncated to it).
    max_len = 75
    add_special_tokens = True

    # ---- model ----
    model_name = "facebook/mbart-large-50-many-to-many-mmt"

    # ---- training ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    learning_rate = 5e-5
    train_batch_size = eval_batch_size = 16
    num_train_epochs = 2
    save_total_limit = 1
    # Checkpoint directory is derived from the language pair defined above.
    ckpt_dir = f'./mbart50-{src_lang}-{tgt_lang}'
    # Used for both the evaluation interval and the checkpoint interval.
    eval_steps = 1000

    # ---- inference ----
    beam_size = 5


cfg = NMTConfig()

In [None]:
# Load the pretrained tokenizer and seq2seq model named by cfg.model_name.
# The tokenizer is stored on cfg because the dataset classes and compute_metrics read cfg.tokenizer.
# NOTE(review): no src_lang / tgt_lang is passed here, so the tokenizer's default
# language code is applied to both source and target text — confirm this is intended.
cfg.tokenizer = MBart50TokenizerFast.from_pretrained(cfg.model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name)

In [None]:
# sacreBLEU scorer; compute_metrics calls metric.compute(...) and reports its "score" as BLEU.
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded texts.

    Returns:
        A ``(preds, labels)`` pair: ``preds`` is a flat list of stripped
        strings, while every stripped label is wrapped in its own one-item
        list (one reference per prediction).
    """
    stripped_preds = []
    for text in preds:
        stripped_preds.append(text.strip())

    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])

    return stripped_preds, wrapped_labels

def compute_metrics(eval_preds):
    """Compute BLEU and the mean generated length for one evaluation pass.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays. If
            ``predictions`` is a tuple, only its first element is used.

    Returns:
        Dict with ``"bleu"`` (the sacreBLEU score) and ``"gen_len"`` (average
        count of non-pad tokens per prediction), both rounded to 4 decimals.
    """
    tokenizer = cfg.tokenizer
    pad_id = tokenizer.pad_token_id

    pred_ids, label_ids = eval_preds
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    def _decode(ids):
        return tokenizer.batch_decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

    # -100 marks ignored positions and is not a real token id; swap in the pad id before decoding.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    hypotheses, references = postprocess_text(_decode(pred_ids), _decode(label_ids))
    bleu = metric.compute(predictions=hypotheses, references=references)["score"]

    gen_len = np.mean([np.count_nonzero(row != pad_id) for row in pred_ids])

    return {"bleu": round(bleu, 4), "gen_len": round(gen_len, 4)}

In [None]:
# Build (and fully tokenize) the three splits in the order train, validation, test.
train_dataset, valid_dataset, test_dataset = [
    NMTDataset(cfg, data_type=split_name)
    for split_name in ("train", "validation", "test")
]

Found cached dataset mt_eng_vietnamese (/home/server-ailab-12gb/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71)
Found cached dataset mt_eng_vietnamese (/home/server-ailab-12gb/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71)
Found cached dataset mt_eng_vietnamese (/home/server-ailab-12gb/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71)


In [None]:
# NOTE(review): `evaluation_strategy` and the trainer's `tokenizer` argument are
# spelled as in the transformers release this notebook was run with; newer
# releases rename them — confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir=cfg.ckpt_dir,
    # Optimisation
    learning_rate=cfg.learning_rate,
    num_train_epochs=cfg.num_train_epochs,
    per_device_train_batch_size=cfg.train_batch_size,
    per_device_eval_batch_size=cfg.eval_batch_size,
    # Evaluate and checkpoint on the same step schedule.
    evaluation_strategy="steps",
    eval_steps=cfg.eval_steps,
    save_strategy='steps',
    save_steps=cfg.eval_steps,
    save_total_limit=cfg.save_total_limit,
    # No metric_for_best_model is set, so "best" uses the Trainer's default criterion.
    load_best_model_at_end=True,
    # Evaluation generates sequences; compute_metrics decodes the resulting ids.
    predict_with_generate=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer=cfg.tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=cfg.tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the VI-EN model; evaluation and checkpointing run every cfg.eval_steps steps.
trainer.train()

  0%|          | 0/8334 [00:00<?, ?it/s]You're using a MBart50TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  6%|▌         | 500/8334 [05:45<1:24:48,  1.54it/s]

{'loss': 0.6006, 'learning_rate': 4.700023998080154e-05, 'epoch': 0.12}


 12%|█▏        | 1000/8334 [11:10<1:19:31,  1.54it/s]

{'loss': 0.3267, 'learning_rate': 4.4000479961603073e-05, 'epoch': 0.24}


                                                     
 12%|█▏        | 1000/8334 [12:36<1:19:31,  1.54it/s]

{'eval_loss': 0.4115850627422333, 'eval_bleu': 35.2226, 'eval_gen_len': 30.632, 'eval_runtime': 85.9398, 'eval_samples_per_second': 14.766, 'eval_steps_per_second': 0.465, 'epoch': 0.24}


 18%|█▊        | 1500/8334 [18:30<1:18:43,  1.45it/s] 

{'loss': 0.3224, 'learning_rate': 4.100071994240461e-05, 'epoch': 0.36}


 24%|██▍       | 2000/8334 [24:15<1:14:02,  1.43it/s]

{'loss': 0.3238, 'learning_rate': 3.8000959923206145e-05, 'epoch': 0.48}


                                                     
 24%|██▍       | 2000/8334 [25:40<1:14:02,  1.43it/s]

{'eval_loss': 0.40333837270736694, 'eval_bleu': 36.3721, 'eval_gen_len': 30.2963, 'eval_runtime': 85.3961, 'eval_samples_per_second': 14.86, 'eval_steps_per_second': 0.468, 'epoch': 0.48}


 30%|██▉       | 2500/8334 [31:11<1:03:20,  1.53it/s] 

{'loss': 0.3192, 'learning_rate': 3.500119990400768e-05, 'epoch': 0.6}


 36%|███▌      | 3000/8334 [36:37<57:53,  1.54it/s]  

{'loss': 0.3138, 'learning_rate': 3.2001439884809216e-05, 'epoch': 0.72}


                                                   
 36%|███▌      | 3000/8334 [37:58<57:53,  1.54it/s]

{'eval_loss': 0.39890530705451965, 'eval_bleu': 36.0185, 'eval_gen_len': 30.591, 'eval_runtime': 81.1346, 'eval_samples_per_second': 15.641, 'eval_steps_per_second': 0.493, 'epoch': 0.72}


 42%|████▏     | 3500/8334 [43:30<52:32,  1.53it/s]   

{'loss': 0.3121, 'learning_rate': 2.9001679865610755e-05, 'epoch': 0.84}


 48%|████▊     | 4000/8334 [49:16<50:08,  1.44it/s]

{'loss': 0.3062, 'learning_rate': 2.6001919846412287e-05, 'epoch': 0.96}


                                                   
 48%|████▊     | 4000/8334 [50:47<50:08,  1.44it/s]

{'eval_loss': 0.3957849442958832, 'eval_bleu': 37.5987, 'eval_gen_len': 30.0969, 'eval_runtime': 91.6207, 'eval_samples_per_second': 13.851, 'eval_steps_per_second': 0.437, 'epoch': 0.96}


 54%|█████▍    | 4500/8334 [56:21<41:35,  1.54it/s]   

{'loss': 0.2427, 'learning_rate': 2.3002159827213822e-05, 'epoch': 1.08}


 60%|█████▉    | 5000/8334 [1:01:46<36:08,  1.54it/s]

{'loss': 0.217, 'learning_rate': 2.000239980801536e-05, 'epoch': 1.2}


                                                     
 60%|█████▉    | 5000/8334 [1:03:07<36:08,  1.54it/s]

{'eval_loss': 0.4291688799858093, 'eval_bleu': 36.1415, 'eval_gen_len': 30.2734, 'eval_runtime': 80.4327, 'eval_samples_per_second': 15.777, 'eval_steps_per_second': 0.497, 'epoch': 1.2}


 66%|██████▌   | 5500/8334 [1:08:39<30:44,  1.54it/s]   

{'loss': 0.2182, 'learning_rate': 1.7002639788816897e-05, 'epoch': 1.32}


 72%|███████▏  | 6000/8334 [1:14:04<25:18,  1.54it/s]

{'loss': 0.2142, 'learning_rate': 1.4002879769618432e-05, 'epoch': 1.44}


                                                     
 72%|███████▏  | 6000/8334 [1:15:26<25:18,  1.54it/s]

{'eval_loss': 0.42046114802360535, 'eval_bleu': 36.0593, 'eval_gen_len': 30.5122, 'eval_runtime': 81.5502, 'eval_samples_per_second': 15.561, 'eval_steps_per_second': 0.49, 'epoch': 1.44}


 78%|███████▊  | 6500/8334 [1:20:57<19:54,  1.54it/s]   

{'loss': 0.2167, 'learning_rate': 1.1003119750419968e-05, 'epoch': 1.56}


 84%|████████▍ | 7000/8334 [1:26:23<14:30,  1.53it/s]

{'loss': 0.2119, 'learning_rate': 8.003359731221503e-06, 'epoch': 1.68}


                                                     
 84%|████████▍ | 7000/8334 [1:27:44<14:30,  1.53it/s]

{'eval_loss': 0.41834622621536255, 'eval_bleu': 36.3178, 'eval_gen_len': 30.316, 'eval_runtime': 81.2804, 'eval_samples_per_second': 15.613, 'eval_steps_per_second': 0.492, 'epoch': 1.68}


 90%|████████▉ | 7500/8334 [1:33:16<09:02,  1.54it/s]  

{'loss': 0.2103, 'learning_rate': 5.003599712023038e-06, 'epoch': 1.8}


 96%|█████████▌| 8000/8334 [1:38:46<03:52,  1.44it/s]

{'loss': 0.2099, 'learning_rate': 2.003839692824574e-06, 'epoch': 1.92}


                                                     
 96%|█████████▌| 8000/8334 [1:40:16<03:52,  1.44it/s]

{'eval_loss': 0.4232094883918762, 'eval_bleu': 35.9713, 'eval_gen_len': 30.5201, 'eval_runtime': 90.3194, 'eval_samples_per_second': 14.05, 'eval_steps_per_second': 0.443, 'epoch': 1.92}


100%|██████████| 8334/8334 [1:44:22<00:00,  1.75it/s]  

{'train_runtime': 6262.0212, 'train_samples_per_second': 42.58, 'train_steps_per_second': 1.331, 'train_loss': 0.28221002299064124, 'epoch': 2.0}


100%|██████████| 8334/8334 [1:44:22<00:00,  1.33it/s]


TrainOutput(global_step=8334, training_loss=0.28221002299064124, metrics={'train_runtime': 6262.0212, 'train_samples_per_second': 42.58, 'train_steps_per_second': 1.331, 'train_loss': 0.28221002299064124, 'epoch': 2.0})

In [None]:
# Generate translations for the test split; the result bundles the predicted ids,
# the label ids and the metrics returned by compute_metrics.
test_prediction = trainer.predict(test_dataset)
test_prediction

You're using a MBart50TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 40/40 [01:28<00:00,  2.21s/it]


PredictionOutput(predictions=array([[     2, 250004,  14847, ...,      1,      1,      1],
       [     2, 250004,     87, ...,      1,      1,      1],
       [     2, 250004,    360, ...,      1,      1,      1],
       ...,
       [     2, 250004,     87, ...,      1,      1,      1],
       [     2, 250004,  25689, ...,      1,      1,      1],
       [     2, 250004,      2, ...,      1,      1,      1]]), label_ids=array([[250004,  14847,     87, ...,      1,      1,      1],
       [250004,   3493,     87, ...,      1,      1,      1],
       [250004,    360,  10696, ...,      1,      1,      1],
       ...,
       [250004,     87,  15673, ...,      1,      1,      1],
       [250004,  25689,    398, ...,      1,      1,      1],
       [250004,      2,      1, ...,      1,      1,      1]]), metrics={'test_loss': 0.3957849442958832, 'test_bleu': 37.5987, 'test_gen_len': 30.0969, 'test_runtime': 90.3645, 'test_samples_per_second': 14.043, 'test_steps_per_second': 0.443})

## **Synthetic Data**

Vi Monolingual Data: https://drive.google.com/drive/folders/1brbflvjGcF7GqDQFxafUbGcbf68jCB0g?usp=sharing

**Cách 1: Sử dụng trainer => tạo dataset**

In [None]:
class PhoNMTDataset(Dataset):
    """Dataset over caller-supplied source and target sentence lists.

    Both lists are tokenized once in ``__init__`` and kept as id tensors,
    each padded or truncated to ``cfg.max_len``.

    Args:
        cfg: Configuration object; ``tokenizer`` and ``max_len`` are read.
        src_texts: List of source sentences.
        tgt_texts: List of target sentences (one per source sentence).
    """

    def __init__(self, cfg, src_texts, tgt_texts):
        super().__init__()
        self.cfg = cfg

        self.src_texts, self.tgt_texts = src_texts, tgt_texts

        self.src_input_ids = self.texts_to_sequences(self.src_texts)
        self.labels = self.texts_to_sequences(self.tgt_texts)

    def texts_to_sequences(self, texts):
        """Tokenize ``texts`` and return the ``input_ids`` tensor.

        Each row is padded or truncated to exactly ``cfg.max_len`` tokens.
        """
        data_inputs = self.cfg.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=self.cfg.max_len,
            return_tensors='pt'
        )
        return data_inputs.input_ids

    def __getitem__(self, idx):
        """Return one example as a dict of tensors.

        Keys:
            input_ids: Source token ids, padded to ``cfg.max_len``.
            attention_mask: 1 for real source tokens, 0 for padding, so the
                padded positions are not attended to.
            labels: Target token ids with padding replaced by -100, the label
                value that is ignored by the loss.
        """
        pad_id = self.cfg.tokenizer.pad_token_id
        input_ids = self.src_input_ids[idx]
        labels = self.labels[idx]
        return {
            "input_ids": input_ids,
            "attention_mask": (input_ids != pad_id).long(),
            # masked_fill returns a new tensor; the stored labels stay untouched.
            "labels": labels.masked_fill(labels == pad_id, -100),
        }

    def __len__(self):
        """Number of sentence pairs."""
        return self.src_input_ids.shape[0]

In [None]:
# Read the monolingual Vietnamese sentences, one per line.
# The encoding is given explicitly so the result does not depend on the platform default,
# and the trailing newline is dropped so it does not end up in the training text.
with open('./test.vi', 'r', encoding='utf-8') as f:
    vi_phomt = [line.rstrip('\n') for line in f]

In [None]:
vi_phomt[:5]

['Anh Albert Barnett và chị Susan Barnett , thuộc hội thánh West ở Tuscaloosa , Alabama\n',
 'Ngày 11 và 12-1-2020 , những cơn bão lớn đã quét qua và phá huỷ nhiều vùng ở miền nam và miền trung Hoa Kỳ .\n',
 'Những trận mưa to và gió lớn trong suốt hai ngày cùng với nhiều cơn lốc xoáy đã gây thiệt hại nặng nề cho nhiều bang .\n',
 'Đáng buồn là anh Albert Barnett 85 tuổi , và vợ anh là chị Susan Barnett 75 tuổi đã thiệt mạng do một cơn lốc xoáy quét qua nhà họ .\n',
 'Chi nhánh Hoa Kỳ cũng cho biết có ít nhất bốn căn nhà của anh em chúng tôi và hai Phòng Nước Trời bị hư hại nhẹ .\n']

In [None]:
len(vi_phomt)

19151

In [None]:
# The Vietnamese sentences are passed as both source and target: the target side
# is only a placeholder, since just the generated predictions are used afterwards.
phomt_dataset = PhoNMTDataset(cfg, vi_phomt, vi_phomt)

In [None]:
# Method 1: back-translate the Vietnamese sentences with the fine-tuned VI-EN trainer.
# The metrics in the result are scored against the placeholder labels (the Vietnamese
# inputs themselves), so they do not measure translation quality.
pred = trainer.predict(phomt_dataset)

You're using a MBart50TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 1197/1197 [19:37<00:00,  1.02it/s]


In [None]:
pred

PredictionOutput(predictions=array([[     2, 250004,  24748, ...,      1,      1,      1],
       [     2, 250004,   2161, ...,      1,      1,      1],
       [     2, 250004,    581, ...,      1,      1,      1],
       ...,
       [     2, 250004,    360, ...,      1,      1,      1],
       [     2, 250004,   3164, ...,      1,      1,      1],
       [     2, 250004,  48176, ...,      1,      1,      1]]), label_ids=array([[250004,  67921,  24748, ...,      1,      1,      1],
       [250004,  48752,     13, ...,      1,      1,      1],
       [250004,  32964,  13312, ...,      1,      1,      1],
       ...,
       [250004,    360,  15824, ...,      1,      1,      1],
       [250004,  64511,  13909, ...,      1,      1,      1],
       [250004,     44,  83425, ...,      1,      1,      1]]), metrics={'test_loss': 0.6733054518699646, 'test_bleu': 24.9305, 'test_gen_len': 24.3073, 'test_runtime': 1178.7906, 'test_samples_per_second': 16.246, 'test_steps_per_second': 1.015})

In [None]:
# Decode the generated ids into text, dropping special tokens; these are the synthetic source sentences.
en_phomt_preds = cfg.tokenizer.batch_decode(pred.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)

In [None]:
en_phomt_preds[:10]

['Albert Barnett and his sister Susan Barnett, from the West Church in Tuscaloosa, Alabama.',
 'On November 11 and 12, 2020, major storms swept through and destroyed parts of the south and central United States.',
 'The heavy rains and winds over the course of two days, coupled with the heavy storm surge, caused extensive damage in many states.',
 'Sadly, Albert Barnett was 85, and his wife Susan Barnett was 75 when a tornado struck their home.',
 'The U.S. branch also reported that at least four of our brother &apos;s homes and two of our guest houses suffered minor damage.',
 'Also, storms can devastate a brother &apos;s business.',
 'Local elders and survivors are helping and providing material and spiritual support to those affected by the disaster.',
 'We believe that our heavenly Father, Jesus Christ, is comforting our brothers and sisters in their grief.',
 'International agencies and government officials have spoken out in response to the ruling of the Russian Supreme Court ban

In [None]:
vi_phomt[:10]

['Anh Albert Barnett và chị Susan Barnett , thuộc hội thánh West ở Tuscaloosa , Alabama\n',
 'Ngày 11 và 12-1-2020 , những cơn bão lớn đã quét qua và phá huỷ nhiều vùng ở miền nam và miền trung Hoa Kỳ .\n',
 'Những trận mưa to và gió lớn trong suốt hai ngày cùng với nhiều cơn lốc xoáy đã gây thiệt hại nặng nề cho nhiều bang .\n',
 'Đáng buồn là anh Albert Barnett 85 tuổi , và vợ anh là chị Susan Barnett 75 tuổi đã thiệt mạng do một cơn lốc xoáy quét qua nhà họ .\n',
 'Chi nhánh Hoa Kỳ cũng cho biết có ít nhất bốn căn nhà của anh em chúng tôi và hai Phòng Nước Trời bị hư hại nhẹ .\n',
 'Ngoài ra , những cơn bão cũng gây hư hại lớn cho cơ sở kinh doanh của một anh em .\n',
 'Các trưởng lão địa phương và giám thị xung quanh đang giúp đỡ và cung cấp về vật chất và tinh thần cho các anh chị bị ảnh hưởng trong thảm hoạ này .\n',
 'Chúng ta tin chắc rằng Cha trên trời , Đức Giê-hô-va , đang an ủi những anh chị em của chúng ta trong cảnh đau buồn .\n',
 'Các cơ quan và viên chức chính phủ quốc

**Cách 2: Sinh theo sample hoặc theo batch**

In [None]:
def inference(
    text,
    tokenizer,
    model,
    device="cpu",
    max_length=75,
    beam_size=5
    ):
    """Translate ``text`` with beam search and return the decoded output.

    Args:
        text: A single sentence (``str``) or a batch of sentences (list of ``str``).
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        model: Seq2seq model exposing ``to`` and ``generate``.
        device: Device the inputs and the model are moved to.
        max_length: Token limit used for both the encoded input and the
            generated output.
        beam_size: Number of beams for beam search.

    Returns:
        The decoded string when ``text`` is a ``str``; otherwise a list with
        one decoded string per input sentence.
    """
    is_single = isinstance(text, str)

    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
        )
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    model.to(device)

    # Generate in eval mode so dropout does not perturb the output, then put
    # the model back into training mode if that is how the caller handed it over.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            early_stopping=True,
            num_beams=beam_size,
            length_penalty=2.0
        )
    finally:
        if was_training:
            model.train()

    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # A batch keeps every translation; a single sentence still returns a plain str.
    return output_str[0] if is_single else output_str

# Method 2: back-translate the Vietnamese sentences one at a time.
# The device, length limit and beam size come from cfg, so generation runs on
# cfg.device instead of falling back to inference()'s "cpu" default.
en_phomt_preds = []
for sent in vi_phomt:
    en_sent = inference(
        sent.strip(),  # drop the newline that readlines() keeps on each sentence
        cfg.tokenizer,
        model,
        device=cfg.device,
        max_length=cfg.max_len,
        beam_size=cfg.beam_size,
    )
    en_phomt_preds.append(en_sent)

## **En-VI**

In [None]:
# Add created data above to dataset

In [None]:
class NMTDataset(Dataset):
    """``iwslt2015-en-vi`` dataset, optionally extended with extra sentence pairs.

    Every sentence is tokenized once in ``__init__`` and kept as id tensors,
    each padded or truncated to ``cfg.max_len``.

    Args:
        cfg: Configuration object. The attributes read here are ``src_lang``
            and ``tgt_lang`` (keys of each sample's ``"translation"`` dict),
            ``tokenizer`` and ``max_len``.
        src_augs: Extra source sentences appended to the training split
            (``None`` or empty for no augmentation).
        tgt_augs: Extra target sentences, aligned one-to-one with ``src_augs``.
        data_type: Split name passed to ``load_dataset``. The extra pairs are
            only added when this is ``"train"``.

    Raises:
        ValueError: If ``data_type`` is ``"train"`` and ``src_augs`` and
            ``tgt_augs`` differ in length.
    """

    def __init__(self, cfg, src_augs=None, tgt_augs=None, data_type="train"):
        super().__init__()
        self.cfg = cfg

        self.src_texts, self.tgt_texts = self.read_data(data_type)
        if data_type == 'train':
            extra_src = [] if src_augs is None else list(src_augs)
            extra_tgt = [] if tgt_augs is None else list(tgt_augs)
            # Unequal lists would leave source sentences without a label (or
            # labels without a source), so fail here rather than during training.
            if len(extra_src) != len(extra_tgt):
                raise ValueError(
                    f"src_augs and tgt_augs must have the same length, "
                    f"got {len(extra_src)} and {len(extra_tgt)}"
                )
            self.src_texts.extend(extra_src)
            self.tgt_texts.extend(extra_tgt)

        self.src_input_ids = self.texts_to_sequences(self.src_texts)
        self.labels = self.texts_to_sequences(self.tgt_texts)

    def read_data(self, data_type):
        """Load one split and return ``(source_texts, target_texts)`` as lists of str."""
        data = load_dataset(
            "mt_eng_vietnamese",
            "iwslt2015-en-vi",
            split=data_type
        )
        src_texts = [sample["translation"][self.cfg.src_lang] for sample in data]
        tgt_texts = [sample["translation"][self.cfg.tgt_lang] for sample in data]
        return src_texts, tgt_texts

    def texts_to_sequences(self, texts):
        """Tokenize ``texts`` and return the ``input_ids`` tensor.

        Each row is padded or truncated to exactly ``cfg.max_len`` tokens.
        """
        data_inputs = self.cfg.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=self.cfg.max_len,
            return_tensors='pt'
        )
        return data_inputs.input_ids

    def __getitem__(self, idx):
        """Return one example as a dict of tensors.

        Keys:
            input_ids: Source token ids, padded to ``cfg.max_len``.
            attention_mask: 1 for real source tokens, 0 for padding, so the
                padded positions are not attended to.
            labels: Target token ids with padding replaced by -100, the label
                value that is ignored by the loss.
        """
        pad_id = self.cfg.tokenizer.pad_token_id
        input_ids = self.src_input_ids[idx]
        labels = self.labels[idx]
        return {
            "input_ids": input_ids,
            "attention_mask": (input_ids != pad_id).long(),
            # masked_fill returns a new tensor; the stored labels stay untouched.
            "labels": labels.masked_fill(labels == pad_id, -100),
        }

    def __len__(self):
        """Number of sentence pairs, including any appended pairs."""
        return self.src_input_ids.shape[0]

In [None]:
class BaseConfig:
    """Minimal config container: keyword arguments become instance attributes."""

    def __init__(self, **kwargs):
        # Copy every supplied option onto the instance under its own name,
        # overriding any class-level default with the same name.
        for name in kwargs:
            setattr(self, name, kwargs[name])

class NMTConfig(BaseConfig):
    """Settings for the English -> Vietnamese (EN-VI) run on the augmented data."""

    # ---- data ----
    # src_lang / tgt_lang select the two sides inside each sample's "translation" dict.
    src_lang = 'en'
    tgt_lang = 'vi'
    # Passed to the tokenizer as max_length (sequences are padded/truncated to it).
    max_len = 75
    add_special_tokens = True

    # ---- model ----
    model_name = "facebook/mbart-large-50-many-to-many-mmt"

    # ---- training ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    learning_rate = 5e-5
    train_batch_size = eval_batch_size = 32
    num_train_epochs = 2
    save_total_limit = 1
    # Checkpoint directory is derived from the language pair defined above.
    ckpt_dir = f'./mbart50-{src_lang}-{tgt_lang}-backtranslation-2'
    # Used for both the evaluation interval and the checkpoint interval.
    eval_steps = 1000

    # ---- inference ----
    beam_size = 5


cfg = NMTConfig()

In [None]:
# Start the EN-VI run from a freshly loaded pretrained tokenizer and model (cfg.model_name).
# NOTE(review): no src_lang / tgt_lang is passed here, so the tokenizer's default
# language code is applied to both source and target text — confirm this is intended.
cfg.tokenizer = MBart50TokenizerFast.from_pretrained(cfg.model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name)

In [None]:
# Training split = original corpus + synthetic pairs (back-translated English -> original Vietnamese).
train_dataset = NMTDataset(
    cfg,
    src_augs=en_phomt_preds,
    tgt_augs=vi_phomt,
    data_type="train",
)
# The validation and test splits are left untouched.
valid_dataset = NMTDataset(cfg, data_type="validation")
test_dataset = NMTDataset(cfg, data_type="test")

Found cached dataset mt_eng_vietnamese (/home/server-ailab-12gb/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71)
Found cached dataset mt_eng_vietnamese (/home/server-ailab-12gb/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71)
Found cached dataset mt_eng_vietnamese (/home/server-ailab-12gb/.cache/huggingface/datasets/mt_eng_vietnamese/iwslt2015-en-vi/1.0.0/53add551a01e9874588066f89d42925f9fad43db347199dad00f7e4b0c905a71)


In [None]:
len(train_dataset)

152469

In [None]:
# sacreBLEU scorer; compute_metrics calls metric.compute(...) and reports its "score" as BLEU.
metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded texts.

    Returns:
        A ``(preds, labels)`` pair: ``preds`` is a flat list of stripped
        strings, while every stripped label is wrapped in its own one-item
        list (one reference per prediction).
    """
    stripped_preds = []
    for text in preds:
        stripped_preds.append(text.strip())

    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])

    return stripped_preds, wrapped_labels

def compute_metrics(eval_preds):
    """Compute BLEU and the mean generated length for one evaluation pass.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays. If
            ``predictions`` is a tuple, only its first element is used.

    Returns:
        Dict with ``"bleu"`` (the sacreBLEU score) and ``"gen_len"`` (average
        count of non-pad tokens per prediction), both rounded to 4 decimals.
    """
    tokenizer = cfg.tokenizer
    pad_id = tokenizer.pad_token_id

    pred_ids, label_ids = eval_preds
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    def _decode(ids):
        return tokenizer.batch_decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

    # -100 marks ignored positions and is not a real token id; swap in the pad id before decoding.
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(label_ids != -100, label_ids, pad_id)

    hypotheses, references = postprocess_text(_decode(pred_ids), _decode(label_ids))
    bleu = metric.compute(predictions=hypotheses, references=references)["score"]

    gen_len = np.mean([np.count_nonzero(row != pad_id) for row in pred_ids])

    return {"bleu": round(bleu, 4), "gen_len": round(gen_len, 4)}

In [None]:
# NOTE(review): `evaluation_strategy` and the trainer's `tokenizer` argument are
# spelled as in the transformers release this notebook was run with; newer
# releases rename them — confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir=cfg.ckpt_dir,
    # Optimisation
    learning_rate=cfg.learning_rate,
    num_train_epochs=cfg.num_train_epochs,
    per_device_train_batch_size=cfg.train_batch_size,
    per_device_eval_batch_size=cfg.eval_batch_size,
    # Evaluate and checkpoint on the same step schedule.
    evaluation_strategy="steps",
    eval_steps=cfg.eval_steps,
    save_strategy='steps',
    save_steps=cfg.eval_steps,
    save_total_limit=cfg.save_total_limit,
    # No metric_for_best_model is set, so "best" uses the Trainer's default criterion.
    load_best_model_at_end=True,
    # Evaluation generates sequences; compute_metrics decodes the resulting ids.
    predict_with_generate=True,
)

data_collator = DataCollatorForSeq2Seq(tokenizer=cfg.tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=cfg.tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the EN-VI model on the augmented training set; evaluation and
# checkpointing run every cfg.eval_steps steps.
trainer.train()

  0%|          | 0/9530 [00:00<?, ?it/s]You're using a MBart50TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  5%|▌         | 500/9530 [05:45<1:42:47,  1.46it/s]

{'loss': 0.7768, 'learning_rate': 4.737670514165792e-05, 'epoch': 0.1}


 10%|█         | 1000/9530 [11:24<1:38:14,  1.45it/s]

{'loss': 0.4776, 'learning_rate': 4.475341028331584e-05, 'epoch': 0.21}


                                                     
 10%|█         | 1000/9530 [13:02<1:38:14,  1.45it/s]

{'eval_loss': 0.568303108215332, 'eval_bleu': 32.9774, 'eval_gen_len': 32.948, 'eval_runtime': 97.3289, 'eval_samples_per_second': 13.038, 'eval_steps_per_second': 0.411, 'epoch': 0.21}


 16%|█▌        | 1500/9530 [18:55<1:26:51,  1.54it/s] 

{'loss': 0.473, 'learning_rate': 4.213011542497377e-05, 'epoch': 0.31}


 21%|██        | 2000/9530 [24:20<1:21:28,  1.54it/s]

{'loss': 0.4677, 'learning_rate': 3.950682056663169e-05, 'epoch': 0.42}


                                                     
 21%|██        | 2000/9530 [25:45<1:21:28,  1.54it/s]

{'eval_loss': 0.5526649951934814, 'eval_bleu': 33.5444, 'eval_gen_len': 32.6872, 'eval_runtime': 85.5442, 'eval_samples_per_second': 14.834, 'eval_steps_per_second': 0.468, 'epoch': 0.42}


 26%|██▌       | 2500/9530 [31:15<1:15:57,  1.54it/s] 

{'loss': 0.4618, 'learning_rate': 3.688352570828961e-05, 'epoch': 0.52}


 31%|███▏      | 3000/9530 [36:40<1:10:35,  1.54it/s]

{'loss': 0.4611, 'learning_rate': 3.4260230849947537e-05, 'epoch': 0.63}


                                                     
 31%|███▏      | 3000/9530 [38:02<1:10:35,  1.54it/s]

{'eval_loss': 0.540764570236206, 'eval_bleu': 33.9756, 'eval_gen_len': 32.6391, 'eval_runtime': 82.5237, 'eval_samples_per_second': 15.377, 'eval_steps_per_second': 0.485, 'epoch': 0.63}


 37%|███▋      | 3500/9530 [43:33<1:06:43,  1.51it/s] 

{'loss': 0.4518, 'learning_rate': 3.163693599160546e-05, 'epoch': 0.73}


 42%|████▏     | 4000/9530 [48:58<59:49,  1.54it/s]  

{'loss': 0.4515, 'learning_rate': 2.901364113326338e-05, 'epoch': 0.84}


                                                   
 42%|████▏     | 4000/9530 [50:22<59:49,  1.54it/s]

{'eval_loss': 0.5334458947181702, 'eval_bleu': 34.1069, 'eval_gen_len': 32.9582, 'eval_runtime': 84.4311, 'eval_samples_per_second': 15.03, 'eval_steps_per_second': 0.474, 'epoch': 0.84}


 47%|████▋     | 4500/9530 [55:53<54:25,  1.54it/s]   

{'loss': 0.4487, 'learning_rate': 2.63903462749213e-05, 'epoch': 0.94}


 52%|█████▏    | 5000/9530 [1:01:17<49:01,  1.54it/s]

{'loss': 0.4043, 'learning_rate': 2.3767051416579224e-05, 'epoch': 1.05}


                                                     
 52%|█████▏    | 5000/9530 [1:02:41<49:01,  1.54it/s]

{'eval_loss': 0.5374477505683899, 'eval_bleu': 34.2478, 'eval_gen_len': 32.6856, 'eval_runtime': 83.8705, 'eval_samples_per_second': 15.13, 'eval_steps_per_second': 0.477, 'epoch': 1.05}


 58%|█████▊    | 5500/9530 [1:08:12<43:36,  1.54it/s]   

{'loss': 0.363, 'learning_rate': 2.1143756558237147e-05, 'epoch': 1.15}


 63%|██████▎   | 6000/9530 [1:13:37<38:12,  1.54it/s]

{'loss': 0.3635, 'learning_rate': 1.852046169989507e-05, 'epoch': 1.26}


                                                     
 63%|██████▎   | 6000/9530 [1:15:00<38:12,  1.54it/s]

{'eval_loss': 0.5342561602592468, 'eval_bleu': 34.3835, 'eval_gen_len': 32.7597, 'eval_runtime': 83.6559, 'eval_samples_per_second': 15.169, 'eval_steps_per_second': 0.478, 'epoch': 1.26}


 68%|██████▊   | 6500/9530 [1:20:31<32:45,  1.54it/s]   

{'loss': 0.3669, 'learning_rate': 1.589716684155299e-05, 'epoch': 1.36}


 73%|███████▎  | 7000/9530 [1:25:55<28:26,  1.48it/s]

{'loss': 0.3625, 'learning_rate': 1.3273871983210914e-05, 'epoch': 1.47}


                                                     
 73%|███████▎  | 7000/9530 [1:27:24<28:26,  1.48it/s]

{'eval_loss': 0.5296697020530701, 'eval_bleu': 34.4664, 'eval_gen_len': 32.8731, 'eval_runtime': 88.0977, 'eval_samples_per_second': 14.404, 'eval_steps_per_second': 0.454, 'epoch': 1.47}


 79%|███████▊  | 7500/9530 [1:33:15<21:55,  1.54it/s]   

{'loss': 0.3614, 'learning_rate': 1.0650577124868834e-05, 'epoch': 1.57}


 84%|████████▍ | 8000/9530 [1:38:40<16:31,  1.54it/s]

{'loss': 0.3565, 'learning_rate': 8.027282266526758e-06, 'epoch': 1.68}


                                                     
 84%|████████▍ | 8000/9530 [1:40:03<16:31,  1.54it/s]

{'eval_loss': 0.5277815461158752, 'eval_bleu': 34.7893, 'eval_gen_len': 32.6509, 'eval_runtime': 83.923, 'eval_samples_per_second': 15.121, 'eval_steps_per_second': 0.477, 'epoch': 1.68}


 89%|████████▉ | 8500/9530 [1:45:49<11:54,  1.44it/s]   

{'loss': 0.3587, 'learning_rate': 5.403987408184681e-06, 'epoch': 1.78}


 94%|█████████▍| 9000/9530 [1:51:36<06:08,  1.44it/s]

{'loss': 0.3518, 'learning_rate': 2.7806925498426024e-06, 'epoch': 1.89}


                                                     
 94%|█████████▍| 9000/9530 [1:53:10<06:08,  1.44it/s]

{'eval_loss': 0.5230357646942139, 'eval_bleu': 35.2273, 'eval_gen_len': 32.7801, 'eval_runtime': 93.7786, 'eval_samples_per_second': 13.532, 'eval_steps_per_second': 0.427, 'epoch': 1.89}


100%|█████████▉| 9500/9530 [1:59:04<00:20,  1.45it/s]  

{'loss': 0.3538, 'learning_rate': 1.5739769150052467e-07, 'epoch': 1.99}


100%|██████████| 9530/9530 [1:59:26<00:00,  1.57it/s]

{'train_runtime': 7166.4681, 'train_samples_per_second': 42.551, 'train_steps_per_second': 1.33, 'train_loss': 0.426721374435665, 'epoch': 2.0}


100%|██████████| 9530/9530 [1:59:26<00:00,  1.33it/s]


TrainOutput(global_step=9530, training_loss=0.426721374435665, metrics={'train_runtime': 7166.4681, 'train_samples_per_second': 42.551, 'train_steps_per_second': 1.33, 'train_loss': 0.426721374435665, 'epoch': 2.0})

In [None]:
# Score the trained EN-VI model on the test split (metrics are reported with the "eval_" prefix).
trainer.evaluate(test_dataset)

100%|██████████| 40/40 [01:31<00:00,  2.29s/it]


{'eval_loss': 0.5230357646942139,
 'eval_bleu': 35.2273,
 'eval_gen_len': 32.7801,
 'eval_runtime': 93.6366,
 'eval_samples_per_second': 13.552,
 'eval_steps_per_second': 0.427,
 'epoch': 2.0}

## **Checkpoint**
https://drive.google.com/drive/folders/1ii_lPm2-1CfIhQM8RVzLgTHMxXDKgnk4?usp=sharing