<a href="https://colab.research.google.com/github/vargamartonaron/nma_23_rnn/blob/main/working_memory_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
import numpy as np

def generate_ou_process(n_steps, dt, theta, mu, sigma, x0):
    ou_process = [x0]
    for _ in range(1, n_steps):
        drift = theta * (mu - ou_process[-1]) * dt
        diffusion = sigma * np.sqrt(dt) * np.random.normal()
        x_next = ou_process[-1] + drift + diffusion
        ou_process.append(x_next)
    return np.array(ou_process)

# Parameters for the OU process
n_steps = 1000  # Number of time steps
dt = 0.1        # Time step size
theta = 0.1     # Mean reversion rate
mu = 0.0        # Mean value
sigma = 0.2     # Volatility
x0 = 1.0        # Initial value

# Generate the OU process data
ou_data = generate_ou_process(n_steps, dt, theta, mu, sigma, x0)


In [24]:
import torch
import torch.nn as nn

class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out)
        return out


In [25]:
# Convert the OU data to PyTorch tensors
ou_data_tensor = torch.tensor(ou_data, dtype=torch.float32).view(1, -1, 1)  # Input shape: (batch_size=1, seq_len=n_steps, input_size=1)

# Initialize the model and set hyperparameters
input_size = 1
hidden_size = 64
output_size = 1
learning_rate = 0.01
num_epochs = 100

model = RNNModel(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    outputs = model(ou_data_tensor)
    loss = criterion(outputs[:, :-1, :], ou_data_tensor[:, 1:, :])  # Shift the output by one timestep for decoding

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Epoch [10/100], Loss: 0.2031
Epoch [20/100], Loss: 0.1823
Epoch [30/100], Loss: 0.1628
Epoch [40/100], Loss: 0.1432
Epoch [50/100], Loss: 0.1232
Epoch [60/100], Loss: 0.1027
Epoch [70/100], Loss: 0.0819
Epoch [80/100], Loss: 0.0615
Epoch [90/100], Loss: 0.0432
Epoch [100/100], Loss: 0.0285


In [26]:
# Evaluation
model.eval()
with torch.no_grad():
    decoded_outputs = model(ou_data_tensor)
    mse_per_timestep = torch.mean((decoded_outputs[:, :-1, :] - ou_data_tensor[:, 1:, :])**2, dim=(0, 2))

print("Mean Squared Error at every timestep:")
for timestep, mse in enumerate(mse_per_timestep):
    print(f'Timestep {timestep+1}: {mse.item():.4f}')


Mean Squared Error at every timestep:
Timestep 1: 0.4554
Timestep 2: 0.2572
Timestep 3: 0.1930
Timestep 4: 0.2315
Timestep 5: 0.1571
Timestep 6: 0.0828
Timestep 7: 0.0673
Timestep 8: 0.0508
Timestep 9: 0.0393
Timestep 10: 0.0474
Timestep 11: 0.0917
Timestep 12: 0.0972
Timestep 13: 0.0968
Timestep 14: 0.0571
Timestep 15: 0.0359
Timestep 16: 0.0227
Timestep 17: 0.0008
Timestep 18: 0.0088
Timestep 19: 0.0153
Timestep 20: 0.0242
Timestep 21: 0.0141
Timestep 22: 0.0144
Timestep 23: 0.0370
Timestep 24: 0.0473
Timestep 25: 0.0172
Timestep 26: 0.0341
Timestep 27: 0.0414
Timestep 28: 0.1013
Timestep 29: 0.1130
Timestep 30: 0.0658
Timestep 31: 0.0739
Timestep 32: 0.0374
Timestep 33: 0.0827
Timestep 34: 0.0294
Timestep 35: 0.0388
Timestep 36: 0.0539
Timestep 37: 0.0722
Timestep 38: 0.0964
Timestep 39: 0.0687
Timestep 40: 0.0737
Timestep 41: 0.1297
Timestep 42: 0.1222
Timestep 43: 0.0673
Timestep 44: 0.1136
Timestep 45: 0.1290
Timestep 46: 0.0925
Timestep 47: 0.1347
Timestep 48: 0.1860
Timestep 49

In [27]:
threshold = 0.1

# Evaluation
model.eval()
with torch.no_grad():
    decoded_outputs = model(ou_data_tensor)
    mse_per_timestep = torch.mean((decoded_outputs[:, :-1, :] - ou_data_tensor[:, 1:, :])**2, dim=(0, 2))
    accuracy = torch.mean((mse_per_timestep < threshold).float()) * 100.0

print(f'Accuracy: {accuracy.item():.2f}%')


Accuracy: 94.09%
