In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Sequential, Model

In [4]:
def _make_divisible(ch, divisor=8, min_ch=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch

In [2]:
class ConvBNReLU(layers.Layer):
    def __init__(self, out_channel, kernel_size=3, stride=1, **kwarg):
        super(ConvBNReLU, self).__init__(**kwarg)
        self.Conv = layers.Conv2D(filters=out_channel, kernel_size=kernel_size, strides=stride, 
                                  padding='same', use_bias=False, name='Conv2d', **kwarg)
        self.BN = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name='BatchNorm')
        self.ReLU = layers.ReLU(max_value=6.0)
    
    def call(self, inputs):
        x = self.Conv(inputs)
        x = self.BN(x)
        out = self.ReLU(x)

        return out

In [35]:
class Bottleneck(layers.Layer):
    def __init__(self, in_channel, out_channel, t, stride, **kwarg):
        super(Bottleneck, self).__init__(**kwarg)
        self.hidden_channel = in_channel * t
        self.shortcut = in_channel == out_channel and stride==1

        layer_list = []
        if t != 1:
            layer_list.append(ConvBNReLU(out_channel=self.hidden_channel, kernel_size=1))
        layer_list.extend(
            [
                layers.DepthwiseConv2D(kernel_size=3, padding='same', strides=stride, use_bias=False),
                layers.BatchNormalization(momentum=0.9, epsilon=1e-5),
                layers.ReLU(max_value=6.0),
                layers.Conv2D(filters=out_channel, kernel_size=1, strides=1, use_bias=False)
            ]
        )
        self.block = Sequential(layer_list, name='bottleneck')
    
    def call(self, inputs, **kwarg):
        if self.shortcut:
            return inputs + self.block(inputs)
        else:
            return self.block(inputs)


In [36]:
def MobileNetV2(H = 224, W = 224, num_classes = 1000, alpha = 1.0, round_nearest = 8):
    input_channel = _make_divisible(32*alpha, round_nearest)
    last_channel = _make_divisible(1280*alpha, round_nearest)

    block_setting = [
        [1, 16, 1, 1],
        [6, 24, 2, 2],
        [6, 32, 3, 2],
        [6, 64, 4, 2],
        [6, 96, 3, 1],
        [6, 160, 3, 2],
        [6, 320, 1, 1]
    ]

    input_image = layers.Input(shape=[H, W, 3], dtype='float32')
    x = ConvBNReLU(out_channel=input_channel, stride=2, name='Conv')(input_image)

    for t, c, n, s in block_setting:
        out_channel = _make_divisible(c * alpha, round_nearest)
        for i in range(n):
            stride = s if i == 0 else 1
            x = Bottleneck(x.shape[-1], out_channel, t=t, stride=stride)(x)
        
    x = ConvBNReLU(out_channel=last_channel, kernel_size=1, name='Conv1')(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)
    output = layers.Dense(num_classes, name='logits')(x)

    model = Model(inputs=input_image, outputs=output)

    return model

In [37]:
net = MobileNetV2()

In [38]:
net.summary()

Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_11 (InputLayer)        [(None, 224, 224, 3)]     0         
_________________________________________________________________
Conv (ConvBNReLU)            (None, 112, 112, 32)      992       
_________________________________________________________________
bottleneck_88 (Bottleneck)   (None, 112, 112, 16)      928       
_________________________________________________________________
bottleneck_89 (Bottleneck)   (None, 56, 56, 24)        5472      
_________________________________________________________________
bottleneck_90 (Bottleneck)   (None, 56, 56, 24)        9360      
_________________________________________________________________
bottleneck_91 (Bottleneck)   (None, 28, 28, 32)        10512     
_________________________________________________________________
bottleneck_92 (Bottleneck)   (None, 28, 28, 32)        1555