In [2]:
# Notebook setup: load helper extensions, report the environment, and
# configure inline rendering. Comments are kept on their own lines because
# text after a line magic is passed to the magic as part of its arguments.

# (Re)load the `watermark` and `autoreload` IPython extensions.
%reload_ext watermark
%reload_ext autoreload
# Autoreload mode 2: re-import edited modules before each cell execution.
%autoreload 2
# Print interpreter (-v) and package (-p) versions for reproducibility.
%watermark -v -p numpy,sklearn,pandas
%watermark -v -p cv2,PIL,matplotlib
%watermark -v -p torch,torchvision,torchaudio
%watermark -v -p tensorflow
# Render matplotlib figures inline, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Turn off jedi-based tab completion in IPython.
%config IPCompleter.use_jedi = False

from IPython.display import display, Markdown, HTML, Javascript
# Inject CSS that forces the notebook `.container` element to 80% width.
display(HTML('<style>.container { width:%d%% !important; }</style>' % 80))

import sys, os, io, time, random, math
import json, base64, requests
import os.path as osp
import tensorflow as tf

def _IMPORT_(x):
    """Execute an import-like statement, loading the module source if needed.

    `x` is a whitespace-separated statement such as
    ``'import github.com/user/repo/pkg/mod'``. Its second token selects how
    the source is obtained:

    * contains ``github.com``: download ``<path>.py`` from
      ``raw.githubusercontent.com``, trying the ``main`` branch first and
      then ``master``.
    * contains ``gitee.com``: download ``<path>.py`` through ``/raw/main/``
      and then ``/raw/master/``.
    * starts with ``/``: read the local file ``<path>.py``.
    * anything else: `x` itself is treated as ordinary Python source.

    The resulting source is exec'ed into this module's globals, so every
    top-level name it defines becomes available to the notebook.

    This is a best-effort helper: failures are reported on stderr and never
    raised to the caller.

    Args:
        x (str): The import-like statement described above.

    Returns:
        None
    """
    # SECURITY NOTE(review): this downloads code and exec()s it without any
    # verification. Only point it at repositories and paths you fully trust.
    try:
        segs = x.split()
        target = segs[1] if len(segs) > 1 else ''

        candidates = None
        if 'github.com' in target:
            # Drop an optional URL scheme so 'https://github.com/...' works too.
            host_path = target.split('://', 1)[-1]
            mod = host_path.replace('github.com', 'raw.githubusercontent.com').split('/')
            # Raw URLs are <host>/<user>/<repo>/<branch>/<path>.py, so the
            # branch goes right after the repository name.
            candidates = [
                'https://' + '/'.join(mod[:3]) + '/' + branch + '/' + '/'.join(mod[3:]) + '.py'
                for branch in ('main', 'master')
            ]
        elif 'gitee.com' in target:
            mod = target.split('://', 1)[-1].split('/')
            candidates = [
                'https://' + '/'.join(mod[:3]) + s + '/'.join(mod[3:]) + '.py'
                for s in ('/raw/main/', '/raw/master/')
            ]

        if candidates is not None:
            source = None
            for uri in candidates:
                # A timeout keeps an unreachable host from hanging the kernel.
                resp = requests.get(uri, timeout=10)
                if resp.status_code == 200:
                    source = resp.text
                    break
            if source is None:
                # Never exec an HTTP error page as if it were Python source.
                print('_IMPORT_: could not download any of %s' % candidates,
                      file=sys.stderr)
                return
        elif target.startswith('/'):
            with open(target + '.py', encoding='utf-8') as fr:
                source = fr.read()
        else:
            # Not a remote or absolute-path target: run the statement as-is.
            source = x

        exec(source, globals())
    except Exception as err:
        # Stay best-effort, but tell the user what went wrong instead of
        # silently swallowing the error.
        print('_IMPORT_ failed for %r: %s: %s' % (x, type(err).__name__, err),
              file=sys.stderr)

