In [2]:
import tensorflow as tf
from tensorflow.keras import Sequential, layers, Model

In [3]:
print(tf.__version__)

2.0.0


In [4]:
def _make_divisible(ch, divisor=8, min_ch=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch

In [38]:
class h_swish(layers.Layer):
    def __init__(self):
        super(h_swish, self).__init__()
        self.relu = layers.ReLU(max_value=6.0)
    
    def call(self, inputs):
        x = self.relu(inputs) * inputs

        return x


In [85]:
class SqueezeBlock(layers.Layer):
    def __init__(self, exp_size, divide = 4):
        super(SqueezeBlock, self).__init__()
        self.linear = Sequential([
            layers.GlobalAveragePooling2D(),
            layers.Dense(exp_size // divide),
            layers.ReLU(max_value=6.0),
            layers.Dense(exp_size),
        ])
        self.h_sigmoid = tf.keras.activations.hard_sigmoid

    def call(self, inputs):
        x = self.linear(inputs)
        x = self.h_sigmoid(x)
        x = tf.reshape(x, (tf.shape(x)[0], 1, 1, tf.shape(x)[-1]))
        out = tf.math.multiply(inputs, x)

        return out

In [86]:
class ConvBN(layers.Layer):
    def __init__(self, out_channel, kernel_size, stride, activation, padding):
        super(ConvBN, self).__init__()
        self.conv = layers.Conv2D(out_channel, kernel_size=kernel_size, strides=stride, padding=padding)
        self.BN = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name='BatchNorm')
        self.flag = activation
        self.relu = layers.ReLU()
        self.h_swish = h_swish()

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.BN(x)
        if self.flag == 'relu':
            out = self.relu(x)
        elif self.flag == 'h_swish':
            out = self.h_swish(x)

        return out

In [87]:
class Bneck(layers.Layer):
    def __init__(self, in_channel, out_channel, kernel_size, stride, SE, NL, exp_size, dropout_rate = 1.0):
        super(Bneck, self).__init__()
        self.out_channel= out_channel
        self.SE = SE
        self.dropout_rate = dropout_rate
        self.NL = NL

        self.shortcut = (stride == 1 and in_channel == out_channel)

        if self.NL == 'RE':
            activation = 'relu'
        elif NL == 'HS':
            activation = 'h_swish'
        
        self.ConvBN_1 = ConvBN(out_channel=exp_size, kernel_size=1, stride=1, padding='valid', activation=activation)

        self.DWBN = Sequential([
            layers.DepthwiseConv2D(kernel_size=kernel_size, strides=stride, padding='same', use_bias=False),
            layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name='BatchNorm')
        ])

        if self.SE:
            self.squeeze_block = SqueezeBlock(exp_size)
        
        self.ConvBN_2 = ConvBN(out_channel=out_channel, kernel_size=1, stride=1, padding='valid', activation=activation)

    def call(self, inputs):
        x = self.ConvBN_1(inputs)
        x = self.DWBN(x)

        if self.SE:
            x = self.squeeze_block(x)
            
        x = self.ConvBN_2(x)

        if self.shortcut:
            return inputs + x
        else:
            return x



In [90]:
def MobileNetV3(mode='large', H=224, W=224, num_classes=1000, alpha = 1.0):
    input_channel = _make_divisible(16*alpha)
    last_channel = _make_divisible(1280*alpha)

    bneck_setting = [
        [16, 16, 3, False, 'RE', 1],
        [64, 24, 3, False, 'RE', 2],
        [72, 24, 3, False, 'RE', 1],
        [72, 40, 5, True, 'RE', 2],
        [120, 40, 5, True, 'RE', 1],
        [120, 40, 5, True, 'RE', 1],
        [240, 80, 3, False, 'HS', 2],
        [200, 80, 3, False, 'HS', 1],
        [184, 80, 3, False, 'HS', 1],
        [184, 80, 3, False, 'HS', 1],
        [480, 112, 3, True, 'HS', 1],
        [672, 112, 3, True, 'HS', 1],
        [672, 160, 5, True, 'HS', 2],
        [960, 160, 5, True, 'HS', 1],
        [960, 160, 5, True, 'HS', 1]
    ]

    input_image = layers.Input(shape=(H, W, 3), dtype = 'float32')

    x = ConvBN(out_channel=input_channel, kernel_size=3, stride=2, padding='same', activation='h_swish')(input_image)

    for exp_size, out_channel, kernel_size, SE, NL, stride in bneck_setting:
        exp_size = _make_divisible(exp_size*alpha)
        out_channel = _make_divisible(out_channel*alpha)
        x = Bneck(x.shape[-1], out_channel, kernel_size, stride, SE, NL, exp_size)(x)
    
    last_out = _make_divisible(960*alpha)
    x = ConvBN(last_out, kernel_size=1, stride=1, activation='h_swish', padding='valid')(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(last_channel)(x)
    x = h_swish()(x)
    output = layers.Dense(num_classes)(x)

    model = Model(inputs=input_image, outputs=output)

    return model

In [91]:
net = MobileNetV3()

In [92]:
net.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_18 (InputLayer)        [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv_bn_103 (ConvBN)         (None, 112, 112, 16)      512       
_________________________________________________________________
bneck_44 (Bneck)             (None, 112, 112, 16)      880       
_________________________________________________________________
bneck_45 (Bneck)             (None, 56, 56, 24)        3832      
_________________________________________________________________
bneck_46 (Bneck)             (None, 56, 56, 24)        4872      
_________________________________________________________________
bneck_47 (Bneck)             (None, 28, 28, 40)        9938      
_________________________________________________________________
bneck_48 (Bneck)             (None, 28, 28, 40)        21230 