In [1]:
import collections
#import environments
import sonnet as snt
import numpy as np
import tensorflow as tf
from tensorflow.python.util import nest
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg
# import skimage
# from skimage import data
# from skimage import transform
# import os 

from alif_functions import CustomALIFWithReset, spike_function, CustomALIF
from util import to_bool, switch_time_and_batch_dimension, exp_convolve

#### SNN functions

In [2]:
@tf.custom_gradient
def spike_function(v_scaled, dampening_factor):
    """Heaviside spike nonlinearity with a triangular surrogate gradient.

    Forward pass: 1.0 where ``v_scaled > 0`` and 0.0 elsewhere (float32).
    Backward pass: the gradient w.r.t. ``v_scaled`` is
    ``dy * dampening_factor * max(1 - |v_scaled|, 0)``; ``dampening_factor``
    receives a zero gradient.

    NOTE: this definition shadows the ``spike_function`` imported from
    ``alif_functions`` in the first cell.
    """
    spikes = tf.cast(tf.greater(v_scaled, 0.), dtype=tf.float32)

    def grad(upstream):
        # Triangular pseudo-derivative centred on the threshold, scaled by the dampening factor.
        surrogate = tf.maximum(1 - tf.abs(v_scaled), 0) * dampening_factor
        return [upstream * surrogate, tf.zeros_like(dampening_factor)]

    return tf.identity(spikes, name="spike_function"), grad


In [3]:
def lif_dynamic(v, i, decay, v_th, dampening_factor=.3):
    """One step of a leaky integrate-and-fire neuron with reset-by-subtraction.

    Args:
        v: membrane potential from the previous step.
        i: input current for this step, added to the decayed potential.
        decay: leak factor multiplying ``v``.
        v_th: firing threshold; also the amount subtracted on reset.
        dampening_factor: scale of the surrogate gradient used by ``spike_function``.

    Returns:
        ``(new_v, new_z)``: the updated potential and the spikes it produces.
    """
    def threshold_crossing(potential):
        # Normalised distance to threshold; spike_function fires where it is positive
        # (i.e. potential > v_th when v_th > 0).
        return spike_function((potential - v_th) / v_th, dampening_factor)

    # Units that were above threshold at the previous potential lose v_th.
    reset = threshold_crossing(v) * v_th
    updated_v = decay * v + i - reset
    return updated_v, threshold_crossing(updated_v)

#### SNN model architecture

