In [1]:
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

In [5]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale by 1/255 (presumably mapping 8-bit pixel intensities into [0, 1] --
# confirm against the dataset), append a trailing channels axis, and store
# the result as float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")

In [6]:
# Training pipeline: slice per example, shuffle with a 10000-element buffer,
# then group into batches of 32.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(32)
)

# Evaluation pipeline: same batching, no shuffling.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(32)
)

In [7]:
class MyModel(Model):
  """Small convolutional network: Conv2D -> Flatten -> Dense(128) -> Dense(10)."""

  def __init__(self):
    super().__init__()
    # Layers are declared in the order call() applies them.
    self.conv1 = Conv2D(filters=32, kernel_size=3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(units=128, activation='relu')
    # No activation on the final layer: its output is returned as-is.
    self.d2 = Dense(units=10)

  def call(self, x):
    """Pass `x` through every layer in sequence and return the last Dense output."""
    features = self.conv1(x)
    flat = self.flatten(features)
    hidden = self.d1(flat)
    return self.d2(hidden)

# Single module-level instance; train_step and test_step both call it.
model = MyModel()

In [8]:
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy

# MyModel's last layer is Dense(10) with no activation, so the loss is told
# to expect raw logits rather than probabilities.
loss_object = SparseCategoricalCrossentropy(from_logits=True)
# Adam with its default hyperparameters.
optimizer = Adam()

# Accumulators for the per-batch loss values fed to them.
train_loss = Mean(name='train_loss')
test_loss = Mean(name='test_loss')

# Accumulators for classification accuracy against integer labels.
train_accuracy = SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = SparseCategoricalAccuracy(name='test_accuracy')

In [9]:
@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    Args:
        images: Batch of inputs passed to the module-level ``model``.
        labels: Targets for ``images``; passed to ``loss_object`` and
            ``train_accuracy``.

    Returns:
        None. The step updates the model's trainable variables through
        ``optimizer`` and accumulates into ``train_loss`` / ``train_accuracy``.
    """
    # Only the forward pass and the loss need to be recorded; the tape
    # context is closed before differentiating and applying the update.
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

@tf.function
def test_step(images, labels):
    """Evaluate one batch with ``training=False`` and record the test metrics.

    No gradients are computed and no variables are updated here; the batch
    only contributes to ``test_loss`` and ``test_accuracy``.
    """
    outputs = model(images, training=False)
    test_loss(loss_object(labels, outputs))
    test_accuracy(labels, outputs)

In [11]:
EPOCHS = 5

for epoch in range(EPOCHS):
    # Start every epoch with empty accumulators so the printed numbers
    # describe this epoch only.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        # Keras renamed Metric.reset_states() to reset_state(); the old name
        # is deprecated and may be absent in newer releases. Use whichever
        # the installed version provides.
        reset = getattr(metric, "reset_state", None) or metric.reset_states
        reset()

    # One full pass over the training batches, then one over the test batches.
    for images, labels in train_ds:
        train_step(images, labels)

    for images, labels in test_ds:
        test_step(images, labels)

    print(
        f"""
        Epoch {epoch + 1}
        Loss {train_loss.result()}
        Accuracy {train_accuracy.result() * 100}
        Test Loss {test_loss.result()}
        Test Accuracy {test_accuracy.result() * 100}
        """)


        Epoch 1
        Loss 0.04240938648581505
        Accuracy 98.65833282470703
        Test Loss 0.06251469999551773
        Test Accuracy 98.00999450683594
        

        Epoch 2
        Loss 0.022531278431415558
        Accuracy 99.28500366210938
        Test Loss 0.056357257068157196
        Test Accuracy 98.1500015258789
        

        Epoch 3
        Loss 0.013551957905292511
        Accuracy 99.5199966430664
        Test Loss 0.05379883944988251
        Test Accuracy 98.5199966430664
        

        Epoch 4
        Loss 0.009815998375415802
        Accuracy 99.66999816894531
        Test Loss 0.06393037736415863
        Test Accuracy 98.43999481201172
        

        Epoch 5
        Loss 0.007054198533296585
        Accuracy 99.76166534423828
        Test Loss 0.06637132912874222
        Test Accuracy 98.38999938964844
        
