## Overview
This script contains the source code for analyzing the dynamics of a trained RNN model on a 3-bit memory task. Users need to install the necessary libraries and run the initialization cell (which loads `hps.json`) before creating figures.

In all functions:
* $\beta=0.0$ indicates an $L^p$ regularized model
* $\beta=1.0$ indicates a distance-constrained model

Relevant to [Fig. A.2](#fps)

### Libraries required to run the script

In [1]:
import sys
import os
import numpy as np

import torch
import matplotlib.pyplot as plt

from model import ConstrainedModel
import utils
import random

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
from generate_input import get_data

In [None]:
# Load the hyperparameters from ``hps.json`` in the current working directory.
json_path = os.path.join(os.getcwd(), "hps.json")
if not os.path.isfile(json_path):
    # Raise explicitly: an ``assert`` would be stripped under ``python -O`` and
    # the missing file would only surface later as a less helpful error.
    raise FileNotFoundError(
        "No json configuration file found at {}".format(json_path)
    )
hps = utils.Params(json_path)

### Generates a sequence of static inputs
Static inputs are required to understand different structures output by PCA.

In [2]:
def get_static_inputs(time_steps, n_samples, inputs=None):
    """Build a one-batch input sequence that holds a pattern, then goes silent.

    Args:
        time_steps: Total number of time steps in the returned sequence.
        n_samples: Number of leading time steps that carry ``inputs``.
            Ignored when ``inputs`` is ``None``.
        inputs: A single input pattern such as ``[1, -1, 1]`` (list, tuple or
            array), or ``None`` for an all-zero sequence with three channels.

    Returns:
        An integer ``torch.Tensor``. For a 1-D pattern with ``k`` entries the
        shape is ``(1, time_steps, k)``; with ``inputs=None`` it is
        ``(1, time_steps, 3)``.

    Raises:
        ValueError: If a pattern is given and ``n_samples`` is not within
            ``[0, time_steps]``.
    """
    # ``is None`` rather than ``== None``: the latter is evaluated element-wise
    # for array-like patterns and makes the ``if`` raise.
    if inputs is None:
        sequence = np.zeros((time_steps, 3), dtype=int)
    else:
        if not 0 <= n_samples <= time_steps:
            raise ValueError(
                "n_samples must be between 0 and time_steps "
                "(got n_samples={}, time_steps={})".format(n_samples, time_steps)
            )
        held = np.tile(np.array(inputs, dtype=int), (n_samples, 1))
        # Pad with zeros of the same width as the pattern so the two parts
        # always concatenate, whatever the number of channels.
        silent = np.zeros((time_steps - n_samples, held.shape[1]), dtype=int)
        sequence = np.concatenate([held, silent], axis=0)

    # Add the leading batch dimension.
    return torch.tensor(sequence).unsqueeze(0)

### Find fixed points of the network

In [3]:
# Default input patterns: the eight +/-1 combinations of three bits, in the
# order in which the resulting states are returned.
_BIT_PATTERNS = (
    (1, 1, 1), (1, -1, 1), (1, -1, -1), (-1, 1, 1),
    (-1, -1, 1), (-1, -1, -1), (-1, 1, -1), (1, 1, -1),
)


def get_fps_states(hidden_size, inputs=None, n_samples=5, time_steps=20, net=None):
    """Return the final hidden state reached for each held input pattern.

    For every pattern the network is started from a random standard-normal
    hidden state, driven with the pattern for ``n_samples`` steps and then
    with zeros for the remaining ``time_steps - n_samples`` steps. The hidden
    state returned by the network at the end is collected and used by the
    callers as the fixed-point estimate for that pattern.

    Args:
        hidden_size: Size of the random initial hidden state.
        inputs: Iterable of input patterns; defaults to the eight +/-1
            combinations of three bits.
        n_samples: Number of leading steps during which the pattern is held.
        time_steps: Total length of each input sequence.
        net: Callable ``net(inputs, hidden)`` returning a 3-tuple whose last
            element is the final hidden state. Defaults to the module-level
            ``model``.

    Returns:
        ``numpy.ndarray`` with one row per pattern.
    """
    if net is None:
        # Backward-compatible fallback: the network used to be read from the
        # module-level ``model`` unconditionally.
        net = model
    patterns = _BIT_PATTERNS if inputs is None else inputs
    sampler = torch.distributions.normal.Normal(loc=0.0, scale=1.0)

    final_states = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for pattern in patterns:
            held = np.tile(np.array(pattern, dtype=int), (n_samples, 1))
            silent = np.zeros((time_steps - n_samples, held.shape[1]), dtype=int)
            static_inputs = torch.tensor(np.concatenate([held, silent], axis=0))
            static_inputs = static_inputs.unsqueeze(0).float()

            # Use the ``hidden_size`` argument rather than a global setting so
            # the function honours what the caller asked for.
            hidden = sampler.sample([1, 1, hidden_size])
            _, _, hidden = net(static_inputs, hidden)
            final_states.append(hidden.squeeze().detach().numpy())
    return np.stack(final_states)

def get_state_progression(pca, pca_hidden_sample, inputs=None, net=None):
    """Advance a PCA-space state by one zero-input step of the network.

    The sample is mapped back to hidden-state space with
    ``pca.inverse_transform``, passed through the network together with a
    single all-zero input step, and the resulting hidden state is projected
    with ``pca.transform``.

    Args:
        pca: Fitted object providing ``inverse_transform`` and ``transform``.
        pca_hidden_sample: One state in PCA coordinates; it is reshaped to a
            single row.
        inputs: Must be ``None`` or empty; driving the step with an input
            pattern is not implemented.
        net: Callable ``net(inputs, hidden)`` returning a 3-tuple whose last
            element is the new hidden state. Defaults to the module-level
            ``model``.

    Returns:
        The result of ``pca.transform`` on the new hidden state (one row).

    Raises:
        NotImplementedError: If a non-empty ``inputs`` is supplied.
    """
    # Explicit emptiness test: plain truthiness raises for array-like inputs.
    if inputs is not None and np.asarray(inputs).size > 0:
        raise NotImplementedError(
            "get_state_progression only supports the zero-input step"
        )
    if net is None:
        # Backward-compatible fallback to the module-level network.
        net = model

    static_inputs = get_static_inputs(time_steps=1, n_samples=1).float()
    start = pca.inverse_transform(np.asarray(pca_hidden_sample).reshape(1, -1))
    hidden = torch.tensor(start).unsqueeze(0).float()

    with torch.no_grad():  # inference only
        _, _, hidden = net(static_inputs, hidden)
    next_state = hidden.squeeze().detach().numpy()
    return pca.transform(next_state.reshape(1, -1))

<a id="fps"></a>

### [Set up inputs for state trajectory projections with fixed points](#fps)
`state_traj_fps()` performs principal component analysis on the activation states of all hidden nodes and reduces the dimensionality to 3D. The eight fixed points are shown as red dots in the figure.

This function does not require any inputs, as it will generate a random sequence of inputs with three channels, where each channel can be +/-1 or zero (e.g. [[1, 0, -1], [-1, -1, 1]]).

In [4]:
def state_traj_fps(hps, ckpt_path, beta, seed, alpha, time_steps=1, n_samples=0, inputs=None, plot_sample_size=1):
    """Plot hidden-state trajectories in 3-D PCA space with fixed points overlaid.

    Loads the checkpoint for the given ``seed``/``alpha``/``beta``, runs the
    network on every batch of the test loader, fits a three-component PCA on
    all visited hidden states, then draws the requested number of
    trajectories (black) and the projected fixed-point states (red).

    Args:
        hps: Hyperparameter object; ``n_bits``, ``hidden_size``,
            ``n_spatial_dims``, ``norm`` and ``trial`` are read here, and the
            object is passed on to ``get_data``.
        ckpt_path: Ignored. The checkpoint location is always derived from
            ``hps.trial``, ``seed``, ``alpha`` and ``beta``; the parameter is
            kept so existing calls keep working.
        beta: Used to locate the checkpoint and to name the saved figure.
        seed: Used to locate the checkpoint and to name the saved figure.
        alpha: Used to locate the checkpoint and to name the saved figure.
        time_steps: Together with ``n_samples`` sets the number of zero-input
            steps (``time_steps - n_samples``) appended to each test sequence.
        n_samples: See ``time_steps``.
        inputs: Unused; kept for backward compatibility.
        plot_sample_size: Number of trajectories to draw; capped at the number
            of trajectories available.

    Side effects:
        Rebinds the module-level ``model``, prints the checkpoint path and the
        projected fixed points, saves a PNG in the working directory and
        shows the figure.
    """
    # ``get_fps_states`` reads the network from the module-level ``model``;
    # publish the freshly loaded one so the fixed points and the trajectories
    # come from the same network.
    global model
    model = ConstrainedModel(hps.n_bits, hps.hidden_size, hps.n_bits, hps.n_spatial_dims, hps.norm)
    ckpt_path = f'./trial{hps.trial}/seed{seed}trained_model/alpha_{alpha}_beta_{beta}/checkpoints/last.pth'
    print(f'Load model from: {ckpt_path}')
    utils.load_checkpoint(ckpt_path, model)
    fps_states = get_fps_states(hps.hidden_size)

    test_loader, _ = get_data(hps)

    n_pad = time_steps - n_samples
    noise = torch.distributions.normal.Normal(loc=0.0, scale=1.0)
    hidden_batches = []
    with torch.no_grad():  # inference only
        for batch_inputs, _ in test_loader:
            batch_inputs = torch.as_tensor(batch_inputs).float()
            # Take the sizes from the batch itself so a smaller final batch or
            # a different channel count does not break the concatenation.
            n_rows, _, n_channels = batch_inputs.shape
            padding = torch.zeros(n_rows, n_pad, n_channels)
            padded_inputs = torch.cat([batch_inputs, padding], dim=1)

            hidden = noise.sample([1, n_rows, hps.hidden_size])
            hidden_state, _, _ = model(padded_inputs, hidden)
            hidden_batches.append(hidden_state.detach())
    hidden_states = torch.cat(hidden_batches, dim=0).numpy()
    n_batch, n_time, n_states = hidden_states.shape

    # Fit the PCA on every visited state, pooled over trajectories and time.
    pca = PCA(n_components=3)
    pca.fit(hidden_states.reshape(n_batch * n_time, n_states))
    fps = pca.transform(fps_states)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")
    ax.set_xlabel("PC 1", fontweight=3.5)
    ax.set_zlabel("PC 3", fontweight=3.5)
    ax.set_ylabel("PC 2", fontweight=3.5)

    for batch_idx in range(min(plot_sample_size, n_batch)):
        # Project one trajectory from (n_time, n_states) to (n_time, 3).
        z = pca.transform(hidden_states[batch_idx])
        ax.plot(z[:, 0], z[:, 1], z[:, 2], color="k", linewidth=0.2, alpha=0.1)
    ax.scatter(fps[:, 0], fps[:, 1], fps[:, 2], color='r')
    print(fps[:, 0], fps[:, 1], fps[:, 2])
    fig.savefig(f'seed{seed}_alpha{alpha}_beta{beta}_fps_traj_inputs.png', dpi=400, bbox_inches='tight')
    plt.show()