def print_progress_bar(x):
    """Redraw a one-line text progress bar on stdout.

    Args:
        x (int): Completion percentage; each bar cell stands for 2%.

    The line starts with a carriage return so successive calls overwrite
    each other, and no trailing newline is written.
    """
    out = sys.stdout
    out.write('\r')
    done = '▋' * (x // 2)
    todo = '.' * ((100 - x) // 2)
    out.write('Progress: {}%: {}{}'.format(x, done, todo))
    out.flush()


CPython 3.6.9
IPython 7.16.1

numpy 1.18.5
sklearn 0.24.0
pandas 1.1.5
CPython 3.6.9
IPython 7.16.1

cv2 4.5.1
PIL 6.2.2
matplotlib 3.3.3
CPython 3.6.9
IPython 7.16.1

torch 1.8.0.dev20210103+cu101
torchvision 0.9.0.dev20210103+cu101
torchaudio not installed
CPython 3.6.9
IPython 7.16.1

tensorflow 2.3.2


In [3]:
# -------------------------
# -----  Toy Context  -----
# -------------------------


class Net(tf.keras.Model):
    """Minimal Keras model: a single `Dense` layer with 5 output units."""

    def __init__(self):
        super().__init__()
        # The attribute name `l1` is kept as-is: it is how the layer is
        # attached to this model object that the checkpoint tracks.
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        """Apply the dense layer to `x` and return its output."""
        projected = self.l1(x)
        return projected


def toy_dataset():
    """Build an endlessly repeating toy dataset, batched in pairs.

    Inputs are the values 0.0..9.0 as a column; labels are ``input * 5`` plus
    the row 0.0..4.0 (broadcast across columns). Each element is a dict with
    keys ``"x"`` (inputs) and ``"y"`` (labels).

    Returns:
        A `tf.data.Dataset` that repeats indefinitely with batch size 2.
    """
    features = tf.range(10.0)[:, tf.newaxis]
    offsets = tf.range(5.0)[tf.newaxis, :]
    targets = features * 5.0 + offsets
    examples = tf.data.Dataset.from_tensor_slices({"x": features, "y": targets})
    return examples.repeat().batch(2)


def train_step(net, example, optimizer):
    """Run a single gradient update of `net` on one batch.

    Args:
        net: Callable model exposing `trainable_variables`.
        example: Mapping with the inputs under ``"x"`` and targets under ``"y"``.
        optimizer: Optimizer whose `apply_gradients` performs the update.

    Returns:
        The mean absolute error between the model output and the targets,
        computed before the update is applied.
    """
    with tf.GradientTape() as tape:
        prediction = net(example["x"])
        abs_error = tf.abs(prediction - example["y"])
        loss = tf.reduce_mean(abs_error)
    params = net.trainable_variables
    grads = tape.gradient(loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return loss


# ----------------------------
# -----  Create Objects  -----
# ----------------------------

# Build the model, optimizer and input pipeline that will be checkpointed.
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
# One Checkpoint object groups everything needed to resume training: a step
# counter (starting at 1), the optimizer, the model and the dataset iterator.
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
# The manager writes checkpoints under ./tf_ckpts with max_to_keep=3.
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# ----------------------------
# -----  Train and Save  -----
# ----------------------------

# Resume from the newest checkpoint the manager knows about, if any.
# `manager.latest_checkpoint` is falsy when none exists (see the branch below).
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Train for 50 batches, bumping the checkpointed step counter after each one.
for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    # Save (and report the loss) whenever the step counter hits a multiple
    # of 10. On a fresh run the counter starts at 1, so the first save
    # happens after 9 training steps.
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        print("loss {:1.2f}".format(loss.numpy()))


# ---------------------
# -----  Restore  -----
# ---------------------

# In another script, re-initialize objects
# Fresh optimizer, model, dataset and iterator: nothing here carries state
# over from the training section above except the files in ./tf_ckpts.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
# The Checkpoint is built with the same keyword names (step, optimizer, net,
# iterator) as the one used for saving.
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator
)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# Re-use the manager code above ^

# Load the newest checkpoint found in ./tf_ckpts into the fresh objects.
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

# Placeholder loop: pulls 50 batches from the iterator without using them.
for _ in range(50):
    example = next(iterator)
    # Continue training or evaluate etc.

Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.41
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.83
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 17.27
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.82
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.73
Restored from ./tf_ckpts/ckpt-5


## References

- [TensorFlow2中Keras模型保存与加载][1]

[1]: https://www.cnblogs.com/chenzhen0530/p/13943172.html