In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from training_dashboard import TrainingDashboard

In [2]:
def get_model(input_shape=(28, 28, 1), num_classes=10):
    """Build a small, uncompiled convolutional classifier.

    The network is a single 3x3 convolution (32 filters, ReLU) followed by
    2x2 max pooling, flattening, 50% dropout and a softmax output layer.

    Args:
        input_shape: Shape of one input sample, passed to ``tf.keras.Input``.
            Defaults to ``(28, 28, 1)``.
        num_classes: Number of units in the final softmax layer.
            Defaults to ``10``.

    Returns:
        A ``tf.keras.Sequential`` model. It is not compiled here; the caller
        chooses the loss, optimizer and metrics.
    """
    return tf.keras.Sequential(
        [
            tf.keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            # Dropout sits between the flattened features and the classifier head.
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )

In [3]:
# Dataset constants. `num_classes` sets the width of the one-hot label vectors below.
num_classes = 10
input_shape = (28, 28, 1)

# Load the MNIST train/test split via the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Images: cast to float32, divide by 255, then append a trailing axis of size 1.
x_train, x_test = (
    np.expand_dims(images.astype("float32") / 255, -1)
    for images in (x_train, x_test)
)

# Labels: convert class indices to one-hot vectors of length `num_classes`.
y_train, y_test = (
    tf.keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train, y_test)
)

In [4]:
# The AUC metric is created as an object with an explicit name; the same name
# ("auc") is listed in the dashboard's `metrics` argument below.
auc = tf.keras.metrics.AUC(name='auc')

model = get_model()
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy", auc])

callback = TrainingDashboard(validation=True, # because we are using validation data and want to track its metrics
                             min_loss=0, # we want the loss axes to be fixed on the lower end
                             metrics=["accuracy", "auc"], # metrics that we want plotted
                             batch_step=10, # plot every 10th batch
                             min_metric_dict={"accuracy": 0, "auc": 0}, # minimum possible value for metrics used (optional)
                             max_metric_dict={"accuracy": 1, "auc": 1}) # maximum possible value for metrics used (optional)

# Keep the History object returned by fit() so the recorded per-epoch values
# stay available to later cells, instead of letting it surface only as a bare
# repr in the cell output.
history = model.fit(x_train,
                    y_train,
                    batch_size=512,
                    epochs=25,
                    verbose=1,
                    validation_split=0.2,  # hold out 20% of the training data for validation
                    callbacks=[callback])

AppLayout(children=(Figure(axes=[Axis(label='batch', scale=OrdinalScale()), Axis(grid_lines='dashed', label='t…

AppLayout(children=(Figure(axes=[Axis(label='epoch', scale=OrdinalScale()), Axis(label='accuracy', orientation…

HBox(children=(Output(),), layout=Layout(align_items='center', display='flex', flex_flow='column'))

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25


<tensorflow.python.keras.callbacks.History at 0x655234d50>