# ResNet Implementation

- Based on [Github](https://github.com/calmisential/TensorFlow2.0_ResNet)
- 작성자: 고려대학교 수학과 석사과정 최선묵

In [14]:
import tensorflow as tf

from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, GlobalAveragePooling2D
from tensorflow.keras.layers import BatchNormalization, ReLU
from tensorflow.keras.layers import add

from tensorflow.keras import Sequential

## Basic Block

- 입력 인자 `stride`는 `conv1`에만 적용하고, `conv2`는 stride=1로 고정한다. `stride != 1`이면 shortcut도 1x1 conv(stride=`stride`) + BatchNorm으로 변환해 residual과 크기를 맞춘다.

In [5]:
class BasicBlock(tf.keras.layers.Layer):
    """Two-convolution residual block used by ResNet-18/34.

    Main path: 3x3 conv (``stride``) -> BN -> ReLU -> 3x3 conv (stride 1) -> BN.
    Shortcut: identity when ``stride == 1``; otherwise a 1x1 conv with the same
    ``stride`` followed by BN, so the residual matches the main path's spatial
    size. The two paths are summed and passed through a final ReLU.

    Args:
        num_filters: Output channels of both 3x3 convolutions (and of the
            projection shortcut, when one is used).
        stride: Stride of the first 3x3 convolution and of the shortcut.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_filters, stride=1, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = Conv2D(filters=num_filters,
                            kernel_size=(3, 3),
                            strides=stride,
                            padding='same',
                            kernel_initializer='he_uniform')
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters=num_filters,
                            kernel_size=(3, 3),
                            strides=1,
                            padding='same',
                            kernel_initializer='he_uniform')
        self.bn2 = BatchNormalization()

        if stride != 1:
            # Projection shortcut: downsample spatially so `add` shapes agree.
            self.downsample = Sequential()
            self.downsample.add(Conv2D(filters=num_filters,
                                       kernel_size=(1, 1),
                                       strides=stride,
                                       kernel_initializer='he_uniform'))
            self.downsample.add(BatchNormalization())
        else:
            # Identity shortcut (handled directly in `call`).
            self.downsample = None

    def call(self, inputs, training=None, **kwargs):
        if self.downsample is None:
            residual = inputs
        else:
            # Pass `training` so the shortcut's BatchNorm uses the right mode.
            residual = self.downsample(inputs, training=training)

        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        # Apply the activation. `ReLU(x)` would construct a new layer with
        # max_value=x rather than apply ReLU to the tensor.
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)

        return tf.nn.relu(add([residual, x]))

## Bottleneck Block

- 입력 인자 `stride`는 `conv2`(3x3)에 적용하고, `conv1`(1x1)은 stride=1로 고정한다. shortcut은 항상 1x1 conv(`num_filters*4` 채널, stride=`stride`) + BatchNorm으로 투영(projection)한다.

In [6]:
class Bottleneck(tf.keras.layers.Layer):
    """Bottleneck residual block used by ResNet-50/101/152.

    Main path: 1x1 conv (``num_filters``) -> BN -> ReLU -> 3x3 conv (``stride``)
    -> BN -> ReLU -> 1x1 conv expanding to ``num_filters * 4`` channels -> BN.
    Shortcut: always a 1x1 conv to ``num_filters * 4`` channels with the same
    ``stride``, followed by BN. The two paths are summed and passed through a
    final ReLU.

    Args:
        num_filters: Channels of the two inner convolutions. The block outputs
            ``num_filters * 4`` channels.
        stride: Stride of the 3x3 convolution and of the shortcut.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    expansion = 4  # output channels = num_filters * expansion

    def __init__(self, num_filters, stride=1, **kwargs):
        super().__init__(**kwargs)
        out_filters = num_filters * self.expansion

        self.conv1 = Conv2D(filters=num_filters,
                            kernel_size=(1, 1),
                            strides=1,
                            padding='same',
                            kernel_initializer='he_uniform')
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters=num_filters,
                            kernel_size=(3, 3),
                            strides=stride,
                            padding='same',
                            kernel_initializer='he_uniform')
        self.bn2 = BatchNormalization()
        # conv3 must expand to `out_filters`, matching the shortcut below.
        # Otherwise `add` fails on a channel mismatch.
        self.conv3 = Conv2D(filters=out_filters,
                            kernel_size=(1, 1),
                            strides=1,
                            padding='same',
                            kernel_initializer='he_uniform')
        self.bn3 = BatchNormalization()

        self.downsample = Sequential()
        self.downsample.add(Conv2D(filters=out_filters,
                                   kernel_size=(1, 1),
                                   strides=stride,
                                   kernel_initializer='he_uniform'))
        self.downsample.add(BatchNormalization())

    def call(self, inputs, training=None, **kwargs):
        # Pass `training` so the shortcut's BatchNorm uses the right mode.
        residual = self.downsample(inputs, training=training)

        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv3(x)
        x = self.bn3(x, training=training)

        return tf.nn.relu(add([residual, x]))

