# In[ ]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Parameters
# -- input / patching --
image_size: int = 48  # images are image_size x image_size
patch_size: int = 4  # side length of each square, non-overlapping patch
num_patches: int = pow(image_size // patch_size, 2)  # patches per image

# -- transformer encoder --
embedding_dim: int = 64
num_heads: int = 4
mlp_hidden_dim: int = 128
num_transformer_layers: int = 4

# -- classifier head / regularization --
num_classes: int = 7
dropout_rate: float = 0.1

class ReduceMeanLayer(layers.Layer):
    """Average the input over ``axis`` using ``tf.reduce_mean``.

    Args:
        axis: Axis or axes to reduce. ``None`` reduces over every axis
            (including the batch axis).
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, axis=None, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        return tf.reduce_mean(x, axis=self.axis)

    def get_config(self):
        # Include ``axis`` so Keras can re-create this layer from its config
        # when a model containing it is saved and reloaded.
        config = super().get_config()
        config.update({"axis": self.axis})
        return config

# 1. Patch Embedding Layer
class PatchExtractor(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Input:  (batch, height, width, channels)
    Output: (batch, (height // patch_size) * (width // patch_size),
             patch_size * patch_size * channels)

    Patches use stride == size with 'VALID' padding, so trailing rows/columns
    that do not fill a whole patch are dropped.

    Args:
        patch_size: Side length of each square patch, in pixels.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, patch_size, **kwargs):
        # Accepting **kwargs lets callers name the layer and lets Keras
        # rebuild it from get_config() when a saved model is loaded.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        # patches: (batch, h // p, w // p, p * p * channels)
        patch_dims = tf.shape(patches)[-1]
        # Merge the two patch-grid axes into a single sequence axis.
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

class PatchEmbedding(layers.Layer):
    """Embed patches linearly, prepend a learnable CLS token, add position embeddings.

    Input:  (batch, num_patches, patch_dims)
    Output: (batch, num_patches + 1, embedding_dim); token 0 is the CLS token.

    Args:
        num_patches: Number of patch tokens per example. It sizes the
            position-embedding table (plus one slot for the CLS token).
        embedding_dim: Width of every output token.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_patches, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.embedding_dim = embedding_dim
        self.embedding = layers.Dense(embedding_dim)

        # Learnable classification token, tiled across the batch in call().
        # 'glorot_uniform' is Keras' default for float weights; spelled out
        # here so the initialization is explicit.
        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, embedding_dim),
            initializer="glorot_uniform",
            trainable=True,
        )
        # One learnable position vector per token (CLS + each patch).
        self.pos_embedding = self.add_weight(
            name="pos_embedding",
            shape=(1, num_patches + 1, embedding_dim),
            initializer="zeros",
            trainable=True,
        )

    def call(self, patch_inputs):
        tokens = self.embedding(patch_inputs)  # (batch, num_patches, embedding_dim)
        batch_size = tf.shape(tokens)[0]
        # (1, 1, D) -> (batch, 1, D) so every example gets the same CLS token.
        cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
        tokens = tf.concat([cls_tokens, tokens], axis=1)
        # pos_embedding has a leading axis of 1 and broadcasts over the batch.
        return tokens + self.pos_embedding

    def get_config(self):
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "embedding_dim": self.embedding_dim,
        })
        return config

