# Details

In [1]:
import os
import numpy as np
import tensorflow as tf
tf.get_logger().setLevel("ERROR")

## Preparation
First, I create a dataset and a model, and train the model with `model.fit()`. 

One of the requirements of our code is that it can reproduce the following result.

In [2]:
# Seed NumPy and TensorFlow so that repeated runs of this cell give the same result.
seed = 0
np.random.seed(seed)
tf.random.set_seed(seed)

# create a toy dataset with numpy arrays: 100 samples, 10 features, labels in {0, 1}
x = np.random.random((100, 10))
y = np.random.randint(0, 2, (100, 1))
x = tf.data.Dataset.from_tensor_slices(x)
y = tf.data.Dataset.from_tensor_slices(y)

# shuffling and batching the dataset (10 batches of 10 samples)
dataset = tf.data.Dataset.zip((x, y)).shuffle(100).batch(10)

# create a model to train on the dataset
# `shape` is written as a real 1-tuple: `(10)` is just the integer 10.
# `metrics` is passed as a list, which is the documented form.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(10,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# training: this is the reference result the custom loops below must reproduce
epochs = 3
history = model.fit(dataset, epochs=epochs)
print(history.history)

Epoch 1/3
Epoch 2/3
Epoch 3/3
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316]}


Let's take a look at the toy dataset.
The input is a 10-dimensional random vector and the output is a random label 0 or 1.

In [3]:
# Show the first batch only: its inputs, a blank separator line, then its targets.
first_batch = dataset.take(1)
for xi, yi in first_batch.as_numpy_iterator():
    print(f"Input: shape={xi.shape}", xi, sep="\n")
    print()
    print(f"Target: shape={yi.shape}", yi, sep="\n")

Input: shape=(10, 10)
[[0.57019677 0.43860151 0.98837384 0.10204481 0.20887676 0.16130952
  0.65310833 0.2532916  0.46631077 0.24442559]
 [0.06271295 0.42403225 0.25868407 0.84903831 0.03330463 0.95898272
  0.35536885 0.35670689 0.0163285  0.18523233]
 [0.1390727  0.42690436 0.84285489 0.81803331 0.10241376 0.15638335
  0.30419869 0.07535907 0.424663   0.10761771]
 [0.69742877 0.45354268 0.7220556  0.86638233 0.97552151 0.85580334
  0.01171408 0.35997806 0.72999056 0.17162968]
 [0.18115096 0.78854551 0.05684808 0.69699724 0.7786954  0.77740756
  0.25942256 0.37381314 0.58759964 0.2728219 ]
 [0.30040368 0.54950057 0.93081872 0.52076144 0.26720703 0.87739879
  0.37191875 0.00138335 0.24768502 0.31823351]
 [0.8965466  0.36756187 0.43586493 0.89192336 0.80619399 0.70388858
  0.10022689 0.91948261 0.7142413  0.99884701]
 [0.04680635 0.97073144 0.00386035 0.17857997 0.61286675 0.0813696
  0.8818965  0.71962016 0.96638997 0.50763555]
 [0.99033895 0.21689698 0.6630782  0.26332238 0.020651   0.

Now I prepare a method to create a dataset and a model.

In [4]:
def create_data_and_model():
    """Build the seeded toy dataset and a freshly compiled model.

    The NumPy and TensorFlow seeds are reset on every call, so each call
    starts from the same random state.

    Returns:
        A `(data, model)` tuple: `data` is the shuffled dataset of 100 random
        samples in batches of 10, and `model` is a compiled single-`Dense`
        Keras model.
    """
    # set random seed
    seed = 0
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # create a dataset with numpy arrays (features first, then labels, so the
    # order of the random draws stays fixed)
    x = np.random.random((100, 10))
    y = np.random.randint(0, 2, (100, 1))
    x = tf.data.Dataset.from_tensor_slices(x)
    y = tf.data.Dataset.from_tensor_slices(y)

    # shuffling and batching the dataset
    data = tf.data.Dataset.zip((x, y)).shuffle(100).batch(10)

    # create a model to train on the dataset
    # `shape` is written as a real 1-tuple: `(10)` is just the integer 10.
    # `metrics` is passed as a list, which is the documented form.
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(10,)),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return data, model

## Save and Restore a Dataset

I need to save and restore the dataset because it carries state, such as the random state used by `shuffle()`. 

So I would like to customize `model.fit()`.
One way is to implement a custom callback as described [here](https://www.tensorflow.org/guide/keras/custom_callback). 

However, it seems to be hard to access the dataset from a callback. 
Therefore, I first try to reproduce the behaviour of `model.fit()` as much as possible using the Keras API.

In [5]:
def train1(data, model, epochs):
    """Train `model` on `data` with a hand-written loop that mimics `model.fit()`.

    Args:
        data: batched dataset; must support `len()` and `iter()`.
        model: compiled Keras model.
        epochs: number of epochs to run.

    Returns:
        `model.history` as it stands after training.
    """
    # create an iterator
    # This iterator is never consumed, but the call is kept on purpose: the
    # notebook text singles it out as the key to matching `model.fit()`, and
    # `train2()` (which omits it) produces different results.
    iterator = iter(data)

    # create callbacks
    steps = len(data)
    callbacks = tf.keras.callbacks.CallbackList(
        add_history=True,
        add_progbar=True,
        model=model,
        epochs=epochs,
        steps=steps,
        verbose=True,
    )

    # training
    callbacks.on_train_begin()
    train_fn = model.make_train_function()

    logs = None
    for epoch in range(epochs):
        # a fresh iterator for every epoch
        iterator = iter(data)
        model.reset_metrics()
        callbacks.on_epoch_begin(epoch)

        for step in range(steps):
            callbacks.on_train_batch_begin(step)
            logs = train_fn(iterator)
            # Batch callbacks take the 0-based index of the batch, the same
            # value given to `on_train_batch_begin`; passing `step + 1` made
            # the progress bar run one step ahead.
            callbacks.on_train_batch_end(step, logs)

        callbacks.on_epoch_end(epoch, logs)

    callbacks.on_train_end(logs=logs)
    return model.history

Let's see if the result is the same as when training with `model.fit()`.

In [6]:
epochs = 3

# Each entry pairs the banner to print with the trainer to run on a fresh
# (data, model) pair: first the built-in `model.fit()`, then the custom loop.
runs = [
    ("model.fit()", lambda d, m: m.fit(d, epochs=epochs)),
    ("custom training", lambda d, m: train1(d, m, epochs)),
]

for index, (title, run) in enumerate(runs):
    if index > 0:
        # blank line between the two reports
        print()
    print(title)
    data, model = create_data_and_model()
    history = run(data, model)
    print(history.history)

model.fit()
Epoch 1/3
Epoch 2/3
Epoch 3/3
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316]}

custom training
Epoch 1/3
Epoch 2/3
Epoch 3/3
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316]}


The key to getting the same result is the extra iterator that `train1()` creates before the training loop starts:
```python
iterator = iter(data)
```
In `model.fit()`, this is called at the beginning of the training to convert a dataset to an iterator, and is also called at the beginning of each epoch to get data used in the epoch. 

This means that the very first cycle of the dataset is not used anywhere. So let's see what happens without that extra call.

In [7]:
def train2(data, model, epochs):
    """Train `model` like `train1()`, but without the extra pre-loop iterator.

    The notebook uses this variant to show that omitting the initial
    `iter(data)` call makes the results differ from `model.fit()`.

    Args:
        data: batched dataset; must support `len()` and `iter()`.
        model: compiled Keras model.
        epochs: number of epochs to run.

    Returns:
        `model.history` as it stands after training.
    """
    # create callbacks
    steps = len(data)
    callbacks = tf.keras.callbacks.CallbackList(
        add_history=True,
        add_progbar=True,
        model=model,
        epochs=epochs,
        steps=steps,
        verbose=True,
    )

    # training
    callbacks.on_train_begin()
    train_fn = model.make_train_function()

    logs = None
    for epoch in range(epochs):
        # a fresh iterator for every epoch
        iterator = iter(data)
        model.reset_metrics()
        callbacks.on_epoch_begin(epoch)

        for step in range(steps):
            callbacks.on_train_batch_begin(step)
            logs = train_fn(iterator)
            # Batch callbacks take the 0-based index of the batch, the same
            # value given to `on_train_batch_begin`; passing `step + 1` made
            # the progress bar run one step ahead.
            callbacks.on_train_batch_end(step, logs)

        callbacks.on_epoch_end(epoch, logs)

    callbacks.on_train_end(logs=logs)
    return model.history

In [8]:
epochs = 3

# Each entry pairs the banner to print with the trainer to run on a fresh
# (data, model) pair: first the built-in `model.fit()`, then `train2()`.
runs = [
    ("model.fit()", lambda d, m: m.fit(d, epochs=epochs)),
    ("custom training", lambda d, m: train2(d, m, epochs)),
]

for index, (title, run) in enumerate(runs):
    if index > 0:
        # blank line between the two reports
        print()
    print(title)
    data, model = create_data_and_model()
    history = run(data, model)
    print(history.history)

model.fit()
Epoch 1/3
Epoch 2/3
Epoch 3/3
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316]}

custom training
Epoch 1/3
Epoch 2/3
Epoch 3/3
{'loss': [2.287079334259033, 2.024529457092285, 1.88909912109375], 'accuracy': [0.5, 0.5, 0.49000000953674316]}


Now I modify `train1()` so that I can save and restore the dataset. 
The way to create checkpoints of a dataset during training is described [here](https://www.tensorflow.org/guide/checkpoint).

In [9]:
# Path template for checkpoint files: one sub-directory per (epoch, step),
# holding one entry per saved object.
ckpt_dir = "checkpoint"
ckpt_path = os.path.join(
    ckpt_dir,
    "ckpt_{epoch:04d}_{step:04d}",  # zero-padded epoch and step
    "{name}",                       # name of the saved object
)


# save a dataset
def save_iterator(iterator, epoch, step):
    """Write the state of `iterator` to the checkpoint for (`epoch`, `step`)."""
    target = ckpt_path.format(epoch=epoch, step=step, name="iterator")
    checkpoint = tf.train.Checkpoint(iterator)
    checkpoint.write(target)


# restore a dataset
def restore_iterator(iterator, epoch, step):
    """Load the state saved for (`epoch`, `step`) into `iterator`, in place.

    `assert_consumed()` is called on the read status, so an incomplete match
    between the checkpoint and `iterator` is reported instead of ignored.
    """
    source = ckpt_path.format(epoch=epoch, step=step, name="iterator")
    checkpoint = tf.train.Checkpoint(iterator)
    status = checkpoint.read(source)
    status.assert_consumed()


def train3(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period):
    """Walk through the dataset like a training loop, checkpointing the iterator.

    No training is performed: each step only pulls the next batch and prints
    the first five features of its first sample, so that the order of the
    data can be compared between a fresh run and a resumed run.

    Args:
        data: batched dataset; must support `len()` and `iter()`.
        model: compiled Keras model.
        epochs: epoch index at which to stop (exclusive).
        initial_epoch: epoch index to start from.
        initial_step: step index to start from, within `initial_epoch` only.
        epoch_period: save the iterator every this many epochs (0 disables).
        step_period: save the iterator every this many steps (0 disables).

    Returns:
        `model.history`.
    """
    iterator = iter(data)

    # restore the dataset
    if initial_epoch != 0 or initial_step != 0:
        restore_iterator(iterator, initial_epoch, initial_step)

    steps = len(data)
    callbacks = tf.keras.callbacks.CallbackList(
        add_history=True,
        add_progbar=True,
        model=model,
        epochs=epochs,
        steps=steps,
        verbose=True,
    )

    callbacks.on_train_begin()
    # Built as in the other training loops, although this demo never calls it.
    train_fn = model.make_train_function()

    # start at the "initial_epoch"th epoch
    logs = None
    for epoch in range(initial_epoch, epochs):

        # initialize if starting an epoch from the beginning; when resuming
        # mid-epoch the restored iterator must be kept
        if initial_step == 0:
            iterator = iter(data)
            model.reset_metrics()

        callbacks.on_epoch_begin(epoch)

        # start at the "initial_step"th step for only the first epoch
        for step in range(initial_step, steps):
            callbacks.on_train_batch_begin(step)

            # Only for displaying the dataset: a real training step would be
            # `logs = train_fn(iterator)`.
            x, y = next(iterator)
            print(x.numpy()[0, :5])

            # Batch callbacks take the 0-based index of the batch, the same
            # value given to `on_train_batch_begin`; passing `step + 1` made
            # the progress bar run one step ahead.
            callbacks.on_train_batch_end(step, logs)

            # save the dataset every "step_period" steps
            if step_period != 0 and (step + 1) % step_period == 0:
                save_iterator(iterator, epoch, step + 1)

        # reset "initial_step" after the first epoch
        initial_step = 0

        callbacks.on_epoch_end(epoch, logs)

        # save the dataset every "epoch_period" epochs
        if epoch_period != 0 and (epoch + 1) % epoch_period == 0:
            save_iterator(iterator, epoch + 1, 0)

    callbacks.on_train_end(logs=logs)
    return model.history

This code makes it possible to save the dataset every `step_period` steps and/or every `epoch_period` epochs, and to restart the training from the `initial_step`th step of the `initial_epoch`th epoch. 

Note that it doesn't perform any training; it only displays elements of the dataset so that the behaviour can be checked. Now let's try to run this code.

In [10]:
# first run: write an iterator checkpoint after every 3rd step
print("save the dataset every 3 steps")
epochs, initial_epoch, initial_step = 1, 0, 0
epoch_period, step_period = 0, 3
data, model = create_data_and_model()
history = train3(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)

print()

# second run: resume from the checkpoint written after step 3, without saving
print("restart the training from the 3rd step of the first epoch")
epochs, initial_epoch, initial_step = 1, 0, 3
epoch_period, step_period = 0, 0
data, model = create_data_and_model()
history = train3(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)

save the dataset every 3 steps
[0.00270321 0.64719665 0.60039224 0.58873961 0.96277032]
 2/10 [=====>........................] - ETA: 0s[0.97861834 0.79915856 0.46147936 0.78052918 0.11827443]
[0.60571196 0.11566187 0.72788816 0.63746228 0.81193856]
[0.50106317 0.37638916 0.36491184 0.2609045  0.4959703 ]
[0.18115096 0.78854551 0.05684808 0.69699724 0.7786954 ]
[0.31038083 0.37303486 0.52497044 0.75059502 0.33350747]
[0.4012595  0.92929142 0.09961493 0.94530153 0.86948853]
[0.42370635 0.85712492 0.11731556 0.27125208 0.40379274]
[0.351893   0.72140667 0.63758269 0.81305386 0.97622566]
[0.44132147 0.48641045 0.44836918 0.567846   0.62116925]

restart the training from the 3rd step of the first epoch
[0.50106317 0.37638916 0.36491184 0.2609045  0.4959703 ]
[0.31038083 0.37303486 0.52497044 0.75059502 0.33350747]
[0.4012595  0.92929142 0.09961493 0.94530153 0.86948853]
[0.42370635 0.85712492 0.11731556 0.27125208 0.40379274]
[0.351893   0.72140667 0.63758269 0.81305386 0.97622566]
[0.4413

## Save and Restore a Model

Of course, I need to save and restore the model too.
There are mainly two ways to do that as described [here](https://www.tensorflow.org/tutorials/keras/save_and_load).

One way saves the model architecture and the model weights, and the other way saves only the model weights.

In [11]:
# save and restore a model architecture and model weights
def save_model1(model, epoch, step):
    """Save the whole model with `model.save()` into the checkpoint for (`epoch`, `step`)."""
    model.save(ckpt_path.format(epoch=epoch, step=step, name="model"))

def restore_model1(epoch, step):
    """Load and return the model saved by `save_model1()` for (`epoch`, `step`)."""
    source = ckpt_path.format(epoch=epoch, step=step, name="model")
    loaded = tf.keras.models.load_model(source)
    return loaded


# save and restore only model weights
def save_model2(model, epoch, step):
    """Save only the weights of `model` into the checkpoint for (`epoch`, `step`)."""
    model.save_weights(ckpt_path.format(epoch=epoch, step=step, name="model"))

def restore_model2(model, epoch, step):
    """Load the weights saved for (`epoch`, `step`) into `model`, in place.

    The value returned by `load_weights()` is discarded; nothing is returned.
    """
    source = ckpt_path.format(epoch=epoch, step=step, name="model")
    model.load_weights(source)

Now I modify `train3()` so that I can save and restore the model in two ways.

In [12]:
# modification using the former way
def train4(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period):
    """Resumable training loop that saves and restores the full model.

    When resuming (`initial_epoch` or `initial_step` non-zero) the `model`
    argument is replaced by the model loaded with `restore_model1()`.

    NOTE: the notebook observes that resuming with this variant does not
    reproduce the uninterrupted run; the cause has not been identified.

    Args:
        data: batched dataset; must support `len()` and `iter()`.
        model: compiled Keras model (ignored after a restore, see above).
        epochs: epoch index at which to stop (exclusive).
        initial_epoch: epoch index to start from.
        initial_step: step index to start from, within `initial_epoch` only.
        epoch_period: checkpoint every this many epochs (0 disables).
        step_period: checkpoint every this many steps (0 disables).

    Returns:
        `model.history` of the model that was actually trained.
    """
    iterator = iter(data)

    # restore the iterator and the model
    if initial_epoch != 0 or initial_step != 0:
        restore_iterator(iterator, initial_epoch, initial_step)
        model = restore_model1(initial_epoch, initial_step)

    steps = len(data)
    callbacks = tf.keras.callbacks.CallbackList(
        add_history=True,
        add_progbar=True,
        model=model,
        epochs=epochs,
        steps=steps,
        verbose=True,
    )

    callbacks.on_train_begin()
    train_fn = model.make_train_function()

    logs = None
    for epoch in range(initial_epoch, epochs):
        # start from a fresh iterator unless resuming mid-epoch, in which
        # case the restored iterator must be kept
        if initial_step == 0:
            iterator = iter(data)
            model.reset_metrics()

        callbacks.on_epoch_begin(epoch)

        for step in range(initial_step, steps):
            callbacks.on_train_batch_begin(step)
            logs = train_fn(iterator)
            # Batch callbacks take the 0-based index of the batch, the same
            # value given to `on_train_batch_begin`; passing `step + 1` made
            # the progress bar run one step ahead.
            callbacks.on_train_batch_end(step, logs)

            # save the model every "step_period" steps
            if step_period != 0 and (step + 1) % step_period == 0:
                save_iterator(iterator, epoch, step + 1)
                save_model1(model, epoch, step + 1)

        # only the first (resumed) epoch starts mid-way
        initial_step = 0

        callbacks.on_epoch_end(epoch, logs)

        # save the model every "epoch_period" epochs
        if epoch_period != 0 and (epoch + 1) % epoch_period == 0:
            save_iterator(iterator, epoch + 1, 0)
            save_model1(model, epoch + 1, 0)

    callbacks.on_train_end(logs=logs)
    return model.history


# modification using the latter way
def train5(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period):
    """Resumable training loop that saves and restores only the model weights.

    When resuming (`initial_epoch` or `initial_step` non-zero) the saved
    weights are loaded into the given `model` in place.

    Args:
        data: batched dataset; must support `len()` and `iter()`.
        model: compiled Keras model.
        epochs: epoch index at which to stop (exclusive).
        initial_epoch: epoch index to start from.
        initial_step: step index to start from, within `initial_epoch` only.
        epoch_period: checkpoint every this many epochs (0 disables).
        step_period: checkpoint every this many steps (0 disables).

    Returns:
        `model.history`.
    """
    iterator = iter(data)

    # restore the iterator and the model weights
    if initial_epoch != 0 or initial_step != 0:
        restore_iterator(iterator, initial_epoch, initial_step)
        restore_model2(model, initial_epoch, initial_step)

    steps = len(data)
    callbacks = tf.keras.callbacks.CallbackList(
        add_history=True,
        add_progbar=True,
        model=model,
        epochs=epochs,
        steps=steps,
        verbose=True,
    )

    callbacks.on_train_begin()
    train_fn = model.make_train_function()

    logs = None
    for epoch in range(initial_epoch, epochs):
        # start from a fresh iterator unless resuming mid-epoch, in which
        # case the restored iterator must be kept
        if initial_step == 0:
            iterator = iter(data)
            model.reset_metrics()

        callbacks.on_epoch_begin(epoch)

        for step in range(initial_step, steps):
            callbacks.on_train_batch_begin(step)
            logs = train_fn(iterator)
            # Batch callbacks take the 0-based index of the batch, the same
            # value given to `on_train_batch_begin`; passing `step + 1` made
            # the progress bar run one step ahead.
            callbacks.on_train_batch_end(step, logs)

            # save the model every "step_period" steps
            if step_period != 0 and (step + 1) % step_period == 0:
                save_iterator(iterator, epoch, step + 1)
                save_model2(model, epoch, step + 1)

        # only the first (resumed) epoch starts mid-way
        initial_step = 0

        callbacks.on_epoch_end(epoch, logs)

        # save the model every "epoch_period" epochs
        if epoch_period != 0 and (epoch + 1) % epoch_period == 0:
            save_iterator(iterator, epoch + 1, 0)
            save_model2(model, epoch + 1, 0)

    callbacks.on_train_end(logs=logs)
    return model.history

In [13]:
# first run: checkpoint the iterator and the full model after every 3rd step
print("save the model every 3 steps")
epochs, initial_epoch, initial_step = 5, 0, 0
epoch_period, step_period = 0, 3
data, model = create_data_and_model()
history = train4(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)
print(history.history)

print()

# second run: resume from the checkpoint of epoch index 1, step 3, without saving
print("restart the training from the 3rd step of the second epoch")
epochs, initial_epoch, initial_step = 5, 1, 3
epoch_period, step_period = 0, 0
data, model = create_data_and_model()
history = train4(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)
print(history.history)

save the model every 3 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5]}

restart the training from the 3rd step of the second epoch
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [1.7955976724624634, 1.8284294605255127, 1.810083270072937, 1.8030287027359009], 'accuracy': [0.5285714268684387, 0.49000000953674316, 0.5, 0.49000000953674316]}


In [14]:
# first run: checkpoint the iterator and the model weights after every 3rd step
print("save the model every 3 steps")
epochs, initial_epoch, initial_step = 5, 0, 0
epoch_period, step_period = 0, 3
data, model = create_data_and_model()
history = train5(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)
print(history.history)

print()

# second run: resume from the checkpoint of epoch index 1, step 3, without saving
print("restart the training from the 3rd step of the second epoch")
epochs, initial_epoch, initial_step = 5, 1, 3
epoch_period, step_period = 0, 0
data, model = create_data_and_model()
history = train5(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)
print(history.history)

save the model every 3 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5]}

restart the training from the 3rd step of the second epoch
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [1.7994781732559204, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5285714268684387, 0.49000000953674316, 0.49000000953674316, 0.5]}


The result of the former is completely different, but I don't know the reason at the moment. 

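On the other hand, with the latter the results from the third epoch onward are identical to those of the uninterrupted run.
The resumed (second) epoch still differs and the first epoch is missing from the history, probably because the states of the metrics and the history are not saved.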

## Save and Restore other information

I modify `train5()` so that I can save and restore the metrics and the history.

### metrics

The metrics can be saved in the same way as the iterator.

However, the metric objects are only created by the first call to `train_fn`. 
So I restore the metrics after the first `train_fn` call, once those objects exist and can be referenced.


### history

The history object (i.e. the `History` callback) cannot be saved directly.
So I save and restore its attributes `epoch` and `history` instead.

In [15]:
import joblib


# save metrics
def save_metrics(model, epoch, step):
    """Save `model.metrics` with `Checkpoint.save()` into the checkpoint for (`epoch`, `step`).

    The metrics are attached to an otherwise empty checkpoint object under
    the attribute name `metrics`.
    """
    target = ckpt_path.format(epoch=epoch, step=step, name="metrics")
    checkpoint = tf.train.Checkpoint()
    checkpoint.metrics = model.metrics
    checkpoint.save(target)


# restore metrics
def restore_metrics(model, epoch, step):
    """Merge the metric states saved for (`epoch`, `step`) into `model.metrics`.

    Returns:
        A dict mapping each metric name to its result after merging.
    """
    # `save_metrics()` saves under the name "metrics"; the "-1" suffix read
    # here is presumably the save counter appended by `Checkpoint.save()`.
    source = ckpt_path.format(epoch=epoch, step=step, name="metrics-1")
    checkpoint = tf.train.Checkpoint()

    def _blank_copy(metric):
        # Derive the identifier from the metric name: "loss" maps to "Mean",
        # and snake_case names are converted to CamelCase.
        identifier = "Mean" if metric.name == "loss" else metric.name
        identifier = "".join(part.capitalize() for part in identifier.split("_"))
        return tf.keras.metrics.get(identifier)

    # create initialized metrics, one per metric of the model
    saved_metrics = [_blank_copy(metric) for metric in model.metrics]

    # restore the saved states into the fresh metrics
    checkpoint.metrics = saved_metrics
    checkpoint.restore(source)

    # merge states into the live metrics and rebuild the logs from them
    logs = {}
    for live, saved in zip(model.metrics, saved_metrics):
        live.merge_state([saved])
        logs[live.name] = live.result()
    return logs


# save a history
def save_history(callbacks, epoch, step):
    """Dump the History callback's `epoch` and `history` attributes with joblib.

    Args:
        callbacks: object whose `.callbacks` list is searched for the first
            callback whose class is named "History".
        epoch: epoch index used in the checkpoint path.
        step: step index used in the checkpoint path.

    Raises:
        ValueError: if no History callback is found (previously this failed
            with an `AttributeError` on `None`).
    """
    path = ckpt_path.format(epoch=epoch, step=step, name="history.pkl")

    # get the history object from callbacks
    history = next(
        (cb for cb in callbacks.callbacks if cb.__class__.__name__ == "History"),
        None,
    )
    if history is None:
        raise ValueError("no History callback found in `callbacks`; nothing to save")

    # joblib.dump() does not create missing directories, so do not rely on
    # another save function having created the checkpoint directory first.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    joblib.dump({"epoch": history.epoch, "history": history.history}, path)
    return


# restore a history
def restore_history(callbacks, epoch, step):
    """Load the attributes saved by `save_history()` into the History callback.

    Only the first callback whose class is named "History" is updated; if
    there is none, nothing happens.
    """
    source = ckpt_path.format(epoch=epoch, step=step, name="history.pkl")
    # NOTE(review): joblib.load() unpickles the file, so only use it on
    # checkpoints written by this notebook.
    saved = joblib.load(source)

    target = next(
        (cb for cb in callbacks.callbacks if cb.__class__.__name__ == "History"),
        None,
    )
    if target is not None:
        target.epoch = saved["epoch"]
        target.history = saved["history"]


def train6(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period):
    """Resumable training loop that also saves and restores metrics and history.

    Each checkpoint contains the iterator state, the model weights, the
    metric states and the history.

    Args:
        data: batched dataset; must support `len()` and `iter()`.
        model: compiled Keras model.
        epochs: epoch index at which to stop (exclusive).
        initial_epoch: epoch index to start from.
        initial_step: step index to start from, within `initial_epoch` only.
        epoch_period: checkpoint every this many epochs (0 disables).
        step_period: checkpoint every this many steps (0 disables).

    Returns:
        `model.history`.
    """
    iterator = iter(data)

    # restore the iterator and the model weights
    if initial_epoch != 0 or initial_step != 0:
        restore_iterator(iterator, initial_epoch, initial_step)
        restore_model2(model, initial_epoch, initial_step)

    steps = len(data)
    callbacks = tf.keras.callbacks.CallbackList(
        add_history=True,
        add_progbar=True,
        model=model,
        epochs=epochs,
        steps=steps,
        verbose=True,
    )

    callbacks.on_train_begin()
    train_fn = model.make_train_function()

    # restore the history
    # NOTE(review): done after `on_train_begin()`, presumably because that
    # call resets the History callback; confirm before reordering.
    if initial_epoch != 0 or initial_step != 0:
        restore_history(callbacks, initial_epoch, initial_step)

    logs = None
    for epoch in range(initial_epoch, epochs):
        # start from a fresh iterator unless resuming mid-epoch, in which
        # case the restored iterator must be kept
        if initial_step == 0:
            iterator = iter(data)
            model.reset_metrics()

        callbacks.on_epoch_begin(epoch)

        for step in range(initial_step, steps):
            callbacks.on_train_batch_begin(step)
            logs = train_fn(iterator)

            # restore the metrics and update the logs; this happens after the
            # first `train_fn` call because, as the notebook explains, the
            # metric objects only exist from that point on
            if initial_step != 0 and step == initial_step:
                logs = restore_metrics(model, initial_epoch, initial_step)

            # Batch callbacks take the 0-based index of the batch, the same
            # value given to `on_train_batch_begin`; passing `step + 1` made
            # the progress bar run one step ahead.
            callbacks.on_train_batch_end(step, logs)

            # save everything every "step_period" steps
            if step_period != 0 and (step + 1) % step_period == 0:
                save_iterator(iterator, epoch, step + 1)
                save_model2(model, epoch, step + 1)
                save_metrics(model, epoch, step + 1)
                save_history(callbacks, epoch, step + 1)

        # only the first (resumed) epoch starts mid-way
        initial_step = 0

        callbacks.on_epoch_end(epoch, logs)

        # save everything every "epoch_period" epochs
        if epoch_period != 0 and (epoch + 1) % epoch_period == 0:
            save_iterator(iterator, epoch + 1, 0)
            save_model2(model, epoch + 1, 0)
            save_metrics(model, epoch + 1, 0)
            save_history(callbacks, epoch + 1, 0)

    callbacks.on_train_end(logs=logs)
    return model.history

In [17]:
# first run: checkpoint the iterator, weights, metrics and history
# every 3 steps and every 2 epochs
print("save the metrics and the history every 3 steps")
epochs = 5
initial_epoch = 0
initial_step = 0
epoch_period = 2
step_period = 3
data, model = create_data_and_model()
history = train6(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)
print(history.history)

print()

# second run: resume from the checkpoint of epoch index 2, step 3.
# Epoch indices are 0-based, so this is the third epoch (the progress output
# of the resumed run starts at "Epoch 3/5"); the banner previously said
# "second epoch".
print("restart the training from the 3rd step of the third epoch")
epochs = 5
initial_epoch = 2
initial_step = 3
epoch_period = 0
step_period = 0
data, model = create_data_and_model()
history = train6(data, model, epochs, initial_epoch, initial_step, epoch_period, step_period)
print(history.history)

save the metrics and the history every 3 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5]}

restart the training from the 3rd step of the second epoch
Epoch 3/5
Epoch 4/5
Epoch 5/5
{'loss': [2.1480836868286133, 1.9837868213653564, 1.8324090242385864, 1.812350869178772, 1.8041961193084717], 'accuracy': [0.5099999904632568, 0.49000000953674316, 0.49000000953674316, 0.49000000953674316, 0.5]}
