In [None]:
# It is best to run this notebook on Kaggle (the paths below assume the /kaggle directory layout)

# Import Libraries

In [None]:
# fairseq provides the Adafactor optimizer imported in the next cell.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install fairseq

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from transformers import get_linear_schedule_with_warmup, T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from fairseq.optim.adafactor import Adafactor
import re, os, time, gc, codecs, string, math, heapq
import numpy as np
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.translate.bleu_score import corpus_bleu
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Language tags handed to preprocess_sentence for the source and target side of each pair.
SRC_LANGUAGE = 'eng'
TGT_LANGUAGE = 'vi'
# Name of the pretrained checkpoint used for both the model and the tokenizer.
model_name = 't5-base'
# Task prefix string referenced by preprocess_sentence.
task_prefix = "translate English to Vietnamese: "


In [None]:
# Number of the last epoch of the training loop (epochs are numbered from 1).
NUM_EPOCHS = 20
# Upper bound on the number of sentence pairs load_data keeps per split.
SAMPLES = 1e9
# Base sequence length; the tokenizer and the datasets are built with MAX_LEN + 3.
MAX_LEN = 128
# Epoch number of the checkpoint to resume from; 0 skips loading and trains from epoch 1.
continue_training_from_checkpoint = 0 # 0: retrain
# Location of the checkpoint that is loaded when the value above is greater than 0.
checkpoint_path = f'/kaggle/input/translate-machine-model-state-dict/pre_trained_checkpoint_{continue_training_from_checkpoint}.pth'
# Batch size shared by the training and evaluation dataloaders.
batch_size = 16

In [None]:
# Load the pretrained T5 model and its tokenizer from the same checkpoint name.
model = T5ForConditionalGeneration.from_pretrained(model_name)
# model_max_length equals the max_len (MAX_LEN + 3) the datasets are built with.
tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length = MAX_LEN + 3, add_special_tokens=True)

# Data Preparation

In [None]:
# Clone the KC4.0 MultilingualNMT repository; the data directory below and the
# third-party/multi-bleu.perl script used in the last cell both live inside it.
!git clone https://github.com/KCDichDaNgu/KC4.0_MultilingualNMT.git
# Directory holding the parallel files read below (train.en/train.vi, tst2012.en/tst2012.vi).
DATA_DIR = '/kaggle/working/KC4.0_MultilingualNMT/data/iwslt_en_vi'

In [None]:
def preprocess_sentence(text, language):
    """Normalise one raw sentence before it is tokenised.

    ASCII punctuation (``string.punctuation``) is removed, every run of
    whitespace is collapsed into a single space and the result is stripped.
    Source-language sentences additionally get ``task_prefix`` prepended so
    the encoder input states the translation task.

    Args:
        text: Raw sentence; may carry a trailing newline.
        language: Language tag of ``text``; compared against ``SRC_LANGUAGE``.

    Returns:
        The cleaned sentence, prefixed with ``task_prefix`` when ``language``
        is the source language.
    """
    # Strip punctuation first so the colon inside task_prefix is not removed.
    text = text.translate(str.maketrans('', '', string.punctuation))

    # Collapse runs of whitespace (including the trailing newline) and trim.
    text = re.sub(r'\s+', ' ', text).strip()

    # The previous check compared against the literal string 'TGT_LANGUAGE',
    # which never matched, so the prefix was silently dropped everywhere.
    # The prefix describes the task and therefore belongs on the source
    # (encoder) side, not on the target sentence the model has to produce.
    if language == SRC_LANGUAGE:
        return task_prefix + text
    return text

