#Training RNNs on cognitive tasks

In [14]:
# # Uninstall the current Gym version
# !pip uninstall -y gym

# # Install Gym version 0.23.1
# !pip install gym==0.23.1

# # Restart the runtime after installation (necessary in some environments like Colab)
# import os
# os._exit(00)


  and should_run_async(code)


#Introduction

In lab 1 we explored the architecture and dynamics of Recurrent Neural Networks (RNNs). Now, we transition from understanding the mechanics of RNNs to deploying them effectively on cognitive tasks. We’ll explore how these networks, inspired by the recurrent connections in our brain, can be trained to perform tasks that mimic cognitive functions. Engaging in such exercises not only offers insights into artificial intelligence but also sheds light on the computational capabilities of our own neural circuits.

We will train our network to perform a perceptual decision-making task. In the laboratory, the test subject (human or animal) is shown moving dots on a screen, and must respond to indicate whether most dots are moving to the left or right. By recording from different brain areas, neuroscientists have been able to isolate the brain areas where the evidence accumulates in order to make this type of perceptual decision [(review paper)](https://www.cell.com/neuron/fulltext/S0896-6273(13)00999-9?script=true&code=cell-site).
Let's take a closer look at how this cognitive task is performed in real life to deepen our understanding. Here is a [link](https://www.youtube.com/watch?v=oDxcyTn-0os&ab_channel=PamelaReinagelatUCSD) to a video featuring a rat executing this perceptual decision-making task.

We will build and train our network using pytorch, and then do the same using only numpy, to understand how the pytorch magic works.

Now, let's proceed with our main topic for today - training RNNs on cognitive tasks!

### Installing and importing relevant packages

In [15]:
# Install neurogym to use cognitive tasks
! git clone https://github.com/neurogym/neurogym.git
%cd neurogym/
! pip install -e .

fatal: destination path 'neurogym' already exists and is not an empty directory.
/content/neurogym/neurogym
Obtaining file:///content/neurogym/neurogym
[31mERROR: file:///content/neurogym/neurogym does not appear to be a Python project: neither 'setup.py' nor 'pyproject.toml' found.[0m[31m
[0m

In [16]:
# Import common packages
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Defining a recurrent neural network

In general, recurrent neural networks transform **sequence to sequence**. In the context of cognitive neuroscience, the sequence is usually a time series of task input or output. Recall the sequence we produced in Tutorial 1 by executing a forward pass through an RNN?

Let's understand the input and output dimensions of a typical recurrent network in machine learning, LSTM networks.

(Usage example adopted from pytorch documentation)

In [17]:
# Make a LSTM, input_size is the dimension of inputs,
# hidden_size is the number of hidden neurons
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# Generate some mock inputs
input = torch.randn(5, 3, 10)  # The arguments represent (Sequence Length, Batch Size, Input Size). Typically, in neuroscience,
# sequence length would correspond to time points in the time series, Batch size corresponds to the number of trials and
# input size corresponds to the dimension of the input (ie., the number of neurons or channels you're collecting data from)
output, (hn, cn) = rnn(input)

print('Output shape is (SeqLen, BatchSize, HiddenSize):', output.shape)

Output shape is (SeqLen, BatchSize, HiddenSize): torch.Size([5, 3, 20])


##**Defining a Leaky Recurrent Neural Network (Leaky RNN)**

Neuroscientists often prefer **Leaky Recurrent Neural Networks (Leaky RNNs)** due to their ability to accurately model the continuous and dynamic nature of biological neural processes. Leaky RNNs can mimic the temporal dynamics and adaptive learning capabilities of biological neural networks, providing a closer approximation to real neurological processes. Furthermore, their robustness in handling noisy environments, capability to generate complex behaviors, and applicability in studying real-time interactions and sensorimotor coordination make them a valuable tool in neuroscience research and experimentation.

Let us define a continuous-time leaky recurrent neural network,
\begin{align}
    \tau \frac{d\mathbf{a}}{dt} = -\mathbf{a}(t) + f(W_{a\rightarrow
a} \mathbf{a}(t) + W_{x\rightarrow a} \mathbf{x}(t) + \mathbf{b}_1).
\end{align}

Where,

$a(t)$ is the vector of neural firing rates (or activations) at time $t$.

$τ$ is the time constant which determines how fast the state approaches its steady-state value.

$f$ is a non-linear activation function applied element-wise.

$W_{a\rightarrow a}$ is the recurrent weight matrix.

$x(t)$ is the input vector at time $t$.

$W_{x\rightarrow a}$ is the input weight matrix.

$b_1​$ is the bias vector.


Let us discretize this network in time using the Euler method with a time step of $\Delta t$,
\begin{align}
    \mathbf{a}(t+\Delta t) = \mathbf{a}(t) + \Delta \mathbf{a} &= \mathbf{a}(t) + \frac{\Delta t}{\tau}[-\mathbf{a}(t) + f(W_{a\rightarrow a} \mathbf{a}(t) + W_{x\rightarrow a}  \mathbf{x}(t) + \mathbf{b}_r)] \\
    &= (1 - \frac{\Delta t}{\tau})\mathbf{a}(t) + \frac{\Delta t}{\tau}f(W_{a\rightarrow a} \mathbf{a}(t) + W_{x\rightarrow a}  \mathbf{x}(t) + \mathbf{b}_r)
\end{align}

Let us now define the network following the dynamics described by the above equation.

In [18]:
class LeakyRNN(nn.Module):
    """Leaky RNN.

    Parameters:
        input_size: Number of input neurons
        hidden_size: Number of hidden neurons
        dt: discretization time step in ms.
            If None, dt equals time constant tau

    Inputs:
        input: tensor of shape (seq_len, batch, input_size)
        hidden: tensor of shape (batch, hidden_size), initial hidden activity
            if None, hidden is initialized through self.init_hidden()

    Outputs:
        output: tensor of shape (seq_len, batch, hidden_size)
        hidden: tensor of shape (batch, hidden_size), final hidden activity
    """

    def __init__(self, input_size, hidden_size, dt=None, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.tau = 100
        if dt is None:
            alpha = 1
        else:
            alpha = dt / self.tau
        self.alpha = alpha

        self.input2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def init_hidden(self, input_shape):
        batch_size = input_shape[1]
        return torch.zeros(batch_size, self.hidden_size)

    def recurrence(self, input, hidden):
        """Run network for one time step.

        Inputs:
            input: tensor of shape (batch, input_size)
            hidden: tensor of shape (batch, hidden_size)

        Outputs:
            h_new: tensor of shape (batch, hidden_size),
                network activity at the next time step
        """
        h_new = torch.relu(self.input2h(input) + self.h2h(hidden))
        h_new = hidden * (1 - self.alpha) + h_new * self.alpha
        return h_new

    def forward(self, input, hidden=None):
        """Propogate input through the network."""

        # If hidden activity is not provided, initialize it
        if hidden is None:
            hidden = self.init_hidden(input.shape).to(input.device)

        # Loop through time
        output = []
        steps = range(input.size(0))
        for i in steps:
            hidden = self.recurrence(input[i], hidden)
            output.append(hidden)

        # Stack together output from all time steps
        output = torch.stack(output, dim=0)  # (seq_len, batch, hidden_size)
        return output, hidden


class RNNNet(nn.Module):
    """Recurrent network model.

    Parameters:
        input_size: int, input size
        hidden_size: int, hidden size
        output_size: int, output size

    Inputs:
        x: tensor of shape (Seq Len, Batch, Input size)

    Outputs:
        out: tensor of shape (Seq Len, Batch, Output size)
        rnn_output: tensor of shape (Seq Len, Batch, Hidden size)
    """
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super().__init__()

        # Leaky RNN
        self.rnn = LeakyRNN(input_size, hidden_size, **kwargs)

        # Add an output layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        rnn_output, _ = self.rnn(x)
        out = self.fc(rnn_output)
        return out, rnn_output

Let's determine the dimensions of its inputs and outputs.

In [19]:
batch_size = 16
seq_len = 20  # sequence length
input_size = 5  # input dimension

# Make some random inputs
input_rnn = torch.rand(seq_len, batch_size, input_size)

# Make network of 100 hidden units and 10 output units
rnn = RNNNet(input_size=input_size, hidden_size=100, output_size=10)

# Run the sequence through the network
out, rnn_output = rnn(input_rnn)

print('Input of shape =', input_rnn.shape)
print('Output of shape =', out.shape)

Input of shape = torch.Size([20, 16, 5])
Output of shape = torch.Size([20, 16, 10])


## Defining a simple cognitive task

Here we use the neurogym package to make a simple "perceptual decision making" task. Let us install the package first.NeuroGym is a curated collection of neuroscience tasks with a common interface. You may explore further [here](https://github.com/neurogym/neurogym)

The code provided below defines a custom environment, PerceptualDecisionMaking, using neurogym. This environment simulates a two-alternative forced choice task where an agent needs to decide which of two stimuli is higher on average, despite the stimuli being noisy. The agent is encouraged to integrate the stimulus over time due to this noise.

Given that the focus of today's tutorial is on training an RNN, after you browse the neurogym website, you can skip over much of this section (except for the first cell), or study at your own pace.

In [32]:
# @title importing neurogym
import neurogym as ngym

# Canned environment from neurogym

task_name = 'PerceptualDecisionMaking-v0'

# Importantly, we set discretization time step for the task as well
kwargs = {'dt': 20, 'timing': {'stimulus': 1000}}


For **supervised learning**, we need a dataset that returns (input, target output pairs).

In [21]:
# Make supervised dataset
seq_len = 100
batch_size = 16
dataset = ngym.Dataset(task_name, env_kwargs=kwargs, batch_size=batch_size, seq_len=seq_len)
env = dataset.env

# Generate one batch of data when called
inputs, target = dataset()
inputs = torch.from_numpy(inputs).type(torch.float)

input_size = env.observation_space.shape[0]
output_size = env.action_space.n

print('Input has shape (SeqLen, Batch, Dim) =', inputs.shape)
print('Target has shape (SeqLen, Batch) =', target.shape)

Input has shape (SeqLen, Batch, Dim) = torch.Size([100, 16, 3])
Target has shape (SeqLen, Batch) = (100, 16)


  logger.warn(
  logger.warn(
  logger.warn(


## Network Training

Let's now train the network to perform the task.

In [22]:
# Instantiate the network and print information
hidden_size = 128
net = RNNNet(input_size=input_size, hidden_size=hidden_size,
             output_size=output_size, dt=env.dt)
print(net)

def train_model(net, dataset):
    """Simple helper function to train the model.

    Args:
        net: a pytorch nn.Module module
        dataset: a dataset object that when called produce a (input, target output) pair

    Returns:
        net: network object after training
    """
    # Use Adam optimizer
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    running_loss = 0
    running_acc = 0
    start_time = time.time()
    # Loop over training batches
    print('Training network...')
    for i in range(2000):
        # Generate input and target, convert to pytorch tensor
        inputs, labels = dataset()
        inputs = torch.from_numpy(inputs).type(torch.float)
        labels = torch.from_numpy(labels.flatten()).type(torch.long)

        # boiler plate pytorch training:
        optimizer.zero_grad()   # zero the gradient buffers
        output, _ = net(inputs)
        # Reshape to (SeqLen x Batch, OutputSize)
        output = output.view(-1, output_size)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()    # Does the update

        # Compute the running loss every 100 steps
        running_loss += loss.item()
        if i % 100 == 99:
            running_loss /= 100
            print('Step {}, Loss {:0.4f}, Time {:0.1f}s'.format(
                i+1, running_loss, time.time() - start_time))
            running_loss = 0
    return net

net = train_model(net, dataset)

RNNNet(
  (rnn): LeakyRNN(
    (input2h): Linear(in_features=3, out_features=128, bias=True)
    (h2h): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=3, bias=True)
)
Training network...
Step 100, Loss 0.1945, Time 7.5s
Step 200, Loss 0.0782, Time 17.2s
Step 300, Loss 0.0607, Time 22.5s
Step 400, Loss 0.0447, Time 26.4s
Step 500, Loss 0.0398, Time 30.4s
Step 600, Loss 0.0325, Time 35.7s
Step 700, Loss 0.0324, Time 39.6s
Step 800, Loss 0.0342, Time 43.6s
Step 900, Loss 0.0328, Time 48.9s
Step 1000, Loss 0.0304, Time 52.9s
Step 1100, Loss 0.0298, Time 56.9s
Step 1200, Loss 0.0301, Time 62.1s
Step 1300, Loss 0.0297, Time 66.1s
Step 1400, Loss 0.0259, Time 70.1s
Step 1500, Loss 0.0273, Time 75.3s
Step 1600, Loss 0.0266, Time 79.3s
Step 1700, Loss 0.0272, Time 83.2s
Step 1800, Loss 0.0261, Time 88.5s
Step 1900, Loss 0.0236, Time 92.5s
Step 2000, Loss 0.0259, Time 96.5s


## Testing the network

Here we run the network after training, record activity, and compute performance. We will explicitly loop through individual trials, so we can log the information and compute the performance of each trial.

In [25]:
# Reset environment
env = dataset.env
env.reset(no_step=True)

# Initialize variables for logging
perf = 0
activity_dict = {}  # recording activity
trial_infos = {}  # recording trial information

num_trial = 200
for i in range(num_trial):
    # Neurogym boiler plate
    # Sample a new trial
    trial_info = env.new_trial()
    # Observation and groud-truth of this trial
    ob, gt = env.ob, env.gt
    # Convert to numpy, add batch dimension to input
    inputs = torch.from_numpy(ob[:, np.newaxis, :]).type(torch.float)

    # Run the network for one trial
    # inputs (SeqLen, Batch, InputSize)
    # action_pred (SeqLen, Batch, OutputSize)
    action_pred, rnn_activity = net(inputs)

    # Compute performance
    # First convert back to numpy
    action_pred = action_pred.detach().numpy()[:, 0, :]
    # Read out final choice at last time step
    choice = np.argmax(action_pred[-1, :])
    # Compare to ground truth
    correct = choice == gt[-1]

    # Record activity, trial information, choice, correctness
    rnn_activity = rnn_activity[:, 0, :].detach().numpy()
    activity_dict[i] = rnn_activity
    trial_infos[i] = trial_info  # trial_info is a dictionary
    trial_infos[i].update({'correct': correct})

# Print information for sample trials
for i in range(5):
    print('Trial ', i, trial_infos[i])

print('Average performance', np.mean([val['correct'] for val in trial_infos.values()]))

Trial  0 {'ground_truth': 0, 'coh': 0.0, 'correct': True}
Trial  1 {'ground_truth': 0, 'coh': 6.4, 'correct': True}
Trial  2 {'ground_truth': 0, 'coh': 51.2, 'correct': True}
Trial  3 {'ground_truth': 1, 'coh': 25.6, 'correct': True}
Trial  4 {'ground_truth': 0, 'coh': 25.6, 'correct': True}
Average performance 0.925


## Backpropagation Through Time (BPTT)

We will now delve into the world of Recurrent Neural Networks (RNNs), gaining an understanding of their functionality and constructing one from the ground up using only NumPy in Python. Having previously explored backpropagation in feedforward neural networks, we now turn our attention to the complexity introduced by temporal dependencies. Frameworks like PyTorch handle BPTT automatically via autograd. Below, we'll modify our training loop to illustrate how BPTT works under the hood.

### Implementing BPTT Step-by-Step
We'll implement BPTT manually to illustrate how it works. This involves:

Forward Pass: Compute the network's output and store necessary variables.
Backward Pass: Compute gradients of the loss with respect to weights by backpropagating errors through time.
Weight Updates: Update the weights using the computed gradients.







### 1. Forward Pass (Modifying the LeakyRNN Class)

We need to store the inputs and hidden states at each time step during the forward pass to use them in the backward pass.

In [26]:
class LeakyRNN(nn.Module):
    """ LeakyRNN with BPTT support."""

    def __init__(self, input_size, hidden_size, dt=None, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.tau = 100
        if dt is None:
            alpha = 1
        else:
            alpha = dt / self.tau
        self.alpha = alpha

        # Define weights explicitly
        self.Wxh = nn.Parameter(torch.randn(hidden_size, input_size) * 0.01)
        self.Whh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.bh = nn.Parameter(torch.zeros(hidden_size))

    def init_hidden(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size)

    def forward(self, inputs):
        """Forward pass through time, storing variables for BPTT."""
        seq_len, batch_size, _ = inputs.size()
        hidden = self.init_hidden(batch_size).to(inputs.device)

        self.inputs = []   # Store inputs for BPTT
        self.hiddens = [hidden]  # Store hidden states for BPTT

        outputs = []
        for t in range(seq_len):
            input_t = inputs[t]
            self.inputs.append(input_t)
            # Compute pre-activation
            pre_activation = self.Wxh @ input_t.T + self.Whh @ hidden.T + self.bh[:, None]
            pre_activation = pre_activation.T  # Shape: (batch_size, hidden_size)
            # Apply activation function
            hidden = (1 - self.alpha) * hidden + self.alpha * torch.tanh(pre_activation)
            self.hiddens.append(hidden)
            outputs.append(hidden)
        outputs = torch.stack(outputs)
        return outputs

class RNNNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super().__init__()

        # Use our modified LeakyRNN
        self.rnn = LeakyRNN(input_size, hidden_size, **kwargs)

        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        rnn_output = self.rnn(x)
        out = self.fc(rnn_output)
        return out

### 2. Backward Pass (Manual Gradient Computation)



In [33]:
def bptt(net, outputs, targets):
    """Manual Backpropagation Through Time."""
    # Initialize gradients
    dWxh = torch.zeros_like(net.rnn.Wxh)
    dWhh = torch.zeros_like(net.rnn.Whh)
    dbh = torch.zeros_like(net.rnn.bh)
    dWhy = torch.zeros_like(net.fc.weight)
    dby = torch.zeros_like(net.fc.bias)

    # Initialize gradient w.r.t hidden state
    dh_next = torch.zeros(outputs.size(1), net.rnn.hidden_size)

    seq_len, batch_size, num_classes = outputs.size()

    # Compute gradient of loss w.r.t. output logits
    outputs_flat = outputs.view(-1, num_classes)
    outputs_softmax = torch.softmax(outputs_flat, dim=1)
    outputs_softmax = outputs_softmax.view(seq_len, batch_size, num_classes)

    # Create one-hot encoding of targets
    targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()

    # Compute dy = dL/dy (gradient of loss w.r.t. logits)
    dy = outputs_softmax - targets_one_hot  # Shape: (seq_len, batch_size, num_classes)

    # Loop backward through time
    for t in reversed(range(seq_len)):
        # Gradients for output layer
        ht = net.rnn.hiddens[t+1]  # Hidden state at time t
        dWhy += dy[t].T @ ht
        dby += dy[t].sum(0)

        # Backprop into hidden layer
        dh = dy[t] @ net.fc.weight + dh_next  # Shape: (batch_size, hidden_size)

        # Derivative through activation function
        dtanh = net.rnn.alpha * (1 - ht ** 2) * dh  # Shape: (batch_size, hidden_size)

        # Gradients w.r.t parameters
        xt = net.rnn.inputs[t]  # Input at time t
        ht_prev = net.rnn.hiddens[t]  # Hidden state at time t-1
        dWxh += dtanh.T @ xt
        dWhh += dtanh.T @ ht_prev
        dbh += dtanh.sum(0)

        # Prepare dh_next for next iteration
        dh_next = dh * (1 - net.rnn.alpha) + dtanh @ net.rnn.Whh.T

    # Clip gradients to prevent exploding gradients
    clip_value = 1.0
    for grad in [dWxh, dWhh, dbh, dWhy, dby]:
        grad.clamp_(-clip_value, clip_value)

    # Update weights manually
    learning_rate = 0.0001
    net.rnn.Wxh.data -= learning_rate * dWxh
    net.rnn.Whh.data -= learning_rate * dWhh
    net.rnn.bh.data -= learning_rate * dbh
    net.fc.weight.data -= learning_rate * dWhy
    net.fc.bias.data -= learning_rate * dby


### 3. Weight Updates

In [34]:
def train_model_bptt(net, dataset):
    """Train the model using manual BPTT."""
    criterion = nn.CrossEntropyLoss()
    running_loss = 0
    print('Training network with BPTT...')
    for i in range(2000):
        # Generate input and target, convert to PyTorch tensors
        inputs, labels = dataset()
        inputs = torch.from_numpy(inputs).type(torch.float)  # Shape: (seq_len, batch_size, input_size)
        labels = torch.from_numpy(labels).type(torch.long)   # Shape: (seq_len, batch_size)

        # Zero gradients
        net.zero_grad()

        # Forward pass
        outputs = net(inputs)  # outputs shape: (seq_len, batch_size, num_classes)

        # Compute loss
        outputs_flat = outputs.view(-1, outputs.size(-1))    # Shape: (seq_len * batch_size, num_classes)
        labels_flat = labels.view(-1)                        # Shape: (seq_len * batch_size)
        loss = criterion(outputs_flat, labels_flat)

        # Backward pass using manual BPTT
        bptt(net, outputs, labels)

        # Logging
        running_loss += loss.item()
        if i % 100 == 99:
            running_loss /= 100
            print('Step {}, Loss {:0.4f}'.format(i+1, running_loss))
            running_loss = 0
    return net


### Training the Network with BPTT

In [35]:
# Instantiate the network
hidden_size = 128
net = RNNNet(input_size=input_size, hidden_size=hidden_size,
             output_size=output_size, dt=env.dt)

# Train the network
net = train_model_bptt(net, dataset)


Training network with BPTT...
Step 100, Loss 0.9611
Step 200, Loss 0.4608
Step 300, Loss 0.4134
Step 400, Loss 0.3899
Step 500, Loss 0.3681
Step 600, Loss 0.3494
Step 700, Loss 0.3289
Step 800, Loss 0.3077
Step 900, Loss 0.2893
Step 1000, Loss 0.2703
Step 1100, Loss 0.2516
Step 1200, Loss 0.2358
Step 1300, Loss 0.2201
Step 1400, Loss 0.2049
Step 1500, Loss 0.1925
Step 1600, Loss 0.1803
Step 1700, Loss 0.1685
Step 1800, Loss 0.1589
Step 1900, Loss 0.1493
Step 2000, Loss 0.1396


### Testing the network

In [36]:
def test_model(net, env, num_trial=200):
    # Initialize variables for logging
    activity_dict = {}  # recording activity
    trial_infos = {}  # recording trial information
    for i in range(num_trial):
        # Sample a new trial
        trial_info = env.new_trial()
        # Observation and ground-truth of this trial
        ob, gt = env.ob, env.gt
        # Convert to tensor, add batch dimension to input
        inputs = torch.from_numpy(ob[:, np.newaxis, :]).type(torch.float)
        # Run the network for one trial
        outputs = net(inputs)
        outputs = outputs.detach().numpy()[:, 0, :]
        # Compute performance
        choice = np.argmax(outputs[-1, :])
        correct = choice == gt[-1]
        # Record activity, trial information, choice, correctness
        activity_dict[i] = outputs
        trial_infos[i] = trial_info  # trial_info is a dictionary
        trial_infos[i].update({'correct': correct})
    return trial_infos, activity_dict


In [37]:
trial_infos, activity_dict = test_model(net, env, num_trial=200)
print('Average performance', np.mean([val['correct'] for val in trial_infos.values()]))


Average performance 0.715



*Acknowledgments*

*Special thanks to Guangyu Robert Yang for their [original work](https://github.com/gyyang/nn-brain/blob/master/RNN_tutorial.ipynb), which served as a foundation for this tutorial.*