In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import tensorflow as tf

In [86]:
def conv_layer(nf, kernel=(3, 3), strides=(1, 1), padding:str='valid', activation:str='relu', kernel_init='he_normal', input_layer:bool=False, input_shape=None, bn:bool = True, bn_mom=0.9):
    layers = []
    if input_layer:
        layers.append(tf.keras.layers.Conv2D(nf, kernel, strides, padding, kernel_initializer=kernel_init, input_shape=input_shape))
    else:
        layers.append(tf.keras.layers.Conv2D(nf, kernel, strides, padding, kernel_initializer=kernel_init))
    if bn:
        layers.append(tf.keras.layers.BatchNormalization(momentum=bn_mom))
    layers.append(tf.keras.layers.Activation(activation=activation))
    return tf.keras.Sequential(layers)

In [87]:
class ResBlock(tf.keras.Model):
    """Residual Block Consisting of 2 (conv + batch_norm + relu) layers
    nf: number of filters
    shortcut: there is a an identity  
    """
    def __init__(self, nf, bottleneck:bool=False):
        super(ResBlock, self).__init__()
        self.nf, self.bottleneck = nf, bottleneck
        if self.bottleneck:
            self.strides = (2, 2)
            self.bottleneck_block = conv_layer(self.nf, kernel=(1, 1), strides=self.strides, padding='same') 
        else:
            self.strides = (1, 1)
            
        self.conv1 = conv_layer(self.nf, kernel=(3, 3), strides=self.strides, padding='same')
        self.conv2 = conv_layer(self.nf, kernel=(3, 3), padding='same')
    def call(self, x, training=False):
        skip = x
        fx = x
        if self.bottleneck:
            skip = self.bottleneck_block(x)
        fx = self.conv1(fx)
        fx = self.conv2(fx)
        print(skip.shape, fx.shape)
        fx = tf.add(skip, fx) # may be function to minimize memory consumption.
        fx = tf.nn.relu(fx)
        return fx

In [88]:
class ResNet34(tf.keras.Model):
    def __init__(self, input_shape, include_top=True, n_classes=1000):
        super(ResNet34, self).__init__()
        self.ishape, self.include_top, self.n_classes = input_shape, include_top, n_classes
        self.conv1 = conv_layer(64, kernel=(7, 7), padding='same', strides=2, input_layer=True, input_shape=self.ishape)
        self.maxpool = tf.keras.layers.MaxPool2D((3, 3), 2, padding='same')
        self.res_blocks = tf.keras.Sequential()
        for nf, num_blocks, downscale in zip([64, 128, 256, 512],
                                             [3, 4, 6, 3], 
                                             [False, True, True, True]):
            for i in range(num_blocks):
                if i == 0 and downscale:
                    self.res_blocks.add(ResBlock(nf, bottleneck=True))
                else:
                    self.res_blocks.add(ResBlock(nf))
        if include_top:
            self.avg = tf.keras.layers.GlobalAveragePooling2D()
            self.flatten = tf.keras.layers.Flatten()
            self.out = tf.keras.layers.Dense(self.n_classes, activation='softmax')
            
    def call(self, x):
        print(x.shape)
        x = self.conv1(x)
        print(x.shape, x.shape)
        x = self.maxpool(x)
        print(x.shape, x.shape)
        x = self.res_blocks(x)
        if self.include_top:
            x = self.avg(x)
            print(x.shape, x.shape)
            x = self.flatten(x)
            x = self.out(x)
        print(x.shape)
        return x

In [89]:
res34 = ResNet34((224, 224, 3))
res34.build((1, 224, 224, 3))
res34.summary()

(1, 224, 224, 3)
(1, 112, 112, 64) (1, 112, 112, 64)
(1, 56, 56, 64) (1, 56, 56, 64)
(1, 56, 56, 64) (1, 56, 56, 64)
(1, 56, 56, 64) (1, 56, 56, 64)
(1, 56, 56, 64) (1, 56, 56, 64)
(1, 28, 28, 128) (1, 28, 28, 128)
(1, 28, 28, 128) (1, 28, 28, 128)
(1, 28, 28, 128) (1, 28, 28, 128)
(1, 28, 28, 128) (1, 28, 28, 128)
(1, 14, 14, 256) (1, 14, 14, 256)
(1, 14, 14, 256) (1, 14, 14, 256)
(1, 14, 14, 256) (1, 14, 14, 256)
(1, 14, 14, 256) (1, 14, 14, 256)
(1, 14, 14, 256) (1, 14, 14, 256)
(1, 14, 14, 256) (1, 14, 14, 256)
(1, 7, 7, 512) (1, 7, 7, 512)
(1, 7, 7, 512) (1, 7, 7, 512)
(1, 7, 7, 512) (1, 7, 7, 512)
(1, 512) (1, 512)
(1, 1000)
Model: "res_net34_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_534 (Sequential)  (None, 112, 112, 64)      9728      
_________________________________________________________________
max_pooling2d_21 (MaxPooling multiple                  0         
_________