<a href="https://colab.research.google.com/github/vargamartonaron/nma_23_rnn/blob/main/working_memory_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [53]:
!pip install torch torchvision
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Define the OU process function for generating input data
def ou_process_function(x, mu, theta, sigma):
    dx = theta * (mu - x) + sigma * torch.randn_like(x)
    return x + dx

# Define the InputLayer, RNNLayer, and OutputLayer classes
class InputLayer(nn.Module):
    def __init__(self):
        super(InputLayer, self).__init__()
        self.size = 20
        self.gain = 10.0
        self.time_before_input_starts = 10
        self.time_for_input_active = 1
        self.fc = nn.Linear(self.size, self.size, bias=False)
        self.input_data = torch.zeros(self.size)
        self.time_step_counter = 0

    def forward(self, x):
        # Increment the time step counter
        self.time_step_counter += 1

        if self.time_step_counter >= self.time_before_input_starts and self.time_step_counter < self.time_before_input_starts + self.time_for_input_active:
            self.input_data = ou_process_function(self.input_data, mu=1, theta=0.05, sigma=0.1)
        else:
            self.input_data = torch.zeros(self.size)

        # Reset the time step counter after 10 milliseconds
        if self.time_step_counter == self.time_before_input_starts + self.time_for_input_active:
            self.time_step_counter = 0

        return self.gain * self.fc(self.input_data)

class RNNLayer(nn.Module):
    def __init__(self):
        super(RNNLayer, self).__init__()
        self.num_neurons = 2000
        self.gain_synaptic_weights = 0.95
        self.fraction_nonzero_weights = 0.25
        self.tau = 20
        self.dt = 0.1
        self.sparsity = 0.05
        self.fc = nn.Linear(self.num_neurons, self.num_neurons, bias=False)
        self.nonlinearity = nn.Tanh()
        self.weight_init()

    def weight_init(self):
        # Initialize the dense weight tensor with zeros
        self.fc.weight.data = torch.zeros(self.num_neurons, self.num_neurons)

        # Set a fraction of the weights to non-zero values to achieve sparsity
        num_nonzero_weights = int(self.num_neurons * self.num_neurons * self.sparsity)
        indices = torch.randint(0, self.num_neurons, size=(2, num_nonzero_weights))
        values = torch.randn(num_nonzero_weights) * self.gain_synaptic_weights
        self.fc.weight.data[indices[0], indices[1]] = values

    def forward(self, x, state):
        dx = (-state + self.fc(self.nonlinearity(state))) / self.tau
        state = torch.add(state, self.dt * dx)
        return state

class OutputLayer(nn.Module):
    def __init__(self):
        super(OutputLayer, self).__init__()
        self.fc = nn.Linear(2000, 20, bias=False)

    def forward(self, x):
        return self.fc(x)

# Define the complete WorkingMemoryModel class
class WorkingMemoryModel(nn.Module):
    def __init__(self):
        super(WorkingMemoryModel, self).__init__()
        self.input_layer = InputLayer()
        self.rnn_layer = RNNLayer()
        self.output_layer = OutputLayer()

    def forward(self, x):
        rnn_state = torch.zeros(self.rnn_layer.num_neurons)
        inputs = torch.stack([self.input_layer(t) for t in x])
        for input_t in inputs:
            rnn_state = self.rnn_layer(input_t, rnn_state)
        output = self.output_layer(rnn_state)
        return output



In [54]:
def simulate_model_with_rnn_state(model, total_time):
    input_data = torch.stack([model.input_layer(t) for t in total_time])
    rnn_state = torch.zeros(model.rnn_layer.num_neurons)

    outputs = []
    for input_t in input_data:
        rnn_state = model.rnn_layer(input_t, rnn_state)
        output = model.output_layer(rnn_state)
        outputs.append(output)

    return torch.stack(outputs), rnn_state

In [None]:
def train_model(model, total_time, num_epochs):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    criterion = nn.MSELoss()

    # Enable anomaly detection
    autograd.set_detect_anomaly(True)

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output, _ = simulate_model_with_rnn_state(model, total_time)

        # Generate the input data for the current time step
        input_data = torch.stack([model.input_layer(t) for t in total_time])

        # Calculate the loss between the generated input and the decoded output
        loss = criterion(input_data, output)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

# Example usage for training the model
model = WorkingMemoryModel()
total_time = 10.0
total_time = [t * 0.1 for t in range(int(total_time * 10))]  # Adjust the total time based on your requirement
num_epochs = 200

train_model(model, total_time, num_epochs)

Epoch 1/200, Loss: 0.035526808351278305
Epoch 2/200, Loss: 0.042763762176036835
Epoch 3/200, Loss: 0.030870338901877403
Epoch 4/200, Loss: 0.04346989095211029
Epoch 5/200, Loss: 0.03793623298406601
Epoch 6/200, Loss: 0.036768969148397446
Epoch 7/200, Loss: 0.03384675085544586
Epoch 8/200, Loss: 0.030449654906988144
Epoch 9/200, Loss: 0.04049912095069885
Epoch 10/200, Loss: 0.03141278028488159
Epoch 11/200, Loss: 0.025123581290245056
Epoch 12/200, Loss: 0.030217895284295082
Epoch 13/200, Loss: 0.029994338750839233
Epoch 14/200, Loss: 0.027675170451402664
Epoch 15/200, Loss: 0.03106709197163582
Epoch 16/200, Loss: 0.03222484514117241
Epoch 17/200, Loss: 0.02783745899796486
Epoch 18/200, Loss: 0.03632538020610809
Epoch 19/200, Loss: 0.027975425124168396
Epoch 20/200, Loss: 0.02659611403942108
Epoch 21/200, Loss: 0.023287536576390266
Epoch 22/200, Loss: 0.027468902990221977
Epoch 23/200, Loss: 0.026698697358369827
Epoch 24/200, Loss: 0.022970542311668396
Epoch 25/200, Loss: 0.0214885342866