In [1]:
%matplotlib inline

import os
import numpy as np
import tensorflow as tf
from os.path import join
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Flatten, Dense, Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

In [2]:
# Load the MNIST digits as (train images, train labels), (test images, test labels).
(X_tr, y_tr), (X_ts, y_ts) = tf.keras.datasets.mnist.load_data()

# The images come back as integer pixel intensities in [0, 255]. Rescale them to
# float32 values in [0, 1] so the Dense layers trained with plain SGD receive
# well-scaled inputs. Labels are left untouched (integer class ids).
X_tr = X_tr.astype("float32") / 255.0
X_ts = X_ts.astype("float32") / 255.0

In [3]:
# Hyper-parameters: width of the hidden layer, and one output unit per
# distinct label value found in the training labels.
nodes = 128
n_classes = len(np.unique(y_tr))

# Single-hidden-layer perceptron:
# 28x28 image -> flat vector -> ReLU hidden layer -> softmax over the classes.
model = tf.keras.Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(nodes, activation='relu'),
    Dense(n_classes, activation='softmax'),
])

# Labels are integer class ids, hence the *sparse* categorical cross-entropy.
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [4]:
# Filename template for the weight files; the {epoch:02d} placeholder is
# filled in with the epoch number each time the callback saves.
checkpoint_path = "checkpoint/cp_{epoch:02d}.ckpt"

# Make sure the directory part of the template exists before training starts.
checkpoint_dir = os.path.split(checkpoint_path)[0]
os.makedirs(checkpoint_dir, exist_ok=True)

# Save weights only (not the full model) at the end of every epoch,
# regardless of whether the monitored metric improved.
cp_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    save_freq='epoch',
    save_best_only=False,
    save_weights_only=True,
    verbose=1,
)

In [5]:
# Keep only the first n_tr training examples for fitting; every example after
# them becomes the validation set.
# NOTE: X_tr / y_tr are overwritten here, so re-running this cell without first
# re-running the data-loading cell slices the already-shortened arrays and
# leaves the validation set empty.
n_tr = 100

X_tr, X_vl = X_tr[:n_tr], X_tr[n_tr:]
y_tr, y_vl = y_tr[:n_tr], y_tr[n_tr:]

In [6]:
# Fit the model for 5 epochs, evaluating on the validation split after each
# one. cp_callback is attached so that a weight checkpoint is written per epoch.
history = model.fit(
    X_tr,
    y_tr,
    batch_size=100,
    epochs=5,
    validation_data=(X_vl, y_vl),
    callbacks=[cp_callback],
)

Train on 100 samples, validate on 59900 samples
Epoch 1/5

Epoch 00001: saving model to checkpoint/cp_01.ckpt
Epoch 2/5

Epoch 00002: saving model to checkpoint/cp_02.ckpt
Epoch 3/5

Epoch 00003: saving model to checkpoint/cp_03.ckpt
Epoch 4/5

Epoch 00004: saving model to checkpoint/cp_04.ckpt
Epoch 5/5

Epoch 00005: saving model to checkpoint/cp_05.ckpt


In [7]:
# Build a second, freshly initialised network with the same architecture as
# the trained one, so that the saved weights can be loaded into it.
nodes = 128
n_classes = 10

model_new = tf.keras.Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(nodes, activation='relu'),
    Dense(n_classes, activation='softmax'),
])

model_new.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Print the layer-by-layer overview (output shapes and parameter counts).
model_new.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               100480    
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [13]:
# Look up the most recently written checkpoint prefix in checkpoint_dir.
latest = tf.train.latest_checkpoint(checkpoint_dir)

# latest_checkpoint() returns None when the directory holds no checkpoint
# (e.g. the training cell has not been run yet). Fail here with a clear
# message instead of letting a later load_weights(None) fail obscurely.
if latest is None:
    raise FileNotFoundError(
        "No checkpoint found in {!r}; run the training cell first.".format(checkpoint_dir)
    )

# Display the resolved checkpoint path as the cell output.
latest

'checkpoint/cp_05.ckpt'

In [14]:
# Restore the weights stored under the checkpoint path `latest` into model_new.
# The call's return value (a load-status object) is shown as the cell output.
model_new.load_weights(latest)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0xb3a46a6d8>

In [15]:
# Display model_new's weight variables (kernel and bias of each Dense layer)
# to inspect the values after loading the checkpoint.
model_new.weights

[<tf.Variable 'dense_2/kernel:0' shape=(784, 128) dtype=float32, numpy=
 array([[-0.07783668, -0.02070515, -0.05053351, ..., -0.02784806,
         -0.04537976, -0.06529312],
        [-0.00805727,  0.0581274 , -0.00715165, ...,  0.066951  ,
         -0.03352453,  0.06791117],
        [ 0.01524504,  0.01145628, -0.04390196, ..., -0.00702702,
          0.05692265,  0.07503694],
        ...,
        [-0.04490545, -0.02593266,  0.07017189, ...,  0.05841414,
         -0.01869958, -0.07579486],
        [-0.07930264,  0.03656329, -0.01906898, ...,  0.02907538,
         -0.01763   , -0.04813129],
        [-0.05327711, -0.03428362, -0.01786644, ...,  0.03825085,
         -0.05213016,  0.04331684]], dtype=float32)>,
 <tf.Variable 'dense_2/bias:0' shape=(128,) dtype=float32, numpy=
 array([-9.5348878e-05, -7.5824256e-04, -1.6841773e-04, -2.7600443e-04,
        -9.5423294e-04, -7.7495375e-04, -3.2228179e-04, -2.2641562e-03,
         1.6353496e-04, -4.2950673e-04, -2.9968747e-04, -4.1736843e-04,
   

In [None]:
model_new.