# Save and load models

Tutorial URL: https://www.tensorflow.org/tutorials/keras/save_and_load

Valid as of: 2023.05.02

## Imports

In [1]:
import ml.core.repo_paths
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")

TensorFlow version: 2.10.1


# Loading the dataset

In [5]:
# Use prebuilt MNIST dataset
(data_train, label_train), (data_test, label_test) = tf.keras.datasets.mnist.load_data()

# Normalize data
data_train = data_train / 255.0
data_test = data_test / 255.0

# Declare model

In [6]:
# Declare model
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10),
    ]
)

# Compile model
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train with checkpoints

In [7]:
model.fit(
    data_train,
    label_train,
    epochs=10,
    validation_data=(data_test, label_test),
    callbacks=[
        tf.keras.callbacks.ModelCheckpoint(
            filepath=ml.core.repo_paths.get_dir_checkpoints("tutorial_tensorflow_save_and_load") / "cp.ckpt", save_weights_only=True, verbose=1
        )
    ],
)

Epoch 1/10
Epoch 1: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 2/10
Epoch 2: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 3/10
Epoch 3: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 4/10
Epoch 4: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 5/10
Epoch 5: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 6/10
Epoch 6: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 7/10
Epoch 7: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 8/10
Epoch 8: saving model to c:\Users\sophi\Code\ml\artifacts\checkpoints\tutorial_tensorflow_save_and_load\cp.ckpt
Epoch 9/10
Epoch

<keras.callbacks.History at 0x1cd1f7fbdc0>

# Load model

In [9]:
# Create untrained model
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10),
    ]
)

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Load weights
model.load_weights(ml.core.repo_paths.get_dir_checkpoints("tutorial_tensorflow_save_and_load") / "cp.ckpt")

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x1cd4c256fa0>

# Evaluate in different formats

In [10]:
# Evaluate
model.evaluate(data_test, label_test, verbose=2)

# Save entire model in SavedModel
model.save(ml.core.repo_paths.get_dir_models("tutorial_tensorflow_save_and_load") / "mnist")

# Load the entire model
model = tf.keras.models.load_model(ml.core.repo_paths.get_dir_models("tutorial_tensorflow_save_and_load") / "mnist")

model.evaluate(data_test, label_test, verbose=2)

# Save the entire model in HDF5
model.save(ml.core.repo_paths.get_dir_models("tutorial_tensorflow_save_and_load") / "mnist.h5")

# Load the entire model again
model = tf.keras.models.load_model(ml.core.repo_paths.get_dir_models("tutorial_tensorflow_save_and_load") / "mnist.h5")

model.evaluate(data_test, label_test, verbose=2)

313/313 - 1s - loss: 0.0684 - accuracy: 0.9812 - 1s/epoch - 4ms/step
INFO:tensorflow:Assets written to: c:\Users\sophi\Code\ml\artifacts\models\tutorial_tensorflow_save_and_load\mnist\assets
313/313 - 1s - loss: 0.0684 - accuracy: 0.9812 - 1000ms/epoch - 3ms/step
313/313 - 1s - loss: 0.0684 - accuracy: 0.9812 - 1s/epoch - 4ms/step


[0.06838671863079071, 0.9811999797821045]