In [0]:
from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

# Data

## Load data

In [6]:
# Load the MNIST handwritten digits, already split into train and test sets.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Report the shape of every array so the split sizes are visible at a glance.
for caption, values in (('train_images:', train_images),
                        ('train_labels:', train_labels),
                        ('test_images:', test_images),
                        ('test_labels:', test_labels)):
  print(caption, values.shape)


train_images: (60000, 28, 28)
train_labels: (60000,)
test_images: (10000, 28, 28)
test_labels: (10000,)


## Subset, flatten, and normalize data

In [7]:
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

print('train_images:', train_images.shape)
print('train_labels:', train_labels.shape)
print('test_images:', test_images.shape)
print('test_labels:', test_labels.shape)

train_images: (1000, 784)
train_labels: (1000,)
test_images: (1000, 784)
test_labels: (1000,)


# Model

## Create model

In [0]:
# Returns a short sequential model
def create_model(input_dim=784, num_classes=10, hidden_units=512, dropout_rate=0.2):
  """Build and compile a small fully connected classifier.

  Architecture: Dense(hidden_units, relu) -> Dropout(dropout_rate) ->
  Dense(num_classes, softmax). It is compiled with Adam and sparse
  categorical cross-entropy, so labels are integer class ids (as in
  `train_labels`), not one-hot vectors.

  The defaults reproduce the original hard-coded model. Weights saved from
  one configuration can only be loaded into a model built with the same
  arguments, because the layer shapes must match.

  Args:
    input_dim: length of each flattened input vector (784 = 28 * 28 pixels).
    num_classes: number of output classes / softmax units.
    hidden_units: width of the hidden Dense layer.
    dropout_rate: fraction of hidden activations dropped during training.

  Returns:
    A compiled, untrained `tf.keras.models.Sequential` model.
  """
  model = tf.keras.models.Sequential([
    keras.layers.Dense(hidden_units, activation=tf.keras.activations.relu,
                       input_shape=(input_dim,)),
    keras.layers.Dropout(dropout_rate),
    keras.layers.Dense(num_classes, activation=tf.keras.activations.softmax)
  ])

  model.compile(optimizer=tf.keras.optimizers.Adam(),
                loss=tf.keras.losses.sparse_categorical_crossentropy,
                metrics=['accuracy'])

  return model

model = create_model()




## Create checkpoint callback

In [0]:
# Each checkpoint is named by its zero-padded epoch number, e.g. cp-0005.ckpt.
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Write only the model weights (not the full model) once every 5 epochs,
# printing a line each time a checkpoint is saved.
# NOTE(review): `period` is deprecated in TF 2.x tf.keras in favor of
# `save_freq` (counted in batches); confirm against the installed version
# before upgrading TensorFlow.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    period=5,
    verbose=1)

# Train

## Train with periodic weight checkpoints

In [28]:
# Start from a freshly built model and store its untrained weights as
# "epoch 0", so the checkpoint directory also records the starting point.
model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))

# Per-epoch logging is silenced (verbose=0); cp_callback still prints a line
# whenever it writes a checkpoint.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=50,
    callbacks=[cp_callback],
    validation_data=(test_images, test_labels),
    verbose=0)


Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00005: saving model to training_2/cp-0005.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00010: saving model to training_2/cp-0010.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00015: saving model to training_2/cp-0015.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00020: saving model to training_2/cp-0020.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00025: saving model to training_2/cp-0025.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00030: saving model to training_2/cp-0030.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00035: saving model to training_2/cp-0035.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00040: saving model to training_2/cp-0040.ckpt

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00045: saving model to training_2/cp-0045.ckpt

Consider 

<tensorflow.python.keras.callbacks.History at 0x7f4442db9828>

In [29]:
!ls -l training_2

total 17604
-rw-r--r-- 1 root root      81 Apr  1 08:20 checkpoint
-rw-r--r-- 1 root root 1631517 Apr  1 08:20 cp-0000.ckpt.data-00000-of-00001
-rw-r--r-- 1 root root     648 Apr  1 08:20 cp-0000.ckpt.index
-rw-r--r-- 1 root root 1631517 Apr  1 08:20 cp-0005.ckpt.data-00000-of-00001
-rw-r--r-- 1 root root     648 Apr  1 08:20 cp-0005.ckpt.index
-rw-r--r-- 1 root root 1631517 Apr  1 08:20 cp-0010.ckpt.data-00000-of-00001
-rw-r--r-- 1 root root     648 Apr  1 08:20 cp-0010.ckpt.index
-rw-r--r-- 1 root root 1631517 Apr  1 08:20 cp-0015.ckpt.data-00000-of-00001
-rw-r--r-- 1 root root     648 Apr  1 08:20 cp-0015.ckpt.index
-rw-r--r-- 1 root root 1631517 Apr  1 08:20 cp-0020.ckpt.data-00000-of-00001
-rw-r--r-- 1 root root     648 Apr  1 08:20 cp-0020.ckpt.index
-rw-r--r-- 1 root root 1631517 Apr  1 08:20 cp-0025.ckpt.data-00000-of-00001
-rw-r--r-- 1 root root     648 Apr  1 08:20 cp-0025.ckpt.index
-rw-r--r-- 1 root root 1631517 Apr  1 08:20 cp-0030.ckpt.data-00000-of-00001
-rw-r--r-- 1 roo

# Evaluate

## Restore model

In [19]:
# Path prefix of the most recent checkpoint in checkpoint_dir, or None if none exists.
latest = tf.train.latest_checkpoint(checkpoint_dir)
# Fail here with a clear message rather than inside load_weights() in the next cell.
assert latest is not None, (
    'No checkpoint found in %r; run the training cell first.' % checkpoint_dir)
print(latest)

training_2/cp-0050.ckpt


In [20]:
# Build a fresh, untrained model and copy the newest checkpoint's weights into it.
model = create_model()
model.load_weights(latest)

# With metrics=['accuracy'], evaluate() yields the loss followed by the accuracy.
loss, acc = model.evaluate(x=test_images, y=test_labels)
print("Restored model, accuracy: {:5.2f}%".format(acc * 100))

Restored model, accuracy: 87.70%
