In [1]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [4]:
x_train.shape

(60000, 28, 28, 1)

In [6]:
# CNN -> BatchNorm -> ReLU (common structure)
# x10 (a lot of code to write!)


class CNNBlock(layers.Layer):
    def __init__(self, out_channels, kernel_size=3):
        super(CNNBlock, self).__init__()
        self.conv = layers.Conv2D(out_channels, kernel_size, padding="same")
        self.bn = layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        x = tf.nn.relu(x)
        return x


In [7]:
model = keras.Sequential(
    [CNNBlock(32), CNNBlock(64), CNNBlock(128), layers.Flatten(), layers.Dense(10)]
)


In [9]:
model.layers

<__main__.CNNBlock at 0x7cb402e312a0>

In [10]:
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)


In [12]:
model.fit(x_train, y_train, batch_size=64, epochs=1, verbose=1)
model.evaluate(x_test, y_test, batch_size=64, verbose=2)

157/157 - 1s - loss: 0.0994 - accuracy: 0.9796 - 1s/epoch - 7ms/step


[0.09941404312849045, 0.9796000123023987]

In [13]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cnn_block (CNNBlock)        (None, 28, 28, 32)        448       
                                                                 
 cnn_block_1 (CNNBlock)      (None, 28, 28, 64)        18752     
                                                                 
 cnn_block_2 (CNNBlock)      (None, 28, 28, 128)       74368     
                                                                 
 flatten (Flatten)           (None, 100352)            0         
                                                                 
 dense (Dense)               (None, 10)                1003530   
                                                                 
Total params: 1,097,098
Trainable params: 1,096,650
Non-trainable params: 448
_________________________________________________________________


In [14]:
class ResBlock(layers.Layer):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.channels = channels
        self.cnn1 = CNNBlock(channels[0], 3)
        self.cnn2 = CNNBlock(channels[1], 3)
        self.cnn3 = CNNBlock(channels[2], 3)
        self.pooling = layers.MaxPooling2D()
        self.identity_mapping = layers.Conv2D(channels[1], 3, padding="same")

    def call(self, input_tensor, training=False):
        x = self.cnn1(input_tensor, training=training)
        x = self.cnn2(x, training=training)
        x = self.cnn3(x + self.identity_mapping(input_tensor), training=training)
        x = self.pooling(x)
        return x


In [15]:
class ResNet_Like(keras.Model):
    def __init__(self, num_classes=10):
        super(ResNet_Like, self).__init__()
        self.block1 = ResBlock([32, 32, 64])
        self.block2 = ResBlock([128, 128, 256])
        self.block3 = ResBlock([128, 256, 512])
        self.pool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(num_classes)

    def call(self, input_tensor, training=False):
        x = self.block1(input_tensor, training=training)
        x = self.block2(x, training=training)
        x = self.block3(x, training=training)
        x = self.pool(x, training=training)
        x = self.classifier(x)
        return x

    def model(self):
        x = keras.Input(shape=(28, 28, 1))
        return keras.Model(inputs=[x], outputs=self.call(x))


In [16]:
model = ResNet_Like().model()
# base_input = model.layers[0].input
# base_output = model.layers[2].output
# output = layers.Dense(10)(layers.Flatten()(base_output))
# model = keras.Model(base_input, output)


In [17]:
model.layers

[<keras.engine.input_layer.InputLayer at 0x7cb2d80c28c0>,
 <__main__.ResBlock at 0x7cb2e16910f0>,
 <__main__.ResBlock at 0x7cb2d816a2f0>,
 <__main__.ResBlock at 0x7cb2d810ad40>,
 <keras.layers.pooling.global_average_pooling2d.GlobalAveragePooling2D at 0x7cb2d80c2e00>,
 <keras.layers.core.dense.Dense at 0x7cb2d80c36a0>]

In [18]:
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)


In [19]:
model.fit(x_train, y_train, batch_size=64, epochs=1, verbose=1)
model.evaluate(x_test, y_test, batch_size=64, verbose=2)
# model.save("pretrained")

157/157 - 2s - loss: 0.0349 - accuracy: 0.9884 - 2s/epoch - 12ms/step


[0.034905821084976196, 0.9883999824523926]

In [20]:
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 res_block (ResBlock)        (None, 14, 14, 64)        28896     
                                                                 
 res_block_1 (ResBlock)      (None, 7, 7, 256)         592512    
                                                                 
 res_block_2 (ResBlock)      (None, 3, 3, 512)         2364032   
                                                                 
 global_average_pooling2d (G  (None, 512)              0         
 lobalAveragePooling2D)                                          
                                                                 
 dense_1 (Dense)             (None, 10)                5130      
                                                             