In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout,  Conv2D, Input, Lambda, Flatten, TimeDistributed
from tensorflow.keras.layers import Add, Reshape, MaxPooling2D, Concatenate, Embedding, RepeatVector, \
BatchNormalization, MultiHeadAttention
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
import keras
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
#from tensorflow.keras.layers.core import Dense, Dropout, Activation
from keras.utils import np_utils
from keras.engine.topology import Layer
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.utils import to_categorical

# Models

## 0 - Default MultiHeadAttention

In [4]:
num_classes = 10

# the data, shuffled and split between tran and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

inp = Input(shape = (28, 28, 1))
x = Conv2D(32, (2,2), activation='relu', padding='same')(inp)
x = MaxPooling2D(pool_size=(2, 2))(x)
c = Conv2D(64, (2,2), activation='relu')
x = c(x)
x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
x = Conv2D(64*3, (2,2), activation='relu')(x)


x = Reshape([6*6,64*3])(x)


att = MultiHeadAttention(num_heads=4, key_dim=64)
x = att(x,x)

# x = Reshape([6,6,32])(x)

# x = BatchNormalization()(x)

x = Flatten()(x) 
x = Dense(256, activation='relu')(x)
x = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=inp, outputs=x)
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=1, validation_data=(x_test, y_test))

 175/1875 [=>............................] - ETA: 46s - loss: 1.9277 - accuracy: 0.2516

KeyboardInterrupt: 

In [49]:
model.summary()

Model: "model_21"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_32 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 28, 28, 32)   160         input_32[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_56 (MaxPooling2D) (None, 14, 14, 32)   0           conv2d_84[0][0]                  
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 13, 13, 64)   8256        max_pooling2d_56[0][0]           
___________________________________________________________________________________________

## 1 - Multi head attention, defined in a functional way from github

In [2]:
def MultiHeadsAttModel(l=8*8, d=512, dv=64, dout=512, nv = 8 ):

    v1 = tf.keras.layers.Input(shape = (l, d))
    q1 = tf.keras.layers.Input(shape = (l, d))
    k1 = tf.keras.layers.Input(shape = (l, d))

    v2 = tf.keras.layers.Dense(dv*nv, activation = "relu")(v1)
    q2 = tf.keras.layers.Dense(dv*nv, activation = "relu")(q1)
    k2 = tf.keras.layers.Dense(dv*nv, activation = "relu")(k1)

    v = tf.keras.layers.Reshape([l, nv, dv])(v2)
    q = tf.keras.layers.Reshape([l, nv, dv])(q2)
    k = tf.keras.layers.Reshape([l, nv, dv])(k2)
    
    # CAN BE SIMULATED WITH tf experiment EinsumDense, but it has weights, which it learns
    att = tf.einsum('baik,baij->bakj',q, k)/np.sqrt(dv)
    #att = Lambda(lambda x: K.batch_dot(x[0],x[1] ,axes=[-1,-1]) / np.sqrt(dv),output_shape=(l, nv, nv))([q,k])# l, nv, nv
    #att = tf.einsum('', q, k)
    att = tf.keras.layers.Softmax(axis=-1)(att)
#     att = Lambda(lambda x:  K.softmax(x) , output_shape=(l, nv, nv))(att)
    out = tf.einsum('bajk,baik->baji',att, v)
    #out = Lambda(lambda x: K.batch_dot(x[0], x[1],axes=[2,2]),  output_shape=(l, nv, dv))([att, v])
    out = tf.keras.layers.Reshape([l, d])(out)
    
    out = tf.keras.layers.Add()([out, q1])

    out = tf.keras.layers.Dense(dout, activation = "relu")(out)

    return  Model(inputs=[q1,k1,v1], outputs=out)

### Functional Modular ViT with Convolutions

In [3]:
"""
Fully Functional Modular version with matmul instead of einsums
"""

num_classes = 10

# the data, shuffled and split between tran and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

inp = Input(shape = (28, 28, 1))
x = Conv2D(32, (2,2), activation='relu', padding='same')(inp)
x = MaxPooling2D(pool_size=(2, 2))(x)
c = Conv2D(64, (2,2), activation='relu')
x = c(x)
x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
x = Conv2D(64*3, (2,2), activation='relu')(x)


