In [None]:
import warnings
warnings.filterwarnings('ignore')

import sys
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import pandas as pd
import sentencepiece as spm
import tensorflow as tf
import numpy as np
from tqdm import tqdm 
from sklearn.model_selection import train_test_split



# Make the project's own source folders importable (training code, utilities and
# baseline models). They go to the very front of sys.path, in this order, so they
# are searched before any installed package with the same module name.
sys.path[0:0] = [
    os.getenv('HOME') + '/saturi_lab_multi_nmt_low_resource/src/training/',
    os.getenv('HOME') + '/saturi_lab_multi_nmt_low_resource/src/utils',
    os.getenv('HOME') + '/saturi_lab_multi_nmt_low_resource/src/models/baseline/',
]

from dataset_util import CustomDatasetforTranslation
import utils
from vanilla_transformer import Transformer, generate_masks


In [None]:
main_path = os.getenv('HOME') + '/saturi_lab_multi_nmt_low_resource'
data_path = f'{main_path}/data/processed/translated_train_data.csv'
# 'Unnamed: 0' is the column pandas creates for an unnamed leading CSV column
# (presumably an index saved with to_csv) — it carries no data, so drop it.
df = pd.read_csv(data_path).drop(columns='Unnamed: 0')
df.head()

In [None]:
# SentencePiece model files: one for the English source side (enc), one for the
# target side (dec), both with a 16k vocabulary according to their file names.
src_tok_path = f'{main_path}/saved_models/tokenizer/old/spm_enc_spm16000.model'
tgt_tok_path = f'{main_path}/saved_models/tokenizer/old/spm_dec_spm16000.model'

In [None]:
# Load the source and target SentencePiece tokenizer models.
src_tokenizer = spm.SentencePieceProcessor()
src_tokenizer.Load(src_tok_path)

tgt_tokenizer = spm.SentencePieceProcessor()
tgt_tokenizer.Load(tgt_tok_path)
# Every target encoding gets BOS prepended and EOS appended; train_step and
# model_validate rely on this when they shift the target by one position.
tgt_tokenizer.set_encode_extra_options("bos:eos")

# Quick sanity check: vocabulary size and a sample segmentation for each side.
for _label, _tok, _sample in (
    ('source tokenizer vocab size :', src_tokenizer, 'Here is an example of source tokenization.'),
    ('target tokenizer vocab size :', tgt_tokenizer, '이것은 토큰화 예시입니다.'),
):
    print(_label, _tok.vocab_size())
    print(_tok.EncodeAsPieces(_sample))

In [None]:
# Hold out 20% of the rows for validation, stratified on the 'reg' column so each
# label keeps the same proportion in both splits; random_state pins the shuffle.
# NOTE: the name `train` is later rebound to the training function — run this cell
# and the dataset cell before that definition.
train, test = train_test_split(
    df,
    test_size=0.2,
    random_state=2,
    shuffle=True,
    stratify=df['reg'],
)

In [None]:
# Create the training/validation datasets with the project class
# CustomDatasetforTranslation (imported from dataset_util). Arguments: English
# source text, dialect target text, 'reg' column, 32 (presumably the batch size —
# confirm in dataset_util), then the source and target tokenizers.
train_dataset = CustomDatasetforTranslation(
    train['eng'].to_numpy(),
    train['dial'].to_numpy(),
    train['reg'].to_numpy(),
    32,
    src_tokenizer,
    tgt_tokenizer,
)
valid_dataset = CustomDatasetforTranslation(
    test['eng'].to_numpy(),
    test['dial'].to_numpy(),
    test['reg'].to_numpy(),
    32,
    src_tokenizer,
    tgt_tokenizer,
)

In [None]:
# Learning-rate schedule from the project's utils module (presumably the
# Transformer warm-up schedule), driving Adam with the beta/epsilon values from
# "Attention Is All You Need".
# NOTE(review): 512 is presumably d_model; it is hard-coded here rather than read
# from config['d_model'] (loaded later) — keep the two in sync.
learningrate = utils.LearningRateScheduler(512)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=learningrate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

In [None]:
# Per-token cross-entropy on raw logits; reduction='none' keeps one loss value per
# position so padded positions can be zeroed out in loss_function.
criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def loss_function(real, pred):
    """Average cross-entropy over target positions whose id is not 0.

    Args:
        real: gold token ids; id 0 is treated as padding (pad_sequences pads with 0).
        pred: logits aligned with `real`, with a trailing vocabulary axis.

    Returns:
        Scalar: sum of unmasked token losses divided by the number of unmasked tokens.

    NOTE(review): with SentencePiece's default trainer options id 0 is <unk>, not a
    pad id, so genuine unknown tokens would be masked too — confirm how these
    tokenizers were trained.
    """
    per_token = criterion(real, pred)
    # 1.0 where the target is a real token, 0.0 on padding, in the loss's dtype.
    weights = tf.cast(tf.math.logical_not(tf.math.equal(real, 0)), dtype=per_token.dtype)
    per_token *= weights
    # Normalise by the number of non-padding tokens, not by the padded length.
    return tf.reduce_sum(per_token) / tf.reduce_sum(weights)

