## Setup and imports

In [1]:
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

## Constants

References:

* https://github.com/google-research/vision_transformer/blob/main/vit_jax/configs/models.py#L103

In [2]:
# DATA
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
IMAGE_SIZE = 224  # Height and width (in pixels) of the square input images.
# Derived from IMAGE_SIZE so the two can never disagree.
INPUT_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)


# PATCHES
PATCH_SIZE = 16  # Size of the patches to be extracted from the input images.
# NUM_PATCHES below uses floor division; fail loudly rather than silently
# dropping the border pixels that would not fit into a whole patch.
assert IMAGE_SIZE % PATCH_SIZE == 0, "IMAGE_SIZE must be a multiple of PATCH_SIZE"
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# ViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 768
NUM_HEADS = 12
NUM_LAYERS = 12
# Widths of the two Dense layers in each Transformer MLP: expand 4x, then
# project back to PROJECTION_DIM so the residual addition is shape-compatible.
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0

## Vision Transformer blocks

References:

* https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py
* https://keras.io/examples/vision/image_classification_with_vision_transformer/

In [3]:
def position_embedding(
    projected_patches,
    num_patches=NUM_PATCHES,
    projection_dim=PROJECTION_DIM,
    classifier="token",
):
    """Add position embeddings to a sequence of projected patches.

    Args:
        projected_patches: Tensor to which the embeddings are added. It must
            broadcast against a `(sequence_length, projection_dim)` tensor.
        num_patches: Number of image patches in the sequence, not counting
            the class token.
        projection_dim: Size of each position embedding vector.
        classifier: When equal to `"token"`, one extra position is reserved
            for the class token; any other value adds no extra position.

    Returns:
        `projected_patches` with the position embeddings added.

    Note:
        A new `Embedding` layer, with freshly initialised weights, is
        instantiated every time this function is called.
    """
    # The class token, when present, occupies one additional position.
    sequence_length = (num_patches + 1) if classifier == "token" else num_patches
    position_ids = tf.range(sequence_length)

    # Look up one embedding vector per position.
    embedding_layer = layers.Embedding(
        input_dim=sequence_length,
        output_dim=projection_dim,
        embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
    )
    position_vectors = embedding_layer(position_ids)

    return projected_patches + position_vectors

In [4]:
def mlp(x, dropout_rate, hidden_units):
    """Apply a stack of `Dense => Dropout` pairs to `x`.

    Args:
        x: Input tensor.
        dropout_rate: Rate passed to every `Dropout` layer.
        hidden_units: Iterable with the width of each `Dense` layer, in order.

    Returns:
        The output of the last `Dropout` layer, or `x` itself when
        `hidden_units` is empty.
    """
    outputs = x
    for position, width in enumerate(hidden_units):
        # Only the first Dense layer is followed by a GELU non-linearity;
        # the remaining ones are plain linear projections.
        if position == 0:
            activation = tf.nn.gelu
        else:
            activation = None

        dense = layers.Dense(
            width,
            activation=activation,
            kernel_initializer="glorot_uniform",
            bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
        )
        outputs = dense(outputs)
        outputs = layers.Dropout(dropout_rate)(outputs)
    return outputs

In [5]:
def transformer(encoded_patches):
    """Apply one pre-norm Transformer encoder block to `encoded_patches`.

    The block is `LayerNorm -> self-attention -> residual add`, followed by
    `LayerNorm -> MLP -> residual add`. Every call instantiates new layers,
    so it is meant to be called once per block while building a model.

    Args:
        encoded_patches: Input sequence tensor. Its last dimension must equal
            `MLP_UNITS[-1]` so that the second residual addition is valid.

    Returns:
        Tensor with the same shape as `encoded_patches`.
    """
    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    # Keras' `key_dim` is the size of *each* head, not of all heads combined,
    # so the model width has to be split across the heads (768 / 12 = 64).
    # Passing PROJECTION_DIM here would make every head PROJECTION_DIM wide.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM // NUM_HEADS,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = layers.Dropout(DROPOUT_RATE)(attention_output)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)

    # Skip connection 2.
    encoded_patches = layers.Add()([x2, x4])

    return encoded_patches