In [8]:
def make_basic_block_layer(num_filters: int, num_blocks: int, stride=1):
    """Build one ResNet stage as a ``Sequential`` of BasicBlocks.

    The first block receives ``stride``, so it may downsample. Every following
    block uses stride 1. At least one block is always created, even if
    ``num_blocks`` is below 1.
    """
    block_strides = [stride] + [1] * (num_blocks - 1)
    return Sequential([BasicBlock(num_filters, stride=s) for s in block_strides])

In [9]:
def make_bottleneck_layer(num_filters: int, num_blocks: int, stride=1):
    """Build one ResNet stage as a ``Sequential`` of Bottleneck blocks.

    The first block receives ``stride``, so it may downsample. Every following
    block uses stride 1. At least one block is always created, even if
    ``num_blocks`` is below 1.
    """
    # The class is named `Bottleneck`. The previous `BottleNeck` raised NameError.
    res_block = Sequential()
    res_block.add(Bottleneck(num_filters, stride=stride))

    for _ in range(1, num_blocks):
        res_block.add(Bottleneck(num_filters, stride=1))

    return res_block

## ResNet 

In [12]:
NUM_CLASSES = 10  # units of the final softmax Dense classifier in the ResNet models

In [16]:
class BasicResNet(tf.keras.Model):
    """ResNet built from BasicBlock stages (ResNet-18/34 style).

    Architecture:
    - Stem: 7x7 conv (64 filters, stride 2) -> BN -> ReLU -> 3x3 max-pool (stride 2).
    - Four BasicBlock stages with 64/128/256/512 filters; stages 2-4 start
      with stride 2.
    - Head: global average pooling and a softmax Dense classifier.

    Args:
        layer_params: Number of blocks in each of the four stages,
            e.g. ``[2, 2, 2, 2]`` for ResNet-18.
        num_classes: Output units of the classifier. ``None`` (the default)
            uses the module-level ``NUM_CLASSES`` at construction time.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``layer_params`` does not have exactly four entries.
    """

    def __init__(self, layer_params: list, num_classes=None, **kwargs):
        super().__init__(**kwargs)
        if len(layer_params) != 4:
            raise ValueError(
                f"layer_params must have 4 entries (one per stage), got {len(layer_params)}"
            )
        if num_classes is None:
            num_classes = NUM_CLASSES

        self.conv1 = Conv2D(filters=64,
                            kernel_size=(7, 7),
                            strides=2,
                            padding='same',
                            kernel_initializer='he_uniform')
        self.bn1 = BatchNormalization()
        self.pool1 = MaxPool2D(pool_size=(3, 3),
                               strides=2,
                               padding='same')

        self.layer1 = make_basic_block_layer(num_filters=64, num_blocks=layer_params[0])
        self.layer2 = make_basic_block_layer(num_filters=128, num_blocks=layer_params[1], stride=2)
        self.layer3 = make_basic_block_layer(num_filters=256, num_blocks=layer_params[2], stride=2)
        self.layer4 = make_basic_block_layer(num_filters=512, num_blocks=layer_params[3], stride=2)

        # Stored as `avgpool`, the name `call` reads. It was previously stored as `avgpol`.
        self.avgpool = GlobalAveragePooling2D()
        self.fc = Dense(units=num_classes, activation='softmax')

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)  # previously misspelled `trianing`
        x = tf.nn.relu(x)  # `ReLU(x)` built a layer instead of applying ReLU
        x = self.pool1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x, training=training)
        x = self.avgpool(x)
        return self.fc(x)

