### Imports

In [None]:
from typing import Union, NamedTuple
from utils import make_spike_raster_dataset
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import pylab as plt


### Parameters

In [None]:
# Number of classes in the synthetic dataset (and number of output neurons).
Nc = 2
# Neurons per layer: N[0] is the input width, N[1:] are the spiking layers,
# the last of which has one neuron per class.
N = [16, 32, Nc]
# Number of training epochs passed to `fit`.
Nepochs = 10
# Number of time steps per sample (sequence length of the spike rasters).
T = 100
NUM_SAMPLES_PER_CLASS = 1000
# Fraction of all samples used for training; the remainder is the test set.
TRAIN_TEST_SPILT = 0.8
# Index at which the generated dataset is cut into train and test parts.
NUM_SAMPLES_TRAIN = int(Nc*NUM_SAMPLES_PER_CLASS*TRAIN_TEST_SPILT)
BATCHSIZE = 48

# Seeded NumPy generator handed to the dataset generator.
SEED = 42
rng = np.random.default_rng(SEED)

### Data Creation

In [None]:


# TODO: a tf.data pipeline is not strictly required for Keras, but it is more general and easier to extend.
def create_dataloader(data, labels, batchsize, shuffle=True):
    """Wrap arrays of samples and labels in a batched ``tf.data`` pipeline.

    Args:
        data: Samples; sliced along the first axis.
        labels: Labels aligned with ``data``; ``labels.shape[0]`` is used as
            the shuffle buffer size.
        batchsize: Number of samples per batch. Incomplete trailing batches
            are dropped.
        shuffle: If true, shuffle the samples with a buffer covering the
            whole dataset.

    Returns:
        A ``tf.data.Dataset`` yielding ``(data, labels)`` batches with
        prefetching enabled.
    """
    buffer_size = labels.shape[0]
    pipeline = tf.data.Dataset.from_tensor_slices((data, labels))
    if shuffle:
        # reshuffle_each_iteration=False: the order is drawn once and then
        # reused on every pass over the dataset.
        pipeline = pipeline.shuffle(buffer_size, reshuffle_each_iteration=False)
    batched = pipeline.batch(batchsize, drop_remainder=True)
    return batched.prefetch(tf.data.AUTOTUNE)

# Generate the synthetic spike-raster dataset.
# NOTE(review): presumably yields Nc * NUM_SAMPLES_PER_CLASS samples in total,
# which is what NUM_SAMPLES_TRAIN assumes - confirm against the helper.
data, labels = make_spike_raster_dataset(
    rng,
    nb_classes=Nc,
    nb_units=N[0],
    nb_steps=T,
    step_frac=1.0,
    dim_manifold=2,
    nb_spikes=1,
    nb_samples=NUM_SAMPLES_PER_CLASS,
    alpha=2.0,
    shuffle=True,
)

# The first NUM_SAMPLES_TRAIN samples are used for training, the rest for testing.
data_train, data_test = data[:NUM_SAMPLES_TRAIN], data[NUM_SAMPLES_TRAIN:]
labels_train, labels_test = labels[:NUM_SAMPLES_TRAIN], labels[NUM_SAMPLES_TRAIN:]

# Only the training pipeline is shuffled.
dataloader_train = create_dataloader(data_train, labels_train, BATCHSIZE, shuffle=True)
dataloader_test = create_dataloader(data_test, labels_test, BATCHSIZE, shuffle=False)

### Model Creation

In [None]:
@tf.custom_gradient
def smooth_step(x):
    """Heaviside step with a smooth surrogate gradient.

    Forward pass: ``heaviside(x, 1)``, i.e. 1 where ``x >= 0`` and 0 elsewhere.
    Backward pass: the upstream gradient is scaled by
    ``1 / (10 * |x| + 1) ** 2`` instead of the (almost everywhere zero)
    derivative of the step function.

    Args:
        x: Input tensor.

    Returns:
        Tuple of the step output and the gradient function, as required by
        ``tf.custom_gradient``.
    """
    steepness = 10.0

    def grad(upstream):
        # Surrogate derivative: largest at x == 0, decaying with |x|.
        return upstream / (steepness * tf.math.abs(x) + 1) ** 2

    return tf.experimental.numpy.heaviside(x, 1), grad