In [6]:
class ViTClassifier(keras.Model):
    """Vision Transformer classifier configured by the notebook constants.

    Images are split into `PATCH_SIZE` x `PATCH_SIZE` patches by a strided
    convolution, optionally prefixed with a class token, summed with position
    embeddings, passed through `NUM_LAYERS` Transformer blocks, layer
    normalised, pooled and finally mapped to `num_classes` logits.

    Every weight-carrying layer is created once in `__init__`, so the weights
    are tracked by the model and reused across forward passes.

    Args:
        classifier: Pooling strategy. `"token"` reads the class-token
            position; `"gap"` averages over the sequence.
        num_classes: Number of output logits.

    Raises:
        ValueError: If `classifier` is neither `"token"` nor `"gap"`.
    """

    def __init__(self, classifier="token", num_classes=10):
        super().__init__()
        if classifier not in ("token", "gap"):
            # Any other value would skip pooling and feed the whole sequence
            # to the classification head.
            raise ValueError(
                f"`classifier` must be 'token' or 'gap', got {classifier!r}."
            )
        self.classifier = classifier
        self.num_classes = num_classes

        # Sequence length seen by the encoder: one slot per patch, plus one
        # for the class token when it is used.
        self.num_positions = NUM_PATCHES + (1 if classifier == "token" else 0)

        self.projection = layers.Conv2D(
            filters=PROJECTION_DIM,
            kernel_size=(PATCH_SIZE, PATCH_SIZE),
            strides=(PATCH_SIZE, PATCH_SIZE),
            padding="VALID",
        )

        if self.classifier == "token":
            initial_value = tf.zeros((1, 1, PROJECTION_DIM))
            self.cls_token = tf.Variable(
                initial_value=initial_value, trainable=True, name="cls"
            )

        if self.classifier == "gap":
            self.gap_layer = layers.GlobalAvgPool1D()

        # Position embeddings, created once so they are tracked and reused.
        self.pos_embedding = layers.Embedding(
            input_dim=self.num_positions,
            output_dim=PROJECTION_DIM,
            embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
        )

        # One sub-model per Transformer block, each with its own weights.
        self.transformer_blocks = [
            self._build_transformer_block(index) for index in range(NUM_LAYERS)
        ]

        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.classifier_head = layers.Dense(num_classes, kernel_initializer="zeros")

    def _build_transformer_block(self, index):
        """Wrap one `transformer` block in a functional model built once."""
        block_inputs = keras.Input(shape=(self.num_positions, PROJECTION_DIM))
        block_outputs = transformer(block_inputs)
        return keras.Model(
            block_inputs, block_outputs, name=f"transformer_block_{index}"
        )

    def call(self, inputs, training=None):
        """Compute class logits for a batch of images.

        Args:
            inputs: Image batch of shape `(batch, height, width, channels)`.
                The spatial size must be statically known and must produce
                `NUM_PATCHES` patches.
            training: Forwarded to the dropout layers and Transformer blocks.

        Returns:
            Logits of shape `(batch, num_classes)`.
        """
        # Create patches and project the patches.
        projected_patches = self.projection(inputs)
        _, height, width, channels = projected_patches.shape
        # -1 lets TensorFlow infer the batch size, which is unknown (None)
        # while the model is being traced by `fit`, `predict` or `tf.function`.
        projected_patches = tf.reshape(
            projected_patches, [-1, height * width, channels]
        )

        # Prepend class token if needed.
        if self.classifier == "token":
            # Read the batch size at run time for the same reason as above.
            batch_size = tf.shape(projected_patches)[0]
            cls_token = tf.tile(self.cls_token, [batch_size, 1, 1])
            projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches; the
        # (num_positions, projection_dim) table broadcasts over the batch.
        positions = tf.range(self.num_positions)
        encoded_patches = projected_patches + self.pos_embedding(positions)
        encoded_patches = self.dropout(encoded_patches, training=training)

        # Run the stack of Transformer blocks.
        for block in self.transformer_blocks:
            encoded_patches = block(encoded_patches, training=training)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        if self.classifier == "token":
            pooled = representation[:, 0]
        else:  # "gap" -- the only other value accepted by __init__.
            pooled = self.gap_layer(representation)

        # Classification head.
        return self.classifier_head(pooled)

## Verification

In [7]:
# Smoke test: a batch of 10 dummy images must yield 10 logits each
# (10 is the default `num_classes` of ViTClassifier).
vit_classifier = ViTClassifier()
random_logits = vit_classifier(tf.ones((10, *INPUT_SHAPE)))
# Fail under "Restart & Run All" instead of relying on eyeballing the output.
assert random_logits.shape == (10, 10), random_logits.shape
random_logits.shape

2022-03-12 13:10:28.603093: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


TensorShape([10, 10])

In [8]:
# Same smoke test for the global-average-pooling variant.
vit_classifier_w_gap = ViTClassifier(classifier="gap")
gap_logits = vit_classifier_w_gap(tf.ones((10, *INPUT_SHAPE)))
# Fail under "Restart & Run All" instead of relying on eyeballing the output.
assert gap_logits.shape == (10, 10), gap_logits.shape
gap_logits.shape

TensorShape([10, 10])