# Simple recurrent neural network with time-averaging

In [None]:
# Install a pip package created for our course
!pip install connectionist

In [None]:
import tensorflow as tf
import numpy as np
from connectionist.layers import MultiInputTimeAveraging

This tutorial follows on from the [previous module](https://colab.research.google.com/drive/1Fiy4g4h3B2oLJ5co4ZHP2Z7aLh-pNsSW#scrollTo=I_yKZ_Uym-g8) on simple recurrent networks (SRNs). In that module, all units update their activations instantaneously at each time step. That is, at each time step, each unit computes its net input, then directly sets its activation to the value specified by the activation function. In neural systems and for many interesting phenomena in cognitive science, however, it is useful to think of unit activation change as itself occurring in continuous time. If a unit currently has very low activation but is receiving strong excitatory input, we might imagine that, instead of "jumping" immediately to a high activation state, the unit adjusts its activation in gradual increments over time. As its activation adjusts, the inputs it contributes to downstream units also gradually adjust, so the propagation of activation in a single step of some processing sequence is viewed as occurring in continuous time. (In the literature this is sometimes referred to as *cascaded* activation.) Such continuous-time change is potentially important for understanding behaviors that unfold in time (e.g. reaction time data), or for understanding how and when different information sources combine during processing in some task.

In this module we will build a recurrent network that simulates continuous-time unit activation via *multiple input time-averaging* (MITA). Understanding continuous time-averaging is an important step in working with continuous and fully recurrent networks, the topic of the next module.



## Spelling model

To start, let’s build the network shown below, which learns to map from localist representations of words to temporally distributed representations of their spelling. The network takes a one-hot word input and outputs the corresponding ordered sequence of letters. To achieve this, the model has dense feed-forward connectivity from the word input to a hidden layer and from the hidden layer to the output units. This is a simpler version of the "fully recurrent" model that comes in the next module.

<div>
<br>
<img src='https://drive.google.com/uc?id=1-1msHzJbXizDk8AzliaWZCPXl0ouQcIg' width="250"/>
</div>

The training patterns will include a single, static input for each word, but temporally distributed output patterns indicating, in correct order, the three letters for each word. To keep things simple, we will use a restricted number of letters in each position. In the figure the network is laid out so that there are three possible letters in the first position (top row), three possible vowels (middle row), and three possible letters in the third position (bottom row). There are 20 English three-letter words that can be formed from these letters, one for each input unit, and we will train the network on all of these.

This module is based in part on the tutorial [here](http://concepts.psych.wisc.edu/index.php/resources/quick-lens-tutorial-for-fully-recurrent-networks/). The fully recurrent extension of this architecture is developed in a subsequent module in this series.

## Training patterns
The data you need to train and test the network are simple English words that are easy to specify ourselves with a few commands, so we will walk through how to build them.

In [None]:
letters = 'crsaoutbn' # here we will use lowercase to make encoding simpler
words = ['cat','cab','can','cot','cob','con','cut','cub','rat','ran','rot','rob','rut','rub','run','sat','sot','sob','son','sun']
print(len(words))

### Encoding
We've specified the letters we need to represent for the outputs, and the words that will comprise the inputs. Now we need to take these string representations and turn them into machine readable vectors.

In [None]:
n = len(words)
word_repr = np.zeros((n, n))
spelling_repr = np.zeros((n, 3, 9))

for i, word in enumerate(words):
    word_repr[i][i] = 1.

    for j, letter in enumerate(word):
        spelling_repr[i][j][letters.index(letter)] = 1.
    

This is what the first input looks like for the word `cat`:

In [None]:
print(word_repr[0])

This is what the first output looks like, with each row representing a step in the sequence, that is, which letter should be generated in the first, second, and third step of the sequence. The first row is the first letter `c`, the second row is the second letter `a`, and the third row is the third letter `t`.

In [None]:
print(spelling_repr[0])

This is the current shape of the entire input and output; note that we will need to reshape it with respect to the time dimension of each pattern.

In [None]:
print(f"Shape of word (input) representation: {word_repr.shape=}")
print(f"Shape of spelling (output): {spelling_repr.shape=}")

### Repeat/Stack to get a matching time dimension

In the simple recurrent network with instantaneous updating, each step of the sequence corresponds to one element of the time dimension in the input and output tensors. In continuous time, units cannot fully update in a single step; you need to provide room for several activation updates between steps of the sequence. A single update pass is sometimes called a *tick* of time, and continuous models specify that each step of a sequence takes a certain number of ticks to complete. In this example, we will allow each step of the sequence to run for 10 ticks. Since each sequence has 3 steps, we need input and output tensors with 30 elements along the time dimension. Recall that the time dimension is the second dimension of the tensor and that Python starts counting at 0, so time extends along axis 1 of the input and output tensors. So:

- An input pattern will remain the same across time, with the word representation copied 30 times along a new axis 1, taking the shape from (20, 20) to (20, 30, 20).
- The output is time varying, so the letter representation is copied 10 times along the existing axis 1. This will take the shape from (20, 3, 9) to (20, 30, 9).

In [None]:
# Stack 30 copies of word_repr along a new axis 1, i.e., clamp the same input on all 30 ticks
x_train = np.stack([word_repr for _ in range(30)], axis=1)

# Repeat each letter 10 times along axis 1, e.g., cat -> ccccccccccaaaaaaaaaatttttttttt
y_train = np.repeat(spelling_repr, 10, axis=1)

print(f"Shape of inputs: {x_train.shape=}")
print(f"Shape of outputs: {y_train.shape=}")
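
The difference between stacking and repeating is easier to see on a tiny array. The sketch below is only an illustration: the names `toy_inputs`, `toy_targets`, and `ticks_per_step` are made up for this example, and it uses 2 patterns, 2 steps, and 3 ticks per step instead of the real 20, 3, and 10.

```python
import numpy as np

# Hypothetical toy data: 2 static input patterns, each with a 2-step target sequence
toy_inputs = np.array([[1, 0],
                       [0, 1]])                  # shape (2, 2): (patterns, features)
toy_targets = np.array([[[1, 0], [0, 1]],
                        [[0, 1], [1, 0]]])       # shape (2, 2, 2): (patterns, steps, features)

ticks_per_step = 3
n_steps = toy_targets.shape[1]
n_ticks = n_steps * ticks_per_step

# Static input: insert a NEW time axis and copy the whole pattern onto every tick
x_toy = np.stack([toy_inputs] * n_ticks, axis=1)          # shape (2, 6, 2)

# Time-varying target: stretch the EXISTING time axis, holding each step for 3 ticks
y_toy = np.repeat(toy_targets, ticks_per_step, axis=1)    # shape (2, 6, 2)

print(x_toy.shape, y_toy.shape)
print(y_toy[0])  # first target: step 1 held for 3 ticks, then step 2 held for 3 ticks
```

`np.stack` creates the time axis for inputs that had none, while `np.repeat` lengthens a time axis that already exists.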

## Create an RNN with time-averaging

To create this model you will again use the `connectionist` package along with `keras`. To build the model with the keras `Sequential` API, add the `TimeAveragedRNN` layer, which is available in `connectionist.layers`. The syntax is just like what you will find with other `keras` models built with the `Sequential` API. Here is a simple workflow to put this RNN layer to work:

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from connectionist.layers import TimeAveragedRNN

model = Sequential()
model.add(TimeAveragedRNN(tau=0.1, units=5))  # Unlike Keras's SimpleRNN, return_sequences=True is not needed; returning the full sequence is the default behavior.
model.add(Dense(units=9, activation='softmax'))

y = model(x_train)
print(f"{y.shape=}")
model.summary()

The construction of the model is just like the simple recurrent case, but using the `TimeAveragedRNN` layer in place of the standard RNN. This layer requires a parameter $\tau$ (i.e. *tau*) that determines how closely continuous time is simulated: the smaller the value, the less a unit updates its activation on each tick, and the more ticks are needed for it to reach the state indicated by its inputs. The number of ticks needed for each step of the sequence is determined by $1/\tau$. Here $\tau=0.1$, so 10 ticks per sequence step are needed.

The code cell above applies the (untrained) model to the input patterns in `x_train` and generates the outputs stored in `y`. From the printed output, you can see that this has shape `(20, 30, 9)`: 20 words, each run for 30 ticks of time, with activation values generated over the 9 output units.
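
To get a feel for what $\tau$ does, here is a minimal sketch in plain Python, separate from the `connectionist` package. It assumes a single unit whose input-driven target activation (`target`, a name invented here) stays fixed at 1.0 and whose activation starts at 0.0. In the real network the target itself changes from tick to tick, so treat this only as an idealized picture.

```python
tau = 0.1
target = 1.0       # assumption: what the inputs say the activation should be, held fixed
activation = 0.0   # assumption: the unit starts fully "off"

trajectory = []
for tick in range(30):
    # Move a fraction tau of the way from the old activation toward the target
    activation = tau * target + (1 - tau) * activation
    trajectory.append(activation)

for tick in (1, 10, 20, 30):
    print(f"tick {tick:2d}: activation = {trajectory[tick - 1]:.3f}")
```

Under these assumptions the activation after $n$ ticks is $1 - (1-\tau)^n$, so with $\tau=0.1$ the unit has covered roughly 65% of the distance after 10 ticks and keeps creeping toward the target afterwards. A smaller $\tau$ gives a smoother but slower approach.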

### Building an entire TimeAveragedRNN layer

Now that we know we can just call `TimeAveragedRNN()` at a high level, let's look at what is inside the layer in detail. As before, we will first consider the RNN cell (what happens in one time step), then how a cell is used to build an RNN layer (the loop over time steps).

### Schematic and equation for RNNCell with time-averaging



![model architecture](https://drive.google.com/uc?id=1K-A0fQuShnhbdwf0lTqhvOVUCqRuCcq5)

Let's look first at the equation for $h_t$, the hidden unit activation at time $t$, in the top right. Each unit's activation is a weighted average of two terms, *old* and *new*, with the weighting determined by the parameter $\tau \in [0,1]$. Here, *old* is just the activation of the hidden units on the previous timestep, $h_{t-1}$, while *new* indicates what the unit activations should be according to their current inputs. Thus *new* is simply the sigmoid activation function applied to the unit net inputs, which in turn are the sum of inputs from $h_{t-1}$, from the current input $x_t$, and from the bias weight $b$. Conceptually, the activation will move a fraction of the way from its prior state toward the state suggested by its current inputs, where the fraction is determined by $\tau$.
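
Written out from the description above, the update can be sketched as follows. The weight names $W_{xh}$ and $W_{hh}$ are borrowed from the code comments (`w_xh`, `w_hh`) and may not match the labels in the figure:

$$
h_t = \tau \, \underbrace{\sigma\left(x_t W_{xh} + h_{t-1} W_{hh} + b\right)}_{\textit{new}} + (1 - \tau) \, \underbrace{h_{t-1}}_{\textit{old}}
$$

Rearranging gives $h_t = h_{t-1} + \tau\,(\textit{new} - h_{t-1})$, which makes the "move a fraction $\tau$ of the way" reading explicit. With $\tau = 1$ the *old* term vanishes and the update reduces to the instantaneous SRN update.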

These operations are shown graphically within the RNN cell at time $t$. Here the `+` operator indicates addition, the `X` operator indicates multiplication, and the $\sigma$ node indicates application of the sigmoid activation function.

With this understanding of the computations within a continuous time RNN, let's look at defining the `TimeAveragedRNNCell` class. NOTE that this class definition has the same functionality as the class included in the `connectionist` module, but is implemented a bit differently in the module to allow for simulation of varying forms of damage.

In [None]:
class TimeAveragedRNNCell(tf.keras.layers.Layer):
    def __init__(self, tau, units):
        super().__init__()
        self.tau = tau
        self.units = units

    def build(self, input_shape):
        self.input_dense = tf.keras.layers.Dense(self.units, use_bias=False)  # w_xh
        self.recurrent_dense = tf.keras.layers.Dense(self.units, use_bias=False)  # w_hh
        self.time_averaging = MultiInputTimeAveraging(tau=self.tau, average_at="after_activation", activation='sigmoid')  # "time-averaging mechanism" with multiple inputs
        self.built = True

    def call(self, inputs, states=None):
        xh = self.input_dense(inputs)  # x @ w_xh

        if states is None:
            hh = tf.zeros_like(xh)
        else:
            hh = self.recurrent_dense(states)  # h @ w_hh

        outputs = self.time_averaging([xh, hh])  # tau * sigmoid(xh + hh + bias) + (1 - tau) * last activation
        return outputs, outputs  # Just to be consistent with the RNN API, one for state and one for output

    def reset_states(self):  # TODO: This need another name?
        self.time_averaging.reset_states()  # Reset the states of the time-averaging mechanism (last activation = None)

Here is how the cell is wrapped in a layer: the loop over ticks (`max_ticks`) governs the temporal flow of the entire time-averaged layer. This layer is also available in the `connectionist` package, so you can call it directly from there.

In [None]:
class TimeAveragedRNN(tf.keras.layers.Layer):
    def __init__(self, tau, units):
        super().__init__()
        self.tau = tau
        self.units = units

    def build(self, input_shape):
        self.rnn_cell = TimeAveragedRNNCell(tau=self.tau, units=self.units)
        self.built = True

    def call(self, inputs):
        max_ticks = inputs.shape[1]  # (batch_size, seq_len, input_dim)
        outputs = tf.TensorArray(dtype=tf.float32, size=max_ticks)
        states = None

        for t in range(max_ticks):
            this_tick_input = inputs[:, t, :]
            states, output = self.rnn_cell(this_tick_input, states=states)
            outputs = outputs.write(t, output)

        # States persist across ticks, but not across batches, so we need to reset them
        self.rnn_cell.reset_states()
        outputs = outputs.stack()  # (seq_len, batch_size, units)
        outputs = tf.transpose(outputs, [1, 0, 2])  # (batch_size, seq_len, units)
        return outputs

### Some notes about implementing time
The network always infers the number of time ticks from the input shape. The temporal length of the sequence is at index 1 of `inputs.shape` (the second element, since Python counts from 0). For example, in the shape (20, 30, 20), the number of ticks is `30`. This is the TensorFlow/Keras convention, and you can learn more about shapes in the tutorial module.

Also note that we collect outputs in a `tf.TensorArray`. This is TensorFlow's container for writing one tensor per loop iteration and stacking them at the end; it also works inside graph-compiled loops, which makes it a good fit for long temporal sequences.
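
The tick loop can also be sketched in plain NumPy, which makes the shapes easy to inspect without TensorFlow. This is an illustrative re-implementation of the update described above, not the code in the `connectionist` package. The function name `time_averaged_forward`, the random weights, and the choice to start the hidden state at zero are all assumptions made for this example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def time_averaged_forward(inputs, w_xh, w_hh, bias, tau):
    """Run a time-averaged recurrent layer over all ticks.

    inputs: (batch_size, n_ticks, input_dim) -> returns (batch_size, n_ticks, units)
    """
    batch_size, n_ticks, _ = inputs.shape          # ticks are inferred from axis 1
    h = np.zeros((batch_size, w_hh.shape[0]))      # assumption: hidden state starts at zero
    outputs = []
    for t in range(n_ticks):
        new = sigmoid(inputs[:, t, :] @ w_xh + h @ w_hh + bias)  # where the inputs point
        h = tau * new + (1.0 - tau) * h                          # move a fraction tau toward it
        outputs.append(h)
    return np.stack(outputs, axis=1)               # put the tick axis back at position 1

rng = np.random.default_rng(0)
toy_x = np.stack([np.eye(20)] * 30, axis=1)        # same shape as x_train: (20, 30, 20)
w_xh = rng.normal(size=(20, 5))                    # hypothetical untrained weights
w_hh = rng.normal(size=(5, 5))
bias = np.zeros(5)

hidden = time_averaged_forward(toy_x, w_xh, w_hh, bias, tau=0.1)
print(hidden.shape)  # (20, 30, 5)
```

Because each tick moves the state only a fraction $\tau$ of the way toward its target, no hidden activation in this sketch can change by more than $\tau$ between consecutive ticks.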

## Training

We don't really need to customize the training process in this example. We can just use Keras's `model.fit()` to train the model. Pass the training inputs (`x_train`) and the target outputs (`y_train`) as the first two arguments, and optionally specify the number of epochs over which training takes place.

In [None]:
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=10)  # Using a large learning rate for this simple toy problem

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(x_train, y_train, epochs=1000)


## Evaluate the model
We have a model that produces language as output, which makes it easy to inspect the performance of the model. Here is a simple function that produces the letter sequence corresponding to the predictions for each input (word). We use `argmax()` to treat the most active unit at each tick as the unit that is "on". By identifying which unit is on at each tick, we can determine which letter the model predicts for that tick. Because each word runs for 30 ticks, each decoded string contains 30 letters.

In [None]:
def decode_prediction(y, idx=None):
    """Decode the prediction."""

    decoded = ''.join([letters[np.argmax(v)] for v in y])
    return decoded

In [None]:
y_pred = model(x_train)

for i, word in enumerate(words):
    print(f"word: {word}; pred: {decode_prediction(y_pred[i])}")
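
Since each decoded string holds one letter per tick, it can help to collapse it back to one letter per sequence step. The sketch below keeps the letter decoded at the last tick of each step, on the assumption that activation has had the most time to settle by then. Both the helper name `read_out_steps` and this read-out rule are choices made for illustration, not part of the `connectionist` package.

```python
def read_out_steps(decoded, ticks_per_step=10):
    """Keep the letter decoded at the final tick of each sequence step."""
    return decoded[ticks_per_step - 1::ticks_per_step]

# Hand-written examples standing in for decode_prediction(...) output
clean = "c" * 10 + "a" * 10 + "t" * 10
lagging = "rrcccccccc" + "ccaaaaaaaa" + "aatttttttt"  # early ticks still show the previous state

print(read_out_steps(clean))    # cat
print(read_out_steps(lagging))  # cat
```

Comparing `read_out_steps(decode_prediction(y_pred[i]))` with `words[i]` then gives a simple per-word correctness check.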


## Examine the temporal dynamics of the output

The time-averaging mechanism essentially slows down the information flow in the network. One consequence is that we can examine how the output changes over time. This is useful in a number of ways, one of which is understanding how the model progressively differentiates segments of the output, in this case letters.

In the graph below we see the unit activations for each letter at each tick for the first predicted word (cat). The output changes gradually over time, and the activations of particular letter units (namely those that correspond to the correct letter at each step of the spelling) rise in a predictable and temporally differentiated way. Notice the "ramping up" of activation that takes place. This dynamic is made possible by the continuous time-averaging of activation in the network. This example shows the pattern of activation over the output layer. We will examine similar dynamics for the hidden layer later.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Get first prediction `cat` in proper type and shape
pred0 = y_pred[0].numpy().squeeze()

# How the prediction changes over time
fig, ax = plt.subplots()
ax.plot(pred0)
ax.legend(letters)