def load_data(source_file, target_file, number_of_examples, MAX_LEN):
    """Read a line-aligned parallel corpus and preprocess both sides.

    Args:
        source_file: Path to the source-language file, one sentence per line.
        target_file: Path to the target-language file, line-aligned with
            ``source_file``.
        number_of_examples: Maximum number of sentence pairs to keep.
        MAX_LEN: Unused; kept so existing call sites keep working.

    Returns:
        A ``(source_data, target_data)`` tuple of equally long lists holding
        the preprocessed sentences.

    Raises:
        AssertionError: If the two files do not have the same number of lines.
    """
    # Context managers close the handles deterministically; the explicit
    # UTF-8 encoding keeps the Vietnamese text independent of the locale.
    with open(source_file, "r", encoding="utf-8") as src_handle:
        source_sents = src_handle.readlines()
    with open(target_file, "r", encoding="utf-8") as tgt_handle:
        target_sents = tgt_handle.readlines()
    assert len(source_sents) == len(target_sents)

    source_data, target_data = [], []

    for src_sentence, trg_sentence in zip(source_sents, target_sents):
        # Stop as soon as the requested number of pairs has been collected.
        if len(source_data) >= number_of_examples:
            break
        source_data.append(preprocess_sentence(src_sentence, SRC_LANGUAGE))
        target_data.append(preprocess_sentence(trg_sentence, TGT_LANGUAGE))
    return source_data, target_data


In [None]:
# Training split, capped at SAMPLES sentence pairs.
train_source_sentences, train_target_sentences = load_data(
    f"{DATA_DIR}/train.en",
    f"{DATA_DIR}/train.vi",
    SAMPLES,
    MAX_LEN,
)
# Evaluation split (tst2012); the last argument is not used by load_data.
eval_source_sentences, eval_target_sentences = load_data(
    f"{DATA_DIR}/tst2012.en",
    f"{DATA_DIR}/tst2012.vi",
    SAMPLES,
    1e9,
)

In [None]:
# src_encoding = tokenizer.encode_plus('My name is tan and i love my something other plan songoku chicken', padding='max_length', max_length=50, truncation=True, return_tensors='pt')
# {'input_ids': src_encoding['input_ids'][0], 
#                 'attention_mask': src_encoding['attention_mask'][0],}

In [None]:
class TranslationDataset(Dataset):
    """Map-style dataset that tokenises parallel sentences on access.

    Args:
        src_texts: Sequence of source sentences.
        tgt_texts: Sequence of target sentences, aligned with ``src_texts``.
        tokenizer: Tokenizer exposing ``encode_plus`` and ``pad_token_id``.
        max_len: Length every encoded sequence is padded or truncated to.
    """

    def __init__(self, src_texts, tgt_texts, tokenizer, max_len):
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of sentence pairs (taken from the source side)."""
        return len(self.src_texts)

    def _encode(self, text):
        """Tokenise ``text`` into tensors padded/truncated to ``max_len``."""
        # Use the tokenizer handed to the constructor, not a module-level
        # global, so the dataset works with whatever tokenizer it was given.
        return self.tokenizer.encode_plus(text, padding='max_length', max_length=self.max_len, truncation=True, return_tensors='pt')

    def __getitem__(self, index):
        """Return the encoded tensors for the pair at ``index``.

        Returns:
            Dict with ``input_ids`` / ``attention_mask`` for the source,
            ``decoder_input_ids`` / ``decoder_attention_mask`` for the target
            and ``labels`` (target ids with padding replaced by -100).
        """
        src_encoding = self._encode(str(self.src_texts[index]))
        tgt_encoding = self._encode(str(self.tgt_texts[index]))

        # encode_plus returns a batch of one; [0] drops that leading dimension.
        labels = tgt_encoding['input_ids'][0].clone()  # copy so the target ids stay untouched
        # replace padding token id's of the labels by -100 so it's ignored by the loss
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {'input_ids': src_encoding['input_ids'][0],
                'attention_mask': src_encoding['attention_mask'][0],
                # Unshifted target ids: train_epoch does not feed them to the
                # model; the evaluation functions decode them as reference text.
                'decoder_input_ids': tgt_encoding['input_ids'][0],
                'decoder_attention_mask': tgt_encoding['attention_mask'][0],
                'labels': labels}


In [None]:
# Both splits are padded/truncated to MAX_LEN + 3, the tokenizer's model_max_length.
train_dataset = TranslationDataset(train_source_sentences, train_target_sentences, tokenizer, max_len=MAX_LEN + 3)
eval_dataset = TranslationDataset(eval_source_sentences, eval_target_sentences, tokenizer, max_len=MAX_LEN + 3)
# Shuffle the training data only. The evaluation loader keeps the file order so
# that evaluation runs are reproducible and the files written by eval_by_KC
# follow the order of tst2012.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

# Training Model

# Let's go bruhhh

In [None]:
def train_epoch(epoch, model, optimizer, scheduler, dataloader):
    """Run one optimisation pass over ``dataloader`` and checkpoint the state.

    Args:
        epoch: Number of the epoch being trained; used in the checkpoint name.
        model: Model called with ``input_ids``, ``attention_mask`` and ``labels``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        dataloader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.

    Returns:
        The summed batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    seen = 0
    for step, batch in enumerate(dataloader):
        src_ids = batch['input_ids'].to(device)
        src_mask = batch['attention_mask'].to(device)
        target_ids = batch['labels'].to(device)

        # With labels supplied, the first element of the output is the loss.
        forward_out = model(
            input_ids=src_ids,
            attention_mask=src_mask,
            labels=target_ids,
        )
        batch_loss = forward_out[0]

        optimizer.zero_grad()
        batch_loss.backward()

        running_loss += batch_loss.item()
        seen += batch['input_ids'].size(0)

        optimizer.step()
        scheduler.step()

        # Drop the device copies of this batch and flush the caches before the next one.
        del src_ids, src_mask, target_ids
        gc.collect()
        torch.cuda.empty_cache()

        # Log each of the first ten batches, then every hundredth one.
        if step < 10 or step % 100 == 0:
            print(f"Total loss: {running_loss:.4f} | Total: {seen} | Loss per batch: {running_loss/(step + 1):.4f}")

    mean_loss = running_loss / len(dataloader)
    snapshot = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': mean_loss,
    }
    # Only epochs after the resumed one are written out.
    if epoch > continue_training_from_checkpoint and epoch % 1 == 0:
        torch.save(snapshot, f'/kaggle/working/pre_trained_checkpoint_{epoch}.pth')
    del snapshot
    gc.collect()
    return mean_loss

