In [1]:
import tensorflow as tf
import numpy as np
from snn_conversion.spiking_models import SpikingReLU, Accumulate
from tensorflow.keras.utils import to_categorical
from snn_conversion.operations_layers import SqueezeLayer, ExpandLayer, Tokpos
from snn_conversion.multi_head_self_attention import multi_head_self_attention
from snn_conversion.weight_normalization import robust_weight_normalization
from snn_conversion.utils import evaluate_conversion

In [None]:
# Notebook configuration: random seed plus the hyperparameters shared by the
# analog (ANN) and spiking (SNN) model builders below.
tf.random.set_seed(1234)  # Set TensorFlow's global random seed
batch_size=128  # Mini-batch size used for ann.fit / ann.evaluate
epochs=2  # Number of training epochs for the analog model
dv = 25  # NOTE(review): not referenced anywhere in the visible code - confirm it is still needed
nv = -1  # NOTE(review): not referenced anywhere in the visible code - confirm it is still needed
vocab_size = 10000  # Only consider the top 10k words
maxlen = 200  # Each review is padded/truncated to 200 tokens (pad_sequences defaults drop the start of longer reviews)
embed_dim = 32  # Embedding size for each token
mlp_dim = 64  # Width of the ReLU Dense layers
l = 50  # Middle dimension of the reshape around the first spiking ReLU; num_heads * l == maxlen with these values
num_heads = 4  # Leading dimension of that reshape; presumably also the attention head count - confirm against multi_head_self_attention
num_classes = 2  # Number of output units of the final Dense layer


def create_ann_approved_version():
    """Build, compile and train the analog (non-spiking) classifier.

    The network is: token/position embedding -> one self-attention block with
    a two-layer feed-forward part and a skip connection -> flatten -> three
    Dense layers -> ``num_classes``-way softmax.

    The training and validation arrays are not parameters: the function reads
    the module-level ``x_train``, ``y_train``, ``x_test`` and ``y_test``, so
    those must exist before it is called.

    Returns:
        The trained ``tf.keras`` model.
    """
    keras_layers = tf.keras.layers

    token_ids = keras_layers.Input(shape=(maxlen,))
    embedded = Tokpos(maxlen, vocab_size, embed_dim)(token_ids)

    block_output = embedded
    for _ in range(1):
        # Only the second return value is consumed; the first is discarded.
        _attended, skip = multi_head_self_attention(block_output)
        hidden = keras_layers.Dense(mlp_dim, activation="relu")(skip)
        projected = keras_layers.Dense(embed_dim)(hidden)
        block_output = keras_layers.Add()([projected, skip])

    # Classification head. Layers are created in a fixed order on purpose:
    # the spiking model is later loaded positionally from this model's weights.
    flat = keras_layers.Flatten()(block_output)
    head = keras_layers.Dense(embed_dim, activation="relu")(flat)
    head = keras_layers.Dense(embed_dim)(head)
    head = keras_layers.Dense(mlp_dim, activation="relu")(head)
    logits = keras_layers.Dense(num_classes)(head)
    probabilities = keras_layers.Softmax()(logits)

    ann = tf.keras.models.Model(inputs=token_ids, outputs=probabilities)
    ann.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

    ann.fit(
        x_train,
        y_train,
        validation_data=(x_test, y_test),
        batch_size=batch_size,
        epochs=epochs,
    )
    return ann


