<a href="https://colab.research.google.com/github/salim-hbk/ai-ml/blob/main/Custom_Model_Resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import plot_model

In [2]:
# Fetch the Fashion-MNIST train/test splits as (images, labels) pairs.
(train_x, train_y), (test_x, test_y) = fashion_mnist.load_data()

# For each image array: cast to float32, reshape to (N, 28, 28, 1) so there is
# an explicit single-channel axis, and divide every pixel by 255.
# NOTE(review): nothing in the visible part of this notebook reads these
# arrays (training further down uses TFDS 'mnist') -- confirm which dataset is
# actually intended.
train_x = train_x.astype('float32').reshape(-1, 28, 28, 1) / 255.
test_x = test_x.astype('float32').reshape(-1, 28, 28, 1) / 255.

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [3]:
class IdentityBlock(tf.keras.Model):
  """Residual block whose shortcut is the unmodified block input.

  The residual branch is Conv2D -> BatchNorm -> ReLU -> Conv2D -> BatchNorm.
  Its output is added element-wise to the block input and a single ReLU is
  applied to the sum.

  Args:
    filters: Number of filters used by both convolutions. Because the
      shortcut is added element-wise to the branch output, the input is
      expected to already have this many channels.
    kernel_size: Kernel size passed to both convolutions.
  """

  def __init__(self, filters, kernel_size):
    super().__init__(name='')

    # Both convolutions use padding='same' so the branch output can be added
    # to the shortcut.
    self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    self.bn1 = tf.keras.layers.BatchNormalization()

    self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    self.bn2 = tf.keras.layers.BatchNormalization()

    # The ReLU layer is stateless, so one instance is reused at every
    # activation site in call().
    self.act = tf.keras.layers.Activation('relu')
    self.add = tf.keras.layers.Add()

  def call(self, input_tensor):
    """Run the residual branch, add the shortcut, and apply ReLU to the sum.

    Args:
      input_tensor: Feature map fed to the block; also used as the shortcut.

    Returns:
      The activated sum of the residual branch output and `input_tensor`.
    """
    x = self.conv1(input_tensor)
    x = self.bn1(x)
    x = self.act(x)

    x = self.conv2(x)
    x = self.bn2(x)
    # Deliberately no ReLU here: activating the branch before the addition
    # would force it to be non-negative, so the block could never subtract
    # from the shortcut. The non-linearity is applied once, after the add.

    x = self.add([x, input_tensor])
    return self.act(x)

In [4]:
class Resnet(tf.keras.Model):
  """Small ResNet-style classifier built from two IdentityBlocks.

  Pipeline: 7x7 convolution (64 filters) -> batch norm -> ReLU -> 3x3 max
  pooling -> two IdentityBlock(64, 3) -> global average pooling -> softmax
  dense layer with `num_classes` units.

  Args:
    num_classes: Number of units in the final softmax layer.
  """

  def __init__(self, num_classes):
    super().__init__()

    # Stem: convolution, normalisation, activation, spatial down-sampling.
    self.conv = tf.keras.layers.Conv2D(filters=64, kernel_size=7, padding='same')
    self.bn = tf.keras.layers.BatchNormalization()
    self.act = tf.keras.layers.Activation('relu')
    self.max_pool = tf.keras.layers.MaxPool2D(pool_size=(3, 3))

    # Residual stage: filter count matches the stem's 64 output channels.
    self.id1 = IdentityBlock(64, 3)
    self.id2 = IdentityBlock(64, 3)

    # Head: collapse the spatial axes, then classify.
    self.global_avg = tf.keras.layers.GlobalAveragePooling2D()
    self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax')

  def call(self, inputs):
    """Map a batch of images to per-class softmax outputs."""
    stem_out = self.act(self.bn(self.conv(inputs)))
    pooled = self.max_pool(stem_out)
    features = self.id2(self.id1(pooled))
    return self.classifier(self.global_avg(features))

In [5]:
def preprocess(features):
  """Convert one dataset element into an (image, label) pair.

  Args:
    features: Mapping with an 'image' entry and a 'label' entry.

  Returns:
    A tuple of the image cast to float32 and divided by 255, and the label
    passed through unchanged.
  """
  image = tf.cast(features['image'], tf.float32)
  label = features['label']
  return image / 255., label

In [6]:
# Training configuration, named so it can be tuned in one place.
NUM_CLASSES = 10
BATCH_SIZE = 32
EPOCHS = 1

# `tfds` is imported in the notebook's first (imports) cell.
# Load the MNIST training split, convert each element to an (image, label)
# pair with `preprocess`, and group the elements into batches.
dataset = tfds.load('mnist', split=tfds.Split.TRAIN, data_dir='./data')
dataset = dataset.map(preprocess).batch(BATCH_SIZE)

resnet = Resnet(NUM_CLASSES)
# The labels are used as yielded by the dataset, hence the sparse loss.
resnet.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Keep the History object: per-epoch metrics stay available in
# `history.history`, and the cell no longer ends by echoing an opaque
# `<...History at 0x...>` repr.
history = resnet.fit(dataset, epochs=EPOCHS)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ./data/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to ./data/mnist/3.0.1. Subsequent calls will reuse this data.[0m


<tensorflow.python.keras.callbacks.History at 0x7f1e16e21c50>