def eval_epoch(model, dataloader):
    """Translate every batch of ``dataloader`` and score the result with BLEU.

    Args:
        model: Model whose ``generate`` method produces output token ids.
        dataloader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``decoder_input_ids`` tensors.

    Returns:
        The ``corpus_bleu`` score of the whitespace-split generated sentences
        against the whitespace-split decoded targets.
    """
    model.eval()
    hypotheses = []
    references = []
    processed = 0
    for step, batch in enumerate(dataloader):
        src_ids = batch['input_ids'].to(device)
        src_mask = batch['attention_mask'].to(device)
        target_ids = batch['decoder_input_ids'].to(device)

        generated = model.generate(
            input_ids=src_ids,
            attention_mask=src_mask,
            max_length=MAX_LEN + 3,
        )
        for text in tokenizer.batch_decode(generated, skip_special_tokens=True):
            hypotheses.append(text.split())
        # corpus_bleu takes a list of references per hypothesis, hence the extra nesting.
        for text in tokenizer.batch_decode(target_ids, skip_special_tokens=True):
            references.append([text.split()])

        processed += batch['input_ids'].size(0)

        # Drop the device copies of this batch and flush the caches before the next one.
        del src_ids, src_mask, target_ids
        gc.collect()
        torch.cuda.empty_cache()

        if step % 10 == 0:
            print(f"Processed: {processed} sequences")

    return corpus_bleu(references, hypotheses)

# def eval_epoch

In [None]:
# Optimize RAM usage
# The deletions below are disabled (the names are still needed by later cells);
# only the garbage collector and the CUDA cache are flushed here.
# del train_source_sentences
# del train_target_sentences 
# del eval_source_sentences
# del eval_target_sentences 
# del train_dataset
# del eval_dataset
gc.collect()
torch.cuda.empty_cache()

# Move the model to the selected device before the optimizer is built from its parameters.
model.to(device)
# Adafactor with an explicit learning rate: relative_step, scale_parameter and
# warmup_init are all switched off, and no first-moment term is requested (beta1=None).
optimizer = Adafactor(model.parameters(), 
                        lr=3e-4, 
                        eps=(1e-30, 1e-3),
                        clip_threshold=1.0,
                        decay_rate=-0.8,
                        beta1=None,
                        weight_decay=0.0,
                        relative_step=False,
                        scale_parameter=False,
                        warmup_init=False)
