## Saving checkpoints

In [None]:
import tensorflow as tf

In [None]:
class Net(tf.keras.Model):
  """Minimal Keras model consisting of a single Dense layer with 5 outputs."""

  def __init__(self):
    super().__init__()
    # The attribute name `l1` is part of the checkpoint key path, so keep it
    # stable or previously saved checkpoints will no longer match.
    self.l1 = tf.keras.layers.Dense(units=5)

  def call(self, x):
    """Return the Dense layer applied to `x`."""
    dense = self.l1
    return dense(x)

In [None]:
# Instantiate the model whose weights are saved in the next cell.
net = Net()

In [None]:
# Save the model's weights under the name 'easy_checkpoint' in the working directory.
# NOTE(review): newer Keras releases may require a '.weights.h5' filename suffix
# here -- confirm against the installed TensorFlow/Keras version.
net.save_weights('easy_checkpoint')

In [None]:
%%bash
# List the working directory to check the files written by net.save_weights('easy_checkpoint').
ls

In [None]:
def toy_dataset():
  """Build an endlessly repeating dataset of {'x', 'y'} examples in batches of 2.

  'x' holds the values 0..9 as a (10, 1) float column; 'y' is 5 * x plus the
  row offsets [0, 1, 2, 3, 4], i.e. a (10, 5) target matrix. Slicing yields
  one row per example, and .repeat() with no count repeats indefinitely.
  """
  inputs = tf.expand_dims(tf.range(10.), axis=1)   # add a trailing axis: (10,) -> (10, 1)
  offsets = tf.expand_dims(tf.range(5.), axis=0)   # add a leading axis: (5,) -> (1, 5)
  labels = inputs * 5. + offsets                   # broadcasts to (10, 5)
  examples = {'x': inputs, 'y': labels}
  dataset = tf.data.Dataset.from_tensor_slices(examples)
  return dataset.repeat().batch(2)

In [None]:
def train_step(net, example, optimizer):
  """Apply one gradient update to `net` on a batch and return the batch loss.

  The loss is the mean absolute error between net(example['x']) and
  example['y']; gradients are taken w.r.t. net.trainable_variables.
  """
  features, targets = example['x'], example['y']
  with tf.GradientTape() as tape:
    predictions = net(features)
    loss = tf.reduce_mean(tf.abs(predictions - targets))
  # Read the variable list only after the forward pass has run.
  trainable = net.trainable_variables
  grads = tape.gradient(loss, trainable)
  optimizer.apply_gradients(zip(grads, trainable))
  return loss

In [None]:
# Training state tracked by the checkpoint: a step counter, the optimizer,
# the model and the input iterator (so the data position is saved too).
opt = tf.keras.optimizers.Adam(learning_rate=0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1),
  optimizer=opt,
  net=net,
  iterator=iterator,
)
# Checkpoints are written under ./tf_ckpts; only the 3 most recent are kept.
manager = tf.train.CheckpointManager(ckpt, directory='./tf_ckpts', max_to_keep=3)

In [None]:
# Train for a fixed number of steps, saving a checkpoint every `save_every` steps.
def train_and_checkpoint(net, manager, num_steps=50, save_every=10):
  """Restore the latest checkpoint if one exists, then train and checkpoint.

  The checkpoint object, optimizer and input iterator are taken from
  `manager.checkpoint` instead of notebook globals, so the function always
  trains and saves exactly the state that `manager` tracks.

  Args:
    net: model passed to `train_step`. Presumably the same object tracked as
      `net` in the manager's checkpoint -- confirm at the call site.
    manager: tf.train.CheckpointManager whose checkpoint was built with
      `step`, `optimizer` and `iterator` entries.
    num_steps: number of training steps to run in this call.
    save_every: a checkpoint is saved whenever the step counter becomes a
      multiple of this value.

  Raises:
    ValueError: if `save_every` is not positive.
  """
  if save_every <= 0:
    raise ValueError('save_every must be positive, got {}'.format(save_every))

  ckpt = manager.checkpoint
  optimizer = ckpt.optimizer
  iterator = ckpt.iterator

  # latest_checkpoint is None before anything has been saved; restore(None)
  # restores nothing, so training then starts from freshly initialized values.
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print('restored from {}'.format(manager.latest_checkpoint))
  else:
    print('from scratch')

  for _ in range(num_steps):
    example = next(iterator)
    loss = train_step(net, example, optimizer)
    # Incremented once per training step; saved with the checkpoint, so the
    # count carries over after a restore.
    ckpt.step.assign_add(1)
    if int(ckpt.step) % save_every == 0:
      save_path = manager.save()
      print('Saved ckpt for step {}: {}'.format(int(ckpt.step), save_path))
      print('loss {:1.2f}'.format(loss.numpy()))

In [None]:
# With no checkpoints in ./tf_ckpts yet this trains from scratch; otherwise it
# resumes from the latest checkpoint found there.
train_and_checkpoint(net, manager)

## Restore and continue training

In [None]:
# Rebuild the optimizer, model and input pipeline as brand-new objects, wire
# them into a new checkpoint over the same ./tf_ckpts directory, and let
# train_and_checkpoint restore their state from the latest saved checkpoint.
opt = tf.keras.optimizers.Adam(learning_rate=0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1),
  optimizer=opt,
  net=net,
  iterator=iterator,
)
manager = tf.train.CheckpointManager(ckpt, directory='./tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)

In [None]:
# Paths of the checkpoints the manager currently retains (max_to_keep=3 above).
print(manager.checkpoints)