In [1]:
try:
  %tensorflow_version 2.x
except Exception:
  pass

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Layer

In [13]:
class Identity_Block(tf.keras.Model):
  def __init__(self,filters,kernel_size):
    super(Identity_Block,self).__init__(name="")
    self.conv1=tf.keras.layers.Conv2D(filters,kernel_size,padding='same')
    self.bn1=tf.keras.layers.BatchNormalization()
    self.conv2=tf.keras.layers.Conv2D(filters,kernel_size,padding='same')
    self.bn2=tf.keras.layers.BatchNormalization()
    self.act=tf.keras.layers.Activation('relu')
    self.add=tf.keras.layers.Add()
  
  def call(self,input_tensor):
    x=self.conv1(input_tensor)
    x=self.bn1(x)
    x=self.act(x)
    x=self.conv2(x)
    x=self.bn2(x)
    x=self.act(x)
    x=self.add([x,input_tensor])
    x=self.act(x)
    return x

In [14]:
class ResNet(tf.keras.Model):
  def __init__(self,classes):
    super(ResNet,self).__init__()
    self.conv = tf.keras.layers.Conv2D(64, 7, padding='same')
    self.bn = tf.keras.layers.BatchNormalization()
    self.act = tf.keras.layers.Activation('relu')
    self.max_pool = tf.keras.layers.MaxPool2D((3, 3))
    self.id1a = Identity_Block(64, 3)
    self.id1b = Identity_Block(64, 3)
    self.global_pool = tf.keras.layers.GlobalAveragePooling2D()
    self.classifier = tf.keras.layers.Dense(classes, activation='softmax')

  def call(self,inputs):
    x = self.conv(inputs)
    x = self.bn(x)
    x = self.act(x)
    x = self.max_pool(x)
    x=self.id1a(x)
    x=self.id1b(x)
    x=self.global_pool(x)
    return self.classifier(x)


In [15]:
def preprocess(features):
    return tf.cast(features['image'], tf.float32) / 255., features['label']

In [16]:
resnet = ResNet(10)
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
dataset = tfds.load('mnist', split=tfds.Split.TRAIN, data_dir='./data')
dataset = dataset.map(preprocess).batch(32)
resnet.fit(dataset, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fc3ef1ef0d0>

In [17]:
dataset = tfds.load('mnist', split=tfds.Split.TEST, data_dir='./data')
dataset = dataset.map(preprocess).batch(32)
resnet.evaluate(dataset)



[0.3313046395778656, 0.9032999873161316]