In [1]:
import tensorflow as tf
from keras.models import Model
from keras.applications import imagenet_utils
from keras.layers import Conv2D, ZeroPadding2D, DepthwiseConv2D
from keras.layers import Input, Dense, Dropout, MultiHeadAttention, GlobalAvgPool2D
from keras.layers import BatchNormalization, LayerNormalization
from keras.layers import Add, Reshape, Concatenate, Rescaling

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 7483232799281881318
xla_global_id: -1
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 15337455616
locality {
  bus_id: 1
  links {
  }
}
incarnation: 9522336491907571479
physical_device_desc: "device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:1e.0, compute capability: 7.0"
xla_global_id: 416903419
]


In [3]:
class HPR:
    """Namespace of hyperparameter constants used throughout this notebook."""

    # Side length of the square, 3-channel model input.
    IMG_SIZE = 256
    # Leading dimension of the unfolded patch tensor fed to the transformer
    # blocks (2x2 patches).
    PATCH_SIZE = 4
    # Channel multiplier for the expansion conv of the MobileNetV2 blocks.
    EXPANSION_FACTOR = 2
    LEARNING_RATE = 0.002
    LABEL_SMOOTHING_FACTOR = 0.1
    # Backward-compatible alias: the constant was originally published under
    # this misspelled name, so existing references keep working.
    LABEL_SMMOTHING_FACTOR = LABEL_SMOOTHING_FACTOR
    EPOCHS = 300
    BATCH_SIZE = 32
    SEED = 41
    # Number of units in the final softmax layer.
    N_CLASS = 5

# MobileViT Architecture

In [4]:
def conv_block(x, filters=16, kernel_size=3, strides=2):
    """Apply a single swish-activated, "same"-padded 2-D convolution to `x`.

    Args:
        x: Input tensor passed to the convolution.
        filters: Number of output filters.
        kernel_size: Size of the convolution kernel.
        strides: Stride of the convolution.

    Returns:
        The result of calling a freshly created `Conv2D` layer on `x`.
    """
    return Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        activation=tf.nn.swish,
    )(x)


# Reference: https://git.io/JKgtC


def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    """Build an inverted residual (expand -> depthwise -> project) block.

    Args:
        x: Input feature-map tensor; its last axis is the channel axis.
        expanded_channels: Number of filters of the 1x1 expansion conv.
        output_channels: Number of filters of the 1x1 projection conv.
        strides: Stride of the 3x3 depthwise conv. With ``strides == 2`` the
            input is explicitly zero-padded and the depthwise conv uses
            "valid" padding; with ``strides == 1`` it uses "same" padding.

    Returns:
        The projected tensor, added to `x` when the block keeps both the
        spatial resolution (``strides == 1``) and the channel count
        (``x.shape[-1] == output_channels``); otherwise the projected tensor
        alone.
    """
    # 1x1 expansion.
    m = Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = BatchNormalization()(m)
    m = tf.nn.swish(m)

    # 3x3 depthwise conv; strided variant pads explicitly beforehand.
    if strides == 2:
        m = ZeroPadding2D(padding=imagenet_utils.correct_pad(m, 3))(m)
    m = DepthwiseConv2D(3,
                        strides=strides,
                        padding="same" if strides == 1 else "valid",
                        use_bias=False)(m)
    m = BatchNormalization()(m)
    m = tf.nn.swish(m)

    # 1x1 linear projection (no activation after the batch norm).
    m = Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = BatchNormalization()(m)

    # The channel count is a static shape entry, so a plain Python comparison
    # is enough; no TensorFlow op is needed to decide on the skip connection.
    if strides == 1 and x.shape[-1] == output_channels:
        return Add()([m, x])
    return m


