# Save & Restore a Model

Save and Restore a model using TensorFlow v2. In this example, we will go over both low and high-level approaches:

- Low-level: TF Checkpoint.
- High-level: TF Module/Model saver.

This example is using the MNIST database of handwritten digits as toy dataset (http://yann.lecun.com/exdb/mnist/).

# set params

In [51]:

from __future__ import absolute_import, division, print_function

import tensorflow as tf
import numpy as np

In [52]:
# MNIST dataset parameters.
num_classes = 10 # 0 to 9 digits
num_features = 784 # 28*28

# Training parameters.
learning_rate = 0.01
training_steps = 1000
batch_size = 256
display_step = 50

# load datasets of mnist

In [53]:
# Prepare MNIST data.
path = '/home/zju/.keras/datasets/mnist.npz'
with np.load(path) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
# Convert to float32.
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])
# Normalize images value from [0, 255] to [0, 1].
x_train, x_test = x_train / 255., x_test / 255.

In [54]:
# Use tf.data API to shuffle and batch data.
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

# tf checkpoint

In [55]:
# Weight of shape [784, 10], the 28*28 image features, and total number of classes.
W = tf.Variable(tf.random.normal([num_features, num_classes]), name="weight")
# Bias of shape [10], the total number of classes.
b = tf.Variable(tf.zeros([num_classes]), name="bias")

# Logistic regression (Wx + b).
def logistic_regression(x):
    # Apply softmax to normalize the logits to a probability distribution.
    return tf.nn.softmax(tf.matmul(x, W) + b)

# Cross-Entropy loss function.
def cross_entropy(y_pred, y_true):
    # Encode label to a one hot vector.
    y_true = tf.one_hot(y_true, depth=num_classes)
    # Clip prediction values to avoid log(0) error.
    y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)
    # Compute cross-entropy.
    return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred)))

# Accuracy metric.
def accuracy(y_pred, y_true):
    # Predicted class is the index of highest score in prediction vector (i.e. argmax).
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Adam optimizer.
optimizer = tf.optimizers.Adam(learning_rate)

In [56]:
# Optimization process.
def run_optimization(x, y):
    # Wrap computation inside a GradientTape for automatic differentiation.
    with tf.GradientTape() as g:
        pred = logistic_regression(x)
        loss = cross_entropy(pred, y)
        gradients = g.gradient(loss, [W, b])
        optimizer.apply_gradients(zip(gradients, [W, b]))

In [57]:
# Run training for the given number of steps.
for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
    # Run the optimization to update W and b values.
    run_optimization(batch_x, batch_y)

    if step % display_step == 0:
        pred = logistic_regression(batch_x)
        loss = cross_entropy(pred, batch_y)
        acc = accuracy(pred, batch_y)
        print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))

step: 50, loss: 511.156372, accuracy: 0.640625
step: 100, loss: 285.288330, accuracy: 0.789062
step: 150, loss: 237.997314, accuracy: 0.804688
step: 200, loss: 194.171921, accuracy: 0.839844
step: 250, loss: 197.448730, accuracy: 0.875000
step: 300, loss: 215.134338, accuracy: 0.855469
step: 350, loss: 107.719086, accuracy: 0.890625
step: 400, loss: 141.498627, accuracy: 0.871094
step: 450, loss: 82.451706, accuracy: 0.933594
step: 500, loss: 88.396576, accuracy: 0.890625
step: 550, loss: 93.500801, accuracy: 0.886719
step: 600, loss: 100.532074, accuracy: 0.902344
step: 650, loss: 114.126762, accuracy: 0.898438
step: 700, loss: 72.538055, accuracy: 0.929688
step: 750, loss: 90.510681, accuracy: 0.921875
step: 800, loss: 66.654335, accuracy: 0.929688
step: 850, loss: 90.080620, accuracy: 0.902344
step: 900, loss: 91.745361, accuracy: 0.894531
step: 950, loss: 98.064079, accuracy: 0.902344
step: 1000, loss: 158.179825, accuracy: 0.867188


# Save and Load with TF Checkpoint

In [73]:
# Save weights and optimizer variables.
vars_to_save = {"W": W, "b": b, "optimizer": optimizer}
checkpoint = tf.train.Checkpoint(**vars_to_save)
saver = tf.train.CheckpointManager(
      checkpoint, directory="/home/zju/Documents/TuchaoZhang/models/LearningTF2/tf-example", max_to_keep=5)

# Save variables.
saver.save()

'/home/zju/Documents/TuchaoZhang/models/LearningTF2/tf-example/ckpt-1'

