In [1]:
# Install the tensorflow-addons library
! pip install tensorflow-addons

Collecting tensorflow-addons
  Downloading tensorflow_addons-0.16.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 2.6 MB/s 
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.16.1


# Setup and Imports

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

# Constants

In [3]:
# DATA
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
IMAGE_SIZE = 32  # Height and width of the square input images.
NUM_CHANNELS = 3
# Derived from IMAGE_SIZE so the two can never drift apart.
INPUT_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)

# PATCHING
PATCH_SIZE = 8  # Size of the patches to be extracted from the input images.
# Patches are extracted with a VALID conv whose stride equals the patch size,
# so a non-divisible image size would silently drop the border pixels that do
# not fill a whole patch.
assert IMAGE_SIZE % PATCH_SIZE == 0, "IMAGE_SIZE must be divisible by PATCH_SIZE"
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# ViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4
# Widths of the per-block MLP: expand 4x, then project back to PROJECTION_DIM
# (the last width must equal PROJECTION_DIM for the residual add).
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0

# Vision Transformer blocks

References:

- https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit_mae/modeling_vit_mae.py
- https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py
- https://keras.io/examples/vision/image_classification_with_vision_transformer/

In [4]:
class PositionEmbedding(keras.layers.Layer):
    """Positional embedding layer.

    Adds a `(num_positions, embed_dim)` positional table to the incoming
    token sequence. `num_positions` is `num_patches + 1` when a class token
    is present and `num_patches` otherwise.

    Args:
        pos_mode (`string`): Either "sin-cos" or "learn". "learn" uses a
            trainable embedding table. Any other value selects the fixed 1D
            sine-cosine table.
        embed_dim (`int`): The dimensions for embedding. Must be even in
            sine-cosine mode: half the channels are sines, half are cosines.
        num_patches (`int`): The number of patches of the image.
        add_cls_token (`boolean`): Whether class token is added or not.

    Raises:
        ValueError: If sine-cosine mode is requested with an odd `embed_dim`.
    """
    def __init__(
        self,
        pos_mode,
        embed_dim,
        num_patches,
        add_cls_token,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.pos_mode = pos_mode
        self.embed_dim = embed_dim
        self.num_patches = num_patches
        self.add_cls_token = add_cls_token

        if self.pos_mode != "learn" and self.embed_dim % 2 != 0:
            # An odd width would give a (M, D - 1) table that cannot be added
            # to (…, M, D) tokens; fail here with a clear message instead.
            raise ValueError(
                "sin-cos positional embedding requires an even embed_dim, "
                f"got {self.embed_dim}"
            )

        # One position per patch, plus one slot for the class token if used.
        self.num_positions = self.num_patches + (1 if self.add_cls_token else 0)

        # Position indices 0..num_positions-1. Float32 so they can feed the
        # sin-cos outer product directly.
        self.pos_flat_patches = tf.range(self.num_positions, dtype=tf.float32, delta=1)

    def build(self, input_shape):
        # Encode the positions with an Embedding layer.
        if self.pos_mode == "learn":
            # The table must cover every index produced above, *including* the
            # class-token slot. Sizing it to num_patches alone makes the last
            # index out of range whenever add_cls_token is True.
            self.pos_embedding = layers.Embedding(
                input_dim=self.num_positions,
                output_dim=self.embed_dim,
                embeddings_initializer=keras.initializers.RandomNormal(
                    stddev=0.02
                ),
            )

    def get_config(self):
        config = super().get_config()
        config.update({
            "pos_mode": self.pos_mode,
            "embed_dim": self.embed_dim,
            "num_patches": self.num_patches,
            "add_cls_token": self.add_cls_token,
        })
        return config

    def get_1d_sincos_pos_embed(self):
        """Return the fixed `(num_positions, embed_dim)` sine-cosine table."""
        # Frequencies 1 / 10000^(2i / D) for i in [0, D/2).
        omega = tf.range(self.embed_dim // 2, dtype="float32")
        omega /= self.embed_dim / 2.0
        omega = 1.0 / 10000 ** omega  # (D/2,)

        out = tf.einsum("m,d->md", self.pos_flat_patches, omega)  # (M, D/2), outer product

        emb_sin = tf.sin(out)  # (M, D/2)
        emb_cos = tf.cos(out)  # (M, D/2)

        emb = tf.concat([emb_sin, emb_cos], axis=1)  # (M, D)
        return emb

    def get_learnable_pos_embed(self):
        """Look up the trainable `(num_positions, embed_dim)` table."""
        # Embedding expects integer indices; cast explicitly rather than
        # relying on the layer's implicit conversion.
        positions = tf.cast(self.pos_flat_patches, tf.int32)
        return self.pos_embedding(positions)

    def call(self, inputs):
        """Add positional embeddings to `inputs`.

        NOTE(review): `inputs` is presumably `(batch, num_positions, embed_dim)`.
        Its last two dims must match the positional table for the broadcast add.
        """
        if self.pos_mode == "learn":
            pos_emb = self.get_learnable_pos_embed()
        else:
            pos_emb = self.get_1d_sincos_pos_embed()

        # Inject the positional embeddings with the tokens (broadcast over batch).
        return inputs + pos_emb

In [5]:
class TransformerBlock(keras.layers.Layer):
    """Pre-norm transformer encoder block.

    For an input `x` this computes::

        y   = x + Dropout(MHSA(LN1(x), LN1(x)))
        out = y + MLP(LN2(y))

    MLP is a stack of Dense -> Dropout pairs. GELU is applied on the first
    Dense only.

    Args:
        layer_norm_eps (`float`): Epsilon shared by both LayerNormalization layers.
        num_heads (`int`): Number of attention heads.
        projection_dim (`int`): Forwarded to MultiHeadAttention as `key_dim`.
        dropout_rate (`float`): Rate for attention dropout, the post-attention
            Dropout, and every Dropout in the MLP.
        mlp_hidden_dim (`list[int]`): Units of each MLP Dense layer. The last
            entry must equal the input feature size so the residual add works.
    """

    def __init__(self, layer_norm_eps, num_heads, projection_dim, dropout_rate, mlp_hidden_dim, **kwargs):
        super().__init__(**kwargs)
        # Hyper-parameters are kept on the instance for build() and get_config().
        self.layer_norm_eps = layer_norm_eps
        self.num_heads = num_heads
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.mlp_hidden_dim = mlp_hidden_dim

    def build(self, input_shapes):
        eps = self.layer_norm_eps
        rate = self.dropout_rate
        # Sub-layers are created in a fixed order so auto-generated names and
        # weight ordering stay stable.
        self.layer_norm1 = layers.LayerNormalization(epsilon=eps)
        self.mhsa = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.projection_dim,
            dropout=rate,
        )
        self.dropout = layers.Dropout(rate)
        # Add has no weights, so one instance serves both residual connections.
        self.add = layers.Add()
        self.layer_norm2 = layers.LayerNormalization(epsilon=eps)
        self.mlp = self.build_mlp()

    def build_mlp(self):
        """Return a Sequential of Dense -> Dropout pairs, one per hidden width."""
        stack = []
        for position, width in enumerate(self.mlp_hidden_dim):
            is_first = position == 0
            stack += [
                layers.Dense(
                    width,
                    activation=tf.nn.gelu if is_first else None,
                    kernel_initializer="glorot_uniform",
                    bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
                ),
                layers.Dropout(self.dropout_rate),
            ]
        return keras.Sequential(stack)

    def get_config(self):
        return {
            **super().get_config(),
            "layer_norm_eps": self.layer_norm_eps,
            "num_heads": self.num_heads,
            "projection_dim": self.projection_dim,
            "dropout_rate": self.dropout_rate,
            "mlp_hidden_dim": self.mlp_hidden_dim,
        }

    def call(self, encoded_patches):
        # Attention sub-block: normalize, self-attend, dropout, residual.
        normed = self.layer_norm1(encoded_patches)
        attended = self.dropout(self.mhsa(normed, normed))
        residual = self.add([attended, encoded_patches])

        # Feed-forward sub-block: normalize, MLP, residual.
        mlp_out = self.mlp(self.layer_norm2(residual))
        return self.add([residual, mlp_out])

In [6]:
class ViT(keras.layers.Layer):
    """Vision Transformer encoder mapping images to one global token.

    Pipeline:
      1. patchify (strided Conv2D + reshape)
      2. optional class token
      3. positional embedding
      4. `num_layers` TransformerBlocks
      5. pooling

    Args:
        pos_mode (`string`): Forwarded to `PositionEmbedding` ("learn" or
            sine-cosine otherwise).
        embed_dim (`int`): Token / projection dimension.
        num_patches (`int`): Patches per image. Must equal the number of
            tokens produced by patchify, or the positional add won't broadcast.
        add_cls_token (`boolean`): Prepend a learnable class token.
        patch_size (`int`): Patch edge length (Conv2D kernel size and stride).
        batch_size (`int`): Kept for backward compatibility and serialization
            only. The class token is tiled to the runtime batch size of the input.
        num_layers (`int`): Number of transformer blocks.
        layer_norm_eps (`float`): LayerNormalization epsilon.
        num_heads (`int`): Attention heads per block.
        dropout_rate (`float`): Dropout rate used inside the blocks.
        mlp_hidden_dim (`list[int]`): MLP widths for each block.
        is_gap (`boolean`): If True, average-pool all tokens. Otherwise take
            the first token (the class token when `add_cls_token` is True).

    `call` returns a tensor of shape `(batch, 1, embed_dim)`. It is a pooled
    representation, not class logits.
    """
    def __init__(
        self,
        pos_mode,
        embed_dim,
        num_patches,
        add_cls_token,
        patch_size,
        batch_size,
        num_layers,
        layer_norm_eps,
        num_heads,
        dropout_rate,
        mlp_hidden_dim,
        is_gap,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.pos_mode = pos_mode
        self.embed_dim = embed_dim
        self.num_patches = num_patches
        self.add_cls_token = add_cls_token
        self.patch_size = patch_size
        self.batch_size = batch_size  # unused by call(); see class docstring
        self.num_layers = num_layers
        self.layer_norm_eps = layer_norm_eps
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.mlp_hidden_dim = mlp_hidden_dim
        self.is_gap = is_gap

    def build(self, input_shapes):
        # Patchify: kernel == stride == patch_size linearly projects each
        # non-overlapping patch to embed_dim. The reshape flattens the spatial
        # grid into a token sequence.
        self.patchify = keras.Sequential([
            layers.Conv2D(
                filters=self.embed_dim,
                kernel_size=self.patch_size,
                strides=self.patch_size,
                padding="VALID",
            ),
            layers.Reshape(target_shape=(-1, self.embed_dim))
        ])

        # Build the positional embedding module.
        self.position_embedding = PositionEmbedding(
            pos_mode=self.pos_mode,
            embed_dim=self.embed_dim,
            num_patches=self.num_patches,
            add_cls_token=self.add_cls_token
        )

        # Build the transformer encoder module.
        self.transformer_layers = self.build_transformer()

        # If a class token is used, register it as a trainable layer weight.
        if self.add_cls_token:
            self.class_token = self.add_weight(
                name="class_token",
                shape=(1, 1, self.embed_dim),
                initializer="zeros",
                trainable=True,
            )

        # If GAP is used, build the GAP layer (keepdims -> (batch, 1, D)).
        if self.is_gap:
            self.gap = layers.GlobalAveragePooling1D(keepdims=True)

    def build_transformer(self):
        """Create and return `num_layers` independent TransformerBlocks."""
        return [
            TransformerBlock(
                layer_norm_eps=self.layer_norm_eps,
                num_heads=self.num_heads,
                projection_dim=self.embed_dim,
                dropout_rate=self.dropout_rate,
                mlp_hidden_dim=self.mlp_hidden_dim,
            )
            for _ in range(self.num_layers)
        ]

    def get_config(self):
        config = super().get_config()
        config.update({
            "pos_mode": self.pos_mode,
            "embed_dim": self.embed_dim,
            "num_patches": self.num_patches,
            "add_cls_token": self.add_cls_token,
            "patch_size": self.patch_size,
            "batch_size": self.batch_size,
            "num_layers": self.num_layers,
            "layer_norm_eps": self.layer_norm_eps,
            "num_heads": self.num_heads,
            "dropout_rate": self.dropout_rate,
            "mlp_hidden_dim": self.mlp_hidden_dim,
            "is_gap": self.is_gap,
        })
        return config

    def call(self, images):
        # Get the patches and linearly project them: (batch, num_patches, D).
        projected_patches = self.patchify(images)

        # Prepend the class token if add_cls_token is True.
        if self.add_cls_token:
            # Tile to the *runtime* batch size, not the configured batch_size.
            # This keeps partial final batches, inference at other batch sizes,
            # and symbolic (None) batch dimensions working.
            batch_size = tf.shape(projected_patches)[0]
            class_token = tf.tile(self.class_token, [batch_size, 1, 1])
            projected_patches = tf.concat([class_token, projected_patches], axis=1)

        # Inject the positional embedding into the projected patches.
        x = self.position_embedding(projected_patches)

        # Pass through the transformer layers.
        for layer in self.transformer_layers:
            x = layer(x)

        # Compute the global representation token: (batch, 1, D).
        if self.is_gap:
            x = self.gap(x)
        else:
            # NOTE(review): without a class token this picks the first patch
            # token — confirm that is intended for add_cls_token=False.
            x = x[:, 0:1]

        # Return the pooled representation (not logits; no classifier head here).
        return x

In [7]:
# Seed TF so the random inputs and the weight initialization are reproducible.
SEED = 42
tf.random.set_seed(SEED)

images = tf.random.normal((BATCH_SIZE,) + INPUT_SHAPE)

vit = ViT(
    pos_mode="learn",
    embed_dim=PROJECTION_DIM,
    num_patches=NUM_PATCHES,
    add_cls_token=True,
    patch_size=PATCH_SIZE,
    batch_size=BATCH_SIZE,
    num_layers=NUM_LAYERS,
    layer_norm_eps=LAYER_NORM_EPS,
    num_heads=NUM_HEADS,
    dropout_rate=DROPOUT_RATE,
    mlp_hidden_dim=MLP_UNITS,
    is_gap=True,
)

# With is_gap=True the encoder returns one pooled token per image, shaped
# (batch, 1, embed_dim). It is a feature vector, not class logits.
logits = vit(images)
assert logits.shape == (BATCH_SIZE, 1, PROJECTION_DIM), logits.shape
print(logits.shape)

(256, 1, 128)
