In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Sequential, Model

<div align="center"><img src="./image/MobileNetV1.png" alt="MobileNetV1 architecture" width="400"/></div>

In [51]:
class ConvBNReLU(layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU6 building block.

    Args:
        out_channel: Number of filters produced by the convolution.
        kernel_size: Convolution kernel size (default 1, i.e. a pointwise conv).
        stride: Convolution stride; 'same' padding is used.
        **kwarg: Base-layer keyword arguments (e.g. ``name``) for this wrapper
            layer only.
    """

    def __init__(self, out_channel, kernel_size=1, stride=1, **kwarg):
        super(ConvBNReLU, self).__init__(**kwarg)
        # Stored so get_config() can rebuild the layer.
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        # `kwarg` is deliberately NOT forwarded to the inner conv. Forwarding it
        # would give the conv the same `name` as this wrapper, and would make
        # every wrapper kwarg also have to be a valid Conv2D kwarg.
        # use_bias=False: the following BatchNormalization supplies the offset.
        self.Conv = layers.Conv2D(filters=out_channel, kernel_size=kernel_size, strides=stride,
                                  padding='same', use_bias=False)
        self.BN = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
        # max_value=6.0 turns this into ReLU6.
        self.ReLU = layers.ReLU(max_value=6.0)

    def call(self, inputs, training=None):
        """Apply conv, batch norm (in train/inference mode per `training`) and ReLU6."""
        x = self.Conv(inputs)
        x = self.BN(x, training=training)
        return self.ReLU(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super(ConvBNReLU, self).get_config()
        config.update({
            'out_channel': self.out_channel,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
        })
        return config


In [52]:
class DWBNReLU(layers.Layer):
    """DepthwiseConv2D -> BatchNormalization -> ReLU6 building block.

    Args:
        out_channel: Not used to build any sublayer. The depthwise conv is
            created with its default depth multiplier and so keeps the input's
            channel count. Kept for interface symmetry with ConvBNReLU and
            recorded in the config.
        kernel_size: Depthwise kernel size (default 3).
        stride: Depthwise stride; 'same' padding is used.
        **kwarg: Base-layer keyword arguments (e.g. ``name``) for this wrapper
            layer only.
    """

    def __init__(self, out_channel, kernel_size=3, stride=1, **kwarg):
        super(DWBNReLU, self).__init__(**kwarg)
        # Stored so get_config() can rebuild the layer.
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        # `kwarg` is deliberately NOT forwarded to the inner conv. Forwarding it
        # would give the conv the same `name` as this wrapper, and would make
        # every wrapper kwarg also have to be a valid DepthwiseConv2D kwarg.
        # use_bias=False: the following BatchNormalization supplies the offset.
        self.Conv = layers.DepthwiseConv2D(kernel_size=kernel_size, strides=stride,
                                           padding='same', use_bias=False)
        self.BN = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
        # max_value=6.0 turns this into ReLU6.
        self.ReLU = layers.ReLU(max_value=6.0)

    def call(self, inputs, training=None):
        """Apply depthwise conv, batch norm (mode per `training`) and ReLU6."""
        x = self.Conv(inputs)
        x = self.BN(x, training=training)
        return self.ReLU(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super(DWBNReLU, self).get_config()
        config.update({
            'out_channel': self.out_channel,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
        })
        return config

In [53]:
class Block(layers.Layer):
    """Depthwise-separable block: 3x3 DWBNReLU followed by 1x1 ConvBNReLU.

    Args:
        out_channel: Channel width of the block. With stride 1 the pointwise
            conv emits `out_channel` channels. With stride 2 the depthwise conv
            uses stride 2 and the pointwise conv emits `2 * out_channel` channels.
        stride: 1 or 2.
        **kwarg: Base-layer keyword arguments (e.g. ``name``).

    Raises:
        ValueError: If `stride` is not 1 or 2.
    """

    def __init__(self, out_channel, stride = 1, **kwarg):
        super(Block, self).__init__(**kwarg)
        if stride not in (1, 2):
            # Previously an invalid stride only failed later, in call(), with an
            # UnboundLocalError.
            raise ValueError(f"Block stride must be 1 or 2, got {stride!r}")
        self.out_channel = out_channel
        self.stride = stride
        # Build only the branch that call() will use. The other branch was never
        # called, so it only added unused tracked sublayers. Attribute names are
        # unchanged from the original.
        if stride == 1:
            self.DWBNReLU_1 = DWBNReLU(out_channel=out_channel)
            self.ConvBNReLU_1 = ConvBNReLU(out_channel=out_channel)
        else:
            self.DWBNReLU_2 = DWBNReLU(out_channel=out_channel, stride=2)
            self.ConvBNReLU_2 = ConvBNReLU(out_channel=2*out_channel)

    def call(self, inputs, training=None):
        """Run the depthwise then pointwise stage, forwarding `training` to both."""
        if self.stride == 1:
            x = self.DWBNReLU_1(inputs, training=training)
            return self.ConvBNReLU_1(x, training=training)
        x = self.DWBNReLU_2(inputs, training=training)
        return self.ConvBNReLU_2(x, training=training)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super(Block, self).get_config()
        config.update({'out_channel': self.out_channel, 'stride': self.stride})
        return config

        

In [59]:
def MobileNetV1(H = 224, W = 224, num_classes = 10):
    """Build a MobileNetV1 classifier (width multiplier 1.0) as a Keras functional model.

    Args:
        H: Input image height.
        W: Input image width.
        num_classes: Number of units in the final Dense layer.

    Returns:
        A ``Model`` mapping float32 inputs of shape (None, H, W, 3) to outputs of
        shape (None, num_classes). No softmax is applied; the outputs are logits.

    Raises:
        ValueError: If `block_setting` contains an unknown layer kind.
    """
    # Layers 2-13 as [kind, channels, stride]. 'DW' is a 3x3 depthwise conv;
    # 'Conv' is a 1x1 pointwise conv (ConvBNReLU's default kernel_size).
    block_setting = [['DW', 32, 1], ['Conv', 64, 1], ['DW', 64, 2], ['Conv', 128, 1], ['DW', 128, 1], ['Conv', 128, 1],
                     ['DW', 128, 2], ['Conv', 256, 1], ['DW', 256, 1], ['Conv', 256, 1], ['DW', 256, 2], ['Conv', 512, 1]]
    input_image = layers.Input(shape=(H, W, 3), dtype='float32')
    x = ConvBNReLU(out_channel=32, kernel_size=3, stride=2, name='Conv_s2_1')(input_image)
    for i, (kind, out_channel, stride) in enumerate(block_setting):
        if kind == 'Conv':
            x = ConvBNReLU(out_channel=out_channel, stride=stride, name='Conv_s'+str(stride)+'_'+str(i+2))(x)
        elif kind == 'DW':
            x = DWBNReLU(out_channel=out_channel, stride=stride, name='Conv_dw_s'+str(stride)+'_'+str(i+2))(x)
        else:
            raise ValueError(f"Unknown layer kind {kind!r} in block_setting")
    # Five stride-1 depthwise-separable blocks at 512 channels.
    for i in range(5):
        x = Block(out_channel=512, stride=1, name='block'+'_'+str(i+14))(x)
    x = DWBNReLU(out_channel=512, stride=2, name='Conv_dw_s2_15')(x)
    x = ConvBNReLU(out_channel=1024, name='Conv_s1_16')(x)
    # Stride 1, not 2. Table 1 of the MobileNets paper labels this layer "s2",
    # but lists both its input and output as 7x7x1024, and the reference Keras
    # MobileNet uses stride 1 here. A second stride-2 layer would shrink the
    # final feature map to 4x4 for a 224x224 input instead of 7x7.
    x = DWBNReLU(out_channel=1024, stride=1, name='Conv_dw_s1_17')(x)
    x = ConvBNReLU(out_channel=1024, name='Conv_s1_18')(x)
    x = layers.GlobalAveragePooling2D(name='Avg_Pool_s1_19')(x)
    output = layers.Dense(num_classes, name='logit_20')(x)

    model = Model(inputs=input_image, outputs=output)

    return model


In [60]:
# Instantiate the network with its default 224x224x3 input and 10 output logits.
net = MobileNetV1(H=224, W=224, num_classes=10)

In [61]:
# Print each layer's output shape and parameter count for the model built above.
net.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_16 (InputLayer)        [(None, 224, 224, 3)]     0         
_________________________________________________________________
Conv_s2_1 (ConvBNReLU)       (None, 112, 112, 32)      992       
_________________________________________________________________
Conv_dw_s1_2 (DWBNReLU)      (None, 112, 112, 32)      416       
_________________________________________________________________
Conv_s1_3 (ConvBNReLU)       (None, 112, 112, 64)      2304      
_________________________________________________________________
Conv_dw_s2_4 (DWBNReLU)      (None, 56, 56, 64)        832       
_________________________________________________________________
Conv_s1_5 (ConvBNReLU)       (None, 56, 56, 128)       8704      
_________________________________________________________________
Conv_dw_s1_6 (DWBNReLU)      (None, 56, 56, 128)       1664