In [1]:
import tensorflow as tf
import numpy as np
from snn_conversion.spiking_models import SpikingReLU, Accumulate
import keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from snn_conversion.operations_layers import SqueezeLayer, ExpandLayer, ExtractPatchesLayer, PositionalEncodingLayer
from snn_conversion.multi_head_self_attention import multi_head_self_attention
from snn_conversion.old_normalization import get_normalized_weights
from snn_conversion.utils import evaluate_conversion

In [None]:
# Global TensorFlow seed so repeated training runs are comparable.
tf.random.set_seed(1238)
# Training hyperparameters read as globals by create_ann_approved_version().
batch_size, epochs = 128, 2


def create_ann_approved_version():
    """Build, compile and train the analog (ANN) transformer classifier.

    Reads module-level globals: d_model, embed_dim, mlp_dim, num_classes,
    batch_size, epochs and the MNIST arrays x_train / y_train / x_test / y_test.

    Returns:
        The trained tf.keras Model (softmax output over num_classes).
    """
    dense = tf.keras.layers.Dense

    image = tf.keras.layers.Input(shape=(28, 28, 1))
    patch_tokens = ExtractPatchesLayer()(image)
    embedded = dense(d_model)(patch_tokens)
    hidden = PositionalEncodingLayer()(embedded)

    for _ in range(1):  # a single transformer block
        # Only the second output of the attention block feeds the MLP.
        _attn_out, residual = multi_head_self_attention(hidden)
        mlp_out = dense(mlp_dim, activation="relu")(residual)
        mlp_out = dense(embed_dim)(mlp_out)
        hidden = tf.keras.layers.Add()([mlp_out, residual])

    flat = tf.keras.layers.Flatten()(hidden)
    head = dense(embed_dim, activation="relu")(flat)
    # --------------------------------------------------
    logits = dense(num_classes)(head)
    probs = tf.keras.layers.Softmax()(logits)

    ann = tf.keras.models.Model(inputs=image, outputs=probs)

    ann.compile(optimizer="adam",
                loss="categorical_crossentropy",
                metrics=["accuracy"])

    ann.fit(x_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(x_test, y_test))
    return ann


def convert_tailored_approved_version(weights, y_test):
    """Build the spiking (SNN) counterpart of the analog model and load weights.

    The layer sequence mirrors ``create_ann_approved_version``. Each
    ``Dense(..., activation="relu")`` of the ANN becomes a linear ``Dense``
    followed by a stateful ``RNN(SpikingReLU(...))``. The softmax head is
    preceded by a stateful ``RNN(Accumulate(...))``.

    Reads module-level globals: d_model, embed_dim, mlp_dim, num_classes, l.

    Args:
        weights: list of arrays passed to ``Model.set_weights``. It must match
            this model's weight order. Assumes the spiking cells, Reshape,
            ExpandLayer and SqueezeLayer add no weights of their own, so the
            ANN's weight list lines up (TODO confirm).
        y_test: test labels. Only ``y_test.shape[0]`` is used, to fix the
            batch size.

    Returns:
        The compiled spiking tf.keras Model with ``weights`` loaded.
    """
    # Stateful RNN layers require a fixed batch size. This model therefore only
    # accepts batches of exactly len(y_test) samples.
    inputs = tf.keras.layers.Input(shape=(28, 28, 1), batch_size=y_test.shape[0])

    patches = ExtractPatchesLayer()(inputs)
    x = tf.keras.layers.Dense(d_model)(patches)
    x = PositionalEncodingLayer()(x)
    out = x
    for _ in range(1):  # one transformer block, matching the ANN
        out, add = multi_head_self_attention(out)
        out = tf.keras.layers.Dense(mlp_dim)(add)
        # Pack all l tokens into one feature vector with a single time step,
        # so one SpikingReLU cell covers the whole MLP activation. Then unpack.
        out = tf.keras.layers.Reshape([1, l * mlp_dim])(out)
        out = tf.keras.layers.RNN(SpikingReLU(l * mlp_dim), return_sequences=True,
                                  return_state=False, stateful=True)(out)
        out = tf.keras.layers.Reshape([1, l, mlp_dim])(out)

        out = tf.keras.layers.Dense(embed_dim)(out)
        # NOTE(review): `out` has rank 4 here, while `add` comes straight from
        # the attention block. Add relies on the two shapes being compatible
        # (confirm).
        out = tf.keras.layers.Add()([out, add])

    x = tf.keras.layers.Flatten()(out)
    # ExpandLayer presumably adds the time axis the RNN below expects (TODO confirm).
    x = ExpandLayer()(x)
    x = tf.keras.layers.Dense(embed_dim)(x)
    x = tf.keras.layers.RNN(SpikingReLU(embed_dim), return_sequences=True,
                            return_state=False, stateful=True)(x)
    # --------------------------------------------------
    x = tf.keras.layers.Dense(num_classes)(x)

    x = tf.keras.layers.RNN(Accumulate(num_classes), return_sequences=True,
                            return_state=False, stateful=True)(x)
    x = tf.keras.layers.Softmax()(x)

    x = SqueezeLayer()(x)

    spiking = tf.keras.models.Model(inputs=inputs, outputs=x)

    print("-" * 32 + "\n")
    spiking.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"])
    # Model.summary() prints the table itself and returns None. Wrapping it in
    # print() would also emit a stray "None" line.
    spiking.summary()
    spiking.set_weights(weights)
    return spiking

# --- Hyperparameters (read as module-level globals by the builder functions) ---
dv = 24
dout = 32
nv = 8
# NOTE(review): vocab_size and maxlen look like leftovers from a text model.
# The original comments referred to words / movie reviews, and nothing in this
# cell uses them.
vocab_size = 20000
maxlen = 200
embed_dim = d_model = 64  # embedding size of each patch token
mlp_dim = 128  # hidden width of the transformer MLP
# Token count used by the SNN reshapes. Presumably num_patches (49) + 1 extra
# token added by the patch / positional-encoding layers (TODO confirm).
l = 50
num_heads = 4
num_classes = 10
image_size = 28
patch_size = 4
num_patches = (image_size // patch_size) ** 2
channels = 1
patch_dim = channels * patch_size ** 2
projection_dim = embed_dim // num_heads


(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixels to [0, 1] so the ANN can be trained on them.
# NOTE(review): an earlier note said inputs are converted back to integers for
# the SNN layer. Nothing in this cell does that (confirm whether it is needed).
x_train = x_train / 255
x_test = x_test / 255

# Add a trailing channel dimension. Both model builders declare
# Input(shape=(28, 28, 1)), i.e. channels-last. The channel axis must therefore
# be last even if keras.backend.image_data_format() is 'channels_first'.
axis = -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors.
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# Analog model
ann = create_ann_approved_version()
ann.summary()  # prints itself; print(ann.summary()) would also print "None"

_, testacc = ann.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)

# Normalize weights for conversion using the training data (percentile=85).
# ann.get_weights() would give the raw, unnormalized weights instead.
weights = get_normalized_weights(ann, x_train, percentile=85)

##################################################
# Conversion to spiking model
snn = convert_tailored_approved_version(weights, y_test)
evaluate_conversion(snn, ann, x_test, y_test, testacc, timesteps=200)