In [21]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [49]:
# Load MNIST dataset
mnist_ds = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist_ds.load_data()
# Normalise dataset
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add the dimenison to tensors that serve as the channel
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")

In [50]:
# Define the LeNet architecture
class LeNet(tf.keras.Model):
    def __init__(self):
        super(LeNet, self).__init__()
        
        self.lenet = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=6,  kernel_size=(5,5), padding='same',  activation='relu'),
            tf.keras.layers.MaxPool2D(),
            tf.keras.layers.Conv2D(filters=16, kernel_size=(5,5), padding='valid', activation='relu'),
            tf.keras.layers.MaxPool2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(120, activation='relu'),
            tf.keras.layers.Dense(84,  activation='relu'),
            tf.keras.layers.Dense(10,  activation='softmax')
        ])
        
    def call(self, inputs):
        return self.lenet(inputs)

In [52]:
lenet = LeNet()
lenet.compile(optimizer='Adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

# Train the model
lenet.fit(x=x_train, y=y_train,
          batch_size=64, 
          epochs=3,
          validation_data=(x_test,y_test),
          verbose=2)

Epoch 1/3
938/938 - 11s - loss: 1.6818 - accuracy: 0.7860 - val_loss: 1.5895 - val_accuracy: 0.8729
Epoch 2/3
938/938 - 11s - loss: 1.5272 - accuracy: 0.9353 - val_loss: 1.4853 - val_accuracy: 0.9768
Epoch 3/3
938/938 - 10s - loss: 1.4849 - accuracy: 0.9774 - val_loss: 1.4814 - val_accuracy: 0.9799


<tensorflow.python.keras.callbacks.History at 0x14f2b4110>

In [69]:
# Evaluate the model
predictions = lenet(x_test, training=False)

acc_funct = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
acc_funct(y_test,predictions)
print('Accuracy = %.2f%%' % (acc_funct.result() * 100))

Accuracy = 97.99%
