In [1]:
! pip install tensorflow-addons -qq
! pip install ml-collections -qq

[K     |████████████████████████████████| 1.1 MB 14.1 MB/s 
[K     |████████████████████████████████| 77 kB 4.9 MB/s 
[?25h  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone


# Setup and Imports

In [2]:
import ml_collections
import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow import keras
from tensorflow.keras import layers

# Configuration

References:

- https://github.com/google-research/vision_transformer/blob/main/vit_jax/configs/models.py#L103

In [3]:
def get_config() -> ml_collections.ConfigDict:
    config = ml_collections.ConfigDict()

    config.batch_size = 32
    config.input_shape = (224, 224, 3)

    config.image_size = 224
    config.patch_size = 16
    config.num_patches = (config.image_size // config.patch_size) ** 2
    config.num_classes = 10

    config.pos_emb_mode = "sin-cos"

    config.layer_norm_eps = 1e-6
    config.projection_dim = 768
    config.num_heads = 12
    config.num_layers = 12
    config.mlp_units = [
        config.projection_dim * 4,
        config.projection_dim,
    ]
    config.dropout_rate = 0.0
    config.classifier = "token"

    return config.lock()

# Vision Transformer blocks

References:

- https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit_mae/modeling_vit_mae.py
- https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py
- https://keras.io/examples/vision/image_classification_with_vision_transformer/

In [4]:
class PositionalEmbedding(layers.Layer):
    def __init__(self, config: ml_collections.ConfigDict, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        
        # Compute the positions
        positions = self.config.num_patches
        positions += 1 if self.config.classifier == "token" else 0
        
        # Build the sequence of positions in 1D
        self.pos_flat_patches = tf.range(positions, dtype=tf.float32, delta=1)

        # Encode the positions with an Embedding layer.
        if self.config.pos_emb_mode == "learn":
            self.pos_embedding = layers.Embedding(
                input_dim=self.config.num_patches,
                output_dim=self.config.projection_dim,
                embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02),
            )

    def get_config(self):
        config = super().get_config()
        config.update(self.config)
        return config

    def get_1d_sincos_pos_embed(self):
        # Build the sine-cosine positional embedding.
        omega = tf.range(self.config.projection_dim // 2, dtype=tf.float32)
        omega /= self.config.projection_dim / 2.0
        omega = 1.0 / 10000 ** omega  # (D/2,)

        out = tf.einsum("m,d->md", self.pos_flat_patches, omega)  # (M, D/2), outer product

        emb_sin = tf.sin(out)  # (M, D/2)
        emb_cos = tf.cos(out)  # (M, D/2)

        emb = tf.concat([emb_sin, emb_cos], axis=1)  # (M, D)
        return emb

    def get_learnable_pos_embed(self):
        emb = self.pos_embedding(self.pos_flat_patches)
        return emb

    def call(self, inputs):
        if self.config.pos_emb_mode == "learn":
            pos_emb = self.get_learnable_pos_embed()
        else:
            pos_emb = self.get_1d_sincos_pos_embed()
        
        # Inject the positional embeddings with the tokens
        outputs = inputs + pos_emb
        return outputs

In [5]:
def mlp(x, dropout_rate, hidden_units):
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
            kernel_initializer="glorot_uniform",
            bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In [6]:
def transformer(config, name):
    num_patches = (
        config.num_patches + 1
        if config.classifier == "token"
        else config.num_patches + 0
    )
    encoded_patches = layers.Input((num_patches, config.projection_dim))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=config.num_heads,
        key_dim=config.projection_dim,
        dropout=config.dropout_rate,
    )(x1, x1)
    attention_output = layers.Dropout(config.dropout_rate)(attention_output)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=config.layer_norm_eps)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=config.mlp_units, dropout_rate=config.dropout_rate)

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)

In [7]:
class ViTClassifier(keras.Model):
    def __init__(self, config: ml_collections.ConfigDict, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.projection = layers.Conv2D(
            filters=config.projection_dim,
            kernel_size=(config.patch_size, config.patch_size),
            strides=(config.patch_size, config.patch_size),
            padding="VALID",
            name="projection",
        )

        self.positional_embedding = PositionalEmbedding(
            config, name="positional_embedding"
        )
        self.transformer_blocks = [
            transformer(config, name=f"transformer_block_{i}")
            for i in range(config.num_layers)
        ]

        if config.classifier == "token":
            initial_value = tf.zeros((1, 1, config.projection_dim))
            self.cls_token = tf.Variable(
                initial_value=initial_value, trainable=True, name="cls"
            )

        if config.classifier == "gap":
            self.gap_layer = layers.GlobalAvgPool1D()

        self.dropout = layers.Dropout(config.dropout_rate)
        self.layer_norm = layers.LayerNormalization(epsilon=config.layer_norm_eps)
        self.classifier_head = layers.Dense(
            config.num_classes, kernel_initializer="zeros", name="classifier"
        )

    def call(self, inputs):
        # Create patches and project the pathces.
        projected_patches = self.projection(inputs)
        n, h, w, c = projected_patches.shape
        projected_patches = tf.reshape(projected_patches, [n, h * w, c])

        # Append class token if needed.
        if self.config.classifier == "token":
            cls_token = tf.tile(self.cls_token, (n, 1, 1))
            projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches.
        encoded_patches = self.positional_embedding(
            projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        if self.config.classifier == "token":
            encoded_patches = representation[:, 0]
        elif self.config.classifier == "gap":
            encoded_patches = self.gap_layer(representation)

        # Classification head.
        output = self.classifier_head(encoded_patches)

        return output

# Verification

| Name | Classifier | Pos Embed Mode |
| :-- | :--: | :--: |
| vit_classifier | `CLS TOKEN` | `sin-cos` |
| vit_classifier_pos_learn | `CLS TOKEN` | `learn` |
| vit_classifier_w_gap | `GAP` | `sin-cos` |
| vit_classifier_w_gap_pos_learn | `GAP` | `learn` |

In [8]:
vit_b16_config = get_config()

In [9]:
print(f"classifier: {vit_b16_config.classifier}\npos_emb_mode: {vit_b16_config.pos_emb_mode}")
vit_classifier = ViTClassifier(vit_b16_config, name="vit_cls_token_pos_sincos")
random_logits = vit_classifier(tf.ones((10, 224, 224, 3)))
random_logits.shape

classifier: token
pos_emb_mode: sin-cos


TensorShape([10, 10])

In [10]:
with vit_b16_config.unlocked():
    vit_b16_config.pos_emb_mode = "learn"

print(f"classifier: {vit_b16_config.classifier}\npos_emb_mode: {vit_b16_config.pos_emb_mode}")

vit_classifier_pos_learn = ViTClassifier(vit_b16_config, name="vit_cls_token_pos_learn")
random_logits = vit_classifier_pos_learn(tf.ones((10, 224, 224, 3)))
random_logits.shape

classifier: token
pos_emb_mode: learn


TensorShape([10, 10])

In [11]:
with vit_b16_config.unlocked():
    vit_b16_config.pos_emb_mode = "sin-cos"
    vit_b16_config.classifier = "gap"

print(f"classifier: {vit_b16_config.classifier}\npos_emb_mode: {vit_b16_config.pos_emb_mode}")
vit_classifier_w_gap = ViTClassifier(vit_b16_config, name="vit_with_gap_pos_sincos")
random_logits = vit_classifier_w_gap(tf.ones((10, 224, 224, 3)))
random_logits.shape

classifier: gap
pos_emb_mode: sin-cos


TensorShape([10, 10])

In [12]:
with vit_b16_config.unlocked():
    vit_b16_config.pos_emb_mode = "learn"
    vit_b16_config.classifier = "gap"

print(f"classifier: {vit_b16_config.classifier}\npos_emb_mode: {vit_b16_config.pos_emb_mode}")
vit_classifier_w_gap_pos_learn = ViTClassifier(vit_b16_config, name="vit_with_gap_pos_learn")
random_logits = vit_classifier_w_gap_pos_learn(tf.ones((10, 224, 224, 3)))
random_logits.shape

classifier: gap
pos_emb_mode: learn


TensorShape([10, 10])

# Layer Inspection

In [13]:
vit_classifier_w_gap.summary()

Model: "vit_with_gap_pos_sincos"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 projection (Conv2D)         multiple                  590592    
                                                                 
 positional_embedding (Posit  multiple                 0         
 ionalEmbedding)                                                 
                                                                 
 transformer_block_0 (Functi  (None, 196, 768)         33065472  
 onal)                                                           
                                                                 
 transformer_block_1 (Functi  (None, 196, 768)         33065472  
 onal)                                                           
                                                                 
 transformer_block_2 (Functi  (None, 196, 768)         33065472  
 onal)                                     

In [14]:
vit_classifier_w_gap.get_layer("transformer_block_3").summary()

Model: "transformer_block_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_28 (InputLayer)          [(None, 196, 768)]   0           []                               
                                                                                                  
 layer_normalization_56 (LayerN  (None, 196, 768)    1536        ['input_28[0][0]']               
 ormalization)                                                                                    
                                                                                                  
 multi_head_attention_27 (Multi  (None, 196, 768)    28339968    ['layer_normalization_56[0][0]', 
 HeadAttention)                                                   'layer_normalization_56[0][0]'] 
                                                                                