<a href="https://colab.research.google.com/github/phonhay103/anything/blob/main/ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import tensorflow as tf
from tensorflow.keras import layers

In [68]:
def block(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):
    """Build one bottleneck residual block (1x1 conv -> kxk conv -> 1x1 conv + shortcut).

    Args:
        x: Input tensor.
        filters: Number of filters of the first two convolutions. The last
            convolution and the projection shortcut use ``4 * filters``.
        kernel_size: Kernel size of the middle convolution.
        stride: Stride applied by the first 1x1 convolution and by the
            projection shortcut.
        conv_shortcut: If true, the shortcut is a 1x1 convolution followed by
            batch normalization; otherwise the input ``x`` is used unchanged.
        name: Optional prefix for the layer names. When ``None`` the layers
            are created with ``name=None`` so Keras picks unique names itself
            (previously the default raised ``TypeError`` on ``None + str``).

    Returns:
        The output tensor of the block (after the final ReLU).
    """
    bn_epsilon = 1.001e-5

    def layer_name(suffix):
        # Only build an explicit name when a prefix was supplied.
        return None if name is None else name + suffix

    if conv_shortcut:
        shortcut = layers.Conv2D(
            4 * filters, 1, strides=stride, name=layer_name('_0_conv'))(x)
        shortcut = layers.BatchNormalization(
            epsilon=bn_epsilon, name=layer_name('_0_bn'))(shortcut)
    else:
        # Identity shortcut: x is added as-is to the main branch output below.
        shortcut = x

    x = layers.Conv2D(filters, 1, strides=stride, name=layer_name('_1_conv'))(x)
    x = layers.BatchNormalization(epsilon=bn_epsilon, name=layer_name('_1_bn'))(x)
    x = layers.Activation('relu', name=layer_name('_1_relu'))(x)

    x = layers.Conv2D(
        filters, kernel_size, padding='same', name=layer_name('_2_conv'))(x)
    x = layers.BatchNormalization(epsilon=bn_epsilon, name=layer_name('_2_bn'))(x)
    x = layers.Activation('relu', name=layer_name('_2_relu'))(x)

    x = layers.Conv2D(4 * filters, 1, name=layer_name('_3_conv'))(x)
    x = layers.BatchNormalization(epsilon=bn_epsilon, name=layer_name('_3_bn'))(x)

    # Skip connection
    x = layers.Add(name=layer_name('_add'))([shortcut, x])
    x = layers.Activation('relu', name=layer_name('_out'))(x)
    return x

def stack(x, filters, blocks, stride=2, name=None):
    """Chain ``blocks`` residual blocks into one stage.

    The first block is called with the given ``stride`` (and the default
    projection shortcut); every following block is called with
    ``conv_shortcut=False``. Blocks are named ``<name>_block1`` ...
    ``<name>_block<blocks>``, so ``name`` must be a string.

    Args:
        x: Input tensor.
        filters: Filter count forwarded to every block.
        blocks: Total number of blocks in the stage.
        stride: Stride of the first block.
        name: Prefix for the block names.

    Returns:
        The output tensor of the last block.
    """
    out = block(x, filters, stride=stride, name=name + '_block1')
    index = 2
    while index <= blocks:
        out = block(out, filters, conv_shortcut=False,
                    name=name + '_block' + str(index))
        index += 1
    return out

In [69]:
def ResNet(stack_fn, input_shape=None, classes=1000, classifier_activation='softmax'):
    """Assemble a ResNet classifier around a caller-supplied stack of stages.

    Args:
        stack_fn: Callable taking the stem output tensor and returning the
            tensor produced by the residual stages.
        input_shape: Shape passed to ``layers.Input``; defaults to
            ``(None, None, 3)`` when ``None``.
        classes: Number of units of the final dense layer.
        classifier_activation: Activation of the final dense layer.

    Returns:
        A ``tf.keras.Model`` mapping the input image tensor to the predictions.
    """
    shape = (None, None, 3) if input_shape is None else input_shape
    img_input = layers.Input(shape=shape)

    # Stem: padded 7x7/2 convolution, BN, ReLU, then padded 3x3/2 max pooling.
    stem_layers = (
        layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad'),
        layers.Conv2D(64, 7, strides=2, name='conv1_conv'),
        layers.BatchNormalization(epsilon=1.001e-5, name='conv1_bn'),
        layers.Activation('relu', name='conv1_relu'),
        layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad'),
        layers.MaxPooling2D(3, strides=2, name='pool1_pool'),
    )
    features = img_input
    for stem_layer in stem_layers:
        features = stem_layer(features)

    # Residual stages supplied by the caller.
    features = stack_fn(features)

    # Classification head.
    pooled = layers.GlobalAveragePooling2D(name='avg_pool')(features)
    predictions = layers.Dense(
        classes, activation=classifier_activation, name='predictions')(pooled)

    return tf.keras.Model(inputs=img_input, outputs=predictions)

In [70]:
def ResNet50(input_shape=None, classes=1000):
    """Build a ResNet with stages of 3, 4, 6 and 3 blocks (conv2..conv5).

    Args:
        input_shape: Forwarded to ``ResNet``.
        classes: Forwarded to ``ResNet``.

    Returns:
        The model returned by ``ResNet``.
    """
    later_stages = ((128, 4, 'conv3'), (256, 6, 'conv4'), (512, 3, 'conv5'))

    def stack_fn(x):
        # conv2 is the only stage called with stride 1; the rest use stack's default.
        out = stack(x, 64, 3, stride=1, name='conv2')
        for filters, blocks, stage_name in later_stages:
            out = stack(out, filters, blocks, name=stage_name)
        return out

    return ResNet(stack_fn, input_shape, classes)

def ResNet101(input_shape=None, classes=1000):
    """Build a ResNet with stages of 3, 4, 23 and 3 blocks (conv2..conv5).

    Args:
        input_shape: Forwarded to ``ResNet``.
        classes: Forwarded to ``ResNet``.

    Returns:
        The model returned by ``ResNet``.
    """
    later_stages = ((128, 4, 'conv3'), (256, 23, 'conv4'), (512, 3, 'conv5'))

    def stack_fn(x):
        # conv2 is the only stage called with stride 1; the rest use stack's default.
        out = stack(x, 64, 3, stride=1, name='conv2')
        for filters, blocks, stage_name in later_stages:
            out = stack(out, filters, blocks, name=stage_name)
        return out

    return ResNet(stack_fn, input_shape, classes)

def ResNet152(input_shape=None, classes=1000):
    """Build a ResNet with stages of 3, 8, 36 and 3 blocks (conv2..conv5).

    Args:
        input_shape: Forwarded to ``ResNet``.
        classes: Forwarded to ``ResNet``.

    Returns:
        The model returned by ``ResNet``.
    """
    later_stages = ((128, 8, 'conv3'), (256, 36, 'conv4'), (512, 3, 'conv5'))

    def stack_fn(x):
        # conv2 is the only stage called with stride 1; the rest use stack's default.
        out = stack(x, 64, 3, stride=1, name='conv2')
        for filters, blocks, stage_name in later_stages:
            out = stack(out, filters, blocks, name=stage_name)
        return out

    return ResNet(stack_fn, input_shape, classes)

In [73]:
# Build ResNet152 with its default arguments and print the layer-by-layer summary.
ResNet152().summary()

Model: "model_10"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_28 (InputLayer)           [(None, None, None,  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, None, None, 3 0           input_28[0][0]                   
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, None, None, 6 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, None, None, 6 256         conv1_conv[0][0]                 
___________________________________________________________________________________________