def mlp(x, hidden_units, dropout_rate):
    """Stack swish-activated `Dense` layers, each followed by `Dropout`.

    Args:
        x: Input tensor.
        hidden_units: Iterable with the unit count of each `Dense` layer,
            applied in order.
        dropout_rate: Rate given to the `Dropout` layer after every `Dense`.

    Returns:
        The tensor produced by the last `Dropout` layer, or `x` unchanged
        when `hidden_units` is empty.
    """
    out = x
    for width in hidden_units:
        projected = Dense(units=width, activation=tf.nn.swish)(out)
        out = Dropout(rate=dropout_rate)(projected)
    return out


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    """Apply `transformer_layers` pre-norm transformer encoder layers to `x`.

    Each layer is: layer norm -> multi-head self-attention -> residual add,
    followed by layer norm -> MLP -> residual add.

    Args:
        x: Input tensor; the size of its last axis sets the MLP widths.
        transformer_layers: Number of encoder layers to stack.
        projection_dim: `key_dim` of every `MultiHeadAttention` layer.
        num_heads: Number of attention heads.

    Returns:
        The tensor after the last encoder layer (`x` itself when
        `transformer_layers` is 0).
    """
    tokens = x
    for _ in range(transformer_layers):
        # Self-attention sub-layer with a skip connection around it.
        normed = LayerNormalization(epsilon=1e-6)(tokens)
        attended = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=projection_dim,
            dropout=0.1,
        )(normed, normed)
        attn_residual = Add()([attended, tokens])

        # Feed-forward sub-layer: widen to 2x the channel count, then project
        # back, again wrapped in a skip connection.
        channels = tokens.shape[-1]
        feed_forward = LayerNormalization(epsilon=1e-6)(attn_residual)
        feed_forward = mlp(
            feed_forward,
            hidden_units=[channels * 2, channels],
            dropout_rate=0.1,
        )
        tokens = Add()([feed_forward, attn_residual])

    return tokens


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    """Build a MobileViT block that fuses convolutional and transformer features.

    The input is projected by two convolutions, unfolded into a
    ``(HPR.PATCH_SIZE, num_patches, projection_dim)`` tensor, processed by
    `transformer_block`, folded back to the feature-map layout, projected to
    the input's channel count, concatenated with `x`, and fused by a final
    convolution.

    Args:
        x: Input feature-map tensor; axes 1 and 2 are the spatial axes and the
            last axis is the channel axis.
        num_blocks: Number of transformer layers to apply.
        projection_dim: Channel count of the local projection, attention
            `key_dim`, and number of filters of the fusing convolution.
        strides: Stride forwarded to every convolution in the block.
            NOTE(review): the result is concatenated with the un-strided
            input `x`, so only ``strides=1`` (the value every visible caller
            uses) looks shape-compatible -- confirm before passing another
            value.

    Returns:
        The fused local/global feature-map tensor.

    Raises:
        ValueError: If the number of spatial positions of the projected
            feature map is not divisible by ``HPR.PATCH_SIZE``.
    """
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(local_features,
                                filters=projection_dim,
                                kernel_size=1,
                                strides=strides)

    # Unfold into patches and then pass through Transformers.
    num_positions = local_features.shape[1] * local_features.shape[2]
    if num_positions % HPR.PATCH_SIZE != 0:
        # Fail early with a clear message instead of an opaque Reshape error.
        raise ValueError(
            f"Feature map with {num_positions} spatial positions cannot be "
            f"split into patches of size {HPR.PATCH_SIZE}."
        )
    # Integer division keeps the patch count exact (no float round-trip).
    num_patches = num_positions // HPR.PATCH_SIZE
    non_overlapping_patches = Reshape((HPR.PATCH_SIZE,
                                       num_patches,
                                       projection_dim))(local_features)
    global_features = transformer_block(non_overlapping_patches,
                                        num_blocks,
                                        projection_dim)

    # Fold into conv-like feature-maps.
    folded_feature_map = Reshape((*local_features.shape[1:-1],
                                  projection_dim))(global_features)

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(folded_feature_map,
                                    filters=x.shape[-1],
                                    kernel_size=1,
                                    strides=strides)
    local_global_features = Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(local_global_features,
                                       filters=projection_dim,
                                       strides=strides)

    return local_global_features

In [6]:
def create_mobilevit(tensor, num_classes=HPR.N_CLASS, include_top=True):
    """Apply the MobileViT feature extractor to `tensor`.

    Args:
        tensor: Input tensor fed to the convolutional stem.
        num_classes: Accepted for interface compatibility; this function does
            not read it (no classification layer is created here).
        include_top: When true, the final feature map is reduced with
            `GlobalAvgPool2D`; when false, the feature map is returned as is.

    Returns:
        The output of the final 1x1, 320-filter convolution, globally
        average-pooled if `include_top` is true.
    """
    expand = HPR.EXPANSION_FACTOR

    # Initial conv-stem -> MV2 block.
    features = conv_block(tensor, filters=16)
    features = inverted_residual_block(
        features, expanded_channels=16 * expand, output_channels=16
    )

    # Downsampling MV2 block, then two MV2 blocks at 24 channels.
    features = inverted_residual_block(
        features, expanded_channels=16 * expand, output_channels=24, strides=2
    )
    for _ in range(2):
        features = inverted_residual_block(
            features, expanded_channels=24 * expand, output_channels=24
        )

    # Three "strided MV2 -> MobileViT" stages, described as
    # (expansion base, MV2 output channels, transformer layers, projection dim).
    stages = (
        (24, 48, 2, 64),
        (64, 64, 4, 80),
        (80, 80, 3, 96),
    )
    for base_channels, mv2_channels, depth, projection_dim in stages:
        features = inverted_residual_block(
            features,
            expanded_channels=base_channels * expand,
            output_channels=mv2_channels,
            strides=2,
        )
        features = mobilevit_block(
            features, num_blocks=depth, projection_dim=projection_dim
        )

    features = conv_block(features, filters=320, kernel_size=1, strides=1)

    if not include_top:
        return features

    # Classification head (pooling only).
    return GlobalAvgPool2D()(features)