x = Reshape([6*6,64*3])(x)

# att = MultiHeadsAttModel(l=6*6, d=64*3 , dv=8*3, dout=32, nv = 8 )
# x = att([x,x,x])
l=6*6
d=64*3
dv=8*3
dout=32
nv = 8
# v1 = tf.keras.layers.Input()(x)
# q1 = tf.keras.layers.Input()(x)
# k1 = tf.keras.layers.Input()(x)

v2 = tf.keras.layers.Dense(dv*nv, activation = "relu")(x)
q2 = tf.keras.layers.Dense(dv*nv, activation = "relu")(x)
k2 = tf.keras.layers.Dense(dv*nv, activation = "relu")(x)

v = tf.keras.layers.Reshape([l, nv, dv])(v2)
q = tf.keras.layers.Reshape([l, nv, dv])(q2)
k = tf.keras.layers.Reshape([l, nv, dv])(k2)

att = tf.matmul(q, k, transpose_b=True)/np.sqrt(dv)
# att = tf.einsum('baik,baij->bakj',q, k)/np.sqrt(dv)
# print(att.shape)
# print(q.shape)
# print(k.shape)
att = tf.keras.layers.Softmax(axis=-1)(att)
# out = tf.einsum('bajk,baik->baji',att, v)
out = tf.matmul(att, v)
out = tf.keras.layers.Reshape([l, d])(out)

out = tf.keras.layers.Add()([out, x])

out = tf.keras.layers.Dense(dout, activation = "relu")(out)

x = Reshape([6,6,32])(out)

x = BatchNormalization()(x)

x = Flatten()(x) 
x = Dense(256, activation='relu')(x)
x = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=inp, outputs=x)
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=1, validation_data=(x_test, y_test))



<tensorflow.python.keras.callbacks.History at 0x7fa3b3136c40>

### Functional modular ViT compatible with snn_toolbox

In [5]:
def extract_patches(images, patch_size, patch_dim):
    batch_size = tf.shape(images)[0]
    print(f"Extract patches {batch_size}")
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    patches = tf.reshape(patches, [batch_size, -1, patch_dim])
    return patches

In [6]:
class ScaleLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ScaleLayer, self).__init__()

    def call(self, inputs):
        scale = inputs[1]
        return inputs[0] / scale

    
class MatMulLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MatMulLayer, self).__init__()

    def call(self, inputs):
        return tf.matmul(inputs[0], inputs[1])
    

class TransposeLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TransposeLayer, self).__init__()

    def call(self, inputs):
        return tf.transpose(inputs, perm=[0, 2, 1, 3])

    
class SqueezeLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(SqueezeLayer, self).__init__()

    def call(self, inputs):
        return tf.squeeze(inputs, axis=1)

class SliceLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(SliceLayer, self).__init__()

    def call(self, inputs):
        return inputs[:, 0]
    
class MatMulLayerTranspose(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MatMulLayerTranspose, self).__init__()

    def call(self, inputs):
        return tf.matmul(inputs[0], inputs[1], transpose_b=True)

    
class PositionalEncodingLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(PositionalEncodingLayer, self).__init__()
        self.pos_emb = self.add_weight("pos_emb", shape=(1, num_patches + 1, d_model))
        self.class_emb = self.add_weight("class_emb", shape=(1, 1, d_model))

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        class_emb = tf.broadcast_to(self.class_emb, [batch_size, 1, d_model])
        x = tf.concat([class_emb, inputs], axis=1)
        return x + self.pos_emb

    
class ExtractPatchesLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ExtractPatchesLayer, self).__init__()
        self.patch_size = 4
        self.patch_dim = 16
        
    def extract_patches(self, images, patch_size, patch_dim):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, patch_size, patch_size, 1],
            strides=[1, patch_size, patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patches = tf.reshape(patches, [batch_size, -1, patch_dim])
        return patches

    def call(self, inputs):
        return self.extract_patches(inputs, self.patch_size, self.patch_dim)

### Model

In [7]:
def multi_head_attention(x):
    # ================== Multi Head Self Attention ===============
    v2 = tf.keras.layers.Dense(embed_dim)(x)
    q2 = tf.keras.layers.Dense(embed_dim)(x)
    k2 = tf.keras.layers.Dense(embed_dim)(x)

    v = tf.keras.layers.Reshape([l, num_heads, projection_dim])(v2)
    v = TransposeLayer()(v)
    q = tf.keras.layers.Reshape([l, num_heads, projection_dim])(q2)
    q = TransposeLayer()(q)
    k = tf.keras.layers.Reshape([l, num_heads, projection_dim])(k2)
    k = TransposeLayer()(k)

    # =============== Scaled dot-product attention =================
    # QK^T
    att = MatMulLayerTranspose()([q, k])
    # softmax(QK^T)
    att = tf.keras.layers.Softmax(axis=-1)(att)
    # softmax(QK^T)*V
    out = MatMulLayer()([att, v])

    att = TransposeLayer()(out)
    out = tf.keras.layers.Reshape([-1, l, embed_dim])(att)
    out = tf.keras.layers.Dense(embed_dim)(out)
    out = Dropout(dropout)(out)
    # out = tf.keras.layers.Reshape([l, d_model, 1])(out)
    x = tf.keras.layers.Reshape([-1, l, embed_dim])(x)
    # ============== End of Multi Head Self Attention =============
    # Concat Layer
    add = tf.keras.layers.Add()([out, x])

    out = tf.keras.layers.Dense(mlp_dim, activation="relu")(add)
    out = Dropout(dropout)(out)
    out = tf.keras.layers.Dense(embed_dim)(out)
    out = Dropout(dropout)(out)
    out = tf.keras.layers.Add()([out, add])
    # ================== End of Transformer =======================
    return out    


In [8]:
"""
Fully Functional Modular version with Multiply layers instead of einsums
"""

num_classes = 10
image_size = 28
patch_size = 4
num_patches = (image_size // patch_size) ** 2
channels = 1
patch_dim = channels * patch_size ** 2
batch_size = 64
embed_dim = d_model = 64
num_heads = 4
projection_dim = embed_dim//num_heads
mlp_dim = 128
l = 50
dropout = 0.1
# the data, shuffled and split between tran and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

dv = 24
dout = 32
nv = 8

inp = Input(shape=(28, 28, 1))
patches = ExtractPatchesLayer()(inp)
x = tf.keras.layers.Dense(d_model)(patches)
x = PositionalEncodingLayer()(x)

out = x
for i in range(4):
    out = multi_head_attention(out)

out = SqueezeLayer()(out)
out = SliceLayer()(out)
# out = tf.keras.layers.Flatten()(out)
out = tf.keras.layers.Dense(embed_dim, activation='relu')(out)
out = Dropout(dropout)(out)
out = tf.keras.layers.Dense(num_classes, activation='softmax')(out)
model = Model(inputs=inp, outputs=out)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=64, epochs=10, verbose=1, validation_data=(x_test, y_test))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f39abc86580>

In [9]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
extract_patches_layer_1 (Extrac (None, None, 16)     0           input_2[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 64)     1088        extract_patches_layer_1[0][0]    
__________________________________________________________________________________________________
positional_encoding_layer_1 (Po (None, 50, 64)       3264        dense_1[0][0]                    
______________________________________________________________________________________________

### Original ViT / Comparison

In [22]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import (
    Dense,
    Dropout,
    LayerNormalization,
)
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
import tensorflow_datasets as tfds


class MultiHeadSelfAttention(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
            )
        self.projection_dim = embed_dim // num_heads
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        self.combine_heads = Dense(embed_dim)

    def attention(self, query, key, value):
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        x = tf.reshape(
            x, (batch_size, -1, self.num_heads, self.projection_dim)
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        print(f"Query shape - {query.shape}")
        query = self.separate_heads(query, batch_size)
        print(f"Query shape after separating heads - {query.shape}")
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)

        attention, weights = self.attention(query, key, value)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(
            attention, (batch_size, -1, self.embed_dim)
        )
        # Dense Layer in the end
        output = self.combine_heads(concat_attention)
        return output


class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.att = MultiHeadSelfAttention(embed_dim, num_heads)
        self.mlp = tf.keras.Sequential(
            [
                Dense(mlp_dim, activation=tfa.activations.gelu),
                Dropout(dropout),
                Dense(embed_dim),
                Dropout(dropout),
            ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

    def call(self, inputs, training):
        inputs_norm = self.layernorm1(inputs)
        attn_output = self.att(inputs_norm)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = attn_output + inputs

        out1_norm = self.layernorm2(out1)
        mlp_output = self.mlp(out1_norm)
        mlp_output = self.dropout2(mlp_output, training=training)
        return mlp_output + out1


class VisionTransformer(tf.keras.Model):
    def __init__(
        self,
        image_size,
        patch_size,
        num_layers,
        num_classes,
        d_model,
        num_heads,
        mlp_dim,
        channels=3,
        dropout=0.1,
    ):
        super(VisionTransformer, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size
        self.d_model = d_model
        self.num_layers = num_layers

#         self.rescale = Rescaling(1.0 / 255)
        self.pos_emb = self.add_weight(
            "pos_emb", shape=(1, num_patches + 1, d_model)
        )
        self.class_emb = self.add_weight("class_emb", shape=(1, 1, d_model))
        self.patch_proj = Dense(d_model)
        self.enc_layers = [
            TransformerBlock(d_model, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ]
        self.mlp_head = tf.keras.Sequential(
            [
                LayerNormalization(epsilon=1e-6),
                Dense(mlp_dim, activation=tfa.activations.gelu),
                Dropout(dropout),
                Dense(num_classes),
            ]
        )

    def extract_patches(self, images):
        batch_size = tf.shape(images)[0]
#         print(f"Batch size for extracting patches - {batch_size}")
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patches = tf.reshape(patches, [batch_size, -1, self.patch_dim])
        return patches
    
    def call(self, x, training):
        batch_size = tf.shape(x)[0]
        print(f"Batch size - {batch_size.shape}")
#         x = self.rescale(x)
        #  print(x.shape) # (None, 28, 28, 1)
        patches = self.extract_patches(x)
        # print(patches.shape) # (None, None, 16)
        x = self.patch_proj(patches)
        # print(x.shape) # (None, None, 64)
        class_emb = tf.broadcast_to(
            self.class_emb, [batch_size, 1, self.d_model]
        )
        # print(class_emb.shape) # (None, 1, 64)
        # print(self.pos_emb.shape) # (1, 50, 64)
        x = tf.concat([class_emb, x], axis=1)
        # print(x.shape) # (None, None, 64)
        x = x + self.pos_emb
#         print(x.shape)
        for layer in self.enc_layers:
            x = layer(x, training)
        print(x.shape)
        # First (class token) is used for classification
        x = x[:, 0]
        print(x.shape)
        x = self.mlp_head(x)
        print(x.shape)
        return x

In [23]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

logdir = "logs"
image_size = 28
patch_size = 4
num_layers = 4
d_model = 64
num_heads = 4
mlp_dim = 128
lr = 3e-4
weight_decay = 1e-4
batch_size = 64
epochs = 5

import numpy as np
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
x_train = tf.expand_dims(x_train, axis)
x_test = tf.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

### Result

In [24]:
model = VisionTransformer(
    image_size=image_size,
    patch_size=patch_size,
    num_layers=num_layers,
    num_classes=10,
    d_model=d_model,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    channels=1,
)

# TODO: THIS THING IS VERY IMPORTANT
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["accuracy"],
)

# model.compile('adam', 'categorical_crossentropy', ['accuracy'])
model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=1, validation_data=(x_test, y_test))

Batch size - ()
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
(None, 50, 64)
(None, 64)
(None, 10)
Batch size - ()
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
Query shape - (None, 50, 64)
Query shape after separating heads - (None, 4, None, 16)
(None, 50, 64)
(None, 64)
(None, 10)


KeyboardInterrupt: 

# 2 - Layer Multi Head Attention, defined in Vision Transformer

In [38]:
"""
Strange behaviour, when commenting transpose permutations, accuracy increases
"""
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    
    def __init__(self, embed_dim, num_heads=8, **kwargs):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
           raise ValueError(
               f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
           )
        self.projection_dim = self.embed_dim // self.num_heads
        self.query_dense = Dense(self.embed_dim)
        self.key_dense = Dense(self.embed_dim)
        self.value_dense = Dense(self.embed_dim)
        self.combine_heads = Dense(self.embed_dim)

    def attention(self, query, key, value):
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        x = tf.reshape(
            x, (batch_size, -1, self.num_heads, self.projection_dim)
        )
#         return tf.transpose(x, perm=[0, 2, 1, 3])
        return x
    
    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'projection_dim': self.projection_dim,
            'query_dense': self.query_dense,
            'key_dense': self.key_dense,
            'value_dense': self.value_dense,
            'combine_heads': self.combine_heads,
        })
        return config

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)

        attention, weights = self.attention(query, key, value)
#         attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(
            attention, (batch_size, -1, self.embed_dim)
        )
        output = self.combine_heads(concat_attention)
        return output

In [44]:
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    
    def __init__(self, embed_dim, num_heads=8, **kwargs):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
           raise ValueError(
               f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
           )
        self.projection_dim = self.embed_dim // self.num_heads
        self.query_dense = Dense(self.embed_dim)
        self.key_dense = Dense(self.embed_dim)
        self.value_dense = Dense(self.embed_dim)
        self.combine_heads = Dense(self.embed_dim)

    def attention(self, query, key, value):
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        x = tf.reshape(
            x, (batch_size, -1, self.num_heads, self.projection_dim)
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def get_config(self):

        config = super().get_config().copy()num_classes = 10
image_size = 28
patch_size = 4
num_patches = (image_size // patch_size) ** 2
channels = 1
patch_dim = channels * patch_size ** 2
batch_size = 64
embed_dim = d_model = 64

l = 50
d = 192
dv = 24
dout = 32
nv = 8

inp = Input(shape=(28, 28, 1))
patches = ExtractPatchesLayer()(inp)
x = Dense(d_model)(patches)
x = PositionalEncodingLayer()(x)

# Attention Module
v2 = tf.keras.layers.Dense(nv*nv, activation="relu")(x)
q2 = tf.keras.layers.Dense(nv*nv, activation="relu")(x)
k2 = tf.keras.layers.Dense(nv*nv, activation="relu")(x)

v = tf.keras.layers.Reshape([l, nv, nv])(v2)
q = tf.keras.layers.Reshape([l, nv, nv])(q2)
k = tf.keras.layers.Reshape([l, nv, nv])(k2)

# =============== Scaled dot-product attention =================
# QK^T
att = MatMulLayerTranspose()([q, k])
# softmax(QK^T)
att = tf.keras.layers.Softmax(axis=-1)(att)
# softmax(QK^T)*V
out = MatMulLayer()([att, v])

out = tf.keras.layers.Reshape([l, d_model, 1])(out)
x = tf.keras.layers.Reshape([l, d_model, 1])(x)

# Concat Layer
add = tf.keras.layers.Add()([out, x])
out = tf.keras.layers.Dense(32, activation="relu")(add)

out = tf.keras.layers.Flatten()(out)
# out = tf.keras.layers.Dense(32, activation='relu')(out)
out = tf.keras.layers.Dense(num_classes, activation='softmax')(out)

model = Model(inputs=inp, outputs=out)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=64, epochs=1, verbose=1, validation_data=(x_test, y_test))
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'projection_dim': self.projection_dim,
            'query_dense': self.query_dense,
            'key_dense': self.key_dense,
            'value_dense': self.value_dense,
            'combine_heads': self.combine_heads,
        })
        return config

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)

        attention, weights = self.attention(query, key, value)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(
            attention, (batch_size, -1, self.embed_dim)
        )
        output = self.combine_heads(concat_attention)
        return output

