In [28]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist



In [29]:
# Fetch the MNIST train/test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Cast to float32 and divide by 255.0, then reshape to (N, 28, 28, 1) so each
# image carries an explicit trailing channel axis for the Conv2D layers.
x_train = (x_train.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype("float32") / 255.0).reshape(-1, 28, 28, 1)

In [30]:
class CNNBlock(layers.Layer):
  """Convolution -> batch normalisation -> ReLU building block.

  Args:
    out_channels: Number of filters produced by the convolution.
    kernel_size: Size of the convolution window (default 3).
    **kwargs: Standard `layers.Layer` keyword arguments (e.g. `name`,
      `dtype`, `trainable`), forwarded to the base class.
  """

  def __init__(self, out_channels, kernel_size=3, **kwargs):
    super().__init__(**kwargs)
    # Remembered so get_config() can reproduce the constructor call.
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    # padding='same' keeps the spatial dimensions of the input unchanged.
    self.conv = layers.Conv2D(out_channels, kernel_size, padding='same')
    self.bn = layers.BatchNormalization()

  def call(self, input_tensor, training=False):
    """Apply conv, batch norm and ReLU to `input_tensor`.

    Args:
      input_tensor: Input passed to the convolution.
      training: Forwarded to BatchNormalization to select between its
        training and inference behaviour.

    Returns:
      The ReLU-activated output tensor.
    """
    x = self.conv(input_tensor)
    x = self.bn(x, training=training)
    return tf.nn.relu(x)

  def get_config(self):
    """Return the constructor arguments so the layer can be serialised."""
    config = super().get_config()
    config.update({
        "out_channels": self.out_channels,
        "kernel_size": self.kernel_size,
    })
    return config

In [31]:
# Three CNNBlocks with growing filter counts, followed by a linear classifier
# head that emits 10 raw scores (no activation on the Dense layer).
model = keras.Sequential(
    [CNNBlock(filters) for filters in (32, 64, 128)]
    + [layers.Flatten(), layers.Dense(10)]
)

In [32]:
# from_logits=True tells the loss to apply the softmax itself, so it expects
# un-normalised scores from the model.
logits_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(
    loss=logits_loss,
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

In [33]:
# Train for a single pass over the training split in mini-batches of 64.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=1,
    verbose=2,
)

938/938 - 24s - loss: 0.5077 - accuracy: 0.9492 - 24s/epoch - 25ms/step


<keras.callbacks.History at 0x7f95f7542cd0>

In [34]:
# Print the per-layer output shapes and parameter counts.
# NOTE(review): the Sequential model is created without an explicit input
# shape, so this presumably has to run after fit() has built it — confirm.
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cnn_block_33 (CNNBlock)     (None, 28, 28, 32)        448       
                                                                 
 cnn_block_34 (CNNBlock)     (None, 28, 28, 64)        18752     
                                                                 
 cnn_block_35 (CNNBlock)     (None, 28, 28, 128)       74368     
                                                                 
 flatten_2 (Flatten)         (None, 100352)            0         
                                                                 
 dense_5 (Dense)             (None, 10)                1003530   
                                                                 
Total params: 1,097,098
Trainable params: 1,096,650
Non-trainable params: 448
_________________________________________________________________


In [35]:
# Score the trained model on the test split; the result is [loss, accuracy].
model.evaluate(
    x=x_test,
    y=y_test,
    batch_size=64,
    verbose=2,
)

157/157 - 2s - loss: 0.1032 - accuracy: 0.9840 - 2s/epoch - 10ms/step


[0.10324744880199432, 0.984000027179718]