<a href="https://colab.research.google.com/github/vargamartonaron/nma_23_rnn/blob/main/working_memory_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Set the random seed for reproducibility
torch.manual_seed(42)

# RNN model definition
class WorkingMemoryRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(WorkingMemoryRNN, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, nonlinearity='tanh', batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.rnn.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out

def generate_ou_input_data(batch_size, input_size, sequence_length, theta=0.15, mu=0.0, sigma=0.3):
    ou_inputs = np.zeros((batch_size, sequence_length, input_size))
    for i in range(batch_size):
        for j in range(input_size):
            x = np.zeros(sequence_length)
            x[0] = np.random.randn()  # Initial value
            for t in range(1, sequence_length):
                x[t] = x[t - 1] + theta * (mu - x[t - 1]) + sigma * np.random.randn()
            ou_inputs[i, :, j] = x

    return torch.tensor(ou_inputs, dtype=torch.float32)

# Set the parameters for the RNN model
input_size = 20
hidden_size = 1000
output_size = 20

# Corrected sequence length
sequence_length = hidden_size + output_size - 2

# Create the RNN model
model = WorkingMemoryRNN(input_size, hidden_size, output_size)

# Set the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop (using synthetic OU data for demonstration)
num_epochs = 1000
batch_size = 32

# Generate synthetic OU input data with the correct sequence length
ou_data = generate_ou_input_data(batch_size, input_size, sequence_length)

# Corrected output size (excluding the last two time steps)
output_size = sequence_length - 2

# Train the model
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Generate input data for the current epoch
    inputs = ou_data.clone()

    # Shift input data by 2 timesteps to create the previous two timesteps prediction task
    inputs_shifted = torch.cat((inputs[:, :-2, :], torch.zeros(batch_size, 2, input_size)), dim=1)

    # Generate the ground truth output for the previous two timesteps prediction task
    ground_truth = inputs[:, 2:2+output_size, :]  # Adjusted to have the correct sequence length

    # Initialize the loss accumulator
    total_loss = 0.0

    # Iterate through each time step for the current sequence
    for t in range(output_size):
        # Forward pass for the current time step
        output_t = model(inputs_shifted[:, t:t+2, :])

        # Compute the loss for the current time step
        loss_t = criterion(output_t, ground_truth[:, t, :])

        # Accumulate the loss
        total_loss += loss_t

    # Average the losses and backpropagate
    total_loss /= output_size
    total_loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}')

# Evaluation
model.eval()

# Generate new input data for evaluation
eval_inputs = generate_ou_input_data(batch_size, input_size, sequence_length)

# Create empty tensors to store the predicted outputs and inferred inputs
eval_outputs_pred = torch.zeros(batch_size, output_size, hidden_size)
inferred_inputs_t = torch.zeros(batch_size, output_size, input_size)
inferred_inputs_t_minus_1 = torch.zeros(batch_size, output_size, input_size)
inferred_inputs_t_minus_2 = torch.zeros(batch_size, output_size, input_size)

# Iterate through each timestep for the evaluation task
for t in range(output_size):
    # Slice the input data for the current timestep
    eval_inputs_t = eval_inputs[:, :t+1, :]  # Include the current timestep and all previous timesteps

    # Shift input data by 2 timesteps for the evaluation task
    eval_inputs_t_shifted = torch.cat((eval_inputs_t[:, :-2, :], torch.zeros(batch_size, 2, input_size)), dim=1)

    # Forward pass for evaluation
    eval_outputs_t = model(eval_inputs_t_shifted)

    # Store the predicted outputs for the current timestep
    eval_outputs_pred[:, t, :] = eval_outputs_t[:, -1, :]

    # Infer the inputs at t-1 and t-2 from the current output
    if t > 1:
        inferred_inputs_t[:, t-1, :] = eval_outputs_t[:, -2, :]
    if t > 0:
        inferred_inputs_t_minus_1[:, t-2, :] = eval_outputs_t[:, -3, :]
    if t > 2:
        inferred_inputs_t_minus_2[:, t-3, :] = eval_outputs_t[:, -4, :]

# Reshape both tensors to have the same shape (flatten the last two dimensions)
outputs = eval_outputs_pred.reshape(batch_size * output_size, hidden_size)
ground_truth_t = eval_inputs[:, 2:2+output_size, :].reshape(batch_size * output_size, input_size)
ground_truth_t_minus_1 = eval_inputs[:, 1:1+output_size, :].reshape(batch_size * output_size, input_size)
ground_truth_t_minus_2 = eval_inputs[:, :output_size, :].reshape(batch_size * output_size, input_size)

# Compute accuracy for inferring the inputs at each timestep
accuracy_t = (inferred_inputs_t.argmax(dim=2) == ground_truth_t.argmax(dim=1)).float().mean().item()
accuracy_t_minus_1 = (inferred_inputs_t_minus_1.argmax(dim=2) == ground_truth_t_minus_1.argmax(dim=1)).float().mean().item()
accuracy_t_minus_2 = (inferred_inputs_t_minus_2.argmax(dim=2) == ground_truth_t_minus_2.argmax(dim=1)).float().mean().item()

print("Accuracy for inferring input at timestep t:", accuracy_t)
print("Accuracy for inferring input at timestep t-1:", accuracy_t_minus_1)
print("Accuracy for inferring input at timestep t-2:", accuracy_t_minus_2)


In [85]:
outputs.shape

torch.Size([32, 20])

In [86]:
ground_truth.shape

torch.Size([32, 18, 20])