# Details

In [2]:
import os
import joblib
import numpy as np
import tensorflow as tf
from utils import *
tf.get_logger().setLevel("ERROR")

## Validation

I modify the training loop so that it can also evaluate a validation dataset at the end of each epoch.

In [3]:
def create_data_and_model():
    """Build seeded toy train/validation datasets and a compiled Keras model.

    NumPy and TensorFlow are both re-seeded with 0 on every call, so that
    repeated calls are meant to start from the same data and the same model
    initialisation (this is what lets separate runs be compared).

    Returns:
        A tuple ``(train_data, validation_data, model)`` where
        ``train_data`` is 100 random samples with 10 features and a 0/1 label,
        shuffled and batched by 10; ``validation_data`` is a second,
        independently drawn set of the same shape, batched by 10 but not
        shuffled; and ``model`` is a compiled single-``Dense(1)`` network.
    """
    # set random seed
    seed = 0
    np.random.seed(seed)
    tf.random.set_seed(seed)

    def _random_xy_dataset():
        # Features are drawn before labels, exactly as in the original
        # copy-pasted code, so the seeded NumPy stream yields the same arrays.
        features = np.random.random((100, 10))
        labels = np.random.randint(0, 2, (100, 1))
        return tf.data.Dataset.zip((
            tf.data.Dataset.from_tensor_slices(features),
            tf.data.Dataset.from_tensor_slices(labels),
        ))

    # training set: shuffled and batched
    train_data = _random_xy_dataset().shuffle(100).batch(10)

    # validation set: batched only (no shuffling)
    validation_data = _random_xy_dataset().batch(10)

    # create a model to train on the dataset
    # `shape` must be a tuple: `(10)` is just the integer 10, `(10,)` is a 1-tuple.
    # NOTE(review): Dense(1) has no activation while the loss is
    # "binary_crossentropy"; confirm whether a sigmoid output (or a
    # from_logits=True loss) was intended. Left unchanged so that the recorded
    # outputs of this notebook stay reproducible.
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(10,)),
        tf.keras.layers.Dense(1)
    ])
    # `metrics` is passed as a list, the documented form of the argument.
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return train_data, validation_data, model


def _save_checkpoint(cm, iterator, model, callbacks, epoch, step):
    """Save iterator, model, metrics and history under the (epoch, step) key."""
    cm.save_iterator(iterator, epoch, step)
    cm.save_model(model, epoch, step)
    cm.save_metrics(model, epoch, step)
    cm.save_history(callbacks, epoch, step)


