## Setup and imports

In [1]:
import ml_collections
import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow import keras
from tensorflow.keras import layers

## Configuration

References:

* https://github.com/google-research/vision_transformer/blob/main/vit_jax/configs/models.py#L103

In [2]:
def get_config() -> ml_collections.ConfigDict:
    config = ml_collections.ConfigDict()

    config.batch_size = 32
    config.input_shape = (224, 224, 3)

    config.image_size = 224
    config.patch_size = 16
    config.num_patches = (config.image_size // config.patch_size) ** 2
    config.num_classes = 10
    
    config.layer_norm_eps = 1e-6
    config.projection_dim = 768
    config.num_heads = 12
    config.num_layers = 12
    config.mlp_units = [
        config.projection_dim * 4,
        config.projection_dim,
    ]
    config.dropout_rate = 0.0
    config.classifier = "token"

    return config.lock()

## Vision Transformer blocks

References:

* https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py
* https://keras.io/examples/vision/image_classification_with_vision_transformer/

In [3]:
class PositionalEmbedding(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        
        self.num_patches = (
            self.config.num_patches + 1
            if self.config.classifier == "token"
            else self.config.num_patches + 0
        )
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

        self.embedding = layers.Embedding(
            input_dim=self.num_patches,
            output_dim=self.config.projection_dim,
            embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
        )

    def call(self, inputs):
        encoded_positions = self.embedding(self.positions)
        return inputs + encoded_positions

    def get_config(self):
        config = super().get_config()
        config.update(self.config)
        return config

In [4]:
def mlp(x, dropout_rate, hidden_units):
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
            kernel_initializer="glorot_uniform",
            bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In [5]:
def transformer(config, name):
    num_patches = (
        config.num_patches + 1
        if config.classifier == "token"
        else config.num_patches + 0
    )
    encoded_patches = layers.Input((num_patches, config.projection_dim))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=config.num_heads,
        key_dim=config.projection_dim,
        dropout=config.dropout_rate,
    )(x1, x1)
    attention_output = layers.Dropout(config.dropout_rate)(attention_output)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=config.mlp_units, dropout_rate=config.dropout_rate)

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)

In [6]:
class ViTClassifier(keras.Model):
    def __init__(self, config: ml_collections.ConfigDict, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.projection = layers.Conv2D(
            filters=config.projection_dim,
            kernel_size=(config.patch_size, config.patch_size),
            strides=(config.patch_size, config.patch_size),
            padding="VALID",
            name="projection",
        )

        self.positional_embedding = PositionalEmbedding(
            config, name="positional_embedding"
        )
        self.transformer_blocks = [
            transformer(config, name=f"transformer_block_{i}")
            for i in range(config.num_layers)
        ]

        if config.classifier == "token":
            initial_value = tf.zeros((1, 1, config.projection_dim))
            self.cls_token = tf.Variable(
                initial_value=initial_value, trainable=True, name="cls"
            )

        if config.classifier == "gap":
            self.gap_layer = layers.GlobalAvgPool1D()

        self.dropout = layers.Dropout(config.dropout_rate)
        self.layer_norm = layers.LayerNormalization(epsilon=config.layer_norm_eps)
        self.classifier_head = layers.Dense(
            config.num_classes, kernel_initializer="zeros", name="classifier"
        )

    def call(self, inputs):
        # Create patches and project the pathces.
        projected_patches = self.projection(inputs)
        n, h, w, c = projected_patches.shape
        projected_patches = tf.reshape(projected_patches, [n, h * w, c])

        # Append class token if needed.
        if self.config.classifier == "token":
            cls_token = tf.tile(self.cls_token, (n, 1, 1))
            projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches.
        encoded_patches = self.positional_embedding(
            projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        if self.config.classifier == "token":
            encoded_patches = representation[:, 0]
        elif self.config.classifier == "gap":
            encoded_patches = self.gap_layer(representation)

        # Classification head.
        output = self.classifier_head(encoded_patches)

        return output

## Verification

In [7]:
vit_b16_config = get_config()

In [8]:
vit_classifier = ViTClassifier(vit_b16_config, name="vit_with_cls_token")
random_logits = vit_classifier(tf.ones((10, 224, 224, 3)))
random_logits.shape

2022-03-14 10:52:50.065208: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


TensorShape([10, 10])

In [9]:
with vit_b16_config.unlocked():
    vit_b16_config.classifier = "gap"

vit_classifier_w_gap = ViTClassifier(vit_b16_config, name="vit_with_gap")
vit_classifier_w_gap(tf.ones((10, 224, 224, 3))).shape

TensorShape([10, 10])

## Layer inspection

In [10]:
vit_classifier_w_gap.summary()

Model: "vit_with_gap"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 projection (Conv2D)         multiple                  590592    
                                                                 
 positional_embedding (Posit  multiple                 150528    
 ionalEmbedding)                                                 
                                                                 
 transformer_block_0 (Functi  (None, 196, 768)         33065472  
 onal)                                                           
                                                                 
 transformer_block_1 (Functi  (None, 196, 768)         33065472  
 onal)                                                           
                                                                 
 transformer_block_2 (Functi  (None, 196, 768)         33065472  
 onal)                                                

In [11]:
vit_classifier_w_gap.get_layer("transformer_block_3").summary()

Model: "transformer_block_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_16 (InputLayer)          [(None, 196, 768)]   0           []                               
                                                                                                  
 layer_normalization_31 (LayerN  (None, 196, 768)    1536        ['input_16[0][0]']               
 ormalization)                                                                                    
                                                                                                  
 multi_head_attention_15 (Multi  (None, 196, 768)    28339968    ['layer_normalization_31[0][0]', 
 HeadAttention)                                                   'layer_normalization_31[0][0]'] 
                                                                                