# MobileViT: A mobile-friendly Transformer-based model for image classification

**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>
**Date created:** 2021/10/20<br>
**Last modified:** 2021/10/20<br>
**Description:** MobileViT for image classification with combined benefits of convolutions and Transformers.

## Introduction

In this example, we implement the MobileViT architecture
([Mehta et al.](https://arxiv.org/abs/2110.02178)),
which combines the benefits of Transformers
([Vaswani et al.](https://arxiv.org/abs/1706.03762))
and convolutions. With Transformers, we can capture long-range dependencies that result
in global representations. With convolutions, we can capture spatial relationships that
model locality.

Besides combining the properties of Transformers and convolutions, the authors introduce
MobileViT as a general-purpose mobile-friendly backbone for different image recognition
tasks. Their findings suggest that, performance-wise, MobileViT is better than other
models with the same or higher complexity ([MobileNetV3](https://arxiv.org/abs/1905.02244),
for example), while being efficient on mobile devices.

## Imports

In [1]:
!pip install tensorflow_addons





In [2]:
import tensorflow as tf
# Query the GPUs visible to TensorFlow once and report how many were found.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  1


In [3]:
import tensorflow as tf

from keras.applications import imagenet_utils
from tensorflow.keras import layers
from tensorflow import keras

import tensorflow_addons as tfa


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [4]:
from collections import Counter
import os
import cv2
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from datetime import datetime
import numpy as np

## Hyperparameters

In [5]:
# Model hyperparameters. The original note says the values come from "table 4"
# (presumably of the MobileViT paper -- confirm).
image_size = 128  # side length of the square model input
patch_size = 4  # 2x2 patch area, used when unfolding for the Transformer blocks
expansion_factor = 2  # channel expansion factor of the MobileNetV2 blocks

## MobileViT utilities

The MobileViT architecture is comprised of the following blocks:

* Strided 3x3 convolutions that process the input image.
* [MobileNetV2](https://arxiv.org/abs/1801.04381)-style inverted residual blocks for
downsampling the resolution of the intermediate feature maps.
* MobileViT blocks that combine the benefits of Transformers and convolutions. The
block is presented in the figure below (taken from the
[original paper](https://arxiv.org/abs/2110.02178)):


![](https://i.imgur.com/mANnhI7.png)

In [6]:
def conv_block(x, filters=16, kernel_size=3, strides=2):
    """Apply one swish-activated, same-padded 2D convolution to `x`.

    Args:
        x: Input tensor handed to the convolution layer.
        filters: Number of output filters.
        kernel_size: Size of the convolution kernel.
        strides: Stride of the convolution.

    Returns:
        The output of the `Conv2D` layer applied to `x`.
    """
    return layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding="same",
        activation=tf.nn.swish,
    )(x)


# Reference: https://git.io/JKgtC


def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    """Build a MobileNetV2-style inverted residual block on top of `x`.

    The block chains a 1x1 expansion convolution, a 3x3 depthwise convolution
    and a 1x1 projection convolution. Each is followed by batch
    normalization; swish is applied after the first two stages only.

    Args:
        x: Input feature map.
        expanded_channels: Number of filters of the 1x1 expansion convolution.
        output_channels: Number of filters of the 1x1 projection convolution.
        strides: Stride of the depthwise convolution.

    Returns:
        The projected feature map, added to `x` when the channel count of
        `x` equals `output_channels` and `strides` is 1.
    """
    # The depthwise stage is "same"-padded only for unit stride; for stride 2
    # explicit zero padding is inserted before a "valid" convolution.
    depthwise_padding = "same" if strides == 1 else "valid"

    # Expansion stage.
    expanded = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    expanded = layers.BatchNormalization()(expanded)
    expanded = tf.nn.swish(expanded)

    # Depthwise stage.
    if strides == 2:
        pad = imagenet_utils.correct_pad(expanded, 3)
        expanded = layers.ZeroPadding2D(padding=pad)(expanded)
    depthwise = layers.DepthwiseConv2D(
        3, strides=strides, padding=depthwise_padding, use_bias=False
    )(expanded)
    depthwise = layers.BatchNormalization()(depthwise)
    depthwise = tf.nn.swish(depthwise)

    # Projection stage (no activation).
    projected = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(
        depthwise
    )
    projected = layers.BatchNormalization()(projected)

    # Residual connection only when input and output shapes line up.
    use_skip = tf.math.equal(x.shape[-1], output_channels) and strides == 1
    return layers.Add()([projected, x]) if use_skip else projected


# Reference:
# https://keras.io/examples/vision/image_classification_with_vision_transformer/


def mlp(x, hidden_units, dropout_rate):
    """Stack swish-activated Dense layers, each followed by dropout.

    Args:
        x: Input tensor.
        hidden_units: Iterable with the number of units of each Dense layer,
            applied in order.
        dropout_rate: Rate passed to the Dropout layer after every Dense layer.

    Returns:
        The tensor produced by the last Dense/Dropout pair (or `x` itself if
        `hidden_units` is empty).
    """
    out = x
    for width in hidden_units:
        dense = layers.Dense(width, activation=tf.nn.swish)
        out = dense(out)
        out = layers.Dropout(dropout_rate)(out)
    return out


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    """Apply `transformer_layers` pre-norm Transformer encoder layers to `x`.

    Each layer is: LayerNorm -> multi-head self-attention -> residual add ->
    LayerNorm -> two-layer MLP -> residual add.

    Args:
        x: Input tensor; its last dimension sets the MLP widths.
        transformer_layers: Number of encoder layers to stack.
        projection_dim: `key_dim` of every attention head.
        num_heads: Number of attention heads.

    Returns:
        The tensor produced by the last encoder layer.
    """
    for _ in range(transformer_layers):
        channels = x.shape[-1]

        # Self-attention sub-layer with a residual connection.
        normed_input = layers.LayerNormalization(epsilon=1e-6)(x)
        attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )
        attended = attention(normed_input, normed_input)
        attended = layers.Add()([attended, x])

        # Feed-forward sub-layer (expand to 2x channels, project back) with a
        # second residual connection.
        normed_attended = layers.LayerNormalization(epsilon=1e-6)(attended)
        feed_forward = mlp(
            normed_attended,
            hidden_units=[channels * 2, channels],
            dropout_rate=0.1,
        )
        x = layers.Add()([feed_forward, attended])

    return x


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    """Build a MobileViT block: local convs -> Transformer -> fuse with input.

    Reads the module-level `patch_size` to reshape the feature map before the
    Transformer layers.

    Args:
        x: Input feature map.
        num_blocks: Number of Transformer layers applied to the unfolded
            features.
        projection_dim: Channel width of the local features and of the block
            output; also passed on to `transformer_block`.
        strides: Stride forwarded to every `conv_block` call in this block.

    Returns:
        The fused local/global feature map with `projection_dim` channels.
    """
    # Local representation: a 3x3 convolution followed by a point-wise one.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold: reshape the (h, w, c) map to (patch_size, num_patches, c) and run
    # the Transformer layers on it.
    height, width = local_features.shape[1], local_features.shape[2]
    num_patches = int((height * width) / patch_size)
    unfolded = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(unfolded, num_blocks, projection_dim)

    # Fold: restore the convolution-style (h, w, c) layout.
    folded = layers.Reshape((height, width, projection_dim))(global_features)

    # Point-wise convolution back to the input channel count, then concatenate
    # with the block input along the channel axis.
    folded = conv_block(folded, filters=x.shape[-1], kernel_size=1, strides=strides)
    fused = layers.Concatenate(axis=-1)([x, folded])

    # Fuse the local and global features using a convolution layer.
    return conv_block(fused, filters=projection_dim, strides=strides)


**More on the MobileViT block**:

* First, the feature representations (A) go through convolution blocks that capture local
relationships. The expected shape of a single entry here would be `(h, w, num_channels)`.
* Then they get unfolded into another vector with shape `(p, n, num_channels)`,
where `p` is the area of a small patch, and `n` is `(h * w) / p`. So, we end up with `n`
non-overlapping patches.
* This unfolded vector is then passed through a Transformer block that captures global
relationships between the patches.
* The output vector (B) is again folded into a vector of shape `(h, w, num_channels)`
resembling a feature map coming out of convolutions.

Vectors A and B are then passed through two more convolutional layers to fuse the local
and global representations. Notice how the spatial resolution of the final vector remains
unchanged at this point. The authors also present an explanation of how the MobileViT
block resembles a convolution block of a CNN. For more details, please refer to the
original paper.

Next, we combine these blocks together and implement the MobileViT architecture (XXS
variant). The following figure (taken from the original paper) presents a schematic
representation of the architecture:

![](https://i.ibb.co/sRbVRBN/image.png)

## Load and prepare the dataset

## Data loading

In [7]:
# from google.colab import drive
# drive.mount('/content/drive')

In [8]:
# Experiment configuration constants.
# NOTE(review): several of these (e.g. PATCH_SIZE, TRANSFORMER_LAYERS,
# MLP_HEAD_UNITS) are not referenced by the cells visible in this notebook --
# confirm whether they are still needed.

# --- Image / dataset ---
RAW_IMG_SIZE = 128
INPUT_SHAPE = (RAW_IMG_SIZE, RAW_IMG_SIZE, 3)
NUM_CLASSES = 268
SPLIT_SEED = 103

# --- Data pipeline ---
BUFFER_SIZE = 512
BATCH_SIZE = 256

# --- Augmentation / patching ---
IMAGE_SIZE = RAW_IMG_SIZE
PATCH_SIZE = 16
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# --- Optimizer ---
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001

# --- Training ---
EPOCHS = 50

# --- Architecture ---
LAYER_NORM_EPS = 1e-6
TRANSFORMER_LAYERS = 8
PROJECTION_DIM = 64
NUM_HEADS = 4
TRANSFORMER_UNITS = [PROJECTION_DIM * 2, PROJECTION_DIM]
MLP_HEAD_UNITS = [2048, 1024]

In [9]:
# Rebind the model input size (read by `create_mobilevit`) to the dataset's
# raw image size.
image_size = RAW_IMG_SIZE

In [10]:
num_classes = 268  # width of the classifier head built for training
input_shape = (128, 128, 3)
boyut = (128, 128)  # "boyut" = size; presumably (height, width) -- confirm

# Machine-specific directory that holds the pre-split dataset arrays.
pathh = 'C:\\Users\\PC\\OneDrive\\Masaüstü\\polen-github\\Veriler\\hepsi\\5\\128'

# Label files. Paths are built with os.path.join so the separator always
# matches the platform instead of mixing "\\" and "/".
# The `*_cat` arrays are the targets used with the categorical cross-entropy
# loss; `*_le` presumably holds label-encoded class ids -- TODO confirm.
y_train_aug_le = np.loadtxt(os.path.join(pathh, 'y_train_aug_le.txt'), dtype=float)
y_train_aug_cat = np.loadtxt(os.path.join(pathh, 'y_train_aug_cat.txt'), dtype=float)
y_val_le = np.loadtxt(os.path.join(pathh, 'y_val_le.txt'), dtype=float)
y_val_cat = np.loadtxt(os.path.join(pathh, 'y_val_cat.txt'), dtype=float)
y_test_le = np.loadtxt(os.path.join(pathh, 'y_test_le.txt'), dtype=float)
y_test_cat = np.loadtxt(os.path.join(pathh, 'y_test_cat.txt'), dtype=float)

# Image arrays and their labels stored in NumPy binary format.
x_train = np.load(os.path.join(pathh, 'x_train.npy'))
x_val = np.load(os.path.join(pathh, 'x_val.npy'))
x_test = np.load(os.path.join(pathh, 'x_test.npy'))
y_test = np.load(os.path.join(pathh, 'y_test.npy'))
y_val = np.load(os.path.join(pathh, 'y_val.npy'))
y_train = np.load(os.path.join(pathh, 'y_train.npy'))

## Train a MobileViT (XXS) model

In [11]:
# Input preprocessing pipeline. Despite its name, only normalization is active.
# Disabled options that were tried here:
#   layers.Resizing(image_size, image_size)
#   layers.RandomFlip("horizontal_and_vertical")
#   layers.RandomRotation(factor=0.02)
#   layers.GaussianNoise(0.3)
data_augmentation = keras.Sequential(name="data_augmentation")
data_augmentation.add(layers.Normalization())

# Compute the normalization statistics from the training images.
data_augmentation.layers[0].adapt(x_train)
# data_augmentation = keras.Sequential(
#     [
#         layers.Normalization(),
#         layers.Resizing(image_size, image_size),
#         layers.RandomFlip("horizontal"),
#         layers.RandomRotation(factor=0.02),
#         layers.RandomZoom(
#             height_factor=0.2, width_factor=0.2
#         ),
#     ],
#     name="data_augmentation",
# )
# 

# class RandomBrightness: A preprocessing layer which randomly adjusts brightness during training.
# class RandomContrast: A preprocessing layer which randomly adjusts contrast during training.
# class RandomCrop: A preprocessing layer which randomly crops images during training.
# class RandomFlip: A preprocessing layer which randomly flips images during training.
# class RandomHeight: A preprocessing layer which randomly varies image height during training.
# class RandomRotation: A preprocessing layer which randomly rotates images during training.
# class RandomTranslation: A preprocessing layer which randomly translates images during training.
# class RandomWidth: A preprocessing layer which randomly varies image width during training.
# class RandomZoom: A preprocessing layer which randomly zooms images during training.

In [12]:

def create_mobilevit(num_classes=5):
    """Assemble the MobileViT classifier as a Keras functional model.

    Reads the module-level `image_size`, `expansion_factor` and
    `data_augmentation` objects.

    Args:
        num_classes: Number of units of the final softmax layer.

    Returns:
        A `keras.Model` mapping `(image_size, image_size, 3)` inputs to
        `num_classes` softmax outputs.
    """

    def mv2(tensor, in_channels, out_channels, strides=1):
        # Inverted residual block whose expansion width is
        # `in_channels * expansion_factor`.
        return inverted_residual_block(
            tensor,
            expanded_channels=in_channels * expansion_factor,
            output_channels=out_channels,
            strides=strides,
        )

    inputs = keras.Input((image_size, image_size, 3))
    # Changed from the original example, which applied
    # `layers.Rescaling(scale=1.0 / 255)` here; inputs now pass through the
    # `data_augmentation` pipeline instead.
    features = data_augmentation(inputs)

    # Initial conv-stem -> MV2 block.
    features = conv_block(features, filters=16)
    features = mv2(features, 16, 16)

    # Downsampling with MV2 blocks.
    features = mv2(features, 16, 24, strides=2)
    features = mv2(features, 24, 24)
    features = mv2(features, 24, 24)

    # First MV2 -> MobileViT block.
    features = mv2(features, 24, 48, strides=2)
    features = mobilevit_block(features, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    features = mv2(features, 64, 64, strides=2)
    features = mobilevit_block(features, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block, then a point-wise conv to 320 channels.
    features = mv2(features, 80, 80, strides=2)
    features = mobilevit_block(features, num_blocks=3, projection_dim=96)
    features = conv_block(features, filters=320, kernel_size=1, strides=1)

    # Classification head.
    pooled = layers.GlobalAvgPool2D()(features)
    outputs = layers.Dense(num_classes, activation="softmax")(pooled)

    return keras.Model(inputs, outputs)


# Build the model with the default head (num_classes=5) and print its layer
# summary.
mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 data_augmentation (Sequential)  (None, 128, 128, 3)  7          ['input_1[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (None, 64, 64, 16)   448         ['data_augmentation[0][0]']      
                                                                                                  
 conv2d_1 (Conv2D)              (None, 64, 64, 32)   512         ['conv2d[0][0]']             

In [13]:
# Training hyperparameters.
epochs = 100
learning_rate = 0.001
label_smoothing_factor = 0.1

# Loss and optimizer objects handed to `compile` inside `run_experiment`.
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)


def run_experiment(epochs=epochs):
    """Train a fresh MobileViT model and checkpoint its best weights.

    Uses the module-level `num_classes`, `optimizer`, `loss_fn`, `x_train`,
    `y_train_aug_cat`, `x_val` and `y_val_cat` objects.

    Args:
        epochs: Number of training epochs. The default is the value of the
            module-level `epochs` at the time this function is defined.

    Returns:
        A `(model, run_dir)` tuple: the trained Keras model and the
        timestamped directory created for this run's artifacts.
    """
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    # NOTE(review): `optimizer` and `loss_fn` are module-level objects, so
    # repeated calls reuse the same optimizer instance -- confirm this is
    # intended.
    mobilevit_xxs.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[
            'accuracy',
            'top_k_categorical_accuracy',
            # Keep the kappa metric in sync with the model's output width
            # instead of repeating the class count as a literal.
            tfa.metrics.CohenKappa(num_classes=num_classes, sparse_labels=False),
        ],
    )

    # Every backslash is escaped ("\O" was an invalid escape sequence that
    # only happened to produce the same text).
    kayit_adresi = "C:\\Users\\PC\\OneDrive\\Masaüstü\\polen-github\\grad_modeller\\mobilevit-20230521T093724Z-001\\mobilevit\\model_kayitlari"

    # One sub-directory per run, named by start time (minute resolution).
    # makedirs creates missing parents and tolerates a rerun within the same
    # minute, where os.mkdir would raise.
    zaman = datetime.today().strftime('%d-%m-%Y-%H-%M')
    model_kayit_noktasi_adresi = os.path.join(kayit_adresi, zaman)
    os.makedirs(model_kayit_noktasi_adresi, exist_ok=True)

    # Checkpoint file for the best weights.
    checkpoint_filepath = os.path.join(model_kayit_noktasi_adresi, "model2.hdf5")

    # Save weights only, and only when validation accuracy improves.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_accuracy',
        save_best_only=True,
    )

    # Multiply the learning rate by 0.1 after 20 epochs without improvement in
    # validation accuracy, never going below 1e-5.
    lrr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_accuracy',
        factor=.1,
        patience=20,
        min_lr=1e-5,
    )

    callbacks = [lrr, checkpoint_callback]

    mobilevit_xxs.fit(
        x=x_train,
        y=y_train_aug_cat,
        epochs=epochs,
        batch_size=128,
        validation_data=(x_val, y_val_cat),
        callbacks=callbacks,
    )

    return mobilevit_xxs, model_kayit_noktasi_adresi


# Train with the default epoch count; keep the trained model and the directory
# created for this run's artifacts for the cells below.
mobilevit_xxs, model_kayit_noktasi_adresi = run_experiment()

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix

# Model outputs (one softmax vector per test image) and the reference labels.
y_pred = mobilevit_xxs.predict(x_test)
y_true = y_test  # NOTE(review): assumed to hold integer class ids -- confirm

# Confusion matrix between the reference labels and the arg-max predictions.
cm = confusion_matrix(y_true, y_pred.argmax(axis=-1))

# Per-class counts: diagonal = true positives, column sums = predicted totals,
# row sums = actual totals.
tp = np.diag(cm)
fp = np.sum(cm, axis=0) - tp
fn = np.sum(cm, axis=1) - tp

# Classes that were never predicted or have no samples produce 0/0 here. The
# resulting NaNs are handled below, so silence the RuntimeWarnings.
with np.errstate(divide="ignore", invalid="ignore"):
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * (precision * recall) / (precision + recall)

# Undefined per-class F1 values count as 0 in the unweighted mean.
arr_replaced = np.nan_to_num(f1, nan=0)
print("F1 Score:", np.mean(arr_replaced))

In [None]:
# Evaluate on the test data; the values come back in the order of the loss
# followed by the metrics passed to `compile`.
test_metrics = mobilevit_xxs.evaluate(x_test, y_test_cat)
loss, accuracy, top_5_accuracy, kappa_score = test_metrics

accuracy_pct = round(accuracy * 100, 2)
top_5_accuracy_pct = round(top_5_accuracy * 100, 2)

print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {accuracy_pct}%")
print(f"Test top 5 accuracy: {top_5_accuracy_pct}%")
print(f"Kappa Score:{kappa_score}")

In [16]:
# Write a short text report of this run next to the model checkpoint.
txt_kayit_adresi = os.path.join(model_kayit_noktasi_adresi, "agumentasyon_modeli.txt")
try:
    # Explicit utf-8 so the layer summary can be written regardless of the
    # platform's default encoding.
    with open(txt_kayit_adresi, "w", encoding="utf-8") as fh:
        # Report the epoch count actually passed to training (`epochs`), not
        # the separate `EPOCHS` configuration constant.
        fh.write(f"Epok sayisi: {epochs}\nbasarisi: Test accuracy: {round(accuracy * 100, 2)}%\nCohenKappa Skoru: {kappa_score}\nF1 Skor: {np.mean(arr_replaced)}\nResim Boyutu: {image_size}\nPatch Boyutu: {patch_size}\nTest top 5 accuracy: {round(top_5_accuracy * 100, 2)}%\n loss='categorical_crossentropy' metrics=['accuracy']\n")
        # Append the preprocessing pipeline's layer summary line by line.
        data_augmentation.summary(print_fn=lambda x: fh.write(x + '\n'))

except Exception as e2:
    # Best effort: a failed report must not abort the notebook.
    print("Model txt ye yazdirilamadi, hata: ", e2)

In [None]:
# Plot the training/validation curves and save them next to the checkpoint.
grafik_adresi = os.path.join(model_kayit_noktasi_adresi, "Grafik.png")

try:
    # `Model.history` is the History callback object; the per-epoch values
    # live in its `.history` dict, keyed by the compiled metric names.
    training_log = mobilevit_xxs.history.history
    tick_positions = np.arange(0, epochs, 10)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 12))

    # (axis, history key, training label, validation label) per panel. The
    # top-5 curve is recorded under the metric name given to `compile`.
    panels = [
        (ax1, 'loss', "Training loss", "validation loss"),
        (ax2, 'accuracy', "Training accuracy", "Validation accuracy"),
        (
            ax3,
            'top_k_categorical_accuracy',
            "train-top-5-accuracy-accuracy",
            "val_top-5-accuracy",
        ),
    ]
    for ax, key, train_label, val_label in panels:
        ax.plot(training_log[key], color='b', label=train_label)
        ax.plot(training_log['val_' + key], color='r', label=val_label)
        ax.set_xticks(tick_positions)
        ax.legend(loc='upper left')

    plt.savefig(grafik_adresi, bbox_inches='tight', facecolor='w')
    plt.show()

except Exception as e:
    # Best effort: a failed plot must not abort the notebook.
    print(f"Grafik çizilemedi {e}")