In [1]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.utils import plot_model

In [8]:
def conv_batchnorm_relu(x, filters, kernel_size, strides=1):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``x`` and return the result.

    The convolution uses ``padding="same"`` and no bias term.

    Args:
        x: Tensor fed into the convolution.
        filters: Number of output filters of the convolution.
        kernel_size: Kernel size of the convolution.
        strides: Stride of the convolution (default 1).

    Returns:
        The output of the ReLU layer.
    """
    convolution = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding="same",
        use_bias=False,
    )
    convolved = convolution(x)
    normalized = layers.BatchNormalization()(convolved)
    return layers.ReLU()(normalized)

In [9]:
def conv_batchnorm(x, filters, kernel_size, strides=1):
    """Apply Conv2D -> BatchNormalization to ``x`` with no activation afterwards.

    The convolution uses ``padding="same"`` and no bias term.

    Args:
        x: Tensor fed into the convolution.
        filters: Number of output filters of the convolution.
        kernel_size: Kernel size of the convolution.
        strides: Stride of the convolution (default 1).

    Returns:
        The output of the BatchNormalization layer.
    """
    convolution = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding="same",
        use_bias=False,
    )
    convolved = convolution(x)
    return layers.BatchNormalization()(convolved)

In [10]:
def residual_block(x, filters, strides=1, use_projection=False):
    """Build one residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The main path is ``conv_batchnorm_relu`` (1x1, ``filters``, ``strides``),
    ``conv_batchnorm_relu`` (3x3, ``filters``) and ``conv_batchnorm``
    (1x1, ``filters * 4``). Its output is added to the shortcut and passed
    through a ReLU.

    Args:
        x: Input tensor of the block.
        filters: Filter count of the first two convolutions; the last
            convolution (and the projection, if any) uses ``filters * 4``.
        strides: Stride of the first 1x1 convolution and of the projection.
        use_projection: When true, the shortcut is ``conv_batchnorm`` (1x1,
            ``filters * 4``, ``strides``) applied to ``x``; otherwise the
            shortcut is ``x`` itself, so ``x`` must already be compatible
            with the main-path output for the addition.

    Returns:
        The output of the final ReLU layer.
    """
    out_filters = filters * 4

    # The shortcut branch is built before the main path.
    skip = (
        conv_batchnorm(x, out_filters, kernel_size=1, strides=strides)
        if use_projection
        else x
    )

    reduced = conv_batchnorm_relu(x, filters, kernel_size=1, strides=strides)
    spatial = conv_batchnorm_relu(reduced, filters, kernel_size=3, strides=1)
    expanded = conv_batchnorm(spatial, out_filters, kernel_size=1, strides=1)

    merged = layers.Add()([expanded, skip])
    return layers.ReLU()(merged)

In [11]:
def build_resnet50(input_shape=(224, 224, 3), num_classes=1000):
    """Assemble a ResNet-50 classifier with the Keras functional API.

    Layout:
        * Stem: 7x7 ``conv_batchnorm_relu`` with 64 filters and stride 2,
          then 3x3 max pooling with stride 2 and ``'same'`` padding.
        * Four stages of ``residual_block`` calls with 3, 4, 6 and 3 blocks
          and ``filters`` of 64, 128, 256 and 512. The first block of each
          stage uses a projection shortcut; stages two to four start with
          stride 2, the first stage with stride 1.
        * Head: global average pooling followed by a softmax ``Dense`` layer
          with ``num_classes`` units.

    Args:
        input_shape: Shape passed to ``layers.Input`` (default ``(224, 224, 3)``).
        num_classes: Number of units of the final softmax layer (default 1000).

    Returns:
        A ``models.Model`` mapping the input tensor to the class probabilities.
    """
    # One row per stage: (filters, stride of the first block, identity blocks after it).
    stage_table = (
        (64, 1, 2),
        (128, 2, 3),
        (256, 2, 5),
        (512, 2, 2),
    )

    image = layers.Input(shape=input_shape)

    stem = conv_batchnorm_relu(image, filters=64, kernel_size=7, strides=2)
    features = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(stem)

    for stage_filters, first_stride, identity_blocks in stage_table:
        # The first block of a stage projects the shortcut to the new width.
        features = residual_block(
            features, filters=stage_filters, strides=first_stride, use_projection=True
        )
        for _ in range(identity_blocks):
            features = residual_block(features, filters=stage_filters, strides=1)

    pooled = layers.GlobalAveragePooling2D()(features)
    probabilities = layers.Dense(num_classes, activation='softmax')(pooled)

    return models.Model(image, probabilities)

In [None]:
# Build the from-scratch ResNet-50 with its default arguments
# (input_shape=(224, 224, 3), num_classes=1000) and print its layer summary.
resnet50_scratch_model = build_resnet50()
resnet50_scratch_model.summary()

In [None]:
# Instantiate the Keras reference ResNet50 (classification head included,
# weights="imagenet", same 224x224x3 input) and print its layer summary --
# presumably for side-by-side comparison with the from-scratch model's summary.
resnet50_tf_model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3))
resnet50_tf_model.summary()