In [None]:
class LIFDenseNeuronState(NamedTuple):
    '''
    State of one LIF population that is carried from one time step to the next.

    Each field holds either the state tensor itself or, when the tuple is used
    as a state-size specification, the corresponding ``tf.TensorShape``.
    '''
    # Membrane potential; a spike is emitted when it reaches the threshold of 1.
    U: Union[tf.Tensor, tf.TensorShape]
    # Filtered feed-forward input current.
    I: Union[tf.Tensor, tf.TensorShape]
    # Filtered recurrent input current.
    Ir: Union[tf.Tensor, tf.TensorShape]
    # Spikes emitted by the population at this time step.
    S: Union[tf.Tensor, tf.TensorShape]

In [None]:
def custom_init(shape, dtype=None):
    """Draw initial weights uniformly from ``[-limit, limit]``.

    ``limit`` is ``sqrt(6 / (shape[0] + shape[1]))``, so the range shrinks as
    the two leading dimensions of the weight tensor grow.

    Args:
        shape: Shape of the tensor to create; the first two entries are read.
        dtype: Data type forwarded to ``tf.random.uniform``.

    Returns:
        A tensor of the requested shape with uniformly distributed values.
    """
    dim_sum = shape[0] + shape[1]
    bound = (6 / dim_sum) ** 0.5
    return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

class LIFDensePopulation(keras.layers.Layer):
    """Fully connected population of leaky integrate-and-fire neurons.

    The population processes a single time step per call and keeps its
    dynamics in a ``LIFDenseNeuronState`` that the caller threads through
    successive calls.

    Args:
        out_channels: Number of neurons in the population.
        alpha: Decay factor of the membrane potential.
        beta: Decay factor of the feed-forward input current.
        betar: Decay factor of the recurrent input current.
    """

    def __init__(self, out_channels, alpha, beta, betar):
        super().__init__()
        # Feed-forward weights. Other initializers can be swapped in here,
        # e.g. keras.initializers.GlorotUniform() or the module-level
        # custom_init function.
        self.fc_layer = keras.layers.Dense(
            out_channels,
            use_bias=False,
            kernel_initializer=keras.initializers.RandomUniform(-0.5, 0.5),
        )
        # Recurrent weights start at zero, i.e. no recurrent coupling initially.
        self.rec_layer = keras.layers.Dense(
            out_channels,
            use_bias=False,
            kernel_initializer=keras.initializers.Constant(0.0),
        )
        self.out_channels = out_channels
        self.alpha = alpha
        self.beta = beta
        self.betar = betar

    def call(self, Sin_t, state):
        """Advance the population by one time step.

        Args:
            Sin_t: Input spikes of the current time step.
            state: ``LIFDenseNeuronState`` from the previous time step.

        Returns:
            Tuple ``(S, new_state)`` with the emitted spikes and the updated
            state (whose ``S`` field is the same spike tensor).
        """
        # Neurons that spiked in the previous step lose their membrane
        # potential; the reset is excluded from backpropagation.
        keep_potential = 1 - tf.stop_gradient(state.S)
        # The membrane integrates the *previous* currents; the feed-forward
        # current is weighted by a fixed factor of 20.
        total_current = 20 * state.I + state.Ir
        U = self.alpha * keep_potential * state.U + (1 - self.alpha) * total_current

        # Low-pass filter the feed-forward and recurrent inputs.
        feedforward = self.fc_layer(Sin_t)
        I = self.beta * state.I + (1 - self.beta) * feedforward
        recurrent = self.rec_layer(state.S)
        Ir = self.betar * state.Ir + (1 - self.betar) * recurrent

        # Spike wherever the membrane potential reaches the threshold of 1.
        S = smooth_step(U - 1)
        return S, LIFDenseNeuronState(U=U, I=I, Ir=Ir, S=S)

    def get_initial_state(self, inputs, batch_size, dtype):
        """Return an all-zero state for ``batch_size`` samples.

        ``inputs`` and ``dtype`` are accepted but not used; the tensors are
        created with the default dtype of ``tf.zeros``.
        """
        state_shape = (batch_size, self.out_channels)
        return LIFDenseNeuronState(
            U=tf.zeros(state_shape),
            I=tf.zeros(state_shape),
            Ir=tf.zeros(state_shape),
            S=tf.zeros(state_shape),
        )

    def get_state_size(self):
        """Return the per-sample shape of every state field."""
        return LIFDenseNeuronState(
            U=tf.TensorShape((self.out_channels,)),
            I=tf.TensorShape((self.out_channels,)),
            Ir=tf.TensorShape((self.out_channels,)),
            S=tf.TensorShape((self.out_channels,)),
        )