In [45]:
num_classes = 10

# the data, shuffled and split between tran and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

inp = Input(shape = (28, 28, 1))
x = Conv2D(32, (2,2), activation='relu', padding='same')(inp)
x = MaxPooling2D(pool_size=(2, 2))(x)
c = Conv2D(64, (2,2), activation='relu')
x = c(x)
x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
x = Conv2D(64*3, (2,2), activation='relu')(x)


x = Reshape([6*6,64*3])(x)


att = MultiHeadSelfAttention(embed_dim=128, num_heads=8)
x = att(x)

x = Reshape([24,6,32])(x)

# x = Reshape([6,6,32])(x)

x = BatchNormalization()(x)

x = Flatten()(x) 
x = Dense(256, activation='relu')(x)
x = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=inp, outputs=x)
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, verbose=1, validation_data=(x_test, y_test))

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fc288cc2190>

## 3 - Functional Multi Head Attentional Model with separate heads and attention

In [23]:
"""
Multi Head attention in a functional fashion. Includes two more complicated functions > attention and separate heads
"""

def separate_heads(x, batch_size, num_head, projection_dim):
    x = tf.reshape(
        x, (batch_size, -1, num_heads, projection_dim)
    )
    return tf.transpose(x, perm=[0, 2, 1, 3])

