# Library and Configs

In [1]:
# For Colab. Run this after run time reset
# import sys
# sys.path.append('/content/drive/MyDrive/Colab Notebooks/transformer_mastery')
# path = '/content/drive/MyDrive/Colab Notebooks/transformer_mastery/metadata/'

In [2]:
# %pip install keras_preprocessing

In [3]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.metrics import Mean
from tensorflow import data, train, math, reduce_sum, cast, equal
from tensorflow import argmax, float32, GradientTape, function
from keras.losses import sparse_categorical_crossentropy
from transformer import TransformerModel
from prepare_dataset import PrepareDataset
from time import time
from pickle import dump, load

# Data

In [6]:
# Prepare the dataset
dataset = PrepareDataset()
trainX, trainY, valX, valY, train_org, val_org, enc_seq_length, dec_seq_length, enc_vocab_size, dec_vocab_size = dataset("None")
print('enc_seq_length, dec_seq_length, enc_vocab_size, dec_vocab_size')
print(enc_seq_length, dec_seq_length, enc_vocab_size, dec_vocab_size)

enc_seq_length, dec_seq_length, enc_vocab_size, dec_vocab_size
9 7 3157 2009


# Train

In [7]:
# Parameters
h = 8
d_k = 64
d_v = 64
d_model = 512
d_ff = 2048
n = 6

#
epochs = 20
batch_size = 64
beta_1 = 0.9
beta_2 = 0.98
epsilon = 1e-9
dropout_rate = 0.1

In [8]:
# Prepare the batches
train_dataset = data.Dataset.from_tensor_slices((trainX, trainY))
train_dataset = train_dataset.batch(batch_size)
val_dataset = data.Dataset.from_tensor_slices((valX, valY))
val_dataset = val_dataset.batch(batch_size)

In [9]:
# Loss function
def loss_fcn(target, prediction):
    padding_mask = math.logical_not(equal(target, 0))
    padding_mask = cast(padding_mask, float32)
    #
    loss = sparse_categorical_crossentropy(target, prediction)
    #
    return reduce_sum(loss) / reduce_sum(padding_mask)


def accuracy_fcn(target, prediction):
    padding_mask = math.logical_not(equal(target, 0))
    #
    accuracy = equal(target, argmax(prediction, axis=2))
    accuracy = math.logical_and(padding_mask, accuracy)
    #
    accuracy = cast(accuracy, float32)
    padding_mask = cast(padding_mask, float32)
    #

    return reduce_sum(accuracy) / reduce_sum(padding_mask)


train_loss = Mean(name='train_loss')
train_accuracy = Mean(name='train_accuracy')
val_loss = Mean(name='val_loss')

# Optimizer

class LRScheduler(LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000, **kwargs):
        super(LRScheduler, self).__init__(**kwargs)
        self.d_model = cast(d_model, float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step_num):
        step_num = cast(step_num, float32)
        arg1 = step_num ** -0.5
        arg2 = step_num * (self.warmup_steps ** -1.5)
        return (self.d_model ** -0.5) * math.minimum(arg1, arg2)

#
optimizer = Adam(LRScheduler(d_model), beta_1, beta_2, epsilon)

In [13]:
# Create model
training_model = TransformerModel(enc_vocab_size, dec_vocab_size, enc_seq_length,
                                  dec_seq_length, h, d_k, d_v, d_model, d_ff, n, dropout_rate)

# Checkpoint manage
ckpt = train.Checkpoint(model=training_model, optimizer=optimizer)
ckpt_manager = train.CheckpointManager(ckpt, './metadata/checkpoints', max_to_keep=3)
train_loss_dict = {}
val_loss_dict = {}


