This notebook recreates the 3-bit flip-flop task from the following [paper](https://direct.mit.edu/neco/article/25/3/626/7854/Opening-the-Black-Box-Low-Dimensional-Dynamics-in). The goal is to train the RNN first with Gradient Descent (GD), then with the Exponentiated Gradient (EG) method, and to analyse any differences in the results.

To reproduce the results directly, go to `fixed-point-finder/examples/torch` and run `python run_FlipFlop.py`.

In this notebook the aim is to recreate everything from scratch. The key steps are:
- Training the RNN
- Generating the data to train and test the RNN on
- Calculating the fixed points of the trained RNN
- Visualising the fixed points

## Training the RNN 
Before generating the data, I need to decide what the RNN should look like (number of inputs, number of outputs, etc.).

The paper describes its own setup as follows:

> We trained a randomly connected network (N = 1000) to perform the 3-bit flip-flop task using the FORCE learning algorithm (Sussillo & Abbott) (see section 6). We then performed the linearization analysis, using the trajectories of the system during operation as ICs. Specifically, we spanned all possible transitions between the memory states using the inputs to the network and then randomly selected 600 network states out of these trajectories to serve as ICs for the q optimization. The algorithm resulted in 26 distinct fixed points, on which we performed a linear stability analysis. Specifically, we computed the Jacobian matrix, equation 3.12, around each fixed point and performed an eigenvector decomposition on these matrices. The resulting stable fixed points and saddle points are shown in Figure 3 (left). To display the results of these analyses, the network state x(t) is plotted in the basis of the first three principal components of the network activations (the transient pulses reside in other dimensions; one is shown in the right panel of Figure 3).
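
The q optimization and the linear stability analysis mentioned in the excerpt can be sketched on a toy network. This is a minimal illustration, not the paper's implementation: it assumes zero-input rate dynamics dx/dt = F(x) = -x + J tanh(x), a small weakly coupled network (`N = 20`), plain gradient descent on q(x) = 0.5 |F(x)|^2, and hypothetical helper names (`F`, `find_fixed_point`, `count_unstable_modes`).

```python
import torch

def F(x, J):
    """Zero-input rate dynamics: dx/dt = -x + J tanh(x) (assumed form)."""
    return -x + J @ torch.tanh(x)

def find_fixed_point(x0, J, lr=0.2, steps=2000):
    """Minimise q(x) = 0.5 * |F(x)|^2 by gradient descent, starting from x0."""
    x = x0.clone().requires_grad_(True)
    for _ in range(steps):
        q = 0.5 * F(x, J).pow(2).sum()
        (grad,) = torch.autograd.grad(q, x)
        with torch.no_grad():
            x -= lr * grad
    x = x.detach()
    q_final = 0.5 * F(x, J).pow(2).sum().item()
    return x, q_final

def count_unstable_modes(x_star, J):
    """Eigen-decompose the Jacobian dF/dx at x_star and count the
    eigenvalues with positive real part (0 means a stable fixed point)."""
    jac = torch.autograd.functional.jacobian(lambda x: F(x, J), x_star)
    eigvals = torch.linalg.eigvals(jac)
    return int((eigvals.real > 0).sum())

torch.manual_seed(0)
N = 20                                   # toy size, not the paper's N = 1000
J = 0.3 * torch.randn(N, N) / N ** 0.5   # weak coupling (assumption)
x0 = torch.randn(N)                      # one initial condition

x_star, q_final = find_fixed_point(x0, J)
n_unstable = count_unstable_modes(x_star, J)
print(f"q = {q_final:.2e}, unstable modes = {n_unstable}")
```

In the full analysis the same routine would be run from many initial conditions taken from network trajectories, and fixed points with one or more unstable modes would be labelled as saddles.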

In [1]:
import torch
import torch.nn as nn 
import math

class flipflop(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        # We have to use super for the class to work as an nn.Module
        super().__init__()

        # We save the inputs/parameters here
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # We can't use nn.RNN to define our RNN because it would make every weight learnable, so we implement it by hand
        # We initialise the parameters (only W_out is updated during training)
        # The fixed weights are registered as buffers so that they move with the module when .to(device) is called
        self.register_buffer("J", torch.randn(hidden_size, hidden_size) / math.sqrt(hidden_size))
        self.register_buffer("B", torch.randn(hidden_size, input_size) * 0.1)
        self.register_buffer("W_fb", torch.randn(hidden_size, output_size) * 0.1)
        self.W_out = nn.Parameter(torch.randn(hidden_size, output_size), requires_grad=False)

        # Initialise parameters for the FORCE parameter update
        alpha = 0.1
        # Inverse correlation matrix (also a buffer, so it stays on the same device as the weights)
        self.register_buffer("P", torch.eye(hidden_size) / alpha)


    def forward(self, input_tensor, output_tensor, dt = 1.0):
        '''  
        This is the RNN's main time loop
        This simulates the RNN over time and performs training at every step.
        '''
        # The input_tensor is the input of size [T, input size] where T is the number of time steps
        T = input_tensor.shape[0]
        device = self.W_out.device

        # x is the hidden state of the RNN at the current time — it evolves over time according to the system's dynamics. In continuous-time terms, it’s like the system’s position in "neural space."
        x = torch.zeros(self.hidden_size, device=device)

        outputs = []

        # Training loop (dt is set to one to discretize time)
        for t in range(T):
            r = torch.tanh(x)
            z = self.W_out.T @ r
            x = (1 - dt) * x + dt * (self.J @ r + self.B @ input_tensor[t] + self.W_fb @ z)
            self.force_update(r, z, output_tensor[t], dt)
            outputs.append(z.detach())

        return torch.stack(outputs)

    def force_update(self, r, z, target, dt = 1.0):
        """
        r: firing rates at current time step [hidden_size]
        z: current output [output_size]
        target: desired output [output_size]
        dt: time step (optional but included for consistency)
        """
        # The output weights W_out get updated, using Recursive Least Squares (RLS)
        # Ensure correct shapes
        r = r.unsqueeze(1)              # [hidden_size, 1]
        e = (z - target).unsqueeze(0)   # [1, output_size]

        # Compute gain vector k
        Pr = self.P @ r                 # [hidden_size, 1]
        rPr = (r.T @ Pr).item()         # scalar
        c = 1.0 / (1.0 + rPr)
        k = c * Pr                      # [hidden_size, 1]

        # Update output weights: W_out = W_out - k * e
        delta_W = k @ e                # [hidden_size, output_size]
        self.W_out.data -= delta_W     # update in-place (no autograd)

        # Update P matrix
        self.P = self.P - k @ (r.T @ self.P)
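
The RLS rule used in `force_update` can be checked in isolation on a toy problem. The sketch below is not part of the original notebook: it assumes a hypothetical linear target `w_true` and random feature vectors standing in for the firing rates `r`, and applies the same gain and `P` updates to a single output weight vector.

```python
import torch

torch.manual_seed(0)
n_features, n_steps, alpha = 5, 200, 0.1

w_true = torch.randn(n_features)      # hypothetical weights to recover
w = torch.zeros(n_features)           # readout weights being learned
P = torch.eye(n_features) / alpha     # inverse correlation matrix estimate

for _ in range(n_steps):
    r = torch.randn(n_features)       # stand-in for the firing rates
    target = w_true @ r
    e = w @ r - target                # error before the update
    Pr = P @ r
    k = Pr / (1.0 + r @ Pr)           # gain vector
    w = w - k * e                     # same form as W_out -= k e
    P = P - torch.outer(k, r @ P)     # same form as P -= k (r^T P)

final_error = (w - w_true).norm().item()
print(f"distance to true weights: {final_error:.4f}")
```

The distance should shrink towards zero as more samples are seen; the residual is set by the regulariser `alpha`.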


## Creating the data
I need to create a [T, input_size] tensor whose entries are drawn from {-1, 0, 1}. Ideally each input stays constant and, at every time step, has a small random chance of switching to a different value.

In [None]:
class data_generation():
    def __init__(self, T, input_size, prob, seed = 42):
        # We initialise the values: 
        self.T = T
        self.input_size = input_size
        self.prob = prob
        self.seed = seed
    
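    def generate_flipflop_input(self):
        # Read the settings stored in __init__ (the original signature was invalid:
        # it had no self, and a non-default argument followed a default one)
        T, input_size, prob = self.T, self.input_size, self.prob
        torch.manual_seed(self.seed)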

        # Initialise the input for the RNN
        input_tensor = torch.zeros(T, input_size, dtype=torch.float32)

        # Initial value: randomly chosen for each input bit from {-1, 0, 1}
        current = torch.randint(low=-1, high=2, size=(input_size,), dtype=torch.int)

        for t in range(T):
            if t > 0:
                # Carry over previous values
                current = input_tensor[t - 1].clone().to(torch.int)

            # Decide which bits to change
            change_mask = torch.rand(input_size) < prob

            for i in range(input_size):
                if change_mask[i]:
                    # Pick a new value that's different from the current one
                    possible = [-1, 0, 1]
                    possible.remove(int(current[i].item()))
                    new_value = possible[torch.randint(len(possible), (1,)).item()]
                    current[i] = new_value

            input_tensor[t] = current

        return input_tensor




# We set the hidden size to 1000 as specified in the paper, and the input/output sizes to 3 for the 3-bit flip-flop task.
hidden_size = 1000 
input_size = 3
output_size = 3
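
The `forward` method of the model also expects an `output_tensor` of targets, which the generator above does not produce. A minimal sketch of one way to build it, assuming the usual flip-flop rule that each output channel holds the most recent nonzero value seen on its input channel (and stays at 0 until the first one arrives). The helper name `flipflop_targets` is hypothetical and not part of the original notebook.

```python
import torch

def flipflop_targets(input_tensor):
    """Build [T, n_bits] targets that hold the last nonzero input per channel."""
    targets = torch.zeros_like(input_tensor)
    memory = torch.zeros(input_tensor.shape[1], dtype=input_tensor.dtype)
    for t in range(input_tensor.shape[0]):
        nonzero = input_tensor[t] != 0
        # Overwrite the stored bit only where a nonzero input arrives
        memory = torch.where(nonzero, input_tensor[t], memory)
        targets[t] = memory
    return targets

# Tiny two-channel example
demo_inputs = torch.tensor([[0., 1.],
                            [1., 0.],
                            [0., 0.],
                            [-1., -1.],
                            [0., 0.]])
demo_targets = flipflop_targets(demo_inputs)
print(demo_targets)
```

With the piecewise-constant inputs generated in this notebook, the target only differs from the input while an input sits at 0.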