def attention(query, key, value):
    score = tf.matmul(query, key, transpose_b=True)
    dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
    scaled_score = score / tf.math.sqrt(dim_key)
    weights = tf.nn.softmax(scaled_score, axis=-1)
    output = tf.matmul(weights, value)
    return output, weights


embed_dim = 64
num_heads = 4
batch_size = tf.shape(x)[0]

projection_dim = embed_dim // num_heads
query_dense = tf.keras.layers.Dense(embed_dim)(x)
key_dense = tf.keras.layers.Dense(embed_dim)(x)
value_dense = tf.keras.layers.Dense(embed_dim)(x)

# query = separate_heads(query_dense, batch_size, num_heads, projection_dim)
query = tf.keras.layers.Reshape([batch_size, -1, num_heads, projection_dim])(query_dense)
query = tf.keras.layers.Permute()(query)
key = separate_heads(key_dense, batch_size, num_heads, projection_dim)
value = separate_heads(value_dense, batch_size, num_heads, projection_dim)

attention, weights = attention(query, key, value)
concat_attention = tf.reshape(attention, (batch_size, -1, embed_dim))
x = tf.keras.layers.Dense(embed_dim)(concat_attention)


NameError: name 'inputs' is not defined

In [58]:
num_classes = 10

# the data, shuffled and split between tran and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