In [None]:
# One optimisation step, compiled into a graph by tf.function. reduce_retracing=True
# lets TF use more generic input shapes, because every batch is padded to its own length.
@tf.function(reduce_retracing=True)
def train_step(src, tgt, model, optimizer):
    """Forward/backward pass on one batch, then apply the gradients to `model`.

    Args:
        src: padded source token ids (presumably (batch, src_len) — TODO confirm).
        tgt: padded target token ids, 2-D; position 0 holds the first target token (BOS).
        model: Transformer returning (logits, enc_attns, dec_attns, dec_enc_attns).
        optimizer: optimizer whose apply_gradients updates the weights.

    Returns:
        The scalar masked loss for this batch.
    """
    # Teacher forcing: output step t is scored against target token t+1, so the
    # labels drop the first token and the logits drop the last step.
    labels = tgt[:, 1:]

    enc_mask, dec_enc_mask, dec_mask = generate_masks(src, tgt)

    # Record the forward pass so gradients of the loss can be computed.
    with tf.GradientTape() as tape:
        logits, _enc_attns, _dec_attns, _dec_enc_attns = model(src, tgt, enc_mask, dec_enc_mask, dec_mask)
        batch_loss = loss_function(labels, logits[:, :-1])

    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return batch_loss

In [None]:
# Validation forward pass (no weight update), compiled with tf.function.
# reduce_retracing=True matches train_step: each validation batch is padded to its
# own length, and without it tf.function traces a new graph for every new shape.
@tf.function(reduce_retracing=True)
def model_validate(src, tgt, model):
    """Compute the masked loss and the predictions for one validation batch.

    Args:
        src: padded source token ids (presumably (batch, src_len) — TODO confirm).
        tgt: padded target token ids, 2-D; position 0 holds the first target token (BOS).
        model: Transformer returning (logits, enc_attns, dec_attns, dec_enc_attns).

    Returns:
        (v_loss, predictions): the scalar masked loss and the model's full logits.

    NOTE(review): no `training` flag is passed to the model, so dropout behaviour
    depends on the Transformer's own call() default — confirm it is inference mode.
    """
    # Same teacher-forcing alignment as train_step.
    gold = tgt[:, 1:]

    enc_mask, dec_enc_mask, dec_mask = generate_masks(src, tgt)
    predictions, _enc_attns, _dec_attns, _dec_enc_attns = model(src, tgt, enc_mask, dec_enc_mask, dec_mask)
    v_loss = loss_function(gold, predictions[:, :-1])

    return v_loss, predictions

In [None]:
# train function
def _pad_batch(src, tgt):
    """Zero-pad ('post') one batch of source and target id sequences to a shared width.

    The width is the longest sequence in the batch across BOTH sides. Taking
    maxlen from the source side alone makes pad_sequences truncate any longer
    target, and it truncates from the front by default (truncating='pre'),
    dropping the leading target tokens (presumably including BOS). Both sides
    keep the same width, as before, in case the model or generate_masks expects it.

    Returns:
        (enc_input, dec_input): two 2-D numpy arrays of equal width.
    """
    max_len = max(max(map(len, src)), max(map(len, tgt)))
    enc_input = tf.keras.preprocessing.sequence.pad_sequences(src, padding='post', maxlen=max_len)
    dec_input = tf.keras.preprocessing.sequence.pad_sequences(tgt, padding='post', maxlen=max_len)
    return enc_input, dec_input


def train(transformer, train_dataset, valid_dataset, optimizer, EPOCHS):
    """Train `transformer` for EPOCHS epochs, running a validation pass after each epoch.

    Args:
        transformer: model passed to train_step and model_validate.
        train_dataset: iterable of (src, tgt) batches of token-id sequences
            (presumably lists of int lists — confirm against CustomDatasetforTranslation).
        valid_dataset: iterable of (src, tgt) batches, same format, used for validation.
        optimizer: optimizer used by train_step.
        EPOCHS: number of passes over train_dataset.

    Running mean train/validation losses are shown in the tqdm bars; returns None.

    NOTE(review): defining this function rebinds the name `train`, which earlier
    holds the training DataFrame from train_test_split; re-running the dataset
    cell after this one fails. Renaming would also require updating the call site.
    """
    for epoch in range(EPOCHS):
        t = tqdm(train_dataset)
        total_loss = 0
        t.set_description_str(f'EPOCH {epoch}')

        for i, (src, tgt) in enumerate(t):
            enc_train, dec_train = _pad_batch(src, tgt)
            batch_loss = train_step(enc_train, dec_train, transformer, optimizer)
            total_loss += batch_loss
            t.set_postfix_str('Loss %.4f' % (total_loss.numpy() / (i + 1)))

        # validation
        total_loss_val = 0
        tv = tqdm(valid_dataset)

        for k, (src, tgt) in enumerate(tv):
            enc_val, dec_val = _pad_batch(src, tgt)
            # model_validate returns (loss, predictions); accumulate only the loss
            # (adding the tuple itself to a number raises TypeError).
            val_loss, _ = model_validate(enc_val, dec_val, transformer)
            total_loss_val += val_loss
            tv.set_postfix_str('val_Loss %.4f' % (total_loss_val.numpy() / (k + 1)))

In [None]:
# Load the model hyper-parameters: the 'model' section of src/utils/config.json.
import json
config_path = main_path + '/src/utils/config.json'
# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so the read does not
# depend on the platform's locale default (e.g. cp949 on Korean Windows).
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)['model']

In [None]:
# Take the vocabulary sizes from the loaded tokenizers rather than the JSON file,
# so the model built below is sized to the actual SentencePiece models.
config.update(
    src_vocab_size=src_tokenizer.vocab_size(),
    tgt_vocab_size=tgt_tokenizer.vocab_size(),
)

In [None]:
# Build the baseline Transformer; every keyword argument is read from the config
# entry of the same name.
transformer = Transformer(**{
    key: config[key]
    for key in (
        'n_layers',
        'd_model',
        'n_heads',
        'd_ff',
        'src_vocab_size',
        'tgt_vocab_size',
        'pos_len',
        'dropout',
        'shared',
    )
})

In [None]:
# Run the training loop for the number of epochs given by config['epochs'].
train(
    transformer=transformer,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    optimizer=optimizer,
    EPOCHS=config['epochs'],
)