@function
def train_step(encoder_input, decoder_input, decoder_output):
    with GradientTape() as tape:
        prediction = training_model(
            encoder_input, decoder_input, training=True)
        loss = loss_fcn(decoder_output, prediction)
        accuracy = accuracy_fcn(decoder_output, prediction)
    #
    gradients = tape.gradient(loss, training_model.trainable_weights)
    #
    optimizer.apply_gradients(zip(gradients, training_model.trainable_weights))

    train_loss(loss)
    train_accuracy(accuracy)
    # print(f'Loss {train_loss.result():.4f}, Accuracy {train_accuracy.result():.4f}')



In [14]:
start_time = time()
for epoch in range(epochs):
    train_loss.reset_states()
    train_accuracy.reset_states()
    val_loss.reset_states()

    print('\nStart of epoch %d' % (epoch+1))

    print(f'time() = {time()-start_time}')

    for step, (train_batchX, train_batchY) in enumerate(train_dataset):
        encoder_input = train_batchX[:, 1:]
        decoder_input = train_batchY[:, :-1]
        decoder_output = train_batchY[:, 1:]
        #
        train_step(encoder_input, decoder_input, decoder_output)
        #
        if step % 50 == 0:
            print(
                f'Epoch {epoch+1}, Step {step}, Loss {train_loss.result():.4f}, Accuracy {train_accuracy.result():.4f}')

    for val_batchX, val_batchY in val_dataset:
        encoder_input = val_batchX[:, 1:]
        decoder_input = val_batchY[:, :-1]
        decoder_output = val_batchY[:, 1:]
        #
        prediction = training_model(
            encoder_input, decoder_input, training=False)
        loss = loss_fcn(decoder_output, prediction)
        val_loss(loss)

    print('Epoch %d: Training Loss %.4f, Training Accuracy %.4f, Validation Loss %.4f'
          % (epoch+1, train_loss.result(), train_accuracy.result(), val_loss.result()))

    if (epoch + 1) % 1 == 0:
        save_path = ckpt_manager.save()
        print('Saved checkpoint at epoch %d' % (epoch + 1))

        training_model.save_weights('./metadata/weights/wghts' + str(epoch+1) + '.ckpt')
        train_loss_dict[epoch] = train_loss.result()
        val_loss_dict[epoch] = val_loss.result()




Start of epoch 1
time() = 0.0047757625579833984
Epoch 1, Step 0, Loss 12.6423, Accuracy 0.0000
Epoch 1, Step 50, Loss 7.1565, Accuracy 0.0090
Epoch 1, Step 100, Loss 6.1277, Accuracy 0.1394
Epoch 1: Training Loss 5.8364, Training Accuracy 0.1728, Validation Loss 4.6126
Saved checkpoint at epoch 1

Start of epoch 2
time() = 18.447755098342896
Epoch 2, Step 0, Loss 4.7424, Accuracy 0.2773
Epoch 2, Step 50, Loss 4.4311, Accuracy 0.3237
Epoch 2, Step 100, Loss 4.3059, Accuracy 0.3331
Epoch 2: Training Loss 4.2497, Training Accuracy 0.3390, Validation Loss 3.9393
Saved checkpoint at epoch 2

Start of epoch 3
time() = 28.17925500869751
Epoch 3, Step 0, Loss 4.1876, Accuracy 0.3193
Epoch 3, Step 50, Loss 3.9115, Accuracy 0.3728
Epoch 3, Step 100, Loss 3.8145, Accuracy 0.3871
Epoch 3: Training Loss 3.7719, Training Accuracy 0.3919, Validation Loss 3.6098
Saved checkpoint at epoch 3

Start of epoch 4
time() = 37.772111892700195
Epoch 4, Step 0, Loss 3.7115, Accuracy 0.3782
Epoch 4, Step 50, Lo

In [16]:
with open('./data/train_loss.pkl', 'wb') as file:
    dump(train_loss_dict, file)
with open('./data/val_loss.pkl', 'wb') as file:
    dump(val_loss_dict, file)
print('Total time taken: %.2fs' % (time() - start_time))

Total time taken: 234.21s


# Conclusion