In [4]:
class SpikingCNN(tf.compat.v1.nn.rnn_cell.RNNCell):
    """Two-layer spiking convolutional torso packaged as a TF1 ``RNNCell``.

    At every time step the cell receives the input current of layer 1 (the first
    convolution is applied outside the cell), integrates it with ``lif_dynamic``,
    convolves the resulting spikes into layer 2 (``VALID`` padding) and integrates
    again. The recurrent state is the pair of membrane potentials
    ``(v_conv_1, v_conv_2)``.

    Two optional "broadcast" readouts of each layer's spikes exist:
      * ``ba=True``     -> a ``VALID`` ``snt.Conv2D`` over the spikes;
      * ``avg_ba=True`` -> ``VALID`` average pooling over the spikes
        (takes precedence over ``ba`` when both flags are set).
    When either is enabled, the spikes passed to the next layer and returned as
    outputs are wrapped in ``tf.stop_gradient``.

    Per-step outputs (flattened per example, in this order):
        z_conv_1, v_conv_1, z_conv_2, v_conv_2[, c1_r, c2_r]
    """

    def __init__(self, n_kernel_1=8, n_filter_1=16, stride_1=4, n_kernel_2=8, n_filter_2=32, stride_2=4, ba=False,
                 avg_ba=False, ba_config=None, tau=1, thr=1., avg_pool_1_stride=4, avg_pool_1_k=8, avg_pool_2_stride=2,
                 avg_pool_2_k=4):
        """
        Args:
            n_kernel_1, stride_1: kernel/stride of the external first conv; used here
                only to derive the layer-1 spatial size ``n_w_1``.
            n_filter_1: channels of the layer-1 current / state.
            n_kernel_2, n_filter_2, stride_2: second conv layer (``VALID`` padding).
            ba, avg_ba: enable the conv / average-pool readouts of both layers' spikes.
            ba_config: optional dict with the ``ba_*`` keys read below; built-in
                defaults are used when None.
            tau: membrane time constant; the per-step decay is ``exp(-1 / tau)``.
            thr: firing threshold ``v_th``.
            avg_pool_*_k, avg_pool_*_stride: window / stride of the average-pool readouts.
        """
        super().__init__()
        self.decay = np.exp(-1 / tau)
        self.v_th = thr
        self.n_filters_1 = n_filter_1
        self.n_filters_2 = n_filter_2
        self.ba = ba
        self.avg_ba = avg_ba
        # NOTE(review): `+ 2` (not the VALID-padding `+ 1`) gives 32 for a 128x128 input
        # with kernel 8 / stride 4, matching a SAME-padded first conv; for other
        # kernel/stride choices this may not equal the real conv output size - confirm.
        self.n_w_1 = (128 - n_kernel_1) // stride_1 + 2
        self.n_w_2 = (self.n_w_1 - n_kernel_2) // stride_2 + 1
        self.n_kernel_1 = n_kernel_1
        self.stride_1 = stride_1
        self.n_kernel_2 = n_kernel_2
        self.stride_2 = stride_2
        self.avg_pool_1_stride = avg_pool_1_stride
        self.avg_pool_2_stride = avg_pool_2_stride
        self.avg_pool_1_k = avg_pool_1_k
        self.avg_pool_2_k = avg_pool_2_k

        # Spatial sizes of the VALID average-pool readouts.
        self.n_avg_1 = (self.n_w_1 - avg_pool_1_k) // avg_pool_1_stride + 1
        self.n_avg_2 = (self.n_w_2 - avg_pool_2_k) // avg_pool_2_stride + 1

        if ba_config is not None:
            self.ba_filters_1_1 = ba_config['ba_filters_1_1']
            self.ba_kernel_1_1 = ba_config['ba_kernel_1_1']
            self.ba_stride_1_1 = ba_config['ba_stride_1_1']
            self.ba_filters_1_2 = ba_config['ba_filters_1_2']
            self.ba_kernel_1_2 = ba_config['ba_kernel_1_2']
            self.ba_stride_1_2 = ba_config['ba_stride_1_2']
            self.ba_filters_2 = ba_config['ba_filters_2']
            self.ba_kernel_2 = ba_config['ba_kernel_2']
            self.ba_stride_2 = ba_config['ba_stride_2']
        else:
            self.ba_filters_1_1 = 16
            self.ba_kernel_1_1 = 8
            self.ba_stride_1_1 = 4
            self.ba_filters_1_2 = 32
            self.ba_kernel_1_2 = 4
            self.ba_stride_1_2 = 2
            self.ba_filters_2 = 32
            self.ba_kernel_2 = 4
            self.ba_stride_2 = 2

        # Spatial sizes of the VALID conv readouts.
        self.n_ba_1 = (self.n_w_1 - self.ba_kernel_1_1) // self.ba_stride_1_1 + 1
        self.n_ba_2 = (self.n_w_2 - self.ba_kernel_2) // self.ba_stride_2 + 1

    @property
    def output_size(self):
        """Flat per-example sizes of the tensors emitted by ``__call__``, in the same order.

        The readout sizes follow the same precedence as ``__call__``: ``avg_ba``
        wins over ``ba`` when both are set.
        """
        conv_1 = self.n_w_1 * self.n_w_1 * self.n_filters_1
        conv_2 = self.n_w_2 * self.n_w_2 * self.n_filters_2
        base = (conv_1, conv_1, conv_2, conv_2)
        if self.avg_ba:
            return base + (self.n_avg_1 * self.n_avg_1 * self.n_filters_1,
                           self.n_avg_2 * self.n_avg_2 * self.n_filters_2)
        if self.ba:
            return base + (self.n_ba_1 * self.n_ba_1 * self.ba_filters_1_1,
                           self.n_ba_2 * self.n_ba_2 * self.ba_filters_2)
        return base

    def zero_state(self, batch_size, dtype):
        """Zero membrane potentials ``(v_conv_1, v_conv_2)`` in NHWC layout."""
        return tf.zeros((batch_size, self.n_w_1, self.n_w_1, self.n_filters_1), dtype), \
               tf.zeros((batch_size, self.n_w_2, self.n_w_2, self.n_filters_2), dtype)

    @property
    def state_size(self):
        """Per-example shapes of ``(v_conv_1, v_conv_2)``."""
        return (self.n_w_1, self.n_w_1, self.n_filters_1), (self.n_w_2, self.n_w_2, self.n_filters_2)

    def __call__(self, inputs, state):
        """Advance both spiking layers by one time step.

        Args:
            inputs: layer-1 input current; added to ``state[0]``, so it is assumed
                to be ``(batch, n_w_1, n_w_1, n_filters_1)`` like ``zero_state``.
            state: ``(v_conv_1, v_conv_2)`` membrane potentials from the previous step.

        Returns:
            ``(outputs, new_state)``: ``outputs`` ordered as in ``output_size`` and
            ``new_state = (v_conv_1, v_conv_2)``.
        """
        # Layer 1: integrate the externally computed conv current.
        v_conv_1, z_conv_1 = lif_dynamic(state[0], inputs, self.decay, self.v_th, 1.)
        if self.ba and not self.avg_ba:
            with tf.compat.v1.variable_scope('broadcast_1'):
                c1_r = snt.Conv2D(
                    self.ba_filters_1_1,
                    self.ba_kernel_1_1,
                    stride=self.ba_stride_1_1,
                    padding='VALID'
                )(z_conv_1)
                c1_r = snt.BatchFlatten()(c1_r)

                # Block gradients from flowing back into layer 1 through later uses.
                z_conv_1 = tf.stop_gradient(z_conv_1)
        elif self.avg_ba:
            with tf.compat.v1.variable_scope('broadcast_1'):
                c1_r = tf.nn.avg_pool2d(
                    input=z_conv_1, ksize=self.avg_pool_1_k,
                    strides=self.avg_pool_1_stride, padding='VALID')
                c1_r = snt.BatchFlatten()(c1_r)
                z_conv_1 = tf.stop_gradient(z_conv_1)

        # Layer 2: convolve the layer-1 spikes and integrate.
        i_conv_2 = snt.Conv2D(self.n_filters_2, self.n_kernel_2, stride=self.stride_2, padding='VALID')(z_conv_1)
        v_conv_2, z_conv_2 = lif_dynamic(state[1], i_conv_2, self.decay, self.v_th, 1.)
        if self.ba and not self.avg_ba:
            with tf.compat.v1.variable_scope('broadcast_2'):
                layer_c2_r = snt.Conv2D(
                    self.ba_filters_2,
                    self.ba_kernel_2,
                    stride=self.ba_stride_2,
                    padding='VALID'
                )
                c2_r = layer_c2_r(z_conv_2)
                c2_r = snt.BatchFlatten()(c2_r)

                z_conv_2 = tf.stop_gradient(z_conv_2)
        elif self.avg_ba:
            with tf.compat.v1.variable_scope('broadcast_2'):
                c2_r = tf.nn.avg_pool2d(
                    input=z_conv_2, ksize=self.avg_pool_2_k,
                    strides=self.avg_pool_2_stride, padding='VALID')
                c2_r = snt.BatchFlatten()(c2_r)
                z_conv_2 = tf.stop_gradient(z_conv_2)
        new_state = (v_conv_1, v_conv_2)
        if self.ba or self.avg_ba:
            return (tf.reshape(z_conv_1, (-1, self.n_w_1 * self.n_w_1 * self.n_filters_1)),
                    tf.reshape(v_conv_1, (-1, self.n_w_1 * self.n_w_1 * self.n_filters_1)),
                    tf.reshape(z_conv_2, (-1, self.n_w_2 * self.n_w_2 * self.n_filters_2)),
                    tf.reshape(v_conv_2, (-1, self.n_w_2 * self.n_w_2 * self.n_filters_2)), c1_r, c2_r), new_state
        return (tf.reshape(z_conv_1, (-1, self.n_w_1 * self.n_w_1 * self.n_filters_1)),
                tf.reshape(v_conv_1, (-1, self.n_w_1 * self.n_w_1 * self.n_filters_1)),
                tf.reshape(z_conv_2, (-1, self.n_w_2 * self.n_w_2 * self.n_filters_2)),
                tf.reshape(v_conv_2, (-1, self.n_w_2 * self.n_w_2 * self.n_filters_2))), new_state


