<a href="https://colab.research.google.com/github/visith1577/Deep-learning/blob/main/resnet_from_sratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from keras import backend as K
%matplotlib inline

In [6]:
class IdentityBlock(tf.keras.Model):
  """Residual block whose shortcut path is the unmodified block input.

  The main branch is conv -> BN -> ReLU -> conv -> BN, with both convolutions
  ``filters`` wide and using 'same' padding. The block input is then added
  back and a final ReLU is applied.

  There is no projection on the shortcut, so the Add layer needs the input
  to already have ``filters`` channels.

  Args:
    filters: number of output channels of each convolution.
    kernel_size: kernel size passed to both Conv2D layers.
  """

  def __init__(self, filters, kernel_size):
    super().__init__()
    layers = tf.keras.layers

    # First conv/BN pair of the main branch.
    self.conv1 = layers.Conv2D(filters, kernel_size, padding='same')
    self.bn1 = layers.BatchNormalization()

    # Second conv/BN pair; its output is summed with the shortcut.
    self.conv2 = layers.Conv2D(filters, kernel_size, padding='same')
    self.bn2 = layers.BatchNormalization()

    # Weight-free layers, so one instance can be reused at several points.
    self.act = layers.Activation('relu')
    self.add = layers.Add()

  def call(self, input_tensor):
    """Apply the residual block to ``input_tensor`` and return the result."""
    # Main branch: conv -> BN -> ReLU, then conv -> BN (no ReLU yet).
    residual = self.act(self.bn1(self.conv1(input_tensor)))
    residual = self.bn2(self.conv2(residual))

    # Merge with the identity shortcut, then apply the final non-linearity.
    return self.act(self.add([input_tensor, residual]))


In [10]:
class ResNetBlock(tf.keras.Model):
  """Small ResNet-style classifier: conv stem, two identity blocks, softmax head.

  The layer sequence is:

  - 7x7 Conv2D ('same' padding, stride 1) -> BatchNorm -> ReLU.
  - MaxPool2D with a 3x3 window. Its strides default to the pool size, so it
    downsamples by 3 with 'valid' padding.
  - Two IdentityBlock(filters, 3) blocks.
  - Global average pooling.
  - Dense softmax classifier.

  Args:
    num_classes: number of output units of the softmax classifier.
    filters: channel width shared by the stem convolution and both identity
      blocks. It defaults to 64, the previously hard-coded width. One value is
      used for all three because IdentityBlock adds its input to its output
      without a projection, so the widths must agree.
  """

  def __init__(self, num_classes, filters=64):
    super().__init__()
    # Stem.
    self.conv = tf.keras.layers.Conv2D(filters, 7, padding='same')
    self.bn = tf.keras.layers.BatchNormalization()
    self.act = tf.keras.layers.Activation('relu')
    self.max_pool = tf.keras.layers.MaxPool2D((3, 3))

    # Residual body. Width must match the stem output (see class docstring).
    self.idlb1 = IdentityBlock(filters, 3)
    self.idlb2 = IdentityBlock(filters, 3)

    # Head.
    self.global_pool = tf.keras.layers.GlobalAveragePooling2D()
    self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax')

  def call(self, input):
    """Run the forward pass and return per-class softmax probabilities.

    ``input`` is presumably a batch of channels-last images, e.g.
    (batch, 28, 28, 1) for MNIST (TODO: confirm for other datasets). The
    parameter name shadows the ``input`` builtin. It is kept unchanged so
    existing keyword callers still work.
    """
    x = self.conv(input)
    x = self.bn(x)
    x = self.act(x)
    x = self.max_pool(x)

    # insert the identity blocks in the middle of the network
    x = self.idlb1(x)
    x = self.idlb2(x)

    x = self.global_pool(x)
    return self.classifier(x)

In [None]:
def preprocess(features):
    """Convert a TFDS example dict into an ``(image, label)`` pair.

    The ``'image'`` entry is cast to float32 and divided by 255. This
    presumably maps 8-bit pixels into [0, 1]; confirm for datasets other than
    MNIST. The ``'label'`` entry is returned unchanged.
    """
    scaled_image = tf.cast(features['image'], tf.float32) / 255.
    label = features['label']
    return scaled_image, label

# create a ResNet instance with 10 output units for MNIST
resnet = ResNetBlock(10)
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# load and preprocess the dataset
import tensorflow_datasets as tfds
dataset = tfds.load('mnist', split=tfds.Split.TRAIN, data_dir='./data')
dataset = (
    dataset
    # tfds.load does not shuffle by default (shuffle_files=False).
    # Shuffle examples so training does not see a fixed order. This runs
    # before `preprocess` so the buffer holds compact raw images rather than
    # float32 copies.
    .shuffle(10000)
    .map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(32)
    # Prepare the next batches while the current one is being trained on.
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# train the model.
resnet.fit(dataset, epochs=1)