inp = Input(shape = (28, 28, 1))
x = Conv2D(32, (2,2), activation='relu', padding='same')(inp)
x = MaxPooling2D(pool_size=(2, 2))(x)
c = Conv2D(64, (2,2), activation='relu')
x = c(x)
x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
x = Conv2D(64*3, (2,2), activation='relu')(x)


x = Reshape([6*6,64*3])(x)

# ===========================================================================================
def separate_heads(x, batch_size, num_head, projection_dim):
    x = tf.reshape(
        x, (batch_size, -1, num_heads, projection_dim)
    )
    return tf.transpose(x, perm=[0, 2, 1, 3])

def attention(query, key, value):
    score = tf.matmul(query, key, transpose_b=True)
    dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
    scaled_score = score / tf.math.sqrt(dim_key)
    weights = tf.nn.softmax(scaled_score, axis=-1)
    output = tf.matmul(weights, value)
    return output, weights


embed_dim = 64
num_heads = 4
batch_size = tf.shape(x)[0]

projection_dim = embed_dim // num_heads
query_dense = tf.keras.layers.Dense(embed_dim)(x)
key_dense = tf.keras.layers.Dense(embed_dim)(x)
value_dense = tf.keras.layers.Dense(embed_dim)(x)

# query = separate_heads(query_dense, batch_size, num_heads, projection_dim)
# query = tf.transpose(query, perm=[0, 2, 1, 3])
query = tf.keras.layers.Reshape((-1, 36, num_heads, projection_dim))(query_dense)
# query = tf.reshape(query_dense, (batch_size, -1, num_heads, projection_dim))
query = tf.keras.layers.Permute((1, 3, 2, 4))(query)

# key = separate_heads(key_dense, batch_size, num_heads, projection_dim)
key = tf.keras.layers.Reshape((-1, 36, num_heads, projection_dim))(key_dense)
# key = tf.reshape(key_dense, (batch_size, -1, num_heads, projection_dim))
key = tf.keras.layers.Permute((1, 3, 2, 4))(key)

# value = separate_heads(value_dense, batch_size, num_heads, projection_dim)
value = tf.keras.layers.Reshape((-1, 36, num_heads, projection_dim))(value_dense)
# value = tf.reshape(value_dense, (batch_size, -1, num_heads, projection_dim))
value = tf.keras.layers.Permute((1, 3, 2, 4))(value)

