<a href="https://colab.research.google.com/github/peter-lang/ml-tutorial/blob/master/01_TF_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install tensorflow>=2.0.0 tensorboard>=2.0.0 --upgrade

In [0]:
import tensorflow as tf
import numpy as np
import datetime
import matplotlib.pyplot as plt
import random

%matplotlib inline

In [0]:
# Load the MNIST train/test split through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale the images so that the values are normalized to the [0, 1] range.
x_train = x_train / 255.0
x_test = x_test / 255.0

# Report the array shapes of both splits (images first, then labels).
print("Train shape", x_train.shape, y_train.shape)
print("Test shape", x_test.shape, y_test.shape)

# Preview the first training image together with its label.
plt.imshow(x_train[0], cmap='Greys', interpolation='nearest')
print("Truth: ", y_train[0])

In [0]:
def create_model():
  """Build a fresh, uncompiled digit classifier.

  The network flattens a (28, 28) input, applies a 512-unit ReLU dense
  layer followed by dropout with rate 0.2, and ends in a 10-unit softmax
  layer.

  Returns:
    A new `tf.keras.models.Sequential` instance.
  """
  layers = tf.keras.layers
  stack = [
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax'),
  ]
  return tf.keras.models.Sequential(stack)

In [0]:
# Reset Keras' global state before building a new model in this cell.
tf.keras.backend.clear_session()

# Build and compile the classifier; labels are integer class ids, hence the
# sparse categorical cross-entropy loss.
model = create_model()
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Number of epochs trained so far; the training cell advances this counter.
initial_epoch = 0

# Each run of this cell logs into its own timestamped sub-directory of logs/.
log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    update_freq='batch',
    histogram_freq=1,
    write_graph=True,
    write_images=True,
)


In [0]:
# Uncomment the next line to delete the logs directory (all earlier runs)
# before TensorBoard is started.
#!rm -rf logs
# Load the TensorBoard notebook extension and open it on the logs directory.
%load_ext tensorboard
%tensorboard --logdir logs

In [0]:
# Number of additional epochs to train each time this cell is run.
epochs = 1

# Train on the whole training set by default.
train_start, train_end = 0, len(x_train)
# Alternative: train on a random contiguous window of `train_size` samples.
#train_size = 1000
#train_start = random.randint(0, len(x_train) - train_size)
#train_end = train_start + train_size

# `epochs` is passed as initial_epoch + epochs so that re-running this cell
# continues from where the previous run stopped.
model.fit(
    x_train[train_start:train_end],
    y_train[train_start:train_end],
    initial_epoch=initial_epoch,
    epochs=initial_epoch + epochs,
    #validation_data=(x_test, y_test),
    callbacks=[tensorboard_callback],
)
initial_epoch += epochs

In [0]:
# Pick a random test sample. random.randrange excludes the upper bound;
# random.randint(0, len(x_test)) includes it and could therefore return
# len(x_test), which raises IndexError on the lookup below.
index = random.randrange(len(x_test))
img = x_test[index]
truth = y_test[index]

# Show the selected image.
plt.imshow(img, interpolation='nearest', cmap='Greys')

values = list(range(0, 10))
# predict() is given a batch of one image (expand_dims adds the leading
# axis); [0] takes the single row of per-class outputs back out.
predictions = model.predict(np.expand_dims(img, axis=0))[0]

# Horizontal bar chart of the model output for each digit class, with
# digit 0 at the top.
_, ax = plt.subplots()

ax.barh(values, predictions, align='center')
ax.set_yticks(values)
ax.invert_yaxis()
ax.set_xlabel('Confidence')

plt.show()

print("Prediction: ", predictions.argmax(), " Truth: ", truth)