In [1]:
import tensorflow as tf 
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [2]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True
)

In [3]:
def normalize(image, label):
    return tf.cast(image, tf.float32) / 255., label

In [4]:
ds_train = ds_train.map(
    normalize, num_parallel_calls=tf.data.AUTOTUNE
)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

In [5]:
ds_test = ds_test.map(
    normalize, num_parallel_calls=tf.data.AUTOTUNE
)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

In [196]:
inp = tf.keras.layers.Input((28, 28, 1))

l = tf.keras.layers.Conv2D(4, 7, 1, activation="relu")(inp)
l = tf.keras.layers.MaxPool2D()(l)
l = tf.keras.layers.BatchNormalization()(l)

l = tf.keras.layers.Conv2D(2, 5, 1, activation="relu")(l)
l = tf.keras.layers.MaxPool2D()(l)
l = tf.keras.layers.BatchNormalization()(l)

l = tf.keras.layers.Conv2D(1, 1, 1, activation="relu")(l)

l = tf.keras.layers.Flatten()(l)

outp = tf.keras.layers.Dense(10, activation="softmax")(l)

model = tf.keras.Model(inputs=inp, outputs=outp)

model.summary()

Model: "model_80"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_85 (InputLayer)       [(None, 28, 28, 1)]       0         
                                                                 
 conv2d_164 (Conv2D)         (None, 22, 22, 4)         200       
                                                                 
 max_pooling2d_152 (MaxPooli  (None, 11, 11, 4)        0         
 ng2D)                                                           
                                                                 
 batch_normalization_120 (Ba  (None, 11, 11, 4)        16        
 tchNormalization)                                               
                                                                 
 conv2d_165 (Conv2D)         (None, 7, 7, 2)           202       
                                                                 
 max_pooling2d_153 (MaxPooli  (None, 3, 3, 2)          0  

In [197]:
model.compile(
    'adam', 
    'sparse_categorical_crossentropy',
    metrics=["acc"]
)

In [None]:
model.fit(ds_train, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100

In [195]:
model.evaluate(ds_test)



[0.19479070603847504, 0.9401999711990356]