# attention, weights = attention(query, key, value)
score = tf.matmul(query, key, transpose_b=True)
dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
scaled_score = score / tf.math.sqrt(dim_key)
# weights = tf.nn.softmax(scaled_score, axis=-1)
weights = tf.keras.layers.Softmax(axis=-1)(scaled_score)
attention = tf.matmul(weights, value)

concat_attention = tf.keras.layers.Reshape((-1, 36, embed_dim))(attention)
# concat_attention = tf.reshape(attention, (batch_size, -1, embed_dim))
x = tf.keras.layers.Dense(embed_dim)(concat_attention)

# ===========================================================================================

x = Reshape([6, 12, 32])(x)

x = BatchNormalization()(x)

x = Flatten()(x) 
x = Dense(256, activation='relu')(x)
x = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=inp, outputs=x)
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=1, validation_data=(x_test, y_test))

(None, 1, 4, 36, 16)
(None, 1, 4, 36, 16)


ValueError: Dimensions must be equal, but are 16 and 36 for '{{node tf.linalg.matmul_42/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](Placeholder, Placeholder_1)' with input shapes: [?,1,4,36,16], [?,1,4,36,16].

In [59]:
model.summary()

Model: "model_21"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_55 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_144 (Conv2D)             (None, 28, 28, 32)   160         input_55[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_96 (MaxPooling2D) (None, 14, 14, 32)   0           conv2d_144[0][0]                 
__________________________________________________________________________________________________
conv2d_145 (Conv2D)             (None, 13, 13, 64)   8256        max_pooling2d_96[0][0]           
___________________________________________________________________________________________

## 4 - Simplest Attention Model only with attention

In [None]:
"""
Not multi head attention with simplest operations. Only attention function is somehow more complicated
"""

def attention(query, key, value):
    score = tf.matmul(query, key, transpose_b=True)
    dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
    scaled_score = score / tf.math.sqrt(dim_key)
    weights = tf.nn.softmax(scaled_score, axis=-1)
    output = tf.matmul(weights, value)
    return output, weights


embed_dim = 64
num_heads = 1
batch_size = tf.shape(inputs)[0]

projection_dim = embed_dim // num_heads
query_dense = Dense(embed_dim)(x)
key_dense = Dense(embed_dim)(x)
value_dense = Dense(embed_dim)(x)

query = tf.reshape(query_dense, (batch_size, -1, num_heads, self.projection_dim))
key = tf.reshape(key_dense, (batch_size, -1, num_heads, self.projection_dim))
value = tf.reshape(value_dense, (batch_size, -1, num_heads, self.projection_dim))

attention, weights = attention(query, key, value)
concat_attention = tf.reshape(attention, (batch_size, -1, embed_dim))
output = Dense(embed_dim)(concat_attention)

In [None]:
num_classes = 10

# the data, shuffled and split between tran and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

inp = Input(shape = (28, 28, 1))
x = Conv2D(32, (2,2), activation='relu', padding='same')(inp)
x = MaxPooling2D(pool_size=(2, 2))(x)
c = Conv2D(64, (2,2), activation='relu')
x = c(x)
x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)
x = Conv2D(64*3, (2,2), activation='relu')(x)


x = Reshape([6*6,64*3])(x)

def attention(query, key, value):
    score = tf.matmul(query, key, transpose_b=True)
    dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
    scaled_score = score / tf.math.sqrt(dim_key)
    weights = tf.nn.softmax(scaled_score, axis=-1)
    output = tf.matmul(weights, value)
    return output, weights


embed_dim = 64
num_heads = 1
batch_size = tf.shape(inputs)[0]

projection_dim = embed_dim // num_heads

query_dense = Dense(embed_dim)(x)
key_dense = Dense(embed_dim)(x)
value_dense = Dense(embed_dim)(x)

query = tf.reshape(query_dense, (batch_size, -1, num_heads, projection_dim))
key = tf.reshape(key_dense, (batch_size, -1, num_heads, projection_dim))
value = tf.reshape(value_dense, (batch_size, -1, num_heads, projection_dim))

