#### Importing libraries

In [1]:
import io
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn import metrics



In [6]:
# Number of examples per training mini-batch.
BATCH_SIZE = 128
# Upper bound on training epochs (early stopping may end training sooner).
EPOCHS = 20
# Buffer size passed to Dataset.shuffle.
# NOTE(review): presumably chosen to exceed the MNIST training split so the
# whole split is shuffled uniformly — confirm.
BUFFER_SIZE = 70_000

In [2]:
# Load MNIST as (image, label) pairs, together with the dataset metadata.
mnist_data, mnist_info = tfds.load(
    name='mnist',
    as_supervised=True,
    with_info=True,
)

In [3]:
# Pull the train and test splits out of the loaded dataset mapping.
mnist_train = mnist_data['train']
mnist_test = mnist_data['test']

In [42]:
def scale(image, label):
    """Convert an image to float32 and divide it by 255.

    Args:
        image: Image tensor to rescale.
        label: Label paired with the image; passed through unchanged.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label

In [43]:
# Apply the pixel scaling to both splits. The training split is divided into
# training and validation subsets in later cells.
train_val_data = mnist_train.map(scale)
test_data = mnist_test.map(scale)

In [44]:
# The validation set is 10% of the training split; the test split is used whole.
# Both sizes are cast to int64 tensors.
num_train_examples = mnist_info.splits['train'].num_examples
num_test_examples = mnist_info.splits['test'].num_examples
val_size = tf.cast(.1 * num_train_examples, tf.int64)
test_size = tf.cast(num_test_examples, tf.int64)

In [45]:
# Shuffle once and keep that order for every iteration of the dataset.
# With the default reshuffle_each_iteration=True the data is re-shuffled each
# time the dataset is iterated, so the skip()/take() split below would select
# different examples on every pass: validation examples would leak into
# training, and the arrays extracted for the confusion matrix would not match
# the validation set used by model.fit.
# Note: the element order is therefore identical in every epoch.
train_val_data = train_val_data.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False)

In [46]:
# The first val_size shuffled examples form the validation set;
# everything after them is used for training.
val_data = train_val_data.take(val_size)
train_data = train_val_data.skip(val_size)

In [47]:
# Training uses mini-batches of BATCH_SIZE; validation and test are each
# batched with their full split size.
val_data = val_data.batch(val_size)
test_data = test_data.batch(test_size)
train_data = train_data.batch(BATCH_SIZE)

In [49]:
# Extract the validation images and labels as NumPy arrays for building the
# confusion matrix. val_data is batched with batch size val_size, so this loop
# is expected to run once; if it ran more often, only the last batch would be
# kept.
for images, labels in val_data:
    image_val = images.numpy()
    label_val = labels.numpy()

#### Creating the model

In [50]:
# Two convolution + max-pooling stages followed by a 10-unit output layer.
# The final Dense layer has no activation, so the model outputs raw logits.
cnn_layers = [
    tf.keras.layers.Conv2D(filters=50, kernel_size=5, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(filters=50, kernel_size=3, activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=10),
]
model = tf.keras.Sequential(cnn_layers)

In [51]:
# Print the layer-by-layer architecture, with printed lines 80 characters wide.
model.summary(line_length=80)

Model: "sequential_1"
________________________________________________________________________________
Layer (type)                        Output Shape                    Param #     
conv2d_2 (Conv2D)                   (None, 24, 24, 50)              1300        
________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)      (None, 12, 12, 50)              0           
________________________________________________________________________________
conv2d_3 (Conv2D)                   (None, 10, 10, 50)              22550       
________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)      (None, 5, 5, 50)                0           
________________________________________________________________________________
flatten_1 (Flatten)                 (None, 1250)                    0           
________________________________________________________________________________
dense_

In [52]:
# Define the loss function. from_logits=True because the model's last layer is
# a Dense layer without an activation, so its outputs are unnormalised scores.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [53]:
# Compile the model with the Adam optimizer, loss_fn as the loss, and accuracy
# as the reported metric.
model.compile(optimizer='adam', loss= loss_fn,metrics=['accuracy'] )

In [54]:
# Build the log directory with os.path.join so the path is valid on every OS.
# Hard-coded backslashes only act as separators on Windows; elsewhere the logs
# would be written to a single directory literally named 'logs\fit\run-1',
# which a TensorBoard pointed at 'logs/fit' cannot find.
log_dir = os.path.join('logs', 'fit', 'run-1')

In [55]:
def plot_confusion_matrix(cm, class_names):
    """Render a confusion matrix as a matplotlib figure.

    Cells are coloured by raw count and annotated with the row-normalised
    value (each count divided by its row total), rounded to two decimals.

    Args:
        cm: Square array of counts; rows are plotted on the "True label" axis
            and columns on the "Predicted label" axis.
        class_names: Sequence of tick labels, one per class.

    Returns:
        The matplotlib figure containing the plot.
    """
    cm = np.asarray(cm)

    figure = plt.figure(figsize=(12, 12))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Row-normalise the counts for the cell annotations. A row with no samples
    # would divide by zero and yield NaN, so such rows are reported as 0.
    row_totals = cm.sum(axis=1, keepdims=True)
    labels = np.around(
        np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        ),
        decimals=2,
    )

    # Choose the text colour from the raw counts, i.e. the same values imshow
    # used to colour the cells: white on dark squares, black otherwise.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

    # Set the axis labels before tight_layout so the layout accounts for them.
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    return figure

In [56]:
def plot_to_image(figure):
    """Convert a matplotlib figure to a PNG image tensor with a batch axis.

    The figure is closed after rendering, which also keeps it from being
    displayed in the notebook.

    Args:
        figure: The matplotlib figure to render.

    Returns:
        A 4-channel image tensor decoded from the PNG, with a leading batch
        dimension of size 1.
    """
    # Render the figure that was passed in. plt.savefig would instead save
    # whichever figure is "current", which need not be this one.
    buf = io.BytesIO()
    figure.savefig(buf, format='png')

    # Close the figure to free it and prevent it from displaying in the notebook.
    plt.close(figure)

    # getvalue() returns the whole buffer regardless of the stream position.
    image = tf.image.decode_png(buf.getvalue(), channels=4)

    # Add the batch dimension.
    return tf.expand_dims(image, 0)

In [57]:
# Summary writer dedicated to the confusion-matrix images.
file_writer_cm = tf.summary.create_file_writer(log_dir + '/cm')


def log_confusion_matrix(epoch, logs):
    """Log the validation-set confusion matrix as an image summary.

    Args:
        epoch: Epoch index, used as the summary step.
        logs: Second callback argument; not used here.
    """
    # The model outputs one score per class; the prediction is the argmax.
    val_scores = model.predict(image_val)
    val_pred = np.argmax(val_scores, axis=1)

    # Build the confusion-matrix figure and turn it into an image tensor.
    digit_names = [str(digit) for digit in range(10)]
    figure = plot_confusion_matrix(
        metrics.confusion_matrix(label_val, val_pred),
        class_names=digit_names,
    )
    cm_image = plot_to_image(figure)

    # Write the image with the dedicated writer as the default.
    with file_writer_cm.as_default():
        tf.summary.image('Confusion Matrix', cm_image, step=epoch)

In [60]:
# Define the callbacks: cm_callback logs the confusion matrix at the end of
# every epoch; tensorboard_callback writes TensorBoard logs to log_dir with
# histograms every epoch (histogram_freq=1) and profiling turned off
# (profile_batch=0).
cm_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir =log_dir, histogram_freq = 1, profile_batch =0)

In [59]:
# Early stopping: halt training once val_loss has gone 2 epochs without
# improving, and restore the weights from the best epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=2,
    min_delta=0,
    mode='auto',
    restore_best_weights=True,
    verbose=0,
)


In [61]:
# Train for up to EPOCHS epochs, evaluating on val_data after each epoch.
# verbose=2 prints one summary line per epoch. The callbacks write TensorBoard
# logs, log the confusion matrix, and apply early stopping.
model.fit(
    train_data,
    epochs=EPOCHS,
    verbose=2,
    callbacks=[tensorboard_callback,cm_callback,early_stopping],
    validation_data= val_data,
)

Epoch 1/20
422/422 - 26s - loss: 0.2861 - accuracy: 0.9174 - val_loss: 0.0905 - val_accuracy: 0.9733
Epoch 2/20
422/422 - 25s - loss: 0.0726 - accuracy: 0.9785 - val_loss: 0.0483 - val_accuracy: 0.9858
Epoch 3/20
422/422 - 25s - loss: 0.0541 - accuracy: 0.9841 - val_loss: 0.0613 - val_accuracy: 0.9828
Epoch 4/20
422/422 - 26s - loss: 0.0449 - accuracy: 0.9864 - val_loss: 0.0299 - val_accuracy: 0.9907
Epoch 5/20
422/422 - 25s - loss: 0.0367 - accuracy: 0.9886 - val_loss: 0.0288 - val_accuracy: 0.9903
Epoch 6/20
422/422 - 26s - loss: 0.0324 - accuracy: 0.9902 - val_loss: 0.0249 - val_accuracy: 0.9928
Epoch 7/20
422/422 - 26s - loss: 0.0286 - accuracy: 0.9914 - val_loss: 0.0275 - val_accuracy: 0.9917
Epoch 8/20
422/422 - 26s - loss: 0.0228 - accuracy: 0.9925 - val_loss: 0.0182 - val_accuracy: 0.9942
Epoch 9/20
422/422 - 25s - loss: 0.0216 - accuracy: 0.9930 - val_loss: 0.0188 - val_accuracy: 0.9945
Epoch 10/20
422/422 - 25s - loss: 0.0185 - accuracy: 0.9942 - val_loss: 0.0126 - val_accura

<tensorflow.python.keras.callbacks.History at 0x23c6886e070>

#### Visualising in TensorBoard

In [62]:
%load_ext tensorboard
%tensorboard --logdir 'logs/fit'

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