In [None]:
class LIFNetworkCell(keras.layers.Layer):
    """RNN cell that chains several ``LIFDensePopulation`` layers.

    Args:
        N: Layer widths; ``N[0]`` is the input width and every following
            entry creates one spiking population of that size.
        alpha: Membrane decay factor shared by all populations.
        beta: Feed-forward current decay factor shared by all populations.
        betar: Recurrent current decay factor shared by all populations.
    """

    def __init__(self, N, alpha=0.95, beta=0.9, betar=0.85):
        super().__init__()
        self.layers = [
            LIFDensePopulation(out_channels=width, alpha=alpha, beta=beta, betar=betar)
            for width in N[1:]
        ]
        # One state tuple per population; the cell outputs the spikes of
        # every population, so the output sizes are the spike shapes.
        self.state_size = [population.get_state_size() for population in self.layers]
        self.output_size = [population.get_state_size().S for population in self.layers]

    def call(self, Sin_t, state):
        """Run one time step through all populations in sequence.

        Args:
            Sin_t: Input spikes of the current time step.
            state: Sequence with one ``LIFDenseNeuronState`` per population.

        Returns:
            Tuple ``(spikes, new_state)``: the spikes of every population
            (not only the last one) and the list of updated states.
        """
        signal = Sin_t
        updated_states = []
        for population, previous in zip(self.layers, state):
            # The spikes of one population are the input of the next one.
            signal, current = population(signal, previous)
            updated_states.append(current)
        spikes_per_layer = [layer_state.S for layer_state in updated_states]
        return spikes_per_layer, updated_states

def model_fn(seq_len, batchsize, dims, alpha=0.95, beta=0.9, betar=0.85, return_sequences=True):
    """Build the input and output tensors of the spiking network.

    Args:
        seq_len: Number of time steps per sample.
        batchsize: Fixed batch size of the model input.
        dims: Layer widths; ``dims[0]`` is the input width.
        alpha: Membrane decay factor.
        beta: Feed-forward current decay factor.
        betar: Recurrent current decay factor.
        return_sequences: Forwarded to ``keras.layers.RNN``.

    Returns:
        Tuple ``(inp_spikes, out)`` suitable for ``keras.Model(inp_spikes, out)``.
    """
    inp_spikes = keras.Input(shape=(seq_len, dims[0]), batch_size=batchsize, dtype=tf.float32)
    cell = LIFNetworkCell(dims, alpha, beta, betar)
    recurrent = keras.layers.RNN(cell, return_sequences=return_sequences, time_major=False)
    return inp_spikes, recurrent(inp_spikes)

In [None]:
# model_fn returns (input tensor, outputs); unpack them as the inputs and
# outputs of the functional Keras model.
net = keras.Model(*model_fn(T, BATCHSIZE, N))