def convert_tailored_approved_version(weights, y_test):
    """Rebuild the analog network as a stateful spiking model and load weights.

    The layer stack mirrors ``create_ann_approved_version``: every ReLU Dense
    layer becomes a linear Dense layer followed by a stateful ``SpikingReLU``
    RNN, and an ``Accumulate`` RNN is inserted in front of the softmax. The
    order of weight-bearing layers must stay aligned with the analog model,
    because ``weights`` is loaded positionally through ``set_weights``.

    Args:
        weights: List of weight arrays passed to ``set_weights``.
        y_test: Label array. Only ``y_test.shape[0]`` is read; it fixes the
            batch size of the input layer.

    Returns:
        The compiled spiking ``tf.keras`` model with ``weights`` loaded.
    """
    def spiking_rnn(cell):
        # All spiking layers share the same stateful, sequence-returning wrapper.
        return tf.keras.layers.RNN(cell, return_sequences=True, return_state=False,
                                   stateful=True)

    # Stateful RNN layers need a fixed batch size, taken here from the labels.
    inputs = tf.keras.layers.Input(shape=(1, maxlen,), batch_size=y_test.shape[0])
    x = Tokpos(maxlen, vocab_size, embed_dim)(inputs)
    out = x
    for _ in range(1):
        # Only the second return value is consumed; the first is overwritten.
        out, add = multi_head_self_attention(out)
        out = tf.keras.layers.Dense(mlp_dim)(add)
        # Collapse to one time step of num_heads*l*mlp_dim features for the
        # spiking cell, then restore the (num_heads, l, mlp_dim) layout.
        out = tf.keras.layers.Reshape([1, num_heads*l*mlp_dim])(out)
        out = spiking_rnn(SpikingReLU(num_heads*l*mlp_dim))(out)
        out = tf.keras.layers.Reshape([num_heads, l, mlp_dim])(out)

        out = tf.keras.layers.Dense(embed_dim)(out)
        out = tf.keras.layers.Add()([out, add])

    x = tf.keras.layers.Flatten()(out)
    x = ExpandLayer()(x)
    x = tf.keras.layers.Dense(embed_dim)(x)
    x = spiking_rnn(SpikingReLU(embed_dim))(x)
    x = tf.keras.layers.Dense(embed_dim)(x)
    x = tf.keras.layers.Dense(mlp_dim)(x)
    x = spiking_rnn(SpikingReLU(mlp_dim))(x)
    # --------------------------------------------------
    x = tf.keras.layers.Dense(num_classes)(x)

    x = spiking_rnn(Accumulate(num_classes))(x)
    x = tf.keras.layers.Softmax()(x)

    x = SqueezeLayer()(x)

    spiking = tf.keras.models.Model(inputs=inputs, outputs=x)

    print("-"*32 + "\n")
    spiking.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"])
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    spiking.summary()
    spiking.set_weights(weights)
    return spiking


# Load the IMDB reviews, keeping only the `vocab_size` most frequent words,
# and one-hot encode the labels for the categorical_crossentropy loss.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=vocab_size)
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# Pad/truncate every review to exactly `maxlen` tokens.
x_train = tf.keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)

# Analog model
ann = create_ann_approved_version()
# summary() prints the table itself and returns None, so it is not wrapped in print().
ann.summary()

# evaluate() returns [loss, accuracy]; only the accuracy is kept.
_, testacc = ann.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
# weights = ann.get_weights()
# weights = get_normalized_weights(ann, x_train, percentile=85)

# NOTE(review): presumably returns a model whose weights were rescaled using
# activations on x_train - confirm against snn_conversion.weight_normalization.
model_normalized = robust_weight_normalization(ann, x_train)
weights = model_normalized.get_weights()

##################################################
# Preprocessing for RNN: insert a time axis of length 1
x_train = np.expand_dims(x_train, axis=1)  # (N, maxlen) -> (N, 1, maxlen)
x_test = np.expand_dims(x_test, axis=1)

##################################################
# Conversion to spiking model
# snn = convert(ann, weights, x_test, y_test)
snn = convert_tailored_approved_version(weights, y_test)
evaluate_conversion(snn, ann, x_test, y_test, testacc, timesteps=10)

  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])


Epoch 1/2
Epoch 2/2
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 200)]        0                                            
__________________________________________________________________________________________________
tokpos (Tokpos)                 (None, 200, 32)      326400      input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 200, 32)      1056        tokpos[0][0]                     
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 200, 32)      1056        tokpos[0][0]                     
__________________________________________________________________________