In [30]:
class SpikingAgent(snt.RNNCore):
    """Actor-critic agent built from spiking networks.

    Pipeline per unroll: a (non-spiking) ``snt.Conv2D`` on the frames ->
    ``SpikingCNN`` torso -> ``CustomALIFWithReset`` recurrent core -> linear
    policy / baseline head with exponentially filtered readouts. Every environment
    step is repeated ``n_rnn_step_factor`` times for the spiking networks; the head
    keeps only the last sub-step of each environment step.
    """

    def __init__(self, action_set, rnn_units, stop_gradient=False, n_rnn_step_factor=1,
                 tau=20, tau_readout=5, thr=.615, n_filters_1=16, n_filters_2=32,
                 n_kernel_1=8, n_kernel_2=4, stride_1=4, stride_2=2,
                 beta=.16, tau_adaptation=300, ba=False, avg_ba=False,
                 ba_config=None, fraction_adaptive=.4, n_refractory=3,
                 tau_scnn=1., thr_scnn=1., avg_pool_1_stride=4, avg_pool_1_k=8,
                 avg_pool_2_stride=2, avg_pool_2_k=4):
        """
        Args:
            action_set: sequence of actions; only its length is used.
            rnn_units: number of recurrent ALIF neurons.
            stop_gradient: forwarded to the core as ``stop_gradients``.
            n_rnn_step_factor: RNN sub-steps per environment step.
            tau, thr, beta, tau_adaptation, n_refractory: core neuron parameters.
            tau_readout: scalar or sequence of readout time constants; each value gets
                its own linear readout in ``_head`` and the results are summed.
            fraction_adaptive: fraction of core neurons (the last ones) with non-zero beta.
            n_filters_*, n_kernel_*, stride_*: geometry of the two conv layers.
            ba, avg_ba, ba_config, avg_pool_*: torso readouts (see ``SpikingCNN``).
            tau_scnn, thr_scnn: membrane time constant / threshold of the torso.
        """
        super(SpikingAgent, self).__init__(name='agent')

        self._num_actions = len(action_set)
        self.ba = ba
        self.avg_ba = avg_ba
        self.rnn_units = rnn_units
        tau_readout = np.atleast_1d(tau_readout)
        self.decay = np.exp(-1 / tau_readout)
        self.n_rnn_step_factor = n_rnn_step_factor
        self.thr = thr
        self.n_filters_1 = n_filters_1
        self.n_kernel_1 = n_kernel_1
        self.stride_1 = stride_1
        self.n_filters_2 = n_filters_2
        self.n_kernel_2 = n_kernel_2
        self.stride_2 = stride_2
        self.no_linear = True
        if ba_config is not None:
            self.ba_filters_1_1 = ba_config['ba_filters_1_1']
            self.ba_kernel_1_1 = ba_config['ba_kernel_1_1']
            self.ba_stride_1_1 = ba_config['ba_stride_1_1']
            self.ba_filters_1_2 = ba_config['ba_filters_1_2']
            self.ba_kernel_1_2 = ba_config['ba_kernel_1_2']
            self.ba_stride_1_2 = ba_config['ba_stride_1_2']
            self.ba_filters_2 = ba_config['ba_filters_2']
            self.ba_kernel_2 = ba_config['ba_kernel_2']
            self.ba_stride_2 = ba_config['ba_stride_2']
        else:
            self.ba_filters_1_1 = 32
            self.ba_kernel_1_1 = 8
            self.ba_stride_1_1 = 4
            self.ba_filters_1_2 = 64
            self.ba_kernel_1_2 = 4
            self.ba_stride_1_2 = 2
            self.ba_filters_2 = 64
            self.ba_kernel_2 = 4
            self.ba_stride_2 = 2

        with self._enter_variable_scope():
            # First n_regular neurons are non-adaptive (beta = 0), the rest adaptive.
            n_regular = int(rnn_units * (1. - fraction_adaptive))
            n_adaptive = rnn_units - n_regular
            beta = np.concatenate((np.zeros(n_regular), np.ones(n_adaptive))).astype(np.float32) * beta
            self.beta = beta
            # Same spatial-size formulas as SpikingCNN; the core consumes layer-2 spikes.
            n_w_1 = (128 - n_kernel_1) // stride_1 + 2
            n_w_2 = (n_w_1 - n_kernel_2) // stride_2 + 1
            n_input = n_w_2 * n_w_2 * n_filters_2

            #rnn_units can be experimented with 

            self.core = CustomALIFWithReset(n_input, rnn_units, tau=tau, beta=beta, thr=thr,
                                            tau_adaptation=tau_adaptation, stop_gradients=stop_gradient,
                                            n_refractory=n_refractory)
            
            self.scnn = SpikingCNN(n_filter_1=n_filters_1, stride_1=stride_1, n_kernel_1=n_kernel_1, n_filter_2=n_filters_2, stride_2=stride_2, n_kernel_2=n_kernel_2, ba=ba, avg_ba=avg_ba, ba_config=ba_config, tau=tau_scnn, thr=thr_scnn, avg_pool_1_stride=avg_pool_1_stride, avg_pool_1_k=avg_pool_1_k, avg_pool_2_stride=avg_pool_2_stride, avg_pool_2_k=avg_pool_2_k)

    def initial_state(self, batch_size):
        """Return ``(core_state, (policy, baseline), scnn_state)``, all zeros."""
        return self.core.zero_state(batch_size, tf.float32), \
               (tf.zeros((batch_size, self._num_actions)), tf.zeros((batch_size, 1))), \
               self.scnn.zero_state(batch_size, tf.float32)

    def initial_eligibility_traces(self, batch_size):
        """Zero traces for the core's input and recurrent weights.

        Each trace is ``w[None, ..., None]`` tiled to ``batch_size`` along the new first
        axis and to 2 along the new last axis (assumes 2-D ``w_in_var`` / ``w_rec_var``).
        """
        initial_eligibility_traces = [
            tf.tile(tf.zeros_like(self.core.w_in_var[None, ..., None]), (batch_size, 1, 1, 2)),
            tf.tile(tf.zeros_like(self.core.w_rec_var[None, ..., None]), (batch_size, 1, 1, 2))
        ]
        return initial_eligibility_traces

    def _head(self, core_output, head_state, torso_dict):
        """Policy / baseline readout on top of the recurrent core.

        Args:
            core_output: time-major core spikes,
                ``(n_time * n_rnn_step_factor, batch, rnn_units)``.
            head_state: ``(policy, baseline)`` from the previous unroll, used as the
                ``exp_convolve`` initializer.
            torso_dict: the torso dict built in ``unroll``; its ``'c1_r'`` / ``'c2_r'``
                entries (present only with ``ba`` or ``avg_ba``) are appended to the features.

        Returns:
            ``(AgentOutput(action, policy, baseline), (policy[-1], baseline[-1]))``.
        """
        def f(core_output):
            i_policy_logits = snt.Linear(self._num_actions, name='policy_logits')(core_output)
            i_baseline = tf.squeeze(snt.Linear(1, name='baseline')(core_output), axis=-1)
            return i_policy_logits, i_baseline

        # The broadcast readouts exist only when the torso was built with ba/avg_ba;
        # indexing them unconditionally raised KeyError in the plain configuration.
        broadcast_features = [torso_dict[key] for key in ('c1_r', 'c2_r') if key in torso_dict]
        if broadcast_features:
            core_output = tf.concat([core_output] + broadcast_features, -1)
        policy = 0.
        baseline = 0.
        for decay in self.decay:
            # A fresh pair of linear readouts per readout time constant; results are summed.
            i_policy_logits, i_baseline = snt.BatchApply(f)(core_output)
            policy += exp_convolve(i_policy_logits, decay, initializer=head_state[0])
            baseline += exp_convolve(i_baseline[..., None], decay, initializer=head_state[1])
        # Keep only the last RNN sub-step of every environment step.
        policy = policy[self.n_rnn_step_factor - 1::self.n_rnn_step_factor]
        baseline = baseline[self.n_rnn_step_factor - 1::self.n_rnn_step_factor]

        def g(policy):
            new_action = tf.multinomial(policy, num_samples=1,
                                        output_dtype=tf.int64)
            new_action = tf.squeeze(new_action, 1, name='new_action')
            return new_action

        new_action = snt.BatchApply(g)(policy)
        new_head_state = (policy[-1], baseline[-1])

        # NOTE(review): AgentOutput is not defined in the visible cells - confirm it is imported.
        return AgentOutput(new_action, policy, baseline[..., 0]), new_head_state

    def _build(self, input_, core_state):
        """Single-step call: add a time axis of length 1, unroll, then squeeze it."""
        action, env_output = input_
        # NOTE(review): `import environments` is commented out at the top of this
        # notebook; this raises NameError unless it is imported elsewhere.
        env_outputs = environments.StepOutput(
            reward=env_output.reward[None, ...],
            info=nest.map_structure(lambda t: t[None, ...], env_output.info),
            done=to_bool(tf.cast(env_output.done, tf.int64)[None, ...]),
            observation=(tf.to_float(env_output.observation[0])[None, ...], tf.zeros(()))
        )
        actions = action[None, ...]
        outputs, core_state, custom_rnn_output, torso_outputs = self.unroll(actions, env_outputs, core_state)
        return nest.map_structure(lambda t: tf.squeeze(t, 0), outputs), core_state, \
               custom_rnn_output, torso_outputs

    @snt.reuse_variables
    def unroll(self, actions, env_outputs, core_state, write_to_collection=False):
        """Run the agent over a time-major sequence.

        Args:
            actions: ``(n_time, batch)`` tensor (only its static shape is used here).
            env_outputs: ``environments.StepOutput``; ``observation[0]`` holds the
                frames, presumably ``(n_time, batch, H, W, C)`` uint8 images - TODO confirm.
            core_state: ``(core_state, head_state, scnn_state)`` as from ``initial_state``.
            write_to_collection: if True, add the torso dict to the ``'torso_output'`` collection.

        Returns:
            ``(head_output, new_core_state, custom_rnn_output, torso_outputs)``.
        """
        _, _, done, _ = env_outputs

        env_outputs = environments.StepOutput(
            reward=env_outputs.reward,
            info=env_outputs.info,
            done=env_outputs.done,
            observation=(env_outputs.observation[0], tf.zeros(()))
        )

        head_state = core_state[1]
        n_time, n_batch = actions.get_shape()
        done = tf.cast(done, tf.float32)[..., None]
        # Repeat each `done` flag n_rnn_step_factor times along time.
        expanded_dones = tf.reshape(
            tf.tile(done[:, None, ...],
                    (1, self.n_rnn_step_factor, 1, 1)), (n_time * self.n_rnn_step_factor, n_batch, -1))
        frame = tf.cast(env_outputs.observation[0], tf.float32) / 255.
        with tf.variable_scope('convnet'):
            # NOTE(review): VALID padding gives (128 - 8) // 4 + 1 = 31 for a 128x128
            # frame, while SpikingCNN's state uses n_w_1 = ... + 2 = 32; confirm the frame size.
            i_conv1 = snt.BatchApply(snt.Conv2D(self.n_filters_1, self.n_kernel_1, stride=self.stride_1, padding='VALID', use_bias=True))(frame)
            shp = i_conv1.get_shape()
            # Repeat each frame's current n_rnn_step_factor times, then go batch-major.
            i_conv1 = tf.reshape(
                tf.tile(i_conv1[:, None, ...],
                        (1, self.n_rnn_step_factor, 1, 1, 1, 1)), (n_time * self.n_rnn_step_factor, n_batch, *shp[2:]))
            i_conv1 = tf.transpose(i_conv1, (1, 0, 2, 3, 4))

            scnn_output, new_scnn_state = tf.nn.dynamic_rnn(self.scnn, i_conv1, initial_state=core_state[2]) #look into this

            to_collection = dict()
            to_collection['lin_z'] = tf.zeros_like(scnn_output[1])
            to_collection['lin_act'] = tf.zeros_like(scnn_output[1])
            to_collection['c1_act'] = scnn_output[1]
            to_collection['c1_z'] = scnn_output[0]
            to_collection['c2_act'] = scnn_output[3]
            to_collection['c2_z'] = scnn_output[2]
            if self.ba or self.avg_ba:
                # Broadcast readouts, converted to time-major for the head.
                to_collection['c1_r'] = tf.transpose(scnn_output[4], (1, 0, 2))
                to_collection['c2_r'] = tf.transpose(scnn_output[5], (1, 0, 2))
            if write_to_collection:
                tf.add_to_collection('torso_output', to_collection)
            torso_outputs = scnn_output[2]
            expanded_dones = tf.transpose(expanded_dones, (1, 0, 2))
            # The core receives layer-2 spikes with the done flag appended as an extra input.
            dynamic_rnn_inputs = tf.concat((scnn_output[2], expanded_dones), -1)

        custom_rnn_output, core_state = tf.nn.dynamic_rnn(
            self.core, dynamic_rnn_inputs, initial_state=core_state[0])
        custom_rnn_output = nest.map_structure(switch_time_and_batch_dimension, custom_rnn_output)
        core_output = custom_rnn_output[0]
        core_output.set_shape((n_time * self.n_rnn_step_factor, n_batch, self.rnn_units))
        head_output, head_state = self._head(core_output, head_state, to_collection)
        core_state = (core_state, head_state, new_scnn_state)
        return head_output, core_state, custom_rnn_output, torso_outputs

