In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, BatchNormalization, Conv2D, Activation, Add, MaxPool2D, GlobalAveragePooling2D
from tensorflow.keras.models import Model

In [2]:
class IdentityBlock(Model):
    def __init__(self, filters, kernel_size):
        super(IdentityBlock, self).__init__(name='')

        self.conv1 = Conv2D(filters, kernel_size, padding='same')
        self.bn1 = BatchNormalization()

        self.conv2 = Conv2D(filters, kernel_size, padding='same')
        self.bn2 = BatchNormalization()

        self.act = Activation('relu')
        self.add = Add()

    def call(self, input_tensor):
        x = self.conv1(input_tensor)
        x = self.bn1(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.add([x, input_tensor])
        x = self.act(x)
        return x 

In [3]:
class ResNet(Model):
    def __init__(self, num_classes):
        super(ResNet, self).__init__()
        self.conv = Conv2D(64, 7, padding='same')
        self.bn = BatchNormalization()
        self.act = Activation('relu')
        self.max_pool = MaxPool2D((3,3))

        self.id1 = IdentityBlock(64, 3)
        self.id2 = IdentityBlock(64, 3)

        self.global_pool = GlobalAveragePooling2D()
        self.classifier = Dense(num_classes, activation='softmax')

    def call(self, input):
        x = self.conv(input)
        x = self.bn(x)
        x = self.act(x)
        x = self.max_pool(x)

        x = self.id1(x)
        x = self.id2(x)

        x = self.global_pool(x)
        return self.classifier(x)

In [6]:
def preprocess(features):
    return tf.cast(features['image'], tf.float32) / 255.0, features['label']

In [7]:
import tensorflow_datasets as tfds
resnet = ResNet(10)
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
dataset = tfds.load('mnist', split=tfds.Split.TRAIN)
dataset = dataset.map(preprocess).batch(32)

In [8]:
resnet.fit(dataset, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1bc7b15f648>