# train_epoch steps the scheduler once per batch, so the schedule spans every batch of every epoch.
num_training_steps = len(train_dataloader) * NUM_EPOCHS
num_warmup_steps = 500
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

# Resume model, optimizer and scheduler state when a checkpoint epoch (> 0) is configured.
if continue_training_from_checkpoint > 0:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU runtime can also be restored on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Epoch: {epoch}, Avg train loss: {loss:.3f}")

# Best validation BLEU seen so far in this run; starts at 0 even when resuming.
bleu_score =  0
# Epochs are numbered from 1; a resumed run continues with the epoch after the checkpoint.
for epoch in range(continue_training_from_checkpoint+1, NUM_EPOCHS+1):
    print(f'======== Epoch {epoch} / {NUM_EPOCHS} ========')
    print('Training...')
    start_time = time.time()
    # train_epoch also writes the per-epoch checkpoint file.
    train_loss = train_epoch(epoch, model, optimizer, scheduler, train_dataloader)
    end_time = time.time()
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "f"Training time = {(end_time - start_time):.3f}s"))
    print('Validating...')
    start_time = time.time()
    bleu_score_tmp = eval_epoch(model, eval_dataloader)
    end_time = time.time()
    print((f"Epoch: {epoch}, Bleu Score: {bleu_score_tmp:.6f}, "f"Validating time = {(end_time - start_time):.3f}s"))
    # Keep a separate "best" checkpoint, replaced only when BLEU improves by at least 1e-4.
    if bleu_score_tmp - bleu_score >= 1e-4:
        bleu_score = bleu_score_tmp
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': train_loss,
            'bleu_score': bleu_score_tmp
        }
        torch.save(checkpoint, f'/kaggle/working/best_pre_trained.pth')
print('Best bleu score: ', bleu_score)


In [None]:
def eval_by_KC(model, dataloader,
               output_path='/kaggle/working/translate.en2vi.vi',
               reference_path='/kaggle/working/tst2012.vi'):
    """Write generated translations and decoded references to text files.

    The two files are line-aligned (one sentence per line) so that an
    external scoring script can compare them.

    Args:
        model: Model whose ``generate`` method produces output token ids.
        dataloader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``decoder_input_ids`` tensors.
        output_path: File that receives the generated translations.
        reference_path: File that receives the decoded target sentences.
    """
    model.eval()
    output_sentence = []
    tgt_sentence = []
    total = 0
    for batch_id, batch_data in enumerate(dataloader):
        input_ids = batch_data['input_ids'].to(device)
        attention_mask = batch_data['attention_mask'].to(device)

        output_token_id = model.generate(input_ids=input_ids,
                                         attention_mask=attention_mask,
                                         max_length=MAX_LEN + 3)
        output_sentence.extend(tokenizer.batch_decode(output_token_id, skip_special_tokens=True))
        # The reference ids are only decoded, never fed to the model, so they
        # do not need to be copied to the device first.
        tgt_sentence.extend(tokenizer.batch_decode(batch_data['decoder_input_ids'], skip_special_tokens=True))

        total += batch_data['input_ids'].size(0)
        # Drop the device copies of this batch and flush the caches before the next one.
        del input_ids
        del attention_mask
        gc.collect()
        torch.cuda.empty_cache()
        if batch_id % 10 == 0:
            print(f"""Processed: {total} sequences""")

    # Explicit UTF-8 keeps the Vietnamese output independent of the locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        for sentence in output_sentence:
            f.write(sentence + '\n')

    with open(reference_path, 'w', encoding='utf-8') as f:
        for sentence in tgt_sentence:
            f.write(sentence + '\n')

            
# Write the model translations and the decoded references for the evaluation set under /kaggle/working.
eval_by_KC(model, eval_dataloader)

In [None]:
%cd /kaggle/working/KC4.0_MultilingualNMT/third-party
# multi-bleu.perl takes the reference file as its argument and reads the
# hypothesis (the model output) from stdin; BLEU is not symmetric, so the
# order matters.
!perl ./multi-bleu.perl /kaggle/working/tst2012.vi < /kaggle/working/translate.en2vi.vi