#### Initialization

In [5]:
# Smoke test: build the SCNN on random frames and check the tensor shapes.
batch_size = 5
n_time = 10
rnn_units = 100

# Random (unseeded) stand-in frames of shape (batch, time, 128, 128, 4).
img = np.random.rand(batch_size, n_time, 128,128,4)
img_tensor = tf.convert_to_tensor(img, tf.float32)
# First conv applied per (batch, time) frame via BatchApply. No padding argument is
# given; the printed 32x32 output is consistent with 'SAME' padding at stride 4.
i_conv1 = snt.BatchApply(snt.Conv2D(16, 8, stride=4))(img_tensor)
conv_shape = i_conv1.get_shape()

print('i_conv1 shape:', i_conv1.get_shape())

# Default SpikingCNN: its layer-1 state must match i_conv1's spatial size and channels.
scnn = SpikingCNN()
scnn_state = scnn.zero_state(batch_size, tf.float32)

print('scnn state:', scnn_state[0].get_shape(), scnn_state[1].get_shape())

# dynamic_rnn defaults to batch-major input, so i_conv1 is fed as (batch, time, ...).
scnn_output, new_scnn_state = tf.compat.v1.nn.dynamic_rnn(scnn, i_conv1, initial_state=scnn_state)

print('scnn output shape:', [o.get_shape() for o in scnn_output])
print('new scnn state shape:', [s.get_shape() for s in new_scnn_state])




i_conv1 shape: (5, 10, 32, 32, 16)
scnn state: (5, 32, 32, 16) (5, 7, 7, 32)
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
scnn output shape: [TensorShape([5, 10, 16384]), TensorShape([5, 10, 16384]), TensorShape([5, 10, 1568]), TensorShape([5, 10, 1568])]
new scnn state shape: [TensorShape([5, 32, 32, 16]), TensorShape([5, 7, 7, 32])]


In [6]:
# Geometry of the two spiking conv layers (matches SpikingCNN's defaults).
n_kernel_1, n_kernel_2 = 8, 8
n_stride_1, n_stride_2 = 4, 4
n_filters_2 = 32

