In [1]:
# (Re)load the IPython extensions used by this cell.
%reload_ext watermark
%reload_ext autoreload
# autoreload mode 2: reload all modules before executing each cell.
%autoreload 2
# Print the installed version of each listed package; packages that are
# missing are reported as "not installed" (see the cell output).
%watermark -p numpy,sklearn,pandas
%watermark -p ipywidgets,cv2,PIL,matplotlib,plotly,netron
%watermark -p torch,torchvision,torchaudio
%watermark -p tensorflow,tensorboard,tflite
%watermark -p onnx,tf2onnx,onnxruntime,tensorrt,tvm
# Draw matplotlib figures inline, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Turn off jedi-based completion in the IPython completer.
%config IPCompleter.use_jedi = False

from IPython.display import display, Markdown, HTML, Image, Javascript
from IPython.core.magic import register_line_cell_magic, register_line_magic, register_cell_magic
# Inject CSS that sets the notebook's `.container` element to 90% width.
display(HTML('<style>.container { width:%d%% !important; }</style>' % 90))

import sys, os, io, time, random, math
import json, base64, requests, shutil
import os.path as osp
import numpy as np
import argparse, shlex, signal

# Replace ArgumentParser.exit for every parser created afterwards.
# NOTE(review): `_IGNORE_` is not defined anywhere in the visible source, so
# calling this replacement would raise NameError unless the name is defined
# elsewhere. Presumably the intent is to stop argparse from terminating the
# kernel via sys.exit() on --help / bad arguments -- confirm.
argparse.ArgumentParser.exit = lambda *arg, **kwargs: _IGNORE_

def _import_candidate_uris(spec):
    """Return the raw-download URLs to try, in order, for a remote module spec.

    `spec` is a URL without the `https://` prefix that already ends in `.py`.
    An empty list means the host is not one of the supported ones.
    """
    # Web-UI links carry a `blob/<branch>/` segment that raw URLs do not have.
    spec = spec.replace('blob/main/', '').replace('blob/master/', '')
    if spec.startswith('raw.githubusercontent.com'):
        return ['https://' + spec]
    parts = spec.split('/')
    # head = host/user/repo, tail = path of the file inside the repository.
    head, tail = '/'.join(parts[:3]), '/'.join(parts[3:])
    if spec.startswith('github.com'):
        head = 'raw.githubusercontent.com' + head[len('github.com'):]
        return ['https://%s/%s/%s' % (head, branch, tail)
                for branch in ('main', 'master')]
    if spec.startswith('gitee.com'):
        return ['https://%s/raw/%s/%s' % (head, branch, tail)
                for branch in ('main', 'master')]
    return []


def _IMPORT(x, timeout=10):
    """Execute a Python source file in this module's global namespace.

    Args:
        x: Either an absolute local path (starting with `/`) or a URL on
            `raw.githubusercontent.com`, `github.com` or `gitee.com`. A
            leading `https://` and a trailing `.py` are both optional. For
            github.com / gitee.com the `main` branch is tried first, then
            `master`.
        timeout: Seconds passed to `requests.get` for each download attempt.

    Returns:
        None. This is best-effort: any failure (unreadable file, unsupported
        host, no HTTP 200 response, error raised by the executed code) is
        reported on stderr instead of being raised.

    SECURITY: the fetched text is passed to `exec`, so this runs arbitrary
    code from the given location. Only use it with sources you trust.
    """
    try:
        x = x.strip()
        if x.startswith('https://'):
            x = x[len('https://'):]
        if not x.endswith('.py'):
            x += '.py'
        if x.startswith('/'):
            with open(x, encoding='utf-8') as fr:
                source = fr.read()
        else:
            uris = _import_candidate_uris(x)
            if not uris:
                raise ValueError('unsupported source %r' % x)
            for uri in uris:
                response = requests.get(uri, timeout=timeout)
                if response.status_code == 200:
                    source = response.text
                    break
            else:
                # Never hand a failed Response object to exec().
                raise IOError('no HTTP 200 response from: %s' % ', '.join(uris))
        exec(source, globals())
    except Exception as err:
        # Keep the notebook running, but make the failure visible.
        print('_IMPORT(%r) failed: %s: %s' % (x, type(err).__name__, err),
              file=sys.stderr)

def _DIR(x, dumps=True, ret=True):
    """Summarize the public attributes of `x` as '<type name>: <attributes>'.

    Args:
        x: Any object.
        dumps: When true the sorted attribute names are rendered with
            `json.dumps`; otherwise the plain Python list is formatted.
        ret: When true the summary string is returned; otherwise it is
            printed and None is returned.
    """
    public = [name for name in dir(x) if not name.startswith('_')]
    public.sort()
    # str(type(x)) looks like "<class 'pkg.Name'>"; drop the wrapper text.
    type_name = str(type(x))[8:-2]
    listing = json.dumps(public) if dumps else public
    summary = '%s: %s' % (type_name, listing)
    if not ret:
        print(summary)
        return None
    return summary

numpy 1.19.5
sklearn 0.0
pandas 1.1.5
ipywidgets 7.6.3
cv2 4.5.3
PIL 8.3.1
matplotlib 3.3.4
plotly 5.3.0
netron 5.1.6
torch 1.8.1+cu101
torchvision 0.9.1+cu101
torchaudio not installed
tensorflow 2.6.0
tensorboard 2.6.0
tflite not installed
onnx 1.10.1
tf2onnx 1.10.0
onnxruntime 1.8.1
tensorrt not installed
tvm not installed


In [3]:
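# -------------------------
# -----  Toy Context  -----
# -------------------------
# NOTE(review): `tf` is used throughout this cell but is never imported in the
# visible cells (execution counts jump from In [1] to In [3]). Presumably
# `import tensorflow as tf` ran in a cell that is not shown -- confirm,
# otherwise Restart & Run All fails here with a NameError.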


class Net(tf.keras.Model):
    """Minimal Keras model consisting of a single `Dense` layer with 5 units."""

    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        """Apply the dense layer to `x` and return its output."""
        outputs = self.l1(x)
        return outputs


def toy_dataset():
    """Build the toy dataset of `{"x": ..., "y": ...}` examples.

    `x` holds the values of `tf.range(10.0)` with a trailing axis added; `y`
    is `x * 5.0` plus `tf.range(5.0)` with a leading axis added. The
    per-example slices are repeated indefinitely and grouped in batches of 2.
    """
    features = tf.range(10.0)[:, tf.newaxis]
    offsets = tf.range(5.0)[tf.newaxis, :]
    targets = features * 5.0 + offsets
    examples = tf.data.Dataset.from_tensor_slices({"x": features, "y": targets})
    return examples.repeat().batch(2)


def train_step(net, example, optimizer):
    """Run one optimization step of `net` on `example` and return the loss.

    The loss is the mean absolute difference between `net(example["x"])` and
    `example["y"]`. Gradients of that loss with respect to the model's
    trainable variables are applied through `optimizer`.
    """
    with tf.GradientTape() as tape:
        prediction = net(example["x"])
        abs_error = tf.abs(prediction - example["y"])
        loss = tf.reduce_mean(abs_error)
    # The variables are read after the forward pass, matching the original
    # statement order.
    trainable = net.trainable_variables
    grads_and_vars = zip(tape.gradient(loss, trainable), trainable)
    optimizer.apply_gradients(grads_and_vars)
    return loss


# ----------------------------
# -----  Create Objects  -----
# ----------------------------

# Fresh model, Adam optimizer (constructed with 0.1) and an iterator over the
# endlessly repeating toy dataset.
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
# Bundle everything that is saved and restored together: a step counter that
# starts at 1, the optimizer, the model and the dataset iterator.
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
# Manages checkpoints for `ckpt` under ./tf_ckpts with max_to_keep=3.
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# ----------------------------
# -----  Train and Save  -----
# ----------------------------

# Restore from the manager's latest checkpoint; when the manager reports none
# (falsy `latest_checkpoint`), the run is announced as starting from scratch.
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# 50 training steps. The step counter is incremented after every step, and
# whenever it is a multiple of 10 a checkpoint is saved and the loss of the
# step that just ran is printed.
for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        print("loss {:1.2f}".format(loss.numpy()))


# ---------------------
# -----  Restore  -----
# ---------------------

# In another script, re-initialize objects
# (the names opt, net, dataset, iterator, ckpt and manager are rebound to
# newly constructed objects here, replacing the ones used for training above).
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# Re-use the manager code above ^

# Same restore-or-announce logic as in the training section; in the recorded
# output this prints "Restored from ./tf_ckpts/ckpt-5".
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Draw 50 batches from the iterator; nothing is done with them here.
for _ in range(50):
    example = next(iterator)
    # Continue training or evaluate etc.


Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.41
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.83
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 17.27
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.82
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.73
Restored from ./tf_ckpts/ckpt-5


## References

- [TensorFlow2中Keras模型保存与加载 (Saving and loading Keras models in TensorFlow 2)][1]

[1]: https://www.cnblogs.com/chenzhen0530/p/13943172.html