In [1]:
%load_ext tensorboard
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import Dense, Dropout, LayerNormalization
from tensorflow.keras.layers.experimental.preprocessing import Rescaling

import numpy as np
import os
from argparse import ArgumentParser
import tensorflow_datasets as tfds
from tensorflow.keras.callbacks import TensorBoard


class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.mlp = tf.keras.Sequential(
            [
                Dense(mlp_dim, activation=tfa.activations.gelu),
                Dropout(dropout),
                Dense(embed_dim),
                Dropout(dropout),
            ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        

    def call(self, inputs, training):
        inputs_norm = self.layernorm1(inputs)
        attn_output = self.att(inputs_norm)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = attn_output + inputs

        out1_norm = self.layernorm2(out1)
        mlp_output = self.mlp(out1_norm)
        mlp_output = self.dropout2(mlp_output, training=training)
        return mlp_output + out1


class VisionTransformer(tf.keras.Model):
    def __init__(
        self,
        image_size,
        patch_size,
        num_layers,
        num_classes,
        d_model,
        num_heads,
        mlp_dim,
        channels=1,
        dropout=0.1,
    ):
        super(VisionTransformer, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size
        self.d_model = d_model
        self.num_layers = num_layers

        self.rescale = Rescaling(1.0 / 255)
        self.pos_emb = self.add_weight(
            "pos_emb", shape=(1, num_patches + 1, d_model)
        )
        self.class_emb = self.add_weight("class_emb", shape=(1, 1, d_model))
        self.patch_proj = Dense(d_model)
        self.enc_layers = [
            TransformerBlock(d_model, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ]
        self.mlp_head = tf.keras.Sequential(
            [
                LayerNormalization(epsilon=1e-6),
                Dense(mlp_dim, activation=tfa.activations.gelu),
                Dropout(dropout),
                Dense(num_classes),
            ]
        )

    def extract_patches(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patches = tf.reshape(patches, [batch_size, -1, self.patch_dim])
        return patches

    def call(self, x, training):
        batch_size = tf.shape(x)[0]
        x = self.rescale(x)
        patches = self.extract_patches(x)
        x = self.patch_proj(patches)

        class_emb = tf.broadcast_to(self.class_emb, [batch_size, 1, self.d_model])
        x = tf.concat([class_emb, x], axis=1)
        x = x + self.pos_emb

        for layer in self.enc_layers:
            x = layer(x, training)

        # First (class token) is used for classification
        x = self.mlp_head(x[:, 0])
        return x

 The versions of TensorFlow you are currently using is 2.4.1 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [2]:
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    # TODO: moved init arguments to constants embed_dim and num_heads
    def __init__(self):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = 64
        self.num_heads = 4
        #if embed_dim % num_heads != 0:
        #    raise ValueError(
        #        f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
        #    )
        self.projection_dim = self.embed_dim // self.num_heads
        self.query_dense = Dense(self.embed_dim)
        self.key_dense = Dense(self.embed_dim)
        self.value_dense = Dense(self.embed_dim)
        self.combine_heads = Dense(self.embed_dim)

    def attention(self, query, key, value):
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        x = tf.reshape(
            x, (batch_size, -1, self.num_heads, self.projection_dim)
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'projection_dim': self.projection_dim,
            'query_dense': self.query_dense,
            'key_dense': self.key_dense,
            'value_dense': self.value_dense,
            'combine_heads': self.combine_heads,
        })
        return config

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)

        attention, weights = self.attention(query, key, value)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(
            attention, (batch_size, -1, self.embed_dim)
        )
        output = self.combine_heads(concat_attention)
        return output

In [3]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

logdir = "logs"
image_size = 28
patch_size = 4
num_layers = 4
d_model = 64
num_heads = 4
mlp_dim = 128
lr = 3e-4
weight_decay = 1e-4
batch_size = 16
epochs = 10
channels = 1
dropout = 0.1
num_classes = 10
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
# x_train = x_train / 255
# x_test = x_test / 255

# Add a channel dimension.
axis = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
x_train = tf.expand_dims(x_train, axis)
x_test = tf.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

def extract_patches(images):
    batch_size = tf.shape(images)[0]
    
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    patches = tf.reshape(patches, [batch_size, -1, patch_dim])
    return patches

In [4]:
input_shape = x_train.shape[1:]
inp = tf.keras.layers.Input(shape=(input_shape))
x = Rescaling(1.0 / 255)(inp)

# =============== VISION PART =====================
# patching, positional embedding and class embedding
patches = extract_patches(x)
x = Dense(d_model)(patches)

pos_emb = tf.Variable(initial_value=tf.random.uniform(shape=(1, num_patches + 1, d_model)), 
                      name="pos_emb", validate_shape=(1, num_patches + 1, d_model), trainable=True)
class_emb = tf.Variable(initial_value=tf.random.uniform(shape=(1, 1, d_model)), name="class_emb", 
                        validate_shape=(1, 1, d_model), trainable=True)

class_emb = tf.broadcast_to(class_emb, [batch_size, 1, d_model])

x = tf.concat([class_emb, x], axis=1)
x = tf.keras.layers.Add()([x, pos_emb])

# Transformer Blocks
#x = TransformerBlock(d_model, num_heads, mlp_dim, dropout)(x)
# 1 ============ ATTENTION BLOCK ===================
inputs_norm = LayerNormalization(epsilon=1e-6)(x)
attn_output = MultiHeadSelfAttention()(inputs_norm)
attn_output = Dropout(dropout)(attn_output)
out1 = attn_output + x

out1_norm = LayerNormalization(epsilon=1e-6)(out1)

# MLP module inside of attention block
mlp_dense = Dense(mlp_dim, activation=tf.nn.relu)(out1_norm)
mlp_dropout = Dropout(dropout)(mlp_dense)
mlp_dense = Dense(d_model)(mlp_dropout)
mlp_dropout = Dropout(dropout)(mlp_dense)

mlp_output = Dropout(dropout)(mlp_dropout)
x = mlp_output + out1

# 2 ============ ATTENTION BLOCK ===================
inputs_norm = LayerNormalization(epsilon=1e-6)(x)
attn_output = MultiHeadSelfAttention()(inputs_norm)
attn_output = Dropout(dropout)(attn_output)
out1 = attn_output + x

out1_norm = LayerNormalization(epsilon=1e-6)(out1)

# MLP module inside of attention block
mlp_dense = Dense(mlp_dim, activation=tf.nn.relu)(out1_norm)
mlp_dropout = Dropout(dropout)(mlp_dense)
mlp_dense = Dense(d_model)(mlp_dropout)
mlp_dropout = Dropout(dropout)(mlp_dense)

mlp_output = Dropout(dropout)(mlp_dropout)
x = mlp_output + out1

# 3 ============ ATTENTION BLOCK ===================
inputs_norm = LayerNormalization(epsilon=1e-6)(x)
attn_output = MultiHeadSelfAttention()(inputs_norm)
attn_output = Dropout(dropout)(attn_output)
out1 = attn_output + x

out1_norm = LayerNormalization(epsilon=1e-6)(out1)

# MLP module inside of attention block
mlp_dense = Dense(mlp_dim, activation=tf.nn.relu)(out1_norm)
mlp_dropout = Dropout(dropout)(mlp_dense)
mlp_dense = Dense(d_model)(mlp_dropout)
mlp_dropout = Dropout(dropout)(mlp_dense)

mlp_output = Dropout(dropout)(mlp_dropout)
x = mlp_output + out1

# 4 ============ ATTENTION BLOCK ===================
inputs_norm = LayerNormalization(epsilon=1e-6)(x)
attn_output = MultiHeadSelfAttention()(inputs_norm)
attn_output = Dropout(dropout)(attn_output)
out1 = attn_output + x

out1_norm = LayerNormalization(epsilon=1e-6)(out1)

# MLP module inside of attention block
mlp_dense = Dense(mlp_dim, activation=tf.nn.relu)(out1_norm)
mlp_dropout = Dropout(dropout)(mlp_dense)
mlp_dense = Dense(d_model)(mlp_dropout)
mlp_dropout = Dropout(dropout)(mlp_dense)

mlp_output = Dropout(dropout)(mlp_dropout)
x = mlp_output + out1

# ================= MLP HEAD ===================
x = LayerNormalization(epsilon=1e-6)(x[:, 0])
x = Dense(mlp_dim, activation=tf.nn.relu)(x)
x = Dropout(dropout)(x)
x = Dense(num_classes)(x)

# ================ Model compilation and training ==================
model = tf.keras.models.Model(inputs=inp, outputs=x)

model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["accuracy"],
)

model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=1, validation_data=(x_test, y_test))

ValueError: Can not merge tensors with different batch sizes. Got tensors with shapes : [(16, None, 64), (1, 50, 64)]

In [6]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
rescaling_1 (Rescaling)         (None, 28, 28, 1)    0           input_2[0][0]                    
__________________________________________________________________________________________________
tf.compat.v1.shape_1 (TFOpLambd (4,)                 0           rescaling_1[0][0]                
__________________________________________________________________________________________________
tf.image.extract_patches_1 (TFO (None, 7, 7, 16)     0           rescaling_1[0][0]                
____________________________________________________________________________________________

In [7]:
from tensorflow import keras
keras.models.save_model(
    model, 
    os.path.join("/home/viktor/PycharmProjects/guided_research/transformer-to-snn-conversion", 
    "mnist_transformer" + '.h5')
)

TypeError: Layer tf.__operators__.add_9 was passed non-JSON-serializable arguments. Arguments had types: {'y': <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>, 'name': <class 'NoneType'>}. They cannot be serialized out when saving the model.

In [14]:
%tensorboard --logdir logs

In [42]:
reconstructed_model = keras.models.load_model(
    os.path.join("/home/viktor/PycharmProjects/guided_research/transformer-to-snn-conversion", 
    "mnist_transformer" + '.h5'))

ValueError: Unknown layer: MultiHeadSelfAttention