# Spatial sizes after each layer (same formulas as SpikingCNN.__init__) and the
# flattened layer-2 size that the recurrent core receives as input.
n_w_1 = 2 + (128 - n_kernel_1) // n_stride_1
n_w_2 = 1 + (n_w_1 - n_kernel_2) // n_stride_2
n_inputs = n_filters_2 * n_w_2 ** 2

In [9]:
# Feed the SCNN's layer-2 spikes (scnn_output[2], from the smoke-test cell) through
# an ALIF recurrent core, then through a linear policy / baseline readout.
core = CustomALIF(n_in = n_inputs, n_rec = rnn_units)
rnn_state = core.zero_state(batch_size, tf.float32)

print('scnn_output 2', scnn_output[2].get_shape())

rnn_output, rnn_state = tf.compat.v1.nn.dynamic_rnn(core, scnn_output[2], initial_state=rnn_state)

print('rnn output:', [o.get_shape() for o in rnn_output])
print('rnn state:', [s.get_shape() for s in rnn_state])

# RNN sub-steps per environment step (SpikingAgent's n_rnn_step_factor); 1 = no repetition.
n_rnn_step_factor = 1

# dynamic_rnn returns batch-major tensors; the readout below works time-major.
rnn_output = nest.map_structure(switch_time_and_batch_dimension, rnn_output)
core_output = rnn_output[0]
core_output.set_shape((n_time * n_rnn_step_factor, batch_size, rnn_units))

print('core_output:', core_output.get_shape())

num_actions = 3
tau = 1
decays = [np.exp(-1.0 / tau)]

# Initial (policy logits, baseline) used to seed the exponential filters.
controller_state = (tf.zeros((batch_size, num_actions)), tf.zeros((batch_size, 1)))


def f(core_output):
    """Linear policy-logit and baseline readouts of a 2-D (batch, units) tensor."""
    i_policy_logits = snt.Linear(num_actions, name='policy_logits')(core_output)
    i_baseline = tf.squeeze(snt.Linear(1, name='baseline')(core_output), axis=-1)
    return i_policy_logits, i_baseline


policy = 0.
baseline = 0.
for decay in decays:
    # core_output is (time, batch, units) but snt.Linear expects 2-D input;
    # BatchApply folds time into batch, exactly as SpikingAgent._head does.
    i_policy_logits, i_baseline = snt.BatchApply(f)(core_output)
    policy += exp_convolve(i_policy_logits, decay, initializer=controller_state[0])
    baseline += exp_convolve(i_baseline[..., None], decay, initializer=controller_state[1])
# Keep only the last RNN sub-step of every environment step.
policy = policy[n_rnn_step_factor - 1::n_rnn_step_factor]
baseline = baseline[n_rnn_step_factor - 1::n_rnn_step_factor]

print('policy:', policy.get_shape())
print('baseline:', baseline.get_shape())


def g(policy):
    """Sample one action index per row of 2-D ``policy`` logits."""
    new_action = tf.compat.v1.multinomial(policy, num_samples=1,
                                          output_dtype=tf.int64)
    new_action = tf.squeeze(new_action, 1, name='new_action')
    return new_action


# multinomial requires 2-D logits, so apply it per time step via BatchApply.
new_action = snt.BatchApply(g)(policy)
new_controller_state = (policy[-1], baseline[-1])

print('new action shape:', new_action.get_shape())
print('new controller state:', [s.get_shape() for s in new_controller_state])

scnn_output 2 (5, 10, 1568)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
inputs: (5, 1568)
w_in_val: (1568, 100)
w_rec_val: (100, 100)
z: (5, 100)
rnn output: [TensorShape([5, 10, 100]), TensorShape([5, 10, 100, 2]), TensorShape([5, 10, 1]), TensorShape([5, 10, 1])]
rnn state: [TensorShape([5, 100, 2]), TensorShape([5, 100]), TensorShape([5, 100])]
core_output: (10, 5, 100)
po

In [None]:
# End-to-end shape check (SCNN -> ALIF core -> readout) on a fresh graph.
tf.reset_default_graph()

# NOTE(review): run_session is opened but never used to run anything; only the
# static graph shapes are printed.
with tf.Session() as run_session:
    batch_size = 5
    n_time = 10
    rnn_units = 100

    #img = np.random.rand(batch_size, n_time, 128,128,4)
    #img_tensor = tf.convert_to_tensor(img, tf.float32)
    # NOTE(review): `image1` is not defined in the visible cells. BatchApply over a
    # Conv2D expects a rank-5 (batch, time, H, W, C) input - confirm its shape.
    i_conv1 = snt.BatchApply(snt.Conv2D(16, 8, stride=4))(image1)
    conv_shape = i_conv1.get_shape()


    #i_conv1 = tf.transpose(i_conv1, (1, 0, 2, 3, 4))

    # temp = i_conv1[:, None, ...]
    # print('temp:', temp.get_shape())

    print('i_conv1 shape:', i_conv1.get_shape())

    scnn = SpikingCNN()
    scnn_state = scnn.zero_state(batch_size, tf.float32)

    print('scnn state:', scnn_state[0].get_shape(), scnn_state[1].get_shape())

    scnn_output, new_scnn_state = tf.nn.dynamic_rnn(scnn, i_conv1, initial_state=scnn_state)

    print('scnn output shape:', [o.get_shape() for o in scnn_output])
    print('new scnn state shape:', [s.get_shape() for s in new_scnn_state])

    # n_inputs comes from the geometry cell (In [6]).
    core = CustomALIF(n_in = n_inputs, n_rec = rnn_units)
    rnn_state = core.zero_state(batch_size, tf.float32)

    print('scnn_output 2', scnn_output[2].get_shape())

    rnn_output, rnn_state = tf.nn.dynamic_rnn(core, scnn_output[2], initial_state=rnn_state)

    print('rnn output:', [o.get_shape() for o in rnn_output])
    print('rnn state:', [s.get_shape() for s in rnn_state])

    # Convert batch-major RNN outputs to time-major for the readout.
    rnn_output = nest.map_structure(switch_time_and_batch_dimension, rnn_output)
    core_output = rnn_output[0]
    core_output.set_shape((n_time * 1, batch_size, rnn_units))

    print('core_output:', core_output.get_shape())

    # print('4th and 5th output:', scnn_output[0][3], scnn_output[0][4])

    num_actions = 2
    tau = 1
    decays = [np.exp(-1.0 / tau)]

    controller_state = (tf.zeros((batch_size, num_actions)), tf.zeros((batch_size, 1)))

    def f(core_output):
        # Linear policy-logit and baseline readouts of a (batch, units) tensor.
        i_policy_logits = snt.Linear(num_actions, name='policy_logits')(core_output)
        i_baseline = tf.squeeze(snt.Linear(1, name='baseline')(core_output), axis=-1)
        return i_policy_logits, i_baseline

    policy = 0.
    baseline = 0.
    for decay in decays:
        i_policy_logits, i_baseline = snt.BatchApply(f)(core_output)
        policy += exp_convolve(i_policy_logits, decay, initializer=controller_state[0])
        baseline += exp_convolve(i_baseline[..., None], decay, initializer=controller_state[1])
    # No-op slices: equivalent to n_rnn_step_factor = 1.
    policy = policy[1 - 1::1]
    baseline = baseline[1 - 1::1]

    print('policy:', policy.get_shape())
    print('baseline:', baseline.get_shape())

    def g(policy):
        # NOTE(review): Bernoulli's first positional argument is `logits`, so this draws
        # an independent 0/1 sample per logit (one per action), not a single action
        # index per row as the multinomial version does - confirm this is intended.
        new_action = tf.distributions.Bernoulli(policy).sample()
        #new_action = tf.cast(new_action, tf.float32, name='new_action')
        #new_action = tf.squeeze(new_action, 1, name='new_action')
        # Prints the symbolic tensor at graph-construction time, not its value.
        print('new action shape', new_action[0])
        return new_action

    new_action = snt.BatchApply(g)(policy)
    # print(new_action.eval())
    new_controller_state = (policy[-1], baseline[-1])

    print('new action shape:', new_action.get_shape())
    print('new controller state:', [s.get_shape() for s in new_controller_state])