In [74]:
np.mean(W.numpy())

0.017873187

In [75]:
# Reset variables to test restore.
W = tf.Variable(tf.random.normal([num_features, num_classes]), name="weight")
b = tf.Variable(tf.zeros([num_classes]), name="bias")

# Check resetted weight value.
np.mean(W.numpy())

0.0054865666

In [76]:
vars_to_load = {"W": W, "b": b, "optimizer": optimizer}
checkpoint = tf.train.Checkpoint(**vars_to_load)
latest_ckpt = tf.train.latest_checkpoint("/home/zju/Documents/TuchaoZhang/models/LearningTF2/tf-example/tf-example/")
checkpoint.restore(latest_ckpt)

<tensorflow.python.training.tracking.util.InitializationOnlyStatus at 0x7fc20c32b990>

In [77]:
# Confirm that W has been correctly restored.
np.mean(W.numpy())

0.0054865666

# 2) TF Model

In [78]:
from tensorflow.keras import Model, layers

In [79]:
# MNIST dataset parameters.
num_classes = 10 # 0 to 9 digits
num_features = 784 # 28*28

# Training parameters.
learning_rate = 0.01
training_steps = 1000
batch_size = 256
display_step = 100

In [80]:
# Create TF Model.
class NeuralNet(Model):
    # Set layers.
    def __init__(self):
        super(NeuralNet, self).__init__(name="NeuralNet")
        # First fully-connected hidden layer.
        self.fc1 = layers.Dense(64, activation=tf.nn.relu)
        # Second fully-connected hidden layer.
        self.fc2 = layers.Dense(128, activation=tf.nn.relu)
        # Third fully-connecter hidden layer.
        self.out = layers.Dense(num_classes, activation=tf.nn.softmax)

    # Set forward pass.
    def __call__(self, x, is_training=False):
        x = self.fc1(x)
        x = self.out(x)
        if not is_training:
            # tf cross entropy expect logits without softmax, so only
            # apply softmax when not training.
            x = tf.nn.softmax(x)
        return x

# Build neural network model.
neural_net = NeuralNet()

In [81]:
# Cross-Entropy loss function.
def cross_entropy(y_pred, y_true):
    y_true = tf.cast(y_true, tf.int64)
    crossentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    return tf.reduce_mean(crossentropy)

# Accuracy metric.
def accuracy(y_pred, y_true):
    # Predicted class is the index of highest score in prediction vector (i.e. argmax).
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Adam optimizer.
optimizer = tf.optimizers.Adam(learning_rate)

In [84]:
# Optimization process.
def run_optimization(x, y):
    # Wrap computation inside a GradientTape for automatic differentiation.
    with tf.GradientTape() as g:
        pred = neural_net(x, is_training=True)
        loss = cross_entropy(pred, y)

        # Compute gradients.
        gradients = g.gradient(loss, neural_net.trainable_variables)

        # Update W and b following gradients.
        optimizer.apply_gradients(zip(gradients, neural_net.trainable_variables))

In [85]:
# Run training for the given number of steps.
for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
    # Run the optimization to update W and b values.
    run_optimization(batch_x, batch_y)

    if step % display_step == 0:
        pred = neural_net(batch_x, is_training=False)
        loss = cross_entropy(pred, batch_y)
        acc = accuracy(pred, batch_y)
        print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))

step: 100, loss: 2.194707, accuracy: 0.851562
step: 200, loss: 2.180918, accuracy: 0.949219
step: 300, loss: 2.177780, accuracy: 0.968750
step: 400, loss: 2.180268, accuracy: 0.945312
step: 500, loss: 2.175300, accuracy: 0.984375
step: 600, loss: 2.177104, accuracy: 0.964844
step: 700, loss: 2.176259, accuracy: 0.968750
step: 800, loss: 2.179758, accuracy: 0.945312
step: 900, loss: 2.177743, accuracy: 0.960938
step: 1000, loss: 2.177008, accuracy: 0.964844


In [86]:
# Save TF model.
neural_net.save_weights(filepath="./tfmodel.ckpt")

In [87]:

# Re-build neural network model with default values.
neural_net = NeuralNet()
# Test model performance.
pred = neural_net(batch_x)
print("accuracy: %f" % accuracy(pred, batch_y))

accuracy: 0.078125


In [88]:
neural_net.load_weights(filepath="./tfmodel.ckpt")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fc1e42d99d0>

In [89]:
# Test that weights loaded correctly.
pred = neural_net(batch_x)
print("accuracy: %f" % accuracy(pred, batch_y))

accuracy: 0.964844
