In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Load MNIST through tensorflow_datasets; the result is a dict keyed by split
# name ('train' / 'test').
mnist_dataset = tfds.load(name='mnist')

# Echo the dict so the element shapes and dtypes of each split are shown.
mnist_dataset

{'test': <_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>,
 'train': <_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>}

In [12]:
# Pull the two splits out of the dict returned by tfds.load.
ds_train, ds_test = mnist_dataset['train'], mnist_dataset['test']

# take(1) yields exactly one element. Nothing has been batched yet, so despite
# its name `batch` is a single {'image', 'label'} example at this point.
[batch] = ds_train.take(1)
(batch['image'].shape, batch['label'].shape)

(TensorShape([28, 28, 1]), TensorShape([]))

In [59]:
ds_test = mnist_dataset['test']

# Training input pipeline, built as one chain:
#   dict element -> (image, label) tuple
#   -> shuffle through a 10,000-element buffer
#   -> repeat without end (so fit() must be given steps_per_epoch)
#   -> group into mini-batches of 32
ds_train = (
    mnist_dataset['train']
    .map(lambda example: (example['image'], example['label']))
    .shuffle(buffer_size=10000)
    .repeat()
    .batch(32)
)

# Peek at one mini-batch to confirm the batched image / label shapes.
[batch] = ds_train.take(1)
(batch[0].shape, batch[1].shape)

(TensorShape([32, 28, 28, 1]), TensorShape([32]))

In [65]:
# Small CNN for 28x28x1 digit images: two conv + max-pool stages followed by a
# dense classifier head producing 10 softmax probabilities.
model = tf.keras.models.Sequential()

# Stage 1: 16 filters of 3x3, then 2x2 max-pooling with stride 2.
model.add(tf.keras.layers.Conv2D(
    filters=16,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding='same',
    activation='relu',
    data_format='channels_last',
    input_shape=(28, 28, 1),
))
model.add(tf.keras.layers.MaxPool2D(
    pool_size=(2, 2),
    strides=(2, 2),
    padding='valid',
))

# Stage 2: 32 filters of 3x3, then the same pooling.
model.add(tf.keras.layers.Conv2D(
    filters=32,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding='same',
    activation='relu',
))
model.add(tf.keras.layers.MaxPool2D(
    pool_size=(2, 2),
    strides=(2, 2),
    padding='valid',
))

# Classifier head: flatten -> 128-unit ReLU layer -> dropout (rate 0.2)
# -> 10-way softmax.
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))


# Sanity check: a batch of 28x28x1 inputs should map to (None, 10).
model.compute_output_shape(input_shape=(None, 28, 28, 1))


TensorShape([None, 10])

In [66]:
# Labels are scalar integer class ids (not one-hot vectors), hence the sparse
# variant of categorical cross-entropy; accuracy is the reported metric.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Intentionally empty: `ds_train` is already built by the input-pipeline cell
# above. This cell previously held an unfinished `ds_train = ` assignment,
# which is a SyntaxError and stopped "Restart Kernel -> Run All".

In [67]:
# ds_train repeats indefinitely, so Keras needs steps_per_epoch to know where
# one epoch ends. One epoch should be one full pass over the training split.
num_train_examples = 60000  # size of the TFDS MNIST 'train' split
batch_size = 32             # must match the .batch(32) used to build ds_train

# np.ceil returns a float; Keras expects an integer number of steps.
steps_per_epoch = int(np.ceil(num_train_examples / batch_size))

history = model.fit(ds_train, epochs=10, steps_per_epoch=steps_per_epoch)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


## Prediction on the test-set

In [86]:
# Batch the raw test split; elements remain dicts with 'image' / 'label' keys.
ds_test = mnist_dataset['test'].batch(32)

# Collect true labels and predicted classes batch by batch.
test_labels, test_preds = [], []
for batch in ds_test:
    # Per-class probabilities for this batch -> index of the largest one.
    proba = model.predict(batch['image'])
    preds = tf.argmax(proba, axis=1)
    test_preds += list(preds.numpy())
    test_labels += list(batch['label'].numpy())

print(len(test_labels), len(test_preds))
print(test_labels[:4], test_preds[:4])

# Accuracy = fraction of positions where prediction and label agree.
matches = np.array(test_labels) == np.array(test_preds)
correct = matches.sum()
print('Test Accuracy :: ', correct / len(test_labels))

10000 10000
[6, 2, 3, 7] [6, 2, 3, 7]
Test Accuracy ::  0.9885


## Saving the model

In [89]:
# Keep the result of model.get_config() in `config`.
# NOTE(review): `config` is not read again in the visible cells; the model is
# actually persisted below via to_json() and save_weights().
config = model.get_config()

In [90]:
# Persist the model in two pieces: the architecture as a JSON string and the
# weights in a separate .h5 file.
json_config = model.to_json()
with open('mnist-clf.json', mode='w') as json_file:
    json_file.write(json_config)

model.save_weights('mnist-clf-weights.h5')

## Restart the kernel

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np


In [2]:
## set up the dataset: reload MNIST in the fresh kernel and batch the test split
mnist_dataset = tfds.load(name='mnist')
ds_test = mnist_dataset['test'].batch(32)


## restore the saved model: rebuild the architecture from the JSON file,
## then load the weights from the .h5 file into it
with open('mnist-clf.json', mode='r') as json_file:
    json_config = json_file.read()

model = tf.keras.models.model_from_json(json_config)
model.load_weights('mnist-clf-weights.h5')


## prediction on the test-set with the restored model

# Collect true labels and predicted classes batch by batch.
test_labels, test_preds = [], []
for batch in ds_test:
    # Per-class probabilities for this batch -> index of the largest one.
    proba = model.predict(batch['image'])
    preds = tf.argmax(proba, axis=1)
    test_preds += list(preds.numpy())
    test_labels += list(batch['label'].numpy())

print(len(test_labels), len(test_preds))
print(test_labels[:4], test_preds[:4])

# Accuracy = fraction of positions where prediction and label agree.
matches = np.array(test_labels) == np.array(test_preds)
correct = matches.sum()
print('Test Accuracy :: ', correct / len(test_labels))

10000 10000
[6, 2, 3, 7] [6, 2, 3, 7]
Test Accuracy ::  0.9885