#### Rearranging code

In [15]:
# Shared hyper-parameters for the refactored agent below.
# Conv geometry: kernel sizes, strides and layer-2 channels.
n_kernel_1, n_kernel_2 = 8, 8
n_stride_1, n_stride_2 = 4, 4
n_filters_2 = 32

# Batch / sequence sizes and recurrent width.
batch_size, n_time, rnn_units = 5, 10, 100

# Random stand-in frames: (batch, time, 128, 128, 4).
img = np.random.rand(batch_size, n_time, 128, 128, 4)

In [None]:
tf.reset_default_graph()

class SpikingAgent(snt.RNNCore):
    """Work-in-progress refactor of the agent, split into an SCNN unroll, a
    context pass and a query pass.

    NOTE(review): this redefinition shadows the ``SpikingAgent`` class defined in an
    earlier cell.
    """

    def __init__(self, action_set, batch_size, rnn_units, n_rnn_step_factor=1,
                 tau=20, tau_readout=5, thr=.615, n_filters_1=16, n_filters_2=32,
                 n_kernel_1=8, n_kernel_2=8, stride_1=4, stride_2=4,
                 beta=.16, tau_adaptation=300, fraction_adaptive=.4, n_refractory=3,
                 tau_scnn=1., thr_scnn=1., n_time = 10):
        """
        Args:
            action_set: sequence of actions; only its length is used.
            batch_size: fixed batch size used for all zero states.
            rnn_units: number of recurrent neurons.
            n_rnn_step_factor: RNN sub-steps per environment step.
            tau_readout: scalar or sequence of readout time constants.
            n_filters_*, n_kernel_*, stride_*: conv geometry, forwarded to SpikingCNN.
            tau_scnn, thr_scnn: SpikingCNN membrane time constant and threshold.
            n_time: number of time steps expected by ``context_pass``.

        NOTE(review): tau, thr, beta, tau_adaptation, fraction_adaptive and
        n_refractory are accepted but not yet forwarded to ``CustomALIF``.
        """
        # Sonnet modules must run the base initialiser before use.
        super(SpikingAgent, self).__init__(name='agent')

        self._num_actions = len(action_set)
        self._batch_size = batch_size
        self.rnn_units = rnn_units
        tau_readout = np.atleast_1d(tau_readout)
        self.decay = np.exp(-1 / tau_readout)
        self.n_rnn_step_factor = n_rnn_step_factor
        self.thr = thr
        self.n_filters_1 = n_filters_1
        self.n_kernel_1 = n_kernel_1
        self.stride_1 = stride_1
        self.n_filters_2 = n_filters_2
        self.n_kernel_2 = n_kernel_2
        self.stride_2 = stride_2
        self.n_time = n_time

        # Spatial sizes from this instance's own strides (not notebook globals).
        n_w_1 = (128 - n_kernel_1) // stride_1 + 2
        n_w_2 = (n_w_1 - n_kernel_2) // stride_2 + 1
        n_inputs = n_w_2 * n_w_2 * n_filters_2

        self.core = CustomALIF(n_in = n_inputs, n_rec = rnn_units)

        # Build the torso with the agent's geometry so its state matches n_inputs above.
        self.scnn = SpikingCNN(n_kernel_1=n_kernel_1, n_filter_1=n_filters_1, stride_1=stride_1,
                               n_kernel_2=n_kernel_2, n_filter_2=n_filters_2, stride_2=stride_2,
                               tau=tau_scnn, thr=thr_scnn)
        self.scnn_state = self.scnn.zero_state(self._batch_size, tf.float32)

    def _build(self, inputs):
        """Full forward pass - not implemented yet in this refactor.

        Defined so the Sonnet module base class can be instantiated; use
        ``scnn_unroll`` / ``context_pass`` / ``query_pass`` instead.
        """
        raise NotImplementedError(
            "SpikingAgent._build is not implemented; use scnn_unroll, context_pass and query_pass.")

    def initial_state(self, batch_size, dtype=tf.float32):
        """Zero controller state: ``(policy logits (batch, n_actions), baseline (batch, 1))``."""
        return (tf.zeros((batch_size, self._num_actions)), tf.zeros((batch_size, 1)))

    def _controller(self, core_output, controller_state):
        """Exponentially filtered linear readout plus action sampling.

        Args:
            core_output: time-major core spikes, ``(time, batch, units)``.
            controller_state: ``(policy, baseline)`` initializers for ``exp_convolve``.

        Returns:
            ``(new_action, (policy[-1], baseline[-1]))``.
        """
        def f(core_output):
            i_policy_logits = snt.Linear(self._num_actions, name='policy_logits')(core_output)
            i_baseline = tf.squeeze(snt.Linear(1, name='baseline')(core_output), axis=-1)
            return i_policy_logits, i_baseline

        policy = 0.
        baseline = 0.
        # Use this agent's readout decays (the notebook global `decays` was read before).
        for decay in self.decay:
            i_policy_logits, i_baseline = snt.BatchApply(f)(core_output)
            policy += exp_convolve(i_policy_logits, decay, initializer=controller_state[0])
            baseline += exp_convolve(i_baseline[..., None], decay, initializer=controller_state[1])
        # Keep only the last RNN sub-step of every environment step.
        policy = policy[self.n_rnn_step_factor - 1::self.n_rnn_step_factor]
        baseline = baseline[self.n_rnn_step_factor - 1::self.n_rnn_step_factor]

        def g(policy):
            new_action = tf.compat.v1.multinomial(policy, num_samples=1,
                                        output_dtype=tf.int64)
            new_action = tf.squeeze(new_action, 1, name='new_action')
            return new_action

        new_action = snt.BatchApply(g)(policy)
        new_controller_state = (policy[-1], baseline[-1])

        return new_action, new_controller_state

    @snt.reuse_variables
    def scnn_unroll(self, inputs):
        """Apply the first conv per frame and unroll the SpikingCNN from ``self.scnn_state``.

        Args:
            inputs: frames, assumed ``(batch, time, H, W, C)`` - TODO confirm.

        Returns:
            ``(scnn_output, new_scnn_state)`` from ``dynamic_rnn`` (batch-major).
        """
        i_conv1 = snt.BatchApply(snt.Conv2D(self.n_filters_1, self.n_kernel_1, stride=self.stride_1))(inputs)
        scnn_output, new_scnn_state = tf.compat.v1.nn.dynamic_rnn(self.scnn, i_conv1, initial_state=self.scnn_state)
        return scnn_output, new_scnn_state

    def context_pass(self, inputs):
        """Run the recurrent core over the SCNN output of the context images.

        Args:
            inputs: the ``scnn_output`` tuple from ``scnn_unroll`` for the context
                images; element 2 (layer-2 spikes) is fed to the core.

        Returns:
            ``(core_output, rnn_state)``: time-major core spikes and the final core state.
        """
        rnn_state = self.core.zero_state(self._batch_size, tf.float32)
        rnn_output, rnn_state = tf.nn.dynamic_rnn(self.core, inputs[2], initial_state=rnn_state)

        rnn_output = nest.map_structure(switch_time_and_batch_dimension, rnn_output)
        core_output = rnn_output[0]
        core_output.set_shape((self.n_time * self.n_rnn_step_factor, self._batch_size, self.rnn_units))
        return core_output, rnn_state

    def query_pass(self, query_input, rnn_state=None):
        """Run the core over the query image and sample an action.

        Args:
            query_input: the ``scnn_output`` tuple from ``scnn_unroll`` for the query image.
            rnn_state: core state to continue from (e.g. the one returned by
                ``context_pass``); a zero state is used when None.

        Returns:
            ``(controller_output, controller_state, rnn_state)``.
        """
        if rnn_state is None:
            rnn_state = self.core.zero_state(self._batch_size, tf.float32)
        rnn_output, rnn_state = tf.nn.dynamic_rnn(self.core, query_input[2], initial_state=rnn_state)

        rnn_output = nest.map_structure(switch_time_and_batch_dimension, rnn_output)
        core_output = rnn_output[0]

        initial_controller_state = self.initial_state(self._batch_size)
        controller_output, controller_state = self._controller(core_output, initial_controller_state)
        return controller_output, controller_state, rnn_state