attention, weights = attention(query, key, value)
concat_attention = tf.reshape(attention, (batch_size, -1, embed_dim))
x = Dense(embed_dim)(concat_attention)

# att = MultiHeadSelfAttention(embed_dim=64, num_heads=1)
# x = att(x)
# # att = MultiHeadAttention(num_heads=4, key_dim=64)
# # x = att(x, x)



x = Reshape([12,6,32])(x)

# x = Reshape([6,6,32])(x)

x = BatchNormalization()(x)

x = Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=inp, outputs=x)
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=1, validation_data=(x_test, y_test))

# Validation and Saving

In [6]:
def parse_multi_head_attention(layer, attributes):
    attributes['parameters'] = list(layer.get_weights())
    return attributes

attrs = parse_multi_head_attention(att, {})

AttributeError: 'KerasTensor' object has no attribute 'get_weights'

[array([[[-0.05251397,  0.05196556,  0.04602323, ..., -0.0906411 ,
           0.05808846, -0.05304115],
         [ 0.07712557,  0.06473607,  0.06472164, ..., -0.09025709,
           0.1291429 , -0.06681956],
         [ 0.14465804,  0.05553916,  0.07451871, ...,  0.05384645,
           0.13792342, -0.09032944],
         [-0.07952031,  0.10550651,  0.08444484, ...,  0.08768411,
          -0.10532371,  0.11276584]],
 
        [[-0.13679186,  0.15290363,  0.16425851, ..., -0.22392105,
           0.18364385, -0.12704042],
         [ 0.14089543,  0.15607025,  0.13811384, ..., -0.17982219,
           0.1849788 , -0.18002434],
         [ 0.20148136,  0.21233286,  0.23205617, ...,  0.21763808,
           0.18484046, -0.19925883],
         [-0.2809631 ,  0.20544429,  0.2593414 , ...,  0.2762511 ,
          -0.2782754 ,  0.27995571]],
 
        [[ 0.20744342, -0.19630732, -0.17898946, ...,  0.08385453,
          -0.1826546 ,  0.23804533],
         [-0.14874424, -0.15827128, -0.24028395, ...,  0.1

In [82]:
model.summary()

Model: "model_17"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_28 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_81 (Conv2D)              (None, 28, 28, 32)   160         input_28[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_54 (MaxPooling2D) (None, 14, 14, 32)   0           conv2d_81[0][0]                  
__________________________________________________________________________________________________
conv2d_82 (Conv2D)              (None, 13, 13, 64)   8256        max_pooling2d_54[0][0]           
___________________________________________________________________________________________

In [12]:
model.evaluate(x_test, y_test, batch_size=32)



[0.06048298254609108, 0.9819999933242798]

In [50]:
import os
from tensorflow import keras
keras.models.save_model(
    model, 
    os.path.join("/home/viktor/PycharmProjects/guided_research/transformer-to-snn-conversion", 
    "mnist_transformer" + '.h5')
)

In [51]:
# reconstructed_model = keras.models.load_model(
#     os.path.join("/home/viktor/PycharmProjects/guided_research/transformer-to-snn-conversion", 
#     "mnist_transformer" + '.h5'), custom_objects={'MultiHeadSelfAttention': MultiHeadSelfAttention})
reconstructed_model = keras.models.load_model(
    os.path.join("/home/viktor/PycharmProjects/guided_research/transformer-to-snn-conversion", 
    "mnist_transformer" + '.h5'), custom_objects={'ScaleLayer': ScaleLayer, 'MatMulLayer': MatMulLayer, 
                                                  'MatMulLayerTranspose': MatMulLayerTranspose,
                                                  'PositionalEncodingLayer': PositionalEncodingLayer,
                                                  'ExtractPatchesLayer': ExtractPatchesLayer})

In [23]:
reconstructed_model.summary()

Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_13 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 28, 28, 32)   160         input_13[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_24 (MaxPooling2D) (None, 14, 14, 32)   0           conv2d_36[0][0]                  
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 13, 13, 64)   8256        max_pooling2d_24[0][0]           
____________________________________________________________________________________________

In [24]:
reconstructed_model.evaluate(x_test, y_test)



[0.05877845361828804, 0.9815000295639038]