In [None]:
import tensorflow as tf
from time import time
import matplotlib.pyplot as plt
import config
import utils

@tf.function
def train_step(img_tensor, target, encoder, decoder, optimizer, loss_object):
    """Perform a single gradient update of the encoder and decoder on one batch.

    The decoder receives ``target`` without its last position along axis 1
    and its predictions are scored against ``target`` without its first
    position, i.e. each step is trained to predict the following token.

    Args:
        img_tensor: Batch of inputs passed directly to ``encoder``.
        target: Batch of target sequences, indexed as ``[batch, position]``
            (presumably integer token ids -- confirm against the dataset).
        encoder: Callable model producing features from ``img_tensor``.
        decoder: Callable model invoked as ``decoder(dec_input, features)``
            and returning a 3-tuple whose first element is the predictions.
        optimizer: Optimizer whose ``apply_gradients`` performs the update.
        loss_object: Loss forwarded unchanged to ``utils.loss_function``.

    Returns:
        The loss value computed by ``utils.loss_function`` for this batch.
    """
    decoder_inputs = tf.convert_to_tensor(target[:, :-1])
    expected_outputs = target[:, 1:]

    with tf.GradientTape() as tape:
        predictions, _, _ = decoder(decoder_inputs, encoder(img_tensor))
        loss = utils.loss_function(expected_outputs, predictions, loss_object)

    # Collect the variables only after the forward pass so that lazily built
    # models have already created them.
    variables = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(loss, variables)
    grads_and_vars = zip(grads, variables)
    optimizer.apply_gradients(grads_and_vars)

    return loss

def train_model(encoder, decoder, dataset):
    """Train ``encoder`` and ``decoder`` on ``dataset`` and plot the epoch losses.

    An Adam optimizer and a sparse categorical cross-entropy loss (logits,
    no reduction) are created here. Model and optimizer state are
    checkpointed under ``config.checkpoint_path``; if a checkpoint already
    exists there, it is restored and training resumes from the epoch
    encoded in its file name.

    Checkpoints are written with ``checkpoint_number=epoch + 1`` so the
    ``-N`` suffix equals the number of completed epochs. The manager's
    default counter increases by one per save, and saves only happen every
    fifth epoch, so the default suffix would not match the epoch to resume.

    Args:
        encoder: Model passed to ``train_step`` and tracked by the checkpoint.
        decoder: Model passed to ``train_step`` and tracked by the checkpoint.
        dataset: Iterable yielding ``(img_tensor, target)`` batches.

    Returns:
        None. Progress is printed and a loss curve is shown with matplotlib.
    """
    optimizer = tf.keras.optimizers.Adam()
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

    ckpt = tf.train.Checkpoint(encoder=encoder, decoder=decoder, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, config.checkpoint_path, max_to_keep=5)

    start_epoch = 0
    latest_checkpoint = ckpt_manager.latest_checkpoint
    if latest_checkpoint:
        # The suffix after the last '-' is the number of completed epochs
        # (see the explicit checkpoint_number passed to save() below).
        start_epoch = int(latest_checkpoint.split('-')[-1])
        ckpt.restore(latest_checkpoint)

    loss_plot = []
    for epoch in range(start_epoch, config.EPOCHS):
        start = time()
        total_loss = 0

        for batch, (img_tensor, target) in enumerate(dataset):
            batch_loss = train_step(img_tensor, target, encoder, decoder, optimizer, loss_object)
            total_loss += batch_loss

            if batch % 100 == 0:
                print(f'Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy():.4f}')

        # Mean loss per step; used for both the plot and the printed summary
        # so the two always report the same quantity.
        epoch_loss = total_loss / config.NUM_STEPS
        loss_plot.append(epoch_loss)

        if epoch % 5 == 0:
            # Number the checkpoint by completed epochs so a later run can
            # resume at exactly the next epoch.
            ckpt_manager.save(checkpoint_number=epoch + 1)

        print(f'Epoch {epoch+1} Loss {epoch_loss:.6f}')
        print(f'Time taken for 1 epoch {time() - start:.2f} sec\n')

    plt.plot(loss_plot)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Plot')
    plt.show()