def train(data, model, epochs, validation_data=None, initial_epoch=0, initial_step=0, epoch_period=0, step_period=0, ckpt_dir="checkpoint"):
    """Custom training loop with optional validation and checkpoint/resume.

    Args:
        data: Training dataset. It must support ``len()`` (used as the number
            of steps per epoch) and ``iter()``.
        model: Compiled Keras model; trained via ``model.make_train_function()``.
        epochs: Index at which the epoch loop stops (exclusive).
        validation_data: If not ``None``, passed to ``model.evaluate`` after the
            step loop of every epoch; results are merged into the epoch logs
            with a ``"val_"`` prefix.
        initial_epoch: Epoch index to start from. Together with
            ``initial_step`` it selects the checkpoint to restore when either
            is non-zero.
        initial_step: Step index to start from within ``initial_epoch``.
        epoch_period: If non-zero, a checkpoint keyed ``(epoch + 1, 0)`` is
            written whenever ``(epoch + 1) % epoch_period == 0``.
        step_period: If non-zero, a checkpoint keyed ``(epoch, step + 1)`` is
            written whenever ``(step + 1) % step_period == 0``.
        ckpt_dir: Directory handed to ``CheckpointManager``.

    Returns:
        ``model.history`` (presumably the History callback attached through
        the ``CallbackList`` -- confirm against the Keras version in use).
    """
    cm = CheckpointManager(ckpt_dir)

    iterator = iter(data)

    resuming = initial_epoch != 0 or initial_step != 0
    if resuming:
        cm.restore_iterator(iterator, initial_epoch, initial_step)
        cm.restore_model(model, initial_epoch, initial_step)

    steps = len(data)
    callbacks = tf.keras.callbacks.CallbackList(add_history=True, add_progbar=True, model=model, epochs=epochs, steps=steps, verbose=True)

    callbacks.on_train_begin()
    train_fn = model.make_train_function()

    # History is restored only after on_train_begin() has run.
    if resuming:
        cm.restore_history(callbacks, initial_epoch, initial_step)

    logs = None
    for epoch in range(initial_epoch, epochs):

        # A fresh epoch gets a fresh iterator and cleared metrics; a resumed
        # mid-epoch run keeps the restored iterator instead.
        if initial_step == 0:
            iterator = iter(data)
            model.reset_metrics()

        callbacks.on_epoch_begin(epoch)

        for step in range(initial_step, steps):
            callbacks.on_train_batch_begin(step)
            logs = train_fn(iterator)

            # On the first resumed step the logs come from the checkpoint
            # rather than from the train function.
            if initial_step != 0 and step == initial_step:
                logs = cm.restore_metrics(model, initial_epoch, initial_step)

            # Use the same 0-based batch index as on_train_batch_begin; the
            # original passed `step + 1`, which is off by one.
            callbacks.on_train_batch_end(step, logs)

            if step_period != 0 and (step + 1) % step_period == 0:
                _save_checkpoint(cm, iterator, model, callbacks, epoch, step + 1)

        # Only the first (resumed) epoch starts mid-way.
        if initial_step != 0:
            initial_step = 0

        # validation step
        if validation_data is not None:
            val_logs = model.evaluate(validation_data, callbacks=callbacks, return_dict=True)
            val_logs = {"val_" + name: val for name, val in val_logs.items()}
            # `logs` is still None when this epoch executed no training step
            # (e.g. resuming with initial_step == len(data)); avoid crashing
            # on None.update().
            if logs is None:
                logs = {}
            logs.update(val_logs)

        callbacks.on_epoch_end(epoch, logs)

        if epoch_period != 0 and (epoch + 1) % epoch_period == 0:
            _save_checkpoint(cm, iterator, model, callbacks, epoch + 1, 0)

    callbacks.on_train_end(logs=logs)
    return model.history


This can reproduce the result of `model.fit()`.

In [4]:
epochs = 5

# Reference run: Keras' built-in loop on a freshly seeded dataset and model.
train_data, validation_data, model = create_data_and_model()
history = model.fit(train_data, validation_data=validation_data, epochs=epochs)
print(history.history)

print()

# Custom loop on an identically seeded setup; `epochs` is reused from above.
train_data, validation_data, model = create_data_and_model()
history = train(data=train_data, model=model, epochs=epochs, validation_data=validation_data)
print(history.history)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5], 'val_loss': [1.6133465766906738, 1.229250192642212, 1.1984832286834717, 1.1877398490905762, 1.1842304468154907], 'val_accuracy': [0.6399999856948853, 0.6399999856948853, 0.6499999761581421, 0.6499999761581421, 0.6499999761581421]}

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5], 'val_loss': [1.6133465766906738, 1.229250192642212, 1.1984832286834717, 1.1877398490905762, 1.1842304468154907], 'val_accuracy': [0.6399999856948853, 0.6399999856948853, 0.6499999761581421, 0.6499999761581421, 0.6499999761581421]}


This modification doesn't affect saving and resuming the training.

In [5]:
# Full run that writes a checkpoint every 3 steps and at the end of every epoch.
epochs, step_period, epoch_period = 5, 3, 1
train_data, validation_data, model = create_data_and_model()
history = train(
    train_data,
    model,
    epochs,
    validation_data=validation_data,
    epoch_period=epoch_period,
    step_period=step_period,
)
print(history.history)

print()

# Resume from the checkpoint keyed (epoch 2, step 3) written by the run above.
initial_epoch, initial_step = 2, 3
train_data, validation_data, model = create_data_and_model()
history = train(
    train_data,
    model,
    epochs,
    validation_data=validation_data,
    initial_epoch=initial_epoch,
    initial_step=initial_step,
)
print(history.history)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5], 'val_loss': [1.6133465766906738, 1.229250192642212, 1.1984832286834717, 1.1877398490905762, 1.1842304468154907], 'val_accuracy': [0.6399999856948853, 0.6399999856948853, 0.6499999761581421, 0.6499999761581421, 0.6499999761581421]}

Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5], 'val_loss': [1.6133465766906738, 1.229250192642212, 1.1984832286834717, 1.1877398490905762, 1.1842304468154907], 'val_accuracy': [0.6399999856948853, 0.6399999856948853, 0.6499999761581421, 0.6499999761581421, 0.6499999761581421]}
