In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import app
from absl import flags
import numpy as np
from spiking_tf.src import spiking_models, plots, file_handling
import keras
from keras import backend as K
import os

In [2]:
timesteps = 1
max_rate = 2


n_in = 28*28
n_hidden = 800
n_out = 10

batch_size = 128
epochs = 1

thr = 0.1
tau = 10.0

output_path = file_handling.get_default_path_str()

In [3]:
class MinimalRNNCell(keras.layers.Layer):

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalRNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units), initializer='uniform', name='kernel')
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units), initializer='uniform',
                                                name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        return output, [output]


def flatten(image, label):
    '''Transform image to the flattened version of itself'''
    flattened = tf.reshape(image, [28*28])
    flattened = tf.expand_dims(flattened, 0)
    return tf.cast(flattened, tf.float32), label

In [11]:
tf.random.set_seed(1234)

(ds_train, ds_test) = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True
)
ds_train = ds_train.map(
    flatten, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(batch_size)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
print(ds_train)

ds_test = ds_test.map(
    flatten, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

inputs = tf.keras.layers.Input(shape=(timesteps, 28 * 28))
mid_z = tf.keras.layers.RNN(MinimalRNNCell(n_hidden), return_sequences=True, use_bias=False)(inputs)
out_z = tf.keras.layers.RNN(MinimalRNNCell(n_out), return_sequences=False, use_bias=False)(mid_z)

print("==================== Start training =======================")
model = tf.keras.models.Model(inputs=inputs, outputs=[out_z])

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["sparse_categorical_accuracy"],
)
history = model.fit(ds_train, epochs=epochs, validation_data=ds_test)

model.save_weights(os.path.join("weights", "simple_rnn"))

<PrefetchDataset shapes: ((None, 1, 784), (None,)), types: (tf.float32, tf.int64)>


TypeError: ('Keyword argument not understood:', 'use_bias')

In [5]:
np.save("analog_rnn.npy", model.get_weights())

In [10]:
np.array(model.get_weights())[3].shape

(10, 10)