In [1]:
%load_ext tensorboard
import tensorflow as tf
import datetime

In [2]:
# clear logs
!rm -rf ./logs/history_plotter1/

In [3]:
# load data

(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255., x_test / 255.

In [4]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

train_dataset = train_dataset.shuffle(60000).batch(64)
test_dataset = test_dataset.batch(64)

In [5]:
# build model

model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10))

In [6]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Define our metrics
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')

In [7]:
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        preds = model(inputs, training=True)
        loss = loss_object(targets, preds)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss(loss)
    train_accuracy(targets, preds)


In [8]:
def test_step(inputs, targets):
    preds = model(inputs)
    loss = loss_object(targets, preds)

    test_loss(loss)
    test_accuracy(targets, preds)

In [9]:
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/history_plotter1/' + current_time + '/train'
test_log_dir = 'logs/history_plotter1/' + current_time + '/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

In [10]:
epochs = 5
for epoch in range(epochs):
    # Reset metrics every epoch
    train_loss.reset_states()
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()
    
    for (inp, tar) in train_dataset:
        train_step(inp, tar)
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

    for (inp, tar) in test_dataset:
        test_step(inp, tar)
    with test_summary_writer.as_default():
        tf.summary.scalar('loss', test_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)
  
    template = "Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}"
    print (
        template.format(
        epoch+1,
        train_loss.result(), 
        train_accuracy.result()*100,
        test_loss.result(), 
        test_accuracy.result()*100)
    )




To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Epoch 1, Loss: 0.24630239605903625, Accuracy: 92.8550033569336, Test Loss: 0.12155003100633621, Test Accuracy: 96.45000457763672
Epoch 2, Loss: 0.10629675537347794, Accuracy: 96.74500274658203, Test Loss: 0.09101556241512299, Test Accuracy: 97.13999938964844
Epoch 3, Loss: 0.07132712751626968, Accuracy: 97.79666900634766, Test Loss: 0.07455600053071976, Test Accuracy: 97.68999481201172
Epoch 4, Loss: 0.055107444524765015, Accuracy: 98.32333374023438, Test Loss: 0.06994262337684631, Test Accuracy: 97.61000061035156
Epoch 5, Loss: 0.0440807044506073, Accuracy: 98.60832977294922, Test Loss: 0.06253194063901901, Test Accuracy: 98.05999755859375


In [13]:
%tensorboard --logdir logs/history_plotter1/

Reusing TensorBoard on port 6009 (pid 3714), started 0:50:58 ago. (Use '!kill 3714' to kill it.)