# Assemble the classifier: rescale pixels by 1/255, run the MobileViT
# feature extractor without its pooling head, then pool and classify.
inputs = Input(shape=(HPR.IMG_SIZE, HPR.IMG_SIZE, 3))
x = create_mobilevit(Rescaling(scale=1.0 / 255)(inputs), include_top=False)
x = GlobalAvgPool2D()(x)
outputs = Dense(units=HPR.N_CLASS, activation="softmax")(x)

mobilevit_xxs = Model(inputs=inputs, outputs=outputs)
mobilevit_xxs.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 256, 256, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 16  448         ['rescaling[0][0]']              
                                )                                                                 
                                                                                              

 add_1 (Add)                    (None, 64, 64, 24)   0           ['batch_normalization_8[0][0]',  
                                                                  'batch_normalization_5[0][0]']  
                                                                                                  
 conv2d_7 (Conv2D)              (None, 64, 64, 48)   1152        ['add_1[0][0]']                  
                                                                                                  
 batch_normalization_9 (BatchNo  (None, 64, 64, 48)  192         ['conv2d_7[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 tf.nn.silu_6 (TFOpLambda)      (None, 64, 64, 48)   0           ['batch_normalization_9[0][0]']  
                                                                                                  
 depthwise

                                                                                                  
 add_5 (Add)                    (None, 4, 256, 64)   0           ['multi_head_attention_1[0][0]', 
                                                                  'add_4[0][0]']                  
                                                                                                  
 layer_normalization_3 (LayerNo  (None, 4, 256, 64)  128         ['add_5[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 dense_2 (Dense)                (None, 4, 256, 128)  8320        ['layer_normalization_3[0][0]']  
                                                                                                  
 dropout_2 (Dropout)            (None, 4, 256, 128)  0           ['dense_2[0][0]']                
          

 eadAttention)                                                    'layer_normalization_6[0][0]']  
                                                                                                  
 add_9 (Add)                    (None, 4, 64, 80)    0           ['multi_head_attention_3[0][0]', 
                                                                  'add_8[0][0]']                  
                                                                                                  
 layer_normalization_7 (LayerNo  (None, 4, 64, 80)   160         ['add_9[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 dense_6 (Dense)                (None, 4, 64, 160)   12960       ['layer_normalization_7[0][0]']  
                                                                                                  
 dropout_6

                                                                                                  
 depthwise_conv2d_6 (DepthwiseC  (None, 8, 8, 160)   1440        ['zero_padding2d_3[0][0]']       
 onv2D)                                                                                           
                                                                                                  
 batch_normalization_19 (BatchN  (None, 8, 8, 160)   640         ['depthwise_conv2d_6[0][0]']     
 ormalization)                                                                                    
                                                                                                  
 tf.nn.silu_13 (TFOpLambda)     (None, 8, 8, 160)    0           ['batch_normalization_19[0][0]'] 
                                                                                                  
 conv2d_22 (Conv2D)             (None, 8, 8, 80)     12800       ['tf.nn.silu_13[0][0]']          
          

                                                                                                  
 dropout_17 (Dropout)           (None, 4, 16, 96)    0           ['dense_17[0][0]']               
                                                                                                  
 add_20 (Add)                   (None, 4, 16, 96)    0           ['dropout_17[0][0]',             
                                                                  'add_19[0][0]']                 
                                                                                                  
 reshape_5 (Reshape)            (None, 8, 8, 96)     0           ['add_20[0][0]']                 
                                                                                                  
 conv2d_25 (Conv2D)             (None, 8, 8, 80)     7760        ['reshape_5[0][0]']              
                                                                                                  
 concatena

In [7]:
# Show the model's symbolic input and output tensors, one per line.
print(mobilevit_xxs.input, mobilevit_xxs.output, sep="\n")

KerasTensor(type_spec=TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'")
KerasTensor(type_spec=TensorSpec(shape=(None, 5), dtype=tf.float32, name=None), name='dense_18/Softmax:0', description="created by layer 'dense_18'")