#### Backpropagation


In [None]:
# Unroll the agent over a trajectory and build spike-regularisation terms.
# NOTE(review): `agent`, `agent_outputs`, `env_outputs`, `agent_state`, `FLAGS` and
# `analysis_tensors` are not defined in the visible cells; this cell assumes a
# training-script context - confirm before running.
learner_outputs, final_agent_state, custom_rnn_output, torso_outputs = agent.unroll(
        agent_outputs.action, env_outputs,
        agent_state, write_to_collection=True)

# Channel 0 of custom_rnn_output[1] is treated as the membrane voltage, channel 1
# as the adaptation variable: effective threshold = thr + beta * adaptation.
rnn_v = custom_rnn_output[1][..., 0]
rnn_thr = FLAGS.thr + agent.core.beta * custom_rnn_output[1][..., 1]
# Squared hinge penalty on voltages outside [-thr, thr] (mean over axis 1, then summed).
rnn_pos = tf.nn.relu(rnn_v - rnn_thr)
rnn_neg = tf.nn.relu(-rnn_v - rnn_thr)
voltage_reg_rnn = tf.reduce_sum(tf.reduce_mean(tf.square(rnn_pos), 1))
voltage_reg_rnn += tf.reduce_sum(tf.reduce_mean(tf.square(rnn_neg), 1))
# Per-neuron firing rate averaged over axes 0 and 1 (presumably time and batch),
# pulled towards a target of 0.02 by a squared-error rate loss.
rnn_rate = tf.reduce_mean(custom_rnn_output[0], (0, 1))
rnn_mean_rate = tf.reduce_mean(rnn_rate)
analysis_tensors['rnn_rate'] = rnn_mean_rate
rate_loss = tf.reduce_sum(tf.square(rnn_rate - .02)) * 1.
# Torso activations stored by unroll(write_to_collection=True); [-1] is the latest.
torso_from_collection = tf.get_collection('torso_output')[-1]

In [None]:
# Linear readout, loss and optimiser, with optional fixed/adaptive broadcast weights.
# NOTE(review): this cell appears to be carried over from a separate
# phoneme-classification script: FLAGS, rd, dataset, outputs, phns, BA_logits,
# einsum_bij_jk_to_bik, weighted_relevant_mask, lr and global_step are not defined
# in the visible cells, and nothing here uses the SpikingAgent above.
with tf.name_scope('Output'):
    w_out_init = rd.randn(FLAGS.n_lstm * 2, dataset.n_phns) / np.sqrt(FLAGS.n_lstm * 2)  # original
    w_out = tf.Variable(w_out_init, dtype=tf.float32)
    if FLAGS.eprop in ['random']:
        # Unscaled random feedback weights.
        BA_out = tf.constant(rd.randn(FLAGS.n_lstm * 2, dataset.n_phns),
                             dtype=tf.float32, name='BroadcastWeights')

        BA_out = tf.get_variable(name="BAout", initializer=BA_out, dtype=tf.float32)
        phn_logits = BA_logits(outputs, w_out, BA_out)
    elif FLAGS.eprop in ['adaptive']:
        # Feedback weights scaled like w_out.
        BA_out = tf.constant(rd.randn(FLAGS.n_lstm * 2, dataset.n_phns) / np.sqrt(FLAGS.n_lstm * 2),
                             dtype=tf.float32, name='BroadcastWeights')

        BA_out = tf.get_variable(name="BAout", initializer=BA_out, dtype=tf.float32)
        phn_logits = BA_logits(outputs, w_out, BA_out)
    else:
        phn_logits = einsum_bij_jk_to_bik(outputs, w_out)
    b_out = tf.Variable(np.zeros(dataset.n_phns), dtype=tf.float32)
    phn_logits = phn_logits + b_out


# Define the graph for the loss function and the definition of the error
with tf.name_scope('Loss'):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=phns, logits=phn_logits)
    loss = tf.reduce_mean(loss)
    if FLAGS.l2 > 0:
        # L2 penalty over every trainable variable (biases included).
        losses_l2 = [tf.reduce_sum(tf.square(w)) for w in tf.trainable_variables()]
        loss += FLAGS.l2 * tf.reduce_sum(losses_l2)

    phn_prediction = tf.argmax(phn_logits, axis=2)
    is_correct = tf.equal(phns, phn_prediction)
    is_correct_float = tf.cast(is_correct, dtype=tf.float32)
    # 1 - mean over rows of the mask-weighted number of correct predictions; this is an
    # error rate if the mask sums to 1 per row - TODO confirm.
    ler = tf.reduce_sum(is_correct_float * weighted_relevant_mask, axis=1)
    ler = 1. - tf.reduce_mean(ler)

# Define the training step operation
with tf.name_scope('Train'):
    if not FLAGS.adam:
        train_step = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step)
    else:
        train_step = tf.train.AdamOptimizer(lr, epsilon=FLAGS.adam_epsilon).minimize(loss, global_step=global_step)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()