# 2. Transformer Encoder Block
def mlp(x, hidden_dim, dropout_rate, output_dim=None):
    """Apply a two-layer GELU MLP with dropout after each Dense layer.

    New Dense/Dropout layers are created on every call, so this helper is
    meant for building a functional-API graph. Do not call it from inside a
    Layer's ``call``, where it would create fresh weights on every invocation.

    Args:
        x: Input tensor; the last axis is transformed.
        hidden_dim: Width of the hidden Dense layer.
        dropout_rate: Dropout rate applied after each Dense layer.
        output_dim: Width of the output Dense layer. ``None`` falls back to
            the module-level ``embedding_dim`` (the previously hard-coded
            width); pass it explicitly to avoid depending on that global.

    Returns:
        Tensor with the same leading axes as ``x`` and last axis ``output_dim``.
    """
    if output_dim is None:
        output_dim = embedding_dim
    x = layers.Dense(hidden_dim, activation='gelu')(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(output_dim)(x)
    x = layers.Dropout(dropout_rate)(x)
    return x

class TransformerEncoder(layers.Layer):
    """Transformer encoder block: multi-head self-attention followed by an MLP.

    Each sub-block has a residual connection. By default LayerNormalization
    is applied after each residual add (post-norm), as this block has always
    done. ``norm_first=True`` switches to the pre-norm layout used by ViT:
    normalize, transform, then add the residual.

    Args:
        embedding_dim: Token width. It is also the output width of the MLP,
            so the residual add is shape-compatible.
        num_heads: Number of attention heads.
        mlp_hidden_dim: Hidden width of the two-layer GELU MLP.
        dropout_rate: Dropout on the attention weights and after both MLP
            Dense layers.
        key_dim: Per-head query/key size. ``None`` (default) uses
            ``embedding_dim``, matching the previous behavior.
            ``embedding_dim // num_heads`` is the common ViT choice.
        norm_first: If True, use pre-norm instead of post-norm residual blocks.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embedding_dim, num_heads, mlp_hidden_dim, dropout_rate,
                 key_dim=None, norm_first=False, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_hidden_dim = mlp_hidden_dim
        self.dropout_rate = dropout_rate
        self.key_dim = embedding_dim if key_dim is None else key_dim
        self.norm_first = norm_first

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=self.key_dim, dropout=dropout_rate
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp_block = keras.Sequential([
            layers.Dense(mlp_hidden_dim, activation='gelu'),
            layers.Dropout(dropout_rate),
            layers.Dense(embedding_dim),
            layers.Dropout(dropout_rate),
        ])

    def call(self, x, training=None):
        # Forward `training` explicitly so the attention and MLP dropout
        # follow the caller's mode even outside fit()/predict().
        if self.norm_first:
            h = self.norm1(x)
            x = x + self.attn(h, h, training=training)
            h = self.norm2(x)
            return x + self.mlp_block(h, training=training)

        # Post-norm (default): add the residual, then normalize.
        x = self.norm1(x + self.attn(x, x, training=training))
        return self.norm2(x + self.mlp_block(x, training=training))

    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "num_heads": self.num_heads,
            "mlp_hidden_dim": self.mlp_hidden_dim,
            "dropout_rate": self.dropout_rate,
            "key_dim": self.key_dim,
            "norm_first": self.norm_first,
        })
        return config


# 3. Build the Model
# Input: single-channel images of size image_size x image_size (batch axis implicit).
inputs = keras.Input(shape=(image_size, image_size, 1))
# Split each image into flattened, non-overlapping patch_size x patch_size patches.
patches = PatchExtractor(patch_size)(inputs)
# Project patches to embedding_dim, prepend a learnable CLS token and add
# position embeddings -> sequence length num_patches + 1.
x = PatchEmbedding(num_patches, embedding_dim)(patches)

# Stack of encoder blocks. A new TransformerEncoder is constructed on every
# iteration, so the blocks do not share weights.
for _ in range(num_transformer_layers):
    x = TransformerEncoder(embedding_dim, num_heads, mlp_hidden_dim, dropout_rate)(x)

# Final normalization over every token before classification.
# NOTE(review): the TransformerEncoder defined in this file already ends in a
# LayerNormalization (post-norm), so this extra norm is largely redundant; it is
# the standard final norm for pre-norm ViT blocks -- confirm which layout is intended.
x = layers.LayerNormalization(epsilon=1e-6)(x)
# Token 0 is the CLS token prepended by PatchEmbedding; it serves as the
# whole-image representation.
cls_representation = x[:, 0, :]
# Softmax class probabilities over num_classes.
x = layers.Dense(num_classes, activation='softmax')(cls_representation)

model = keras.Model(inputs, x)
model.summary()

# Cosine-decay learning rate: starts at 1e-3 and decays to alpha * 1e-3 = 1e-4
# over decay_steps optimizer steps, then stays at that floor.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.001,
    decay_steps=10000,
    alpha=0.1
)

# Compile the model.
# The schedule must be handed to the optimizer explicitly: compiling with the
# string 'adam' silently ignores lr_schedule and trains at Adam's constant
# default learning rate.
# NOTE(review): 'categorical_crossentropy' assumes the generators yield one-hot
# labels (e.g. class_mode='categorical') -- confirm.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# ReduceLROnPlateau is deliberately not used. It works by overwriting the
# optimizer's learning rate, which is not supported when the optimizer was built
# from a LearningRateSchedule, and it would compete with the cosine schedule.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
]

model.fit(train_generator, validation_data=validation_generator, epochs=15, callbacks=callbacks)

# Evaluation results recorded for this version of the script (test set):
# 221/221 [==============================] - 10s 46ms/step - loss: 1.3991 - accuracy: 0.4515
# Test Loss: 1.3990623950958252
# Test Accuracy: 0.451475590467453