In [17]:
class BottleneckResNet(tf.keras.Model):
    """ResNet built from Bottleneck stages (ResNet-50/101/152 style).

    Architecture:
    - Stem: 7x7 conv (64 filters, stride 2) -> BN -> ReLU -> 3x3 max-pool (stride 2).
    - Four Bottleneck stages with 64/128/256/512 inner filters; stages 2-4
      start with stride 2.
    - Head: global average pooling and a softmax Dense classifier.

    Args:
        layer_params: Number of blocks in each of the four stages,
            e.g. ``[3, 4, 6, 3]`` for ResNet-50.
        num_classes: Output units of the classifier. ``None`` (the default)
            uses the module-level ``NUM_CLASSES`` at construction time.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``layer_params`` does not have exactly four entries.
    """

    def __init__(self, layer_params, num_classes=None, **kwargs):
        super().__init__(**kwargs)
        if len(layer_params) != 4:
            raise ValueError(
                f"layer_params must have 4 entries (one per stage), got {len(layer_params)}"
            )
        if num_classes is None:
            num_classes = NUM_CLASSES

        self.conv1 = Conv2D(filters=64,
                            kernel_size=(7, 7),
                            strides=2,
                            padding='same',
                            kernel_initializer='he_uniform')
        self.bn1 = BatchNormalization()
        self.pool1 = MaxPool2D(pool_size=(3, 3),
                               strides=2,
                               padding='same')

        self.layer1 = make_bottleneck_layer(num_filters=64, num_blocks=layer_params[0])
        self.layer2 = make_bottleneck_layer(num_filters=128, num_blocks=layer_params[1], stride=2)
        self.layer3 = make_bottleneck_layer(num_filters=256, num_blocks=layer_params[2], stride=2)
        self.layer4 = make_bottleneck_layer(num_filters=512, num_blocks=layer_params[3], stride=2)

        # Stored as `avgpool`, the name `call` reads. It was previously stored as `avgpol`.
        self.avgpool = GlobalAveragePooling2D()
        self.fc = Dense(units=num_classes, activation='softmax')

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)  # previously misspelled `trianing`
        x = tf.nn.relu(x)  # `ReLU(x)` built a layer instead of applying ReLU
        x = self.pool1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x, training=training)
        x = self.avgpool(x)
        return self.fc(x)
    
    
        

In [18]:
def _build_and_summarize(model, height, width, channel):
    """Build `model` for input shape (batch, height, width, channel), print its summary, return it."""
    model.build(input_shape=(None, height, width, channel))
    model.summary()
    return model


def resnet_18(height, width, channel):
    """ResNet-18: BasicBlock stages with [2, 2, 2, 2] blocks."""
    net = BasicResNet(layer_params=[2, 2, 2, 2])
    return _build_and_summarize(net, height, width, channel)


def resnet_34(height, width, channel):
    """ResNet-34: BasicBlock stages with [3, 4, 6, 3] blocks."""
    net = BasicResNet(layer_params=[3, 4, 6, 3])
    return _build_and_summarize(net, height, width, channel)


def resnet_50(height, width, channel):
    """ResNet-50: Bottleneck stages with [3, 4, 6, 3] blocks."""
    net = BottleneckResNet(layer_params=[3, 4, 6, 3])
    return _build_and_summarize(net, height, width, channel)


def resnet_101(height, width, channel):
    """ResNet-101: Bottleneck stages with [3, 4, 23, 3] blocks."""
    net = BottleneckResNet(layer_params=[3, 4, 23, 3])
    return _build_and_summarize(net, height, width, channel)


def resnet_152(height, width, channel):
    """ResNet-152: Bottleneck stages with [3, 8, 36, 3] blocks."""
    net = BottleneckResNet(layer_params=[3, 8, 36, 3])
    return _build_and_summarize(net, height, width, channel)