In [None]:
import tensorflow as tf
from tensorflow import keras

# Define CNN model
cnn_model = keras.Sequential(...)
cnn_model.compile(...)

# Define RNN model
rnn_model = keras.Sequential(...)
rnn_model.compile(...)

# Define controller function
def controller(hidden_state):
    """Map the RNN output `hidden_state` to predicted actions.

    Placeholder: the action-selection logic has not been written yet. A bare `...`
    body would silently return None and let the problem surface later inside the
    loss, so this stub raises immediately instead.

    Raises:
        NotImplementedError: always, until the controller logic is implemented.
    """
    # TODO: implement the logic for predicting actions from hidden_state.
    raise NotImplementedError("controller: action-prediction logic is not implemented yet")

# Define custom loss function
def custom_loss(y_true, y_pred, cnn_output, rnn_output):
    """Combined training loss for the CNN -> RNN -> controller pipeline.

    Intended to return cnn_loss + rnn_loss + controller_loss, computed from the
    targets `y_true`, the controller predictions `y_pred` and the intermediate
    `cnn_output` / `rnn_output`. None of the three terms is defined yet. Adding
    Ellipsis placeholders together would fail with an obscure TypeError, so this
    stub raises an explicit error instead.

    Raises:
        NotImplementedError: always, until the individual loss terms are defined.
    """
    # TODO: compute the individual terms, then
    #     total_loss = cnn_loss + rnn_loss + controller_loss
    #     return total_loss
    raise NotImplementedError(
        "custom_loss: cnn_loss, rnn_loss and controller_loss are not implemented yet")

# Create an optimizer shared by the CNN and RNN parameters. Applying it once per
# step (rather than calling one optimizer twice) advances its step counter exactly
# once per batch.
optimizer = tf.keras.optimizers.Adam()

# Define a training step function
@tf.function
def train_step(images, labels):
    """Run one gradient update of the CNN -> RNN -> controller pipeline.

    Args:
        images: input batch fed to `cnn_model`.
        labels: targets passed to `custom_loss` as `y_true`.

    Returns:
        The loss computed for this batch. Callers that ignore the return value are
        unaffected; it is returned so the training loop can log progress.

    NOTE(review): this rebinds the module-level name `train_step`, which earlier in
    the notebook refers to the TF1 op built in the 'Train' name scope. Those cells
    use TF1 graph mode (tf.Session); tf.function / GradientTape expect eager /
    TF2-style execution -- confirm the runtime before mixing the two.
    """
    with tf.GradientTape() as tape:
        # Pass images through CNN
        cnn_output = cnn_model(images)

        # Pass CNN output through RNN
        rnn_output = rnn_model(cnn_output)

        # Predict actions using controller function
        actions = controller(rnn_output)

        # Calculate loss
        loss = custom_loss(labels, actions, cnn_output, rnn_output)

    # Build the variable list once so gradients and variables are aligned by
    # construction rather than by manual slicing.
    variables = cnn_model.trainable_variables + rnn_model.trainable_variables
    gradients = tape.gradient(loss, variables)

    # No per-model optimizers exist in this cell; update CNN and RNN parameters
    # together with the single shared `optimizer`.
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

# Training loop
# NOTE(review): `num_epochs` and `train_dataset` are not defined in this cell --
# confirm they are created earlier in the notebook. `train_step` here is the
# tf.function defined above in this cell, not the TF1 training op of the same name.
for epoch in range(num_epochs):
    # Each element of train_dataset is unpacked as an (images, labels) pair.
    for batch_images, batch_labels in train_dataset:
        train_step(batch_images, batch_labels)


#### Extra code


In [None]:
# Load the first ten ACRE training frames (train/ACRE_train_000000_00.png ... _09.png).
# Forward slashes work on both Windows and POSIX; a "train\ACRE..." literal relies on
# the invalid escape sequence "\A" and a Windows-only separator.
images = [mpimg.imread("train/ACRE_train_000000_{:02d}.png".format(i)) for i in range(10)]

# Convert each image array to a tensor.
tensors = [tf.convert_to_tensor(image) for image in images]

# Stack the image tensors along a new leading (batch) axis.
# Per the original note the source frames are 240x320 with 4 channels, giving
# (10, 240, 320, 4) -- confirm against the actual files.
images_tensor = tf.stack(tensors)

# Resize every frame to 128x128, the input size SpikingCNN assumes when computing n_w_1.
resized_images = tf.image.resize_images(images_tensor, [128, 128])

# Convert the image data type to float32
float_images = tf.cast(resized_images, tf.float32)

# Normalize the pixels to [0, 1]. matplotlib.image.imread already returns PNGs as
# floats in [0, 1], so dividing by 255 again would squash them to [0, 1/255].
# Only integer data (e.g. uint8 from other formats) needs rescaling, by its dtype max.
if images_tensor.dtype.is_integer:
    normalized_images = float_images / float(images_tensor.dtype.max)
else:
    normalized_images = float_images

normalized_images.shape  # (10, 128, 128, 4)

TensorShape([Dimension(10), Dimension(128), Dimension(128), Dimension(4)])

In [32]:
# #Maell's code

# def pad(effective_kernel_size):	 # pylint: disable=unused-argument
# 	"""No padding."""
# 	return [2, 2]

# print(n_w_1)
# print(n_w_2)
# print(n_inputs)

# img = np.random.rand(5, 10,128,128,4)
# img_tensor = tf.convert_to_tensor(img, tf.float32)
# i_conv1 = snt.Conv2D(16, 8, stride=4)(img_tensor)
# print('conv1:', i_conv1.get_shape())
	
# scnn = SpikingCNN()
# scnn_state = scnn.zero_state(10, tf.float32)

# scnn_output, scnn_state = scnn(i_conv1, scnn_state)
# print([o.get_shape() for o in scnn_output])
# print([s.get_shape() for s in scnn_state])

# core = CustomALIF(n_in = n_inputs, n_rec = 100)
# rnn_state = core.zero_state(10, tf.float32)

# rnn_output, rnn_state = core(scnn_output[2], rnn_state)

# print([o.get_shape() for o in rnn_output])





32
7
1568


IncompatibleShapeError: Input Tensor must have rank 4 corresponding to data_format NHWC, but instead was (5, 10, 128, 128, 4) of rank 5.

originally defined at:
  File "c:\Users\saai2\Anaconda3\envs\reward-based-e-prop\lib\site-packages\sonnet\python\modules\conv.py", line 1758, in __init__
    custom_getter=custom_getter, name=name)
  File "c:\Users\saai2\Anaconda3\envs\reward-based-e-prop\lib\site-packages\sonnet\python\modules\conv.py", line 464, in __init__
    super(_ConvND, self).__init__(custom_getter=custom_getter, name=name)
  File "c:\Users\saai2\Anaconda3\envs\reward-based-e-prop\lib\site-packages\sonnet\python\modules\base.py", line 180, in __init__
    custom_getter_=self._custom_getter)
  File "C:\Users\saai2\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\ops\template.py", line 160, in make_template
    **kwargs)