### Learning Setup

In [None]:
def sum_and_sparse_categorical_crossentropy(y_true, y_pred):
    """Loss on the spike counts of the output layer.

    The spikes are summed over time, turned into class probabilities with a
    softmax and compared against the one-hot targets. Despite the name, the
    value returned is a summed *squared error*, not a cross-entropy.

    Args:
        y_true: Integer class indices.
            NOTE(review): assumed to be a rank-1 vector so that
            ``y_true.shape[-1]`` is the batch size - confirm.
        y_pred: Output spikes of shape (batch, seq_len, neurons).

    Returns:
        Scalar tensor: summed squared error divided by ``y_true.shape[-1]``.
    """
    spike_counts = tf.reduce_sum(y_pred, axis=1)  # -> (batch, neurons)
    class_probs = tf.nn.softmax(spike_counts, axis=1)
    num_classes = class_probs.shape[-1]
    targets = tf.one_hot(y_true, num_classes, axis=-1, dtype=class_probs.dtype)
    squared_error = (class_probs - targets) ** 2
    return tf.math.reduce_sum(squared_error) / y_true.shape[-1]
    # Cross-entropy alternative on the raw spike counts:
    # tf.keras.metrics.sparse_categorical_crossentropy(y_true, spike_counts, from_logits=True)

def calc_activity(y_true, y_pred):
    """Metric: total spike count divided by ``shape[0] * shape[2]`` of ``y_pred``.

    With ``y_pred`` laid out as (batch, seq_len, neurons) this is the average
    number of spikes per neuron and sample over the whole sequence.
    ``y_true`` is unused and only present to satisfy the metric signature.
    """
    normaliser = y_pred.shape[0] * y_pred.shape[2]
    return tf.reduce_sum(y_pred) / normaliser

def calc_accuracy(y_true, y_pred):
    """Metric: sparse categorical accuracy of the time-summed output spikes.

    Args:
        y_true: Integer class indices.
        y_pred: Output spikes; summed over axis 1 before the comparison.

    Returns:
        The result of ``tf.keras.metrics.sparse_categorical_accuracy``.
    """
    spike_counts = tf.reduce_sum(y_pred, axis=1)
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, spike_counts)


opt = keras.optimizers.SGD(learning_rate=1e-1, momentum=0.9, name="SGD")

num_layers = len(N)-1
out_name = net.layers[-1].name
# Names of the model outputs, one per spiking layer. The first output carries
# the bare layer name, later ones get an `_<index>` suffix.
# NOTE(review): naming scheme assumed by this notebook - verify against
# `net.output_names`.
# Deriving every key from this single list keeps the last-layer key consistent
# with the others, including the single-layer case where it is the bare name.
output_names = [out_name if i == 0 else f"{out_name}_{i}" for i in range(num_layers)]

# Only the last layer contributes to the loss; all earlier outputs get a
# constant zero loss.
loss_funcs = {name: (lambda x, y: 0.0) for name in output_names[:-1]}
loss_funcs[output_names[-1]] = sum_and_sparse_categorical_crossentropy

# Track the activity of all layers and additionally the accuracy of the last layer.
metrics = {name: calc_activity for name in output_names}
metrics[output_names[-1]] = [calc_activity, calc_accuracy]

In [None]:
# Configure training. steps_per_execution runs several batches per call into
# the compiled training function.
# Optionally add `jit_compile=True` to jit compile the model for faster execution.
net.compile(
    optimizer=opt,
    loss=loss_funcs,
    metrics=metrics,
    steps_per_execution=24,
)
net.summary()

In [None]:
# Train on the tf.data pipeline. No `workers` argument is passed: it only
# applies to generator / keras.utils.Sequence inputs (a tf.data.Dataset does
# its own prefetching), and the batch size is not a meaningful worker count.
net.fit(dataloader_train, epochs=Nepochs)