In [1]:
import os
import csv
import numpy as np
import tensorflow as tf

from model.seq2seq.Seq2Seq import *
from model.seq2seq_attn.Seq2Seq_Attn import *
from utils.utils import *

Loading JIT Compiled ChatSpace Model


In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [3]:
@tf.function
def loss_function(true, pred, loss_obj):
    mask = tf.math.logical_not(tf.math.equal(true, 0))

    loss = loss_obj(true, pred)

    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask
    
    return tf.reduce_mean(loss)

In [6]:
def train():
    # Load data
    dataset = load_dataset(data_dir)
    
    num_batches_per_epoch = len(dataset) // batch_size
    
    # Load tokenizer
    enc_tokenizer = load_tokenizer('enc-tokenizer', (x for x, y in dataset), target_vocab_size=2**13)
    dec_tokenizer = load_tokenizer('dec-tokenizer', (y for x, y in dataset), target_vocab_size=2**13)
    enc_vocab_size = enc_tokenizer.vocab_size + 1
    dec_vocab_size = dec_tokenizer.vocab_size + 2
    print(f'enc_vocab_size: {enc_vocab_size}\tdec_vocab_size: {dec_vocab_size}')
    
    # Define the optimizer and the loss function
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    
    # Define seq2seq model
    config = {'batch_size': batch_size,
              'enc_max_len': enc_max_len+1,
              'dec_max_len': dec_max_len+2,
              'enc_unit': enc_unit,
              'dec_unit': dec_unit,
              'embed_dim': embed_dim,
              'dropout_rate': dropout_rate,
              'enc_vocab_size': enc_vocab_size,
              'dec_vocab_size': dec_vocab_size,
              'dec_sos_token': dec_tokenizer.vocab_size
              }
    
    model = seq2seq(config)
    
    # checkpoint
    checkpoint_dir = 'checkpoint/daily-korean/seq2seq'
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
        
    checkpoint_prefix = os.path.join(checkpoint_dir, 'checkpoint')
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    
    epoch_loss = tf.keras.metrics.Mean()
    
    for epoch in range(epochs):
        epoch_loss.reset_states()
        
        train_batches = batch_dataset(dataset, batch_size, enc_tokenizer, dec_tokenizer, enc_max_len, dec_max_len)
        
        for batch_idx, (batch_x, batch_y) in enumerate(train_batches):
            loss = 0.
            with tf.GradientTape() as tape:
                preds = model(batch_x, batch_y, True)

                loss = loss_function(batch_y[:, 1:], preds, loss_obj)
            
            variables = model.trainable_variables
            gradients = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(gradients, variables))
            
            epoch_loss(loss)
            
            if (batch_idx + 1) % log_interval == 0:
                print(f'[Epoch {epoch + 1}|Step {batch_idx + 1}/{num_batches_per_epoch}] loss: {loss.numpy()} (Avg. {epoch_loss.result()})')
        
        model.save_weights(filepath=checkpoint_prefix)
    
    print("Training is Done.")

In [7]:
if __name__ == '__main__':
    np.random.seed(1234)
    tf.random.set_seed(1234)
    
    train()

enc_vocab_size: 8633	dec_vocab_size: 7921
11823
[Epoch 1|Step 50/369] loss: 1.5011221170425415 (Avg. 1.7559351921081543)
[Epoch 1|Step 100/369] loss: 1.5256730318069458 (Avg. 1.6276884078979492)
[Epoch 1|Step 150/369] loss: 1.4489439725875854 (Avg. 1.5577911138534546)
[Epoch 1|Step 200/369] loss: 1.3613492250442505 (Avg. 1.5276087522506714)
[Epoch 1|Step 250/369] loss: 1.333289623260498 (Avg. 1.4995362758636475)
[Epoch 1|Step 300/369] loss: 1.2491003274917603 (Avg. 1.47150719165802)
[Epoch 1|Step 350/369] loss: 1.1868115663528442 (Avg. 1.4474748373031616)
11823
[Epoch 2|Step 50/369] loss: 1.1757022142410278 (Avg. 1.2392525672912598)
[Epoch 2|Step 100/369] loss: 1.2452367544174194 (Avg. 1.2425786256790161)
[Epoch 2|Step 150/369] loss: 1.2470276355743408 (Avg. 1.2250256538391113)
[Epoch 2|Step 200/369] loss: 1.0841407775878906 (Avg. 1.2170518636703491)
[Epoch 2|Step 250/369] loss: 1.213249683380127 (Avg. 1.2139114141464233)
[Epoch 2|Step 300/369] loss: 1.016128659248352 (Avg. 1.213134050