# A simple recurrent network

In this tutorial we will learn how to build, train, and test a simple recurrent network (SRN) in the context of the AXBY task. We will do this using Python, and the focus of the tutorial is on building an SRN using the TensorFlow (via Keras) deep learning library. We will first motivate the use of an SRN for the task by showing an alternative implementation using basic Python functions.

This tutorial assumes you have read the earlier mini-tutorial on feed forward networks. If you are not familiar with feed forward networks, you should first familiarize yourself with the course materials on those simpler networks.

## AXBY task
SRNs are networks that take inputs and generate outputs over several sequential steps of time. Many human behaviors we study in cognitive science have this character. SRNs have been instrumental in developing theories in several cognitive domains, including working memory and cognitive control, memory for arbitrary serial order, sequential comprehension and production of single words, and sentence comprehension and production. A very useful resource for understanding the core interest and basic operation of SRNs is Jeff Elman’s classic paper from 1990 in Cognitive Science called ["Finding Structure in Time"](https://onlinelibrary.wiley.com/doi/pdf/10.1207/s15516709cog1402_1).

In the standard AX task, people view a sequence of letters and are instructed to press a button whenever an X appears, UNLESS the X was preceded by an A, in which case they should NOT press the button. This task, though very simple, draws out some important characteristics of cognitive control and working memory. Specifically, to perform correctly the participant must (a) retain a representation of a prior event (the preceding letter), (b) continually update this representation with each new letter in the sequence and (c) use the representation of the prior event to constrain the response generated to the current item. In particular, when an X appears following an A, the participant must use the “memory” of the preceding A to *inhibit* the response usually associated with the X (the button press).

### Instructions
The task variant we will study here adds a further level of complexity. In the AXBY task, participants must press button 1 whenever an X appears, UNLESS the X was preceded by an A, in which case no button should be pressed, UNLESS the A was preceded by the letter B, in which case button 2 should be pressed. This task thus requires the participant to retain and continually update information about the preceding two letters in the sequence, and to use this conjoint representation to shape the response generated for the X (button 2 if preceded by B then A; otherwise no button if preceded by A; otherwise button 1). Extensive research on tasks similar to this has suggested that these abilities are supported by structures in the prefrontal cortex, and that behaviors on such tasks can be understood in terms of SRN-like mechanisms.

### Intuition

One possible intuitive solution to modeling this process (without using an artificial neural network) is to store the stimuli in a 3-slot memory buffer. Based on the pattern in the buffer, decide whether to press button 1, 2, or not to press any button (0).

All possibilities:

    - 9 possible patterns in the buffer ending with X:

        - A, A, X -> 0
        - A, B, X -> 1
        - A, X, X -> 1
        - B, A, X -> 2
        - B, B, X -> 1
        - B, X, X -> 1
        - X, A, X -> 0
        - X, B, X -> 1
        - X, X, X -> 1
        
    - all other patterns

        - ?, ?, not X -> 0

## Solve AXBY task using basic python functions and simple data structures

Due to the simplicity of this task, it is not difficult to solve it with some basic python programming, essentially relying on conditionals. The following code would be one such solution to this task. 

*NOTE:* The functions and objects created in this section will be used later when developing the neural network, so make sure to run the cells!

In [None]:
def do_task(inputs):
    """Function to process AXBY task.
    inputs: A list of successive stimuli ('A', 'B', or 'X'), of any length.
    Returns a list with one response (0, 1, or 2) per stimulus."""

    last_last_input = None
    last_input = None
    outputs = []

    for input in inputs:
        if input == 'X':
            if last_input == 'A':
                if last_last_input == 'B':
                    outputs.append(2)
                else:
                    outputs.append(0)
            else:
                outputs.append(1)
        else:
            outputs.append(0)

        last_last_input = last_input
        last_input = input

    return outputs


Let's test the function:

In [None]:
# test_cases is a list containing all the inputs and expected outputs
# each element of the test_cases list is a tuple (11 total)
# the first element of each tuple is a list with the input (e.g., ['A', 'A', 'X'])
# the second element is an integer expressing the expected output (0, 1, or 2)
test_cases = [
    (['A', 'A', 'X'], 0), 
    (['A', 'B', 'X'], 1), 
    (['A', 'X', 'X'], 1), 
    (['B', 'A', 'X'], 2), 
    (['B', 'B', 'X'], 1),
    (['B', 'X', 'X'], 1),
    (['X', 'A', 'X'], 0),
    (['X', 'B', 'X'], 1),
    (['X', 'X', 'X'], 1),
    (['X', 'X', 'B'], 0),
    (['X', 'X', 'A'], 0),
]

# loop over the test cases, applying do_task() to each input sequence
for test_case in test_cases:
    inputs, expected_output = test_case
    output = do_task(inputs)
    last_output = output[-1]
    print(f"{inputs=}, {output=}, {last_output=}, {expected_output=}, {last_output==expected_output=}")

All the tests pass. Just to make sure, let's randomly generate some cases and double-check.

In [None]:
import pandas as pd
from random import choices


seq_len = 20 # create a desired sequence length
inputs = choices(['X', 'A', 'B'], k=seq_len) # select that many stimuli from our three possible
pd.DataFrame({'inputs': inputs, 'outputs': do_task(inputs)}).T # create a data-frame with the inputs/ outputs
# the dictionary maps column names to lists of values; .T transposes the frame so each stimulus is a column

Looks like this approach works: all the outputs are what is required based on the task description for AXBY. But this type of approach does not scale to more complex problems, e.g., computer vision or speech recognition, where the rules cannot be written out by hand.

## Solve AXBY task with RNN
Let's now approach the task with a recurrent neural network.

We can begin with a schematic drawing of a recurrent neural network that can solve this sequential task. We will go into detail about the network later, but use this simple illustration to guide our thinking.


<br>
<img src='https://drive.google.com/uc?id=1dhsDSGK7Af_Zg9Y0TEjbRvzvLdRxErV4'>


The model gets direct input for the currently-viewed letter, which in this case can be X, A or B. Input units send feed-forward connections to a hidden layer, which in turn projects to an output layer containing two units. The first unit indicates a press for button 1; the second indicates a press for button 2. If no button is to be pressed, neither unit should be activated. Note that all of these connections are feed forward.

Hanging off the side of the hidden layer is a second layer of units, labeled “Context.” The context units, like the input, send feed-forward connections to the hidden layer. However, they also _receive_ connections from the hidden layer, which are shown as a dotted line labeled t-1. These connections play a special role: rather than transmitting information from sending to receiving units via weighted connections in the usual way, they instead simply _copy_ the activation pattern from the previous point in time to the receiving units. It is these connections that make the network “recurrent.” The hidden layer and context layer together are sometimes jointly referred to as a “recurrent” layer.

Note that the context layer has the same number of units as the hidden layer, and simply contains a copy of the hidden layer activation pattern on the previous timestep. The context units can be viewed as containing a "memory" of the previous hidden state. The weights projecting from context to hidden layers act just like all other weights in the network, influencing the pattern of activation on the current time step (together with the current input and the bias). These weights can be viewed as capturing how a "memory" of the prior hidden pattern should influence the current pattern.
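
To make the copy-then-feed-forward idea concrete, here is a minimal plain-Python sketch of two timesteps. The weights are hand-picked and hypothetical (a trained network would learn them), the hidden layer is shrunk to 2 units, and `srn_step` is a name invented for this illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def srn_step(x, context, w_in, w_ctx, bias):
    """One timestep: combine the current input with the copied context."""
    hidden = []
    for j in range(len(bias)):
        net = bias[j]
        net += sum(x[i] * w_in[i][j] for i in range(len(x)))
        net += sum(context[k] * w_ctx[k][j] for k in range(len(context)))
        hidden.append(sigmoid(net))
    return hidden

# Hypothetical weights: 3 input units (X, A, B), 2 hidden units
w_in = [[0.5, -0.5], [1.0, 0.0], [0.0, 1.0]]  # input -> hidden
w_ctx = [[0.8, 0.0], [0.0, 0.8]]               # context -> hidden
bias = [0.0, 0.0]

context = [0.0, 0.0]  # empty "memory" at the start of a sequence
for x in ([0, 1, 0], [1, 0, 0]):  # A, then X
    hidden = srn_step(x, context, w_in, w_ctx, bias)
    context = list(hidden)  # the t-1 connection: a plain copy, no weights involved
    print(hidden)
```

The hidden pattern for X here differs from what the same X would produce with an empty context, which is exactly the property the AXBY task needs.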


### Data preparation (manually, showing all steps)

First import these modules to aid in constructing the input and output patterns.

In [None]:
import numpy as np
import tensorflow as tf

Let's look at the sequence of inputs for the task generated in the earlier code block:

In [None]:
print(f"{inputs=}")

1. Encode these inputs into numbers:

In [None]:
input_mapping = {'X': 0, 'A': 1, 'B': 2}
numeric_inputs = [input_mapping[input] for input in inputs]
print(f"{numeric_inputs=}")

2. Create one-hot encodings of inputs

You can see from the printed array below that each row of binary values has one unit set to `1` (the "hot" unit/node) and the others set to `0`. The representation for the numeric input `0` has the 0th unit set to `1`, the representation for input `1` has the 1st unit set to `1`, and the representation for `2` has the 2nd unit set to `1`. (Remember that `python` counts from zero: the first element of a sequence is the 0th element, and `0` is its index.)

In [None]:
one_hot_inputs = tf.one_hot(numeric_inputs, depth=len(input_mapping))
print(f"{one_hot_inputs=}")
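
If it helps to see what the one-hot step does without TensorFlow, the same encoding can be sketched in plain Python. The helper `one_hot` below is a hypothetical stand-in written for this illustration, not the library function, and it returns nested lists rather than a tensor.

```python
def one_hot(indices, depth):
    """Plain-Python one-hot: one row per index, with a single 1.0 per row."""
    return [[1.0 if unit == index else 0.0 for unit in range(depth)]
            for index in indices]

input_mapping = {'X': 0, 'A': 1, 'B': 2}
letters = ['B', 'X', 'A']
rows = one_hot([input_mapping[letter] for letter in letters], depth=len(input_mapping))
for letter, row in zip(letters, rows):
    print(letter, row)
```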

#### Create the input pipeline
When we create workflows in `python` (and other languages) it is best to automate processes as much as possible. Here, let's rewrite the code above (i.e., we *refactor* our previous code) into better code that automates the process. We call an automated process like this a *pipeline*.

In [None]:
def get_case(seq_len=20, verbose=False):
    """Create a case of inputs and outputs."""

    input_mapping = {'X': 0, 'A': 1, 'B': 2}

    # Input
    input_tokens = list(input_mapping.keys())
    inputs = choices(input_tokens, k=seq_len)
    encoded_inputs = [input_mapping[input] for input in inputs]
    tensor_inputs = tf.one_hot(encoded_inputs, depth=len(input_mapping))

    # Input representation:
    # X: [1, 0, 0]
    # A: [0, 1, 0]
    # B: [0, 0, 1]
    
    # Output (similar to input, except we remove the first bit)  
    expected_outputs = do_task(inputs)
    tensor_outputs = tf.one_hot(expected_outputs, depth=3)
    tensor_outputs = tensor_outputs[:, 1:]  # Remove the first bit. *Technically, not necessary, but it makes the output more intuitive.*

    # output representation: 
    # No button press: [0, 0]
    # Button 1: [1, 0]
    # Button 2: [0, 1]
    
    if verbose:
        print(f"{inputs=}")
        print(f"{expected_outputs=}")

    return tensor_inputs, tensor_outputs

x, y = get_case(verbose=True)
print(f"{x.shape=} and {y.shape=}")


Conceptually we often imagine that a network computes error and adjusts weights after processing a single training example, a mode sometimes called *online learning*. Online learning is often pretty slow though! In practice it is more common to present a model with a *batch* of training examples: not a single pattern, and not the entire corpus, but a random sample of *n* patterns from the corpus, where *n* is a free hyperparameter. The model processes all patterns in the batch, tabulating the error and weight gradients for each. After all patterns in the batch have been processed, the gradients on each weight are summed together (across patterns), and a single weight update is made. This speeds things up considerably, and also ensures that a single pattern does not have a giant impact on the weight update.
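
The batch update can be sketched on a toy problem with a single weight. Everything here is an assumption made for illustration: the model is `y_hat = w * x`, the error is squared error, and the three patterns are made up. Note also that Keras loss objects average over the batch by default rather than summing, which changes the gradient only by a constant factor.

```python
def batch_update(w, batch, learning_rate):
    """One weight update for the toy model y_hat = w * x with squared error."""
    total_gradient = 0.0
    for x, y in batch:
        error = w * x - y
        total_gradient += 2 * error * x  # d/dw of (w*x - y)**2 for this pattern
    return w - learning_rate * total_gradient  # a single update for the whole batch

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # made-up patterns that follow y = 2x
w = 0.0
for step in range(20):
    w = batch_update(w, batch, learning_rate=0.01)
print(round(w, 4))  # close to 2
```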

Let's make a function to get a batch:

In [None]:
def get_cases(n, verbose=False):
    """Get a batch of cases."""
    
    cases = [get_case(verbose=verbose) for _ in range(n)]
    inputs = tf.stack([case[0] for case in cases])
    outputs = tf.stack([case[1] for case in cases])
    return inputs, outputs

batch_x, batch_y = get_cases(5)
print(f"{batch_x.shape=} and {batch_y.shape=}")


The above code shows the shape of the batch objects for the input (batch_x) and output (batch_y) patterns. Recall that (1) there are three input units, (2) we have designed sequences of length 20 for the AXBY task, and (3) the code shown asks for batches of 5 patterns. You can see that the input batch array is of shape [5, 20, 3], so you should infer that the first dimension of the array indicates different patterns in the batch, the second dimension indicates different time-points in the sequence, and the third indicates different input units. This is the standard arrangement for batch input and output arrays for sequential models in Keras/TensorFlow.

You can see the full sequence of inputs for the first item in the batch just by indexing the array as follows:

In [None]:
batch_x[0,:,:]
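
The same [batch, time, units] layout can be mimicked with nested Python lists, which may make the indexing easier to picture. The toy batch and the helper `nested_shape` below are made up for this sketch.

```python
def nested_shape(array):
    """Shape of a regularly nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(array, list):
        shape.append(len(array))
        array = array[0]
    return tuple(shape)

# A made-up batch: 2 sequences, 4 timesteps each, 3 input units (X, A, B)
toy_batch = [
    [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]],  # X A B X
    [[0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 0, 0]],  # B A X X
]
print(nested_shape(toy_batch))  # (batch, time, units)
print(toy_batch[0])             # every timestep of the first sequence
print(toy_batch[1][2])          # third timestep of the second sequence
```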

## Constructing and training the RNN model with built-in tools

The TensorFlow/Keras ecosystem has many convenient tools for building and training RNNs. In this section we will use the `keras.layers.SimpleRNN` layer to build the model, and the built-in *fit* method to train it.
You can read the [documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers/SimpleRNN) for details. To better understand how learning and processing works in a simple RNN, the next section will construct similar functionality from scratch, with a more considered explanation of the underlying computations.

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNN, Dense
from tensorflow.keras.models import Sequential

# Define the shape of the input, output, and sequence length
seq_len = 20
input_units = 3
output_units = 2

# Create a sequential model, i.e., a simple stack of layers, one after the other:
easy_model = Sequential()  

# Add a simple RNN layer with 10 units. This is the combined hidden-and-context layer from the earlier figure.
# The input shape is (20, 3), i.e., 20 time steps, each with 3 inputs matching our input shape (we use 3 units to represent X, A, B).
easy_model.add(SimpleRNN(units=10, input_shape=(seq_len, input_units), return_sequences=True)) 

# Add a dense layer for the output units.
easy_model.add(Dense(output_units, activation='sigmoid'))  

easy_model.summary()  # Print a summary of the model.

In this method of building a model, the first layer added is interpreted as receiving inputs, so input shape must be specified. Each subsequent layer is assumed to receive inputs from the previously-created layer. Here the whole model is created with three lines of code!
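
As a sanity check on the summary, the number of trainable parameters can be worked out by hand. The sketch below assumes the layer sizes used above (3 input units, 10 recurrent units, 2 output units) and that both layers include bias terms; the two helper functions are hypothetical names for this illustration.

```python
def simple_rnn_params(input_units, rnn_units):
    """Weights in a simple recurrent layer: input->hidden, hidden->hidden, biases."""
    return input_units * rnn_units + rnn_units * rnn_units + rnn_units

def dense_params(in_units, out_units):
    """Weights in a fully connected layer: one per connection, plus biases."""
    return in_units * out_units + out_units

rnn = simple_rnn_params(input_units=3, rnn_units=10)
dense = dense_params(in_units=10, out_units=2)
print(rnn, dense, rnn + dense)  # 140 22 162
```

The 100 hidden-to-hidden weights are the context-to-hidden connections from the schematic, and they make up most of the recurrent layer.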

Training the model then involves the same steps employed for feed-forward models:

1. Define the loss function and the optimizer
2. `Compile` the model
3. `Fit` the model with data for some number of (`n`) *epochs* (the number of times the model sees the entire dataset)

In [None]:
epochs = 100
x_train, y_train = get_cases(100)

loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=1)  # Using a large learning rate for this simple toy problem

easy_model.compile(optimizer=optimizer, loss=loss_fn)
easy_model.fit(x_train, y_train, epochs=epochs)

### Building an RNN with custom classes.

To better understand how processing and training work in this model, we will build a model with similar functionality via custom classes and functions.

Before beginning, it is helpful to visualize the information flow in this model in a more detailed illustration, together with some equations to help you map aspects of the computation (and the diagram) to the code. First, here is a more detailed diagram of an RNN "cell": the combination of hidden and context units shown in simple form earlier:

![model architecture](https://drive.google.com/uc?id=1f3kI6XCF6vrtiQFQG1PvpQbTNqbim7uT)


The boxes labelled *RNNCell* contain all of the hidden/context units, while the other chart elements indicate all of the computations carried out by these units in a given timestep. Focusing on time *t*, you can see the units in the cell receive inputs from the hidden units in the previous timestep (arrow labelled $w_{hh}$), from the current input pattern (arrow labelled $x_t$), and from the bias unit $b$. The circle with a plus-sign indicates that these inputs are summed together, while the $\sigma$ indicates that the sum is passed through the sigmoid activation function. These computations are carried out for each hidden unit, creating the tensor of output activations labelled $h_t$. The output activations are passed forward through a layer of weights to the output units at time $t$ (not shown), and also provide input to the hidden units at the next timestep $t+1$ via the weights $w_{hh}$. The entire sequence can then be viewed as iterating this process for $n$ steps, unrolling the network forward in time.
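
The computation just described can be written compactly. This is a sketch in the notation of the diagram: $w_{xh}$ names the input-to-hidden weights, and $w_{hy}$ and $b_y$ are labels introduced here for the output layer, which the diagram does not show. The starting state $h_0 = 0$ and the sigmoid on the output layer are assumptions matching the model built in this tutorial.

$$
h_t = \sigma\left(x_t \, w_{xh} + h_{t-1} \, w_{hh} + b\right), \qquad h_0 = 0
$$

$$
y_t = \sigma\left(h_t \, w_{hy} + b_y\right)
$$

Because $h_{t-1}$ itself depends on $x_{t-1}$ and $h_{t-2}$, the output at time $t$ can in principle depend on every earlier input in the sequence.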

With this understanding, let's create this functionality from more elementary building blocks. First we define an RNNCell class, which defines the computations carried out in a single timestep: 

In [None]:
class RNNCell(tf.keras.layers.Layer):
    """Simple RNN cell."""

    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.input_dense = tf.keras.layers.Dense(self.units, use_bias=False) # w_xh
        self.recurrent_dense = tf.keras.layers.Dense(self.units, use_bias=False)  # w_hh
        self.bias = self.add_weight(shape=(self.units,), name='bias')  # b
        self.built = True # attach an attribute that shows that the network has been built

    def call(self, inputs, states):
        hx = self.input_dense(inputs)  # x @ w_xh (@: matrix multiplication)
        hh = self.recurrent_dense(states)  # h @ w_hh
        outputs = hx + hh + self.bias
        outputs = tf.sigmoid(outputs)
        return outputs, outputs  # technically outputs, states, but they are the same


Next we create a `SimpleRNN` class that loops through RNNCell computations in each step of a sequence:

In [None]:
class SimpleRNN(tf.keras.layers.Layer):
    """Simple RNN Layer."""

    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.rnn_cell = RNNCell(units=self.units)
        self.built = True  # Boiler-plate

    def call(self, inputs, states=None):

        batch_size = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]

        # Rule of thumb: do not use python lists inside a keras layer's call method
        # tf.TensorArray is a better alternative, see docs for details
        outputs = tf.TensorArray(tf.float32, size=seq_len)  

        # Initialize the hidden state
        states = tf.zeros([batch_size, self.units])

        for t in range(seq_len):
            this_input = inputs[:, t, :]
            output, states = self.rnn_cell(inputs=this_input, states=states)  # the cell returns (outputs, states)
            outputs = outputs.write(t, output)  # Analogous to list.append()

        outputs = outputs.stack()  # output with wrong shape (seq_len, batch_size, units)
        return tf.transpose(outputs, [1, 0, 2])  # output with correct shape (batch_size, seq_len, units)

Finally, we take our `SimpleRNN` class and build it into a full model to predict the output. The class defining the AXBY model could look like this:

In [None]:
class AXBY(tf.keras.Model):

    def __init__(self, rnn_units):
        super().__init__()
        self.rnn = SimpleRNN(units=rnn_units)
        self.dense = tf.keras.layers.Dense(units=2, activation='sigmoid')

    def call(self, inputs):
        x = self.rnn(inputs)
        return self.dense(x)

model = AXBY(rnn_units=10)

print(f"{model(batch_x).shape=}")
model.summary()


The new model is doing something very similar to the previous code in `easy_model`, but since we have defined the core classes ourselves, it is easier to see exactly what is going on in each layer, and potentially easier to modify or tailor the functionality of the recurrent layer.

## Training the custom model

To better understand how a training loop might work, let's build it for the custom model using a lower-level API.

In [None]:
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=1)

@tf.function  # Decorator to compile the function into a graph, and make it run faster
def train_step(x, y):  # x and y in this case are the inputs and targets of the model
    with tf.GradientTape() as tape:  # Record operations for automatic differentiation
        y_pred = model(x, training=True)  # Forward pass
        loss_value = loss_fn(y, y_pred)  # Compute the loss value
    grads = tape.gradient(loss_value, model.trainable_weights)  # get gradients: dL/dW in all trainable weights
    optimizer.apply_gradients(zip(grads, model.trainable_weights))  # Update all trainable weights and biases
    return loss_value  # Return the loss value so we can print it later

Advanced topic: Learn from Andrej Karpathy about how modern machine learning frameworks magically get the gradients (reverse-mode autodiff) [here](https://www.youtube.com/watch?v=VMj-3S1tku0&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=1&t=492s).

In [None]:
# model.fit() equivalent
# Train with the same data for 100 epochs
for epoch in range(epochs):
    loss_value = train_step(x_train, y_train)
    print(f"Epoch {epoch + 1}: loss = {loss_value}")

We are using a loss function called _binary crossentropy_. Additionally, we are using _stochastic gradient descent_ as the optimizer, and printing out performance using our measure for loss. You can, of course, measure performance in any way you'd like, not just with the loss function. Different metrics have different quantitative structure and therefore offer different descriptive characteristics, so you can decide on one (or more) that work well for your purpose.
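
For intuition about what the loss value means, binary crossentropy can be sketched in plain Python. This is a simplified illustration that takes flat lists of targets and predictions and averages over them; the clipping constant `eps` is an assumption, and the example values are made up.

```python
import math

def binary_crossentropy(targets, predictions, eps=1e-7):
    """Mean binary cross-entropy over a flat list of output units."""
    total, count = 0.0, 0
    for target, prediction in zip(targets, predictions):
        p = min(max(prediction, eps), 1 - eps)  # keep log() away from 0
        total += -(target * math.log(p) + (1 - target) * math.log(1 - p))
        count += 1
    return total / count

# Made-up targets and predictions for four output units
targets = [1, 0, 0, 1]
confident = [0.9, 0.1, 0.1, 0.9]
unsure = [0.5, 0.5, 0.5, 0.5]
print(binary_crossentropy(targets, confident))
print(binary_crossentropy(targets, unsure))
```

Predictions that sit at 0.5 everywhere give a loss of about 0.69 (the natural log of 2), so a printed loss near that value suggests the model has not yet learned much.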

## Evaluate the results

There are a number of ways that you might evaluate the performance of the model. The method that you choose with a model that you develop will be unique to your particular application and your analytic goals.

Here we provide a workflow that allows you to decode x and y so that you can _visually_ inspect the results. This is often useful when the output can be coerced into a human readable format of some kind.

In [None]:
import numpy as np

def decode_y(y, is_prediction=False):
    """Decode output from one-hot encoding to a list of number."""

    if isinstance(y, tf.Tensor):
        y = y.numpy()

    if is_prediction:
        # y_pred is float, so we need to round it to 0 or 1
        if len(y) != 1:
            raise ValueError(f"Prediction should be a single case, but got {len(y)} cases.")
        y = y[0].round(0)

    decoded = []
    for output in y:
        
        if all(output == [0, 0]):
            decoded.append(0)
        elif all(output == [1, 0]):
            decoded.append(1)
        elif all(output == [0, 1]):
            decoded.append(2)
        else:
            raise ValueError(f"Invalid y: {output}")

    return decoded

Here we evaluate a random case:

In [None]:
def evaluate(model):
    """Randomly generate a case and evaluate the expected outputs vs. model's prediction."""

    x, y = get_cases(n=1, verbose=True)
    pred_y = model(x)
    print(f"Model prediction: {decode_y(pred_y, is_prediction=True)}")


In [None]:
evaluate(model)

The performance of the model is not perfect, but it does alright (especially for button 1). Notice that button 2 is way more difficult to learn due to fewer examples in the training set. You can also think through yourself what would allow this model to learn better. What aspects of the model structure and specification here could be augmented to enhance performance?
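
To see how rare button 2 is, one can enumerate every three-letter window and apply the task rule. This sketch assumes letters are drawn uniformly and independently, as `choices` does in `get_case`, so all 27 windows are equally likely; `response_for` is a name invented here.

```python
from collections import Counter
from itertools import product

def response_for(triple):
    """AXBY response to the last letter of a three-letter window."""
    first, second, third = triple
    if third != 'X':
        return 0
    if second == 'A':
        return 2 if first == 'B' else 0
    return 1

# All 27 equally likely three-letter windows
counts = Counter(response_for(triple) for triple in product('XAB', repeat=3))
for response in (0, 1, 2):
    print(response, counts[response], f"{counts[response] / 27:.3f}")
```

Only 1 window in 27 calls for button 2, against 6 for button 1 and 20 for no press, so a 20-step sequence contains fewer than one button-2 target on average.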

Exercises:

- Try to improve the accuracy by any means you can think of.

## Summary
So, in this module you've learned about simple recurrent neural networks, how to specify an RNN cell (single timestep) and fold more than one cell into an architecture that can handle multiple timesteps. Additionally, you have learned about how to apply the model in training and evaluation routines, along with some experience using the special jargon we use when talking about artificial neural networks.

Next we will move on to a more complex RNN architecture that uses _continuous_ computation, though we will build on many of the concepts we've covered here.