In [1]:
import tensorflow as tf

In [2]:
# Handle to the Keras MNIST dataset module (its load_data() provides the splits).
mnist = tf.keras.datasets.mnist

In [19]:
# Load the MNIST train and test splits; each split is an (images, labels) pair.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [20]:
# Scale pixel intensities by 1/255 (assumes 0-255 integer pixels -> [0, 1]).
# NOTE: not idempotent — re-running this cell by itself divides by 255 again.
x_train = x_train / 255.0
x_test = x_test / 255.0

In [21]:
# Keep a small subset (first 1000 train / 500 test images), cast to float32 and
# append a trailing channel axis, giving (N, 28, 28, 1) tensors for Conv2D.
# NOTE: not idempotent — afterwards x_train/x_test are tf tensors, so re-running
# this cell without the cells above will not do what it did the first time.
train_subset = x_train[:1000].astype('float32')
test_subset = x_test[:500].astype('float32')
x_train = tf.expand_dims(train_subset, axis=-1)
x_test = tf.expand_dims(test_subset, axis=-1)

In [22]:
# Sanity check: each image should be 28x28 with one trailing channel axis,
# as produced by the previous cell; fail here rather than deep inside training.
assert tuple(x_train.shape[1:]) == (28, 28, 1), x_train.shape
x_train.shape

TensorShape([1000, 28, 28, 1])

In [23]:
# Make datasets.
# Slice the labels by however many images the previous cell kept, so images and
# labels cannot fall out of alignment if that subset size is ever changed.
n_train = int(x_train.shape[0])
n_test = int(x_test.shape[0])
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train[:n_train]))
    .shuffle(n_train)  # buffer spans the whole training subset
    .batch(32)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test[:n_test])).batch(32)

In [30]:
class MyModel(tf.keras.Model):
    """Small CNN classifier with its own per-batch train/test helpers.

    Architecture: Conv2D(32 filters, 3x3, ReLU) -> Flatten -> Dense(128, ReLU)
    -> Dense(10, linear). The last layer has no softmax, so the model returns
    logits and ``loss`` must accept logits.

    NOTE(review): ``train_step`` / ``test_step`` reuse the names of the Keras
    ``Model.train_step(data)`` / ``Model.test_step(data)`` hooks but take
    ``(images, labels)`` instead, so ``fit()`` / ``evaluate()`` would call them
    with the wrong arguments. Drive this model from a manual loop only.
    """

    def __init__(self, loss, optimizer):
        """Build the layers and keep the objects used by the step helpers.

        Args:
            loss: callable invoked as ``loss(labels, logits)``; its result is
                what ``train_step`` differentiates.
            optimizer: Keras optimizer whose ``apply_gradients`` updates the
                trainable variables in ``train_step``.
        """
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_out = tf.keras.layers.Dense(10, activation='linear')

        self.loss = loss
        self.optimizer = optimizer

    def call(self, x, training=None):
        """Forward pass returning logits of shape ``(batch, 10)``.

        Assumes ``x`` is a batch of channels-last images, e.g.
        ``(batch, 28, 28, 1)`` — TODO confirm for other inputs.
        ``training`` is accepted so callers can state the mode explicitly;
        the current layers behave identically in both modes, but mode-dependent
        layers (Dropout, BatchNorm) added later would need it.
        """
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense_out(x)

    def train_step(self, images, labels):
        """Run one optimisation step on a batch.

        Returns:
            ``(pred, loss)``: the logits and loss computed *before* the weight
            update is applied.
        """
        with tf.GradientTape() as tape:
            # Explicit training=True: otherwise an eager call outside fit()
            # runs mode-dependent layers in inference mode.
            pred = self(images, training=True)
            loss = self.loss(labels, pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return pred, loss

    def test_step(self, images, labels):
        """Evaluate a batch without updating any weights.

        Returns:
            ``(pred, loss)``: the logits and the loss for the batch.
        """
        pred = self(images, training=False)
        loss = self.loss(labels, pred)
        return pred, loss

In [35]:
# The model's final Dense layer is linear (raw logits), hence from_logits=True.
LEARNING_RATE = 1e-5  # NOTE(review): unusually small step size — confirm intentional
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

In [36]:
# Instantiate the CNN with the optimizer and loss its step helpers will use.
model = MyModel(optimizer=optimizer, loss=loss_obj)

In [37]:
# Streaming metrics, kept separately for the train and test passes:
# a running mean of per-batch loss values, and accuracy of predictions
# against integer class labels.
train_loss, test_loss = tf.keras.metrics.Mean(), tf.keras.metrics.Mean()
train_acc, test_acc = (
    tf.keras.metrics.SparseCategoricalAccuracy(),
    tf.keras.metrics.SparseCategoricalAccuracy(),
)

In [38]:
EPOCHS = 5
LOG_DIR = 'summaries/live_mnist/run2'

writer = tf.summary.create_file_writer(LOG_DIR)


def _reset_metric(metric):
    """Clear a Keras metric's accumulated state across TF/Keras versions.

    Newer releases name the method ``reset_state()`` (Keras 3 no longer has the
    older ``reset_states()``); fall back to the old name on releases that
    predate the rename.
    """
    reset = getattr(metric, 'reset_state', None) or metric.reset_states
    reset()


epoch_metrics = (train_loss, test_loss, train_acc, test_acc)

for epoch in range(EPOCHS):
    # Metrics accumulate across calls; start every epoch from a clean slate.
    for metric in epoch_metrics:
        _reset_metric(metric)

    # Optimisation pass over the training batches.
    for train_images, train_labels in train_ds:
        pred_train, loss_train = model.train_step(train_images, train_labels)
        train_loss(loss_train)
        train_acc(train_labels, pred_train)

    # Evaluation pass; test_step does not touch the optimizer.
    for test_images, test_labels in test_ds:
        pred_test, loss_test = model.test_step(test_images, test_labels)
        test_loss(loss_test)
        test_acc(test_labels, pred_test)

    # Log this epoch's metric values to TensorBoard, keyed by epoch index.
    with writer.as_default():
        tf.summary.scalar('Train loss', train_loss.result(), step=epoch)
        tf.summary.scalar('Train accuracy', train_acc.result(), step=epoch)
        tf.summary.scalar('Test loss', test_loss.result(), step=epoch)
        tf.summary.scalar('Test accuracy', test_acc.result(), step=epoch)
    # Without an explicit flush the writer only writes periodically or on close;
    # flush so a live TensorBoard sees each epoch as soon as it finishes.
    writer.flush()
    