In [1]:
%load_ext tensorboard
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import app
from absl import flags
import numpy as np
from spiking_tf.src import spiking_models, plots, file_handling
import datetime

In [2]:
timesteps = 1
max_rate = 2

n_in = 28*28
n_hidden = 800
n_out = 10

batch_size = 128
epochs = 10

output_path = file_handling.get_default_path_str()

In [3]:
def flatten(image, label):
    '''Transform image to the flattened version of itself'''
    flattened = tf.reshape(image, [28*28])
    flattened = tf.expand_dims(flattened, 0)
    return tf.cast(flattened, tf.float32), label

In [4]:
tf.random.set_seed(1234)

(ds_train, ds_test) = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True
)
ds_train = ds_train.map(
    flatten, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(batch_size)
ds_train = ds_train.batch(batch_size)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
print(ds_train)

ds_test = ds_test.map(
    flatten, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(batch_size)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

<PrefetchDataset shapes: ((None, 1, 784), (None,)), types: (tf.float32, tf.int64)>


In [5]:
model = tf.keras.Sequential([
  tf.keras.layers.Dense(800, activation='relu', input_shape=(784,), use_bias=False),
  tf.keras.layers.Dense(10, use_bias=False),
])

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["sparse_categorical_accuracy"],
    run_eagerly=False
)

log_dir = "logs/fit/analog_feedforward" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

history = model.fit(
    ds_train, epochs=epochs, validation_data=ds_test, callbacks=[tensorboard_callback]
)

Epoch 1/10








Instructions for updating:
use `tf.profiler.experimental.stop` instead.


Instructions for updating:
use `tf.profiler.experimental.stop` instead.










Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [6]:
from tensorflow.keras.utils import plot_model

plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

('Failed to import pydot. You must `pip install pydot` and install graphviz (https://graphviz.gitlab.io/download/), ', 'for `pydotprint` to work.')


In [19]:
%tensorboard --logdir logs/fit