In [2]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

In [56]:
# MNIST as (image, label) pairs (as_supervised=True), plus the dataset
# metadata object ds_info, which exposes per-split example counts.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
)

In [57]:
def normalize(img, label):
    """Cast the image to float32 and divide by 255; pass the label through unchanged.

    NOTE(review): assumes `img` holds 0-255 pixel intensities (presumably uint8
    from tfds), so the result lands in [0, 1] — confirm.
    """
    scaled = tf.cast(img, tf.float32) / 255.
    return scaled, label


In [58]:
# Training pipeline: normalize -> cache (after the map, so the normalized
# tensors are what gets cached) -> shuffle with a buffer as large as the
# whole training split -> batches of 128 -> prefetch.
ds_train = (
    ds_train
    .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [59]:
# Test pipeline: same normalization / caching / batching as training, but
# deliberately NOT shuffled. Evaluation metrics do not depend on order, and a
# fixed order keeps the rows of model.predict(ds_test) aligned with labels
# read from any other pass over ds_test.
ds_test = (
    ds_test
    .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [60]:
# Class names for the ten digits: label k is digit k.
classes = list(range(10))

In [61]:
# Preview nine training digits titled with their labels. The images are
# single-channel (squeezed from 28x28x1), so render them with a gray
# colormap instead of matplotlib's default color map.
plt.figure(figsize=(10, 10))
for images, labels in ds_train.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.squeeze(images[i]), cmap='gray')
        # Convert the scalar label tensor to a Python int before indexing.
        plt.title(classes[int(labels[i])])
        plt.axis('off')

In [62]:
# Baseline MLP: flatten each 28x28x1 image, one 128-unit ReLU hidden layer,
# then a 10-unit softmax output (one unit per digit class).
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28, 1)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [63]:
# Layer-by-layer overview of the model (output shapes and parameter counts).
model.summary()

In [64]:
# Render the model's layer graph as an image.
# NOTE(review): plot_model relies on pydot/graphviz being installed — confirm in this environment.
tf.keras.utils.plot_model(model)

In [65]:
# Labels are integer class ids (they index `classes` directly), so use the
# sparse variant of categorical cross-entropy against the softmax output.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=['accuracy'],
)


In [66]:
# Train for 8 epochs, validating on ds_test; per-epoch metrics end up in
# history.history.
history = model.fit(
    ds_train,
    epochs=8,
    validation_data=ds_test,
)

In [67]:
# Names of the metrics recorded during fit(), one per line.
for metric_name in history.history.keys():
    print(metric_name)

In [68]:
# Learning curves. The x-axis is derived from the recorded history rather than
# a hard-coded range(8), so it stays correct if `epochs=` in fit() changes.
# Train loss is plotted alongside validation loss to make over/under-fitting visible.
fig, ax = plt.subplots(figsize=(10, 10))
epoch_numbers = range(1, len(history.history['val_loss']) + 1)
ax.plot(epoch_numbers, history.history['loss'], label='train loss')
ax.plot(epoch_numbers, history.history['val_loss'], label='validation loss')
ax.set(title='Loss per epoch', xlabel='epoch', ylabel='loss')
ax.legend();

In [69]:
# Loss and accuracy (the compiled metric) on the test split.
model.evaluate(ds_test)

In [70]:
# Softmax outputs for every test example: one row per example, 10 columns
# (one per digit). Row order follows this particular pass over ds_test.
pred = model.predict(ds_test)

In [71]:
# Inspect the raw probability array returned by predict().
pred

In [72]:
# Reduce each row of softmax probabilities to its most likely class index,
# vectorized over the whole array instead of a per-row Python lambda.
# The ndim check makes the cell safe to re-run: once `pred` is already a flat
# list of labels, a second per-element argmax would turn every entry into 0.
if np.ndim(pred) == 2:
    pred = np.argmax(pred, axis=1).tolist()

In [73]:
# First 20 predicted digit labels.
pred[:20]

In [75]:
# Inspect a single batch instead of printing every element of the test set,
# which floods the cell output with pixel arrays. Shapes/dtype show what the
# model actually receives after the input pipeline.
for batch_images, batch_labels in ds_test.take(1).as_numpy_iterator():
    print('images:', batch_images.shape, batch_images.dtype)
    print('labels:', batch_labels[:20])

In [74]:
# Grid of up to 81 test digits, each titled with the model's prediction:
# blue when it matches the true label, red when it does not.
# Predictions are computed on the very batch being displayed, so images,
# labels and predictions always line up — the order of separate passes over
# ds_test is not relied upon.
plt.figure(figsize=(20, 20))

for images, labels in ds_test.take(1):
    batch_pred = np.argmax(model.predict(images, verbose=0), axis=1)
    # Guard against a final/short batch with fewer than 81 examples.
    for i in range(min(81, len(labels))):
        ax = plt.subplot(9, 9, i + 1)
        plt.imshow(np.squeeze(images[i]), cmap='gray')
        true_label = classes[int(labels[i])]
        correct = batch_pred[i] == true_label
        # Title shows the prediction for THIS image (previously pred[labels[i]]).
        plt.title(batch_pred[i], color=('blue' if correct else 'red'))
        plt.axis('off')