In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Get the data: Keras' MNIST loader returns ((train_images, train_labels), (test_images, test_labels)).
train_split, test_split = keras.datasets.mnist.load_data()
trainX, trainy = train_split
testX, testy = test_split
# Drop the temporary tuples so only the four arrays remain in the notebook namespace.
del train_split, test_split

In [3]:
# Scale the pixel values to [0, 1] and append a trailing channel axis so each image is
# (H, W, 1), the layout Conv2D expects.
# Cast to float32 first: dividing an integer array by 255.0 would produce float64,
# doubling memory, while the model itself computes in float32.
# This cell overwrites trainX/testX, so it must run exactly once after the data-loading
# cell; the assertion turns an accidental re-run into an explicit error instead of
# silently rescaling twice and adding a second channel axis.
assert trainX.ndim == 3 and testX.ndim == 3, "re-run the data-loading cell before preprocessing again"


def _to_model_input(images):
    """Return `images` as float32 scaled to [0, 1] with a new trailing channel axis."""
    return (images.astype(np.float32) / 255.0)[..., np.newaxis]


trainX = _to_model_input(trainX)
testX = _to_model_input(testX)

In [4]:
# Create datasets from existing tensors: tf.data pipelines yielding (image, label) batches.
# The training split is shuffled with a fixed seed so runs are reproducible;
# the test split keeps its original order because it is only evaluated.
BATCH_SIZE = 32
SHUFFLE_BUFFER = 10000
SHUFFLE_SEED = 42

trainDS = (
    tf.data.Dataset.from_tensor_slices((trainX, trainy))
    .shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED)
    .batch(BATCH_SIZE)
)
testDS = tf.data.Dataset.from_tensor_slices((testX, testy)).batch(BATCH_SIZE)

In [5]:
# Build the model
class CNNModel(keras.Model):
    """Small image classifier: one conv layer, one hidden dense layer, 10-way softmax output.

    Inputs are presumably batches of single-channel images shaped (batch, H, W, 1),
    as produced by the preprocessing cell — TODO confirm if reused elsewhere.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so their auto-generated names match.
        self.conv = keras.layers.Conv2D(32, 3, activation='relu')  # 32 filters, 3x3 kernel
        self.flatten = keras.layers.Flatten()
        self.d1 = keras.layers.Dense(128, activation='relu')
        # Softmax output: any cross-entropy applied to it must expect probabilities, not logits.
        self.d2 = keras.layers.Dense(10, activation='softmax')

    def call(self, x):
        """Forward pass returning per-class probabilities for each input in the batch."""
        # conv -> flatten -> hidden dense, applied in sequence
        for layer in (self.conv, self.flatten, self.d1):
            x = layer(x)
        return self.d2(x)

# Objective and optimizer used by the training step.
# from_logits=False (the default) because CNNModel ends in a softmax layer.
loss_calculator = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = keras.optimizers.Adam()

# Streaming metrics: each accumulates over every call until it is explicitly reset.
train_loss, train_acc = (
    keras.metrics.Mean(name='train_loss'),
    keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)
test_loss, test_acc = (
    keras.metrics.Mean(name='test_loss'),
    keras.metrics.SparseCategoricalAccuracy(name='test_accuracy'),
)


In [8]:
# Define train and test steps.
# Both are compiled with tf.function and read the global `model`, `loss_calculator`,
# `optimizer` and metric objects, so the "Create the model" cell must run before
# either step is first called.
# NOTE(review): tf.function resolves the global `model` when it first traces; if the
# model cell is re-run, re-run this cell (and likely the optimizer cell) as well so the
# steps do not keep using the old model's variables — confirm.
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a batch and update the running train metrics.

    Args:
        images: batch of input images fed to `model`.
        labels: integer class labels for the batch.
    """
    with tf.GradientTape() as tape:
        preds = model(images, training=True)
        loss = loss_calculator(labels, preds)
    # Only the forward pass needs to be recorded. Computing and applying gradients
    # outside the tape context keeps the tape from also tracing the gradient
    # computation and the optimizer update.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, preds)


@tf.function
def test_step(images, labels):
    """Evaluate `model` on a batch and update the running test metrics (no weight updates).

    Args:
        images: batch of input images fed to `model`.
        labels: integer class labels for the batch.
    """
    preds = model(images, training=False)
    loss = loss_calculator(labels, preds)
    test_loss(loss)
    test_acc(labels, preds)
    

In [11]:
# Create the model: a fresh, untrained CNNModel bound to the global name `model`,
# which train_step / test_step read when they are called.
# NOTE(review): weight initialisation is not seeded anywhere in this notebook, so
# results may vary between runs; re-running this cell after the steps have been
# traced may leave them using the previous model — confirm.
model = CNNModel()

In [14]:
# Fit the model: one pass over the training set per epoch, then a full pass over the test set.
EPOCHS = 5


def _reset_metrics(*metrics):
    """Clear the accumulated state of each Keras metric passed in."""
    for metric in metrics:
        # `reset_state` is the current name (TF >= 2.5 / Keras 3); older TF only has `reset_states`.
        reset = getattr(metric, "reset_state", None) or metric.reset_states
        reset()


for ep in range(EPOCHS):
    # Keras metrics accumulate across calls. Without a reset, every epoch after the
    # first would report a running average over all epochs so far, not this epoch alone.
    _reset_metrics(train_loss, train_acc, test_loss, test_acc)

    for images, labels in trainDS:
        train_step(images, labels)

    for images, labels in testDS:
        test_step(images, labels)

    # .result() returns a scalar tensor; convert to float so the log shows plain numbers
    # rather than tf.Tensor reprs.
    print("{}: loss: {:.4f}, acc: {:.2f}, val_loss: {:.4f}, val_acc: {:.2f}".format(
        ep,
        float(train_loss.result()), 100 * float(train_acc.result()),
        float(test_loss.result()), 100 * float(test_acc.result())))

KeyboardInterrupt: 