# Ungraded Lab: Implementing ResNet

In [14]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, Dense, GlobalAveragePooling2D, MaxPooling2D

## Implement Model subclasses

In [19]:
class IdentityBlock(tf.keras.Model):
    """ResNet identity block: two conv/BN stages plus an unmodified skip connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. The residual branch
    ends in batch normalization with no activation; the single ReLU is applied
    after the addition, as in the ResNet identity block.

    Args:
        filters: number of output channels of both Conv2D layers. Because the
            shortcut adds the block input unchanged, this must equal the channel
            count of the input tensor (the Add layer needs matching shapes).
        kernel_size: kernel size for both Conv2D layers. 'same' padding keeps the
            spatial size unchanged so the residual and the shortcut line up.
    """

    def __init__(self, filters, kernel_size):
        super(IdentityBlock, self).__init__(name='')

        self.conv1 = Conv2D(filters, kernel_size, padding='same')
        self.bn1 = BatchNormalization()

        self.conv2 = Conv2D(filters, kernel_size, padding='same')
        self.bn2 = BatchNormalization()

        # One stateless ReLU layer, reused at both activation points.
        self.act = Activation('relu')
        self.add = Add()

    def call(self, input_tensor):
        """Apply the residual branch, add the input back, then apply ReLU."""
        x = self.conv1(input_tensor)
        x = self.bn1(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.bn2(x)
        # Deliberately no activation here. A ReLU before the add would clamp the
        # residual to >= 0, so the block could only ever increase activations.

        x = self.add([x, input_tensor])
        x = self.act(x)

        return x

In [20]:
class ResNet(tf.keras.Model):
    """Small ResNet-style classifier.

    Layout: 7x7 conv stem (BN + ReLU) -> max pooling -> two 64-channel identity
    blocks -> global average pooling -> softmax dense head.

    Args:
        num_classes: number of output units of the final softmax layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Stem: 64 filters so its output width matches the identity blocks below.
        self.conv = Conv2D(64, 7, padding='same')
        self.bn = BatchNormalization()
        self.act = Activation('relu')

        # pool_size (3, 3); strides is left unset, so Keras uses pool_size.
        self.max_pool = MaxPooling2D((3, 3))

        self.id1a = IdentityBlock(64, 3)
        self.id1b = IdentityBlock(64, 3)

        # Head: collapse spatial dims, then map to per-class probabilities.
        self.global_pool = GlobalAveragePooling2D()
        self.classifier = Dense(num_classes, activation='softmax')

    def call(self, inputs):
        """Run every stage in order and return the softmax output."""
        stages = (
            self.conv,
            self.bn,
            self.act,
            self.max_pool,
            self.id1a,
            self.id1b,
            self.global_pool,
            self.classifier,
        )
        features = inputs
        for stage in stages:
            features = stage(features)
        return features

In [27]:
def preprocess(features):
    """Convert one TFDS example dict into an ``(image, label)`` tuple.

    The image is cast to float32 and divided by 255.0. Pixel values are
    presumably uint8 in [0, 255], giving [0, 1]; verify for non-MNIST inputs.
    The label is returned unchanged.
    """
    image = tf.cast(features['image'], tf.float32)
    label = features['label']
    return image / 255.0, label

In [28]:
NUM_CLASSES = 10  # MNIST digit classes

resnet = ResNet(num_classes=NUM_CLASSES)
resnet.compile(
    optimizer='adam',
    # "Sparse" loss because labels are class indices, not one-hot vectors.
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [29]:
# Training input pipeline for MNIST.
# Shuffle happens before `map`, so the buffer holds the raw (smaller) examples
# rather than float32 copies. Without shuffling, every epoch would see the same
# batch order. `reshuffle_each_iteration` defaults to True.
SHUFFLE_BUFFER = 10_000
BATCH_SIZE = 32

mnist_train = tfds.load('mnist', split=tfds.Split.TRAIN)
dataset = (
    mnist_train
    .shuffle(SHUFFLE_BUFFER)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)  # overlap input prep with training steps
)

In [31]:
EPOCHS = 5

# The returned History object is the cell's displayed result.
resnet.fit(dataset, epochs=EPOCHS)

Epoch 1/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 109s 56ms/step - accuracy: 0.7621 - loss: 0.6894
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 102s 54ms/step - accuracy: 0.9782 - loss: 0.0769
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 129s 47ms/step - accuracy: 0.9842 - loss: 0.0534
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 144s 49ms/step - accuracy: 0.9880 - loss: 0.0417
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 95s 51ms/step - accuracy: 0.9899 - loss: 0.0341


<keras.src.callbacks.history.History at 0x788b5b922e00>