# Load model

## Load options


In [None]:
import json
import os
from types import SimpleNamespace

# Training-run directories; each must contain an `options.json` (read below).
# LSTM
path_lstm  = 'models/LSTM/LSTM_0904_2210'
# RNN
path_rnn   = 'models/RNN/RNN_0904_2100'
# RSNNs
path_if    = 'models/RSNN/RSNN_1015_1353_if'
path_lif   = 'models/RSNN/RSNN_1014_2127_lif'
path_alif  = 'models/RSNN/RSNN_1015_1408_alif'
path_dexat = 'models/RSNN/RSNN_1002_1731_dexat'
# path_texat = 'models/RSNN/RSNN_1012_1512_texat'  # 128
path_texat = 'models/RSNN/RSNN_1029_1232_texat'  # 4096

# Short model name -> run directory. The keys are the values `which_model` may take.
paths_dict = {
    'lstm':  path_lstm,
    'rnn':   path_rnn,
    'if':    path_if,
    'lif':   path_lif,
    'alif':  path_alif,
    'dexat': path_dexat,
    'texat': path_texat,
}

# Set to True to override the device stored in every options.json with 'cpu'
# (replaces the previously commented-out loop that did the same).
FORCE_CPU = False

# Short model name -> training options, with attribute access (options.Np, options.device, ...).
options_dict = dict()
for model_name, model_path in paths_dict.items():
    path_options = os.path.join(model_path, 'options.json')
    # JSON is UTF-8 by specification; do not depend on the platform default encoding.
    with open(path_options, 'r', encoding='utf-8') as file:
        options = json.load(file)
    options = SimpleNamespace(**options)
    # print(f'options: \n{options}')
    if FORCE_CPU:
        options.device = 'cpu'

    options_dict[model_name] = options

print('options loaded')

## Load the trained model

In [None]:
# Model families: conventional recurrent baselines vs. recurrent spiking networks
# (RSNNs) with different neuron types. Later cells branch on membership in these lists.
lstm_rnn_models = ['lstm', 'rnn']
rsnn_models = ['if', 'lif', 'alif', 'dexat', 'texat']


# Model to analyse; must be a key of paths_dict / options_dict. Leave exactly one line active.
# which_model = 'lstm'
# which_model = 'rnn'

# which_model = 'if'
# which_model = 'lif'
# which_model = 'alif'
# which_model = 'dexat'
which_model = 'texat'

In [None]:
# Options of the selected model. This is the same object stored in options_dict,
# so later mutations (e.g. options.sequence_length) also change options_dict[which_model].
options = options_dict[which_model]

print(f'using which_model={which_model}')
print(f'using options={options}')

In [None]:
import numpy as np
import torch
import os

from place_cells import PlaceCells
from trajectory_generator import TrajectoryGenerator
from model import RNN, LSTM
from model_rsnn import RSNN
# from trainer import Trainer

# Place-cell layout; passed to both the model and the trajectory generator below.
place_cells = PlaceCells(options)

# =========================================
# =======load multiple models==========
# Uncomment to fill `models_dict` (iterated by the evaluation cells) with every
# model in options_dict. NOTE: the loop variable shadows the global `options`.
# models_dict = dict()
# for model_name, options in options_dict.items():
#     print('------')
#     print(f'options.RNN_type={options.RNN_type}')
    
#     if options.RNN_type == 'RNN':
#         model = RNN(options, place_cells)
#     elif options.RNN_type == 'LSTM':
#         model = LSTM(options, place_cells)
#     elif options.RNN_type == 'RSNN':
#         print(f'options.neuron_type={options.neuron_type}')
#         model = RSNN(options, place_cells)
#     else:
#         raise NotImplementedError

#     print(f'options.device={options.device}')
#     model = model.to(options.device)
    
#     # load model
#     model_pth = os.path.join(options.save_dir, options.run_ID)
#     model_pth = os.path.join(model_pth, 'model.pth')
#     model.load_state_dict(torch.load(model_pth, map_location=options.device))
    
#     models_dict[model_name] = model
# print('models loaded')
# =========================================


# =========================================
# =======load one model==========
# Build the architecture named in options.RNN_type, then restore its trained weights.
print(options.RNN_type)

if options.RNN_type == 'RNN':
    model = RNN(options, place_cells)
elif options.RNN_type == 'LSTM':
    model = LSTM(options, place_cells)
elif options.RNN_type == 'RSNN':
    print(options.neuron_type)
    model = RSNN(options, place_cells)
else:
    raise NotImplementedError(f'unsupported RNN_type: {options.RNN_type!r}')

print(options.device)
model = model.to(options.device)

# load model: <save_dir>/<run_ID>/<checkpoint file>
model_pth = os.path.join(options.save_dir, options.run_ID)
####################################################################
model_pth = os.path.join(model_pth, 'model.pth')
# model_pth = os.path.join(model_pth, 'model_200.pth')
# model_pth = os.path.join(model_pth, 'model_100.pth')
####################################################################
# map_location places the stored tensors on options.device, so a checkpoint saved
# on a GPU can still be loaded when options.device is e.g. 'cpu'.
model.load_state_dict(torch.load(model_pth, map_location=options.device))
print('model loaded')
# =========================================


trajectory_generator = TrajectoryGenerator(options, place_cells)
# trainer = Trainer(options, model, trajectory_generator)

# Task statistics

In [None]:
# Trajectory length (number of steps) for the task-statistics plots below. This mutates
# the shared options object; assumes trajectory_generator reads
# options.sequence_length at call time — TODO confirm.
options.sequence_length = 100
print(f'options.sequence_length={options.sequence_length}')

In [None]:
from matplotlib import pyplot as plt

# Plot a few sample trajectories
# Draw one test batch: network inputs, ground-truth positions, and target place-cell outputs.
inputs, pos, pc_outputs = trajectory_generator.get_test_batch()
# inputs, pos, pc_outputs = trajectory_generator.get_test_batch(batch_size=2000)
# Move to CPU so matplotlib can consume them.
us = place_cells.us.cpu()  # place-cell centres (plotted as x/y scatter below)
pos = pos.cpu()

In [None]:
# Task illustration: place-cell centres plus trajectory `start` of the batch,
# with its first and last positions marked.
plt.figure(figsize=(5,5))
# Global rcParam: also applies to figures created after this cell.
plt.rcParams['axes.linewidth'] = 2

# plt.scatter(us[:,0], us[:,1], c='lightblue', label='$N_\\text{p}$ place cell centers')
plt.scatter(us[:,0], us[:,1], c='lightblue', label='Place cell centers')

# `start` is read again by later cells to pick the same trajectory from the batch.
start = 15
end = start + 1
for i in range(start, end):
    print(f'i={i}')
    
    # pos is indexed [time step, trajectory, (x, y)]: pos[0] is the start, pos[-1] the end.
    plt.plot(pos[:,i,0], pos[:,i,1], c='black', label='Trajectory', linewidth=3)
    plt.scatter(pos[0,i,0], pos[0,i,1], c='blue', marker='o', label='Start of trajectory', s=70)
    plt.scatter(pos[-1,i,0], pos[-1,i,1], c='red', marker='^', label='End of trajectory', s=70)
    if i == start:
        # plt.legend(loc='upper right', bbox_to_anchor=(1.53, 1.02))
        plt.legend(loc='upper left')
        # plt.legend()
        
# plt.xlabel('$L$')
# plt.ylabel('$L$')
# Trailing ';' suppresses the notebook's echo of the return value.
plt.xticks([]);
plt.yticks([]);
plt.tight_layout()

# plt.savefig('images/task/env.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/task/env.png', dpi=300, bbox_inches='tight')

In [None]:
# Shape check of the test batch; the last line is the slice used for trajectory `start`.
# print(inputs.shape)
print(pos.shape)
print(pc_outputs.shape)
print(pc_outputs[:, start, :].shape)

In [None]:
# Plot a few place cell outputs
# Target place-cell activity of trajectory `start`, one row per time step (assumes
# pc_outputs is indexed [time, trajectory, Np] like pos — TODO confirm). Each row is
# then rendered as a 500x500 spatial map by place_cells.grid_pc.
# pc_outputs = pc_outputs.reshape(-1, options.Np).detach().cpu()
pc_outputs0 = pc_outputs[:, start, :].reshape(-1, options.Np).detach().cpu()

print(f'pc_outputs0.shape={pc_outputs0.shape}')
# print(pc_outputs[::100].shape)

# pc = place_cells.grid_pc(pc_outputs[::100], res=100)
pc = place_cells.grid_pc(pc_outputs0, res=500)
print(f'pc.shape={pc.shape}')

In [None]:
# inputs is a pair: (v, initial place-cell activity). v is presumably the velocity
# input — TODO confirm. Grid the initial activity of trajectory `start`.
v, init_actv = inputs
print(init_actv.shape)

init_actv0 = init_actv[start].reshape(-1, options.Np).detach().cpu()
print(init_actv0.shape)

pc0 = place_cells.grid_pc(init_actv0, res=500)
print(f'pc0.shape={pc0.shape}')

In [None]:
# 3x7 panel: p_0 (initial activity, pc0[0]) followed by the targets p_1..p_20 (pc[0..19]).
plt.figure(figsize=(7*1.5, 3*1.5))

for i in range(21):
    plt.subplot(3, 7, i+1)
    
    if i == 0:
        plt.imshow(pc0[0], cmap='jet', interpolation='gaussian')
        plt.title('$\\boldsymbol{p}_0$')
        plt.axis('off')
    else:
        # `im` is only needed by the commented-out colorbar below.
        im = plt.imshow(pc[i-1], cmap='jet', interpolation='gaussian')
        plt.title(f'$\\boldsymbol{{p}}_{{{i}}}$')
        plt.axis('off')
    # plt.colorbar(im, fraction=0.15, pad=0.04)

# plt.suptitle('Place cell outputs', fontsize=16)
# plt.show()
plt.tight_layout()

# plt.savefig('place_cell_act.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Single-panel version of the initial place-cell activity p_0.
plt.imshow(pc0[0], cmap='jet', interpolation='gaussian')
# plt.title('$\\boldsymbol{p}_0$')
plt.axis('off')

# plt.colorbar()
plt.tight_layout()

# plt.savefig('images/task/p0.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/task/p0.png', dpi=300, bbox_inches='tight')

In [None]:
# Single-panel version of the first target map p_1 (pc[0]), with a colour bar.
plt.imshow(pc[0], cmap='jet', interpolation='gaussian')
plt.title('$\\boldsymbol{p}_{1}$')
plt.axis('off')

plt.colorbar()
plt.tight_layout()

# plt.savefig('images/task/p1.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/task/p1.png', dpi=300, bbox_inches='tight')

In [None]:
# Single-panel version of the last target map. Its index equals the number of
# gridded steps (printed as `length`); the commented title assumes it is 100.
plt.imshow(pc[-1], cmap='jet', interpolation='gaussian')
length = pc.shape[0]
print(f'length={length}')

# plt.title('$\\boldsymbol{p}_{100}$')
plt.axis('off')

# plt.colorbar()
plt.tight_layout()

# plt.savefig('images/task/p100.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/task/p100.png', dpi=300, bbox_inches='tight')

# Evaluate path integration performance

In [None]:
# `models_dict` is only filled by the (commented-out) multi-model loader in the
# model-loading cell. Without it, this section failed with NameError; fall back
# to the single model that was loaded there.
if 'models_dict' not in globals():
    models_dict = {which_model: model}

# Put every model into evaluation mode and confirm it.
for model_name, model in models_dict.items():
    model.eval()
    print(model_name)
    print(f'model.training={model.training}')

In [None]:
# Trajectory length used for the evaluation below (pick one).
# options.sequence_length = 20
# options.sequence_length = 50
options.sequence_length = 100

## compare errors

In [None]:
import torch

print(f'options.sequence_length={options.sequence_length}')

# One shared test batch, so every model is scored on the same trajectories.
inputs, pos, pc_outputs = trajectory_generator.get_test_batch()

# Evaluation only: no gradients are needed, so do not build the autograd graph.
with torch.no_grad():
    for model_name, model in models_dict.items():
        print('------')
        print(model_name)

        loss, err = model.compute_loss(inputs, pc_outputs, pos)

        print(f'loss={loss:.3f}')
        print(f'err={err:.3f}')

## compute predicted paths

In [None]:
import torch
from matplotlib import pyplot as plt

inputs, pos, pc_outputs = trajectory_generator.get_test_batch()

pos = pos.cpu()
print(f'pos.shape={pos.shape}')

# Model name -> decoded positions (same layout as pos), on the CPU for plotting.
pred_pos_dict = dict()
# Inference only. Running under no_grad also guarantees the results do not
# require grad, so matplotlib/numpy can consume them later.
with torch.no_grad():
    for model_name, model in models_dict.items():
        print('------')
        print(model_name)

        # inputs[0] = inputs[0].to(options_dict[model_name].device)
        # inputs[1] = inputs[1].to(options_dict[model_name].device)
        # Predicted place-cell activity -> positions via place_cells.get_nearest_cell_pos
        # (presumably the centres of the most active cells — TODO confirm).
        pred_pos = place_cells.get_nearest_cell_pos(model.predict(inputs)).cpu()
        print(f'pred_pos.shape={pred_pos.shape}')
        pred_pos_dict[model_name] = pred_pos

us = place_cells.us.cpu()

## plot error with respect to each time step

In [None]:
def smooth(scalars, weight=0.8):
    """Exponential moving average of a sequence (TensorBoard-style smoothing).

    Each output is ``prev * weight + (1 - weight) * point``. The running value is
    seeded with the first point, so the first output equals the first input.

    Args:
        scalars: any iterable of numbers (list, numpy array, 1-D tensor, generator).
        weight: smoothing factor in [0, 1]; larger means smoother.

    Returns:
        A list with one smoothed value per input element; ``[]`` for empty input
        (previously an IndexError).
    """
    smoothed = []
    started = False
    last = None
    for point in scalars:
        if not started:
            # Seed with the first value. The update below still runs, so the
            # arithmetic matches the original implementation exactly.
            last = point
            started = True
        last = last * weight + (1.0 - weight) * point
        smoothed.append(last)
    return smoothed

In [None]:
# Model name -> per-time-step position error.
err_dict = dict()

for model_name, pred_pos in pred_pos_dict.items():
    # Euclidean distance for every (time step, trajectory), averaged over trajectories.
    err = torch.sqrt(((pos - pred_pos)**2).sum(-1)).mean(-1)

    print(model_name)
    print(f'err.shape={err.shape}')

    err_dict[model_name] = err

for model_name, err in err_dict.items():
    # The IF-RSNN is left out of this plot.
    if model_name != 'if':
        # Draw the smoothed curve first so the faint raw curve can reuse its colour.
        # Previously every raw curve was blue and matched no smoothed curve.
        (smoothed_line,) = plt.plot(smooth(err, weight=0.5), label=model_name)
        plt.plot(err, alpha=0.2, c=smoothed_line.get_color())

plt.xlabel('Time step')
plt.ylabel('Mean position error')
plt.legend()

## plot trajs

In [None]:
# One colour per model, used in pred_pos_dict order by the trajectory plot below
# (7 entries for the 7 models in paths_dict).
colors = ['Blue', 'Orange', 'Green', 'Red', 'Purple', 'Brown', 'Cyan']
# colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#17becf']
# darker versions:
# colors = ['#134d72', '#b3540a', '#194019', '#851b1b', '#5c4380', '#573d30', '#89497e', '#4c4c4c', '#107288']
# colors = ['#184f81', '#cc660b', '#238327', '#aa2121', '#775593', '#70483d', '#b9609b', '#666666', '#1392a8']

In [None]:
# Ground-truth (G.T.) paths of a few hand-picked trajectories, overlaid with the
# paths integrated by each model in pred_pos_dict, on top of the place-cell centres.
fig = plt.figure(figsize=(6,6))
plt.rcParams['axes.linewidth'] = 2

# plt.scatter(us[:,0], us[:,1], s=20, alpha=0.5, c='lightgrey')
plt.scatter(us[:,0], us[:,1], c='lightblue', label='$N_\\text{p}$ place cell centers')

start = 125
end = start + 1  # only used by the commented-out range loop
# for i in range(start, end):
# for i in [start, 80, 135]:
for i in [start, 20, 62]:
    plt.plot(pos[:,i,0], pos[:,i,1], c='black', label='G.T. path', linewidth=3)
    
    # Colour j follows pred_pos_dict insertion order.
    for j, (model_name, pred_pos) in enumerate(pred_pos_dict.items()):
        if model_name in rsnn_models:
            label = 'Int. path of ' + model_name.upper() + '-RSNN'
        else:
            label = 'Int. path of ' + model_name.upper()

        color = colors[j]
        
        # plt.plot(
        #     pred_pos[:,i,0], pred_pos[:,i,1], 
        #     '.-', color=color, label=label, 
        #     linewidth=2,
        #     alpha=0.7,
        # )
        
        # =================================
        # T=100
        # Only these models' paths are drawn; every other model gets an empty
        # "(Failed)" line so it still appears in the legend.
        if model_name in ['lstm', 'rnn', 'texat']:
            plt.plot(
                pred_pos[:,i,0], pred_pos[:,i,1], 
                '.-', color=color, label=label, 
                linewidth=2,
                alpha=0.7,
            )
        else:
            # Dummy plot for legend
            plt.plot(
                [], [], 
                '.-', color=color, label='(Failed) ' + label, 
                linewidth=2,
                alpha=0.7,
            )
        # =================================
            
    # Build the legend once, after the first trajectory, so that the repeated
    # labels of later trajectories are not added to it.
    if i == start:
        # plt.legend(loc='upper left')
        plt.legend(loc='lower right')
        
# plt.xlabel('$L$')
# plt.ylabel('$L$')
plt.xticks([])
plt.yticks([])
# Limit the view to the environment box, centred on the origin.
plt.xlim([-options.box_width/2, options.box_width/2])
plt.ylim([-options.box_height/2, options.box_height/2])
plt.tight_layout()

# plt.savefig('images/int_path_results/int_path_t' + str(options.sequence_length) + '.pdf')

In [None]:
# Ground truth vs. integrated path for trajectory 0, comparing only the LSTM,
# the RNN and the TEXAT-RSNN (other models in pred_pos_dict are skipped).
fig = plt.figure(figsize=(6,6))
plt.rcParams['axes.linewidth'] = 2

# plt.scatter(us[:,0], us[:,1], s=20, alpha=0.5, c='lightgrey')
plt.scatter(us[:,0], us[:,1], c='lightblue', label='$N_\\text{p}$ place cell centers')

for i in range(1):
    plt.plot(pos[:,i,0], pos[:,i,1], c='black', label='G.T. path', linewidth=2)
    
    for model_name, pred_pos in pred_pos_dict.items():
        if model_name == 'lstm':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='blue', label='Integrated path using LSTM')
        elif model_name == 'rnn':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='orange', label='Integrated path using RNN')
        elif model_name == 'texat':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='green', label='Integrated path using TEXAT-RSNN')
            
    if i == 0:
        plt.legend()
        
plt.xlabel('$L$')
plt.ylabel('$L$')
plt.xticks([])
plt.yticks([])
plt.xlim([-options.box_width/2, options.box_width/2])
plt.ylim([-options.box_height/2, options.box_height/2])
# plt.legend()

# plt.savefig('images/int_path_lstm_rnn_texat.pdf')

In [None]:
# Ground truth vs. integrated paths of the five RSNN variants for trajectories
# 66, 67 and 68. The legend is built on the first of them only.
fig = plt.figure(figsize=(6,6))
plt.rcParams['axes.linewidth'] = 2

# plt.scatter(us[:,0], us[:,1], s=20, alpha=0.5, c='lightgrey')
plt.scatter(us[:,0], us[:,1], c='lightblue', label='$N_\\text{p}$ place cell centers')

start = 66
end = start + 3
for i in range(start, end):
    plt.plot(pos[:,i,0], pos[:,i,1], c='black', label='G.T. path', linewidth=2)
    
    for model_name, pred_pos in pred_pos_dict.items():
        if model_name == 'if':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='blue', label='Integrated path using IF-RSNN')
        elif model_name == 'lif':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='orange', label='Integrated path using LIF-RSNN')
        elif model_name == 'alif':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='green', label='Integrated path using ALIF-RSNN')
        elif model_name == 'dexat':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='purple', label='Integrated path using DEXAT-RSNN')
        elif model_name == 'texat':
            plt.plot(pred_pos[:,i,0], pred_pos[:,i,1], '.-', color='red', label='Integrated path using TEXAT-RSNN')
            
    if i == start:
        plt.legend()
        
plt.xlabel('$L$')
plt.ylabel('$L$')
plt.xticks([])
plt.yticks([])
plt.xlim([-options.box_width/2, options.box_width/2])
plt.ylim([-options.box_height/2, options.box_height/2])

# plt.savefig('images/int_path_rsnns.pdf')

In [None]:
import torch

# Visualize predicted place cell outputs
# NOTE(review): `model` is whatever the global refers to here. After the evaluation
# loops above it is the last entry of models_dict, not necessarily which_model.
inputs, pos, pc_outputs = trajectory_generator.get_test_batch()

# Inference only: no autograd graph needed.
with torch.no_grad():
    preds = model.predict(inputs)
    preds = preds.reshape(-1, options.Np).detach().cpu()

    pc_outputs = model.softmax(pc_outputs).reshape(-1, options.Np).cpu()

# Spatial maps of the first 100 rows of predictions and targets.
pc_pred = place_cells.grid_pc(preds[:100])
pc = place_cells.grid_pc(pc_outputs[:100])


def _hide_axes_keep_ylabel(ax):
    """Hide ticks and frame like axis('off'), but keep the y-label visible."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


plt.figure(figsize=(16,4))
# Bottom row: predictions (every other row of the gridded maps).
for i in range(8):
    ax = plt.subplot(2,8,i+9)
    plt.imshow(pc_pred[2*i], cmap='jet')
    if i == 0:
        # Previously hidden by plt.axis('off').
        ax.set_ylabel('Predicted')
    _hide_axes_keep_ylabel(ax)

# Top row: targets.
for i in range(8):
    ax = plt.subplot(2,8,i+1)
    plt.imshow(pc[2*i], cmap='jet', interpolation='gaussian')
    if i == 0:
        ax.set_ylabel('True')
    _hide_axes_keep_ylabel(ax)

plt.suptitle('Place cell outputs', fontsize=16)
plt.show()

# Ratemaps

## compute ratemaps

In [None]:
# Compute spatial rate maps of the hidden units on a res x res grid (n_avg is passed
# through to the visualize helpers; presumably the number of batches averaged — TODO confirm).
# NOTE(review): uses the current global `model`, which the evaluation loops above
# may have rebound to the last entry of models_dict.
# from visualize import compute_ratemaps, plot_ratemaps, rgb, compute_ratemaps_rsnn
from visualize import compute_ratemaps  # lstm, rnn
from visualize import compute_ratemaps_rsnn  # rsnns

res = 20
n_avg = 50
# res = 50
# n_avg = 100

if which_model in rsnn_models:
    # dict of (Ng, res, res) np.array
    # keyed by recorded layer name
    ratemaps = compute_ratemaps_rsnn(model, trajectory_generator, options, res=res, n_avg=n_avg)
    for key, rm in ratemaps.items():
        print(f'{key}, rm.shape={rm.shape}')
else:
    # (Ng, res, res) np.array
    ratemaps = compute_ratemaps(model, trajectory_generator, options, res=res, n_avg=n_avg)
    print(f'ratemaps.shape={ratemaps.shape}')

In [None]:
from utils import Ratemap

if which_model not in rsnn_models:
    # LSTM/RNN: one Ratemap wrapping the whole (Ng, res, res) array.
    rate_map = Ratemap(options=options, res=res, ratemaps=ratemaps)
else:
    # RSNN: one Ratemap per recorded layer, keyed like `ratemaps`.
    rate_maps = {
        layer_key: Ratemap(options=options, res=res, ratemaps=layer_rm)
        for layer_key, layer_rm in ratemaps.items()
    }

## save ratemaps

In [None]:
import os
import pickle

# Save under exactly the names the "load ratemaps" cell reads:
#   RSNN     -> data/rate_maps/rate_maps_<model>.pkl  (dict of Ratemap, one per layer)
#   LSTM/RNN -> data/rate_maps/rate_map_<model>.pkl   (single Ratemap)
# The RSNN name used to carry a hard-coded '_Ng128' suffix, whatever model was
# loaded, so the load cell could not find the file.
if which_model in rsnn_models:
    # if not ('model.pth' in model_pth):
    #     pkl_name = f'data/rate_maps/rate_maps_{which_model}_{model_pth[-7:-4]}.pkl'
    # else:
    #     pkl_name = f'data/rate_maps/rate_maps_{which_model}.pkl'
    pkl_name = f'data/rate_maps/rate_maps_{which_model}.pkl'
    to_save = rate_maps
else:
    pkl_name = f'data/rate_maps/rate_map_{which_model}.pkl'
    to_save = rate_map

# Create the output directory on first use instead of failing with FileNotFoundError.
os.makedirs(os.path.dirname(pkl_name), exist_ok=True)
with open(pkl_name, 'wb') as f:
    pickle.dump(to_save, f)
print('saved pickle to ' + pkl_name)

---
## load ratemaps

In [None]:
# Recorded RSNN layer -> short label used in plots. ISL/RSL/OL presumably stand for
# input spiking layer, recurrent spiking layer and output layer — TODO confirm.
layer_labels = {
    'spike_in': 'ISL', 
    'spike_rnn_1': 'RSL 1',
    'spike_rnn_2': 'RSL 2',
    'spike_rnn_3': 'RSL 3',
    'mem_out': 'OL',
}

# Re-declared with the same values as at the top, so this section does not rely
# on the earlier cells for them.
lstm_rnn_models = ['lstm', 'rnn']
rsnn_models = ['if', 'lif', 'alif', 'dexat', 'texat']


# Model whose saved rate maps should be loaded (leave exactly one line active).
# which_model = 'lstm'
# which_model = 'rnn'

# which_model = 'if'
# which_model = 'lif'
# which_model = 'alif'
# which_model = 'dexat'
which_model = 'texat'

In [None]:
import pickle

# Load the pickles written by the "save ratemaps" cell. Only unpickle files this
# project produced: pickle can execute arbitrary code on load.
if which_model in rsnn_models:
    pkl_name = f'data/rate_maps/rate_maps_{which_model}.pkl'
    
    with open(pkl_name, 'rb') as f:
        # dict of Ratemap
        rate_maps = pickle.load(f)
    print('loaded pickle from ' + pkl_name)
    
else:
    pkl_name = f'data/rate_maps/rate_map_{which_model}.pkl'
    with open(pkl_name, 'rb') as f:
        # Ratemap
        rate_map = pickle.load(f)
    print('loaded pickle from ' + pkl_name)
    
    
# check: print how many gridness scores / masks were restored
if which_model in rsnn_models:
    print(f"len(rate_maps['mem_out'].score_60)={len(rate_maps['mem_out'].score_60)}")
    print(f"len(rate_maps['mem_out'].max_60_mask)={len(rate_maps['mem_out'].max_60_mask)}")
else:
    print(f'len(rate_map.score_60)={len(rate_map.score_60)}')
    print(f'len(rate_map.max_60_mask)={len(rate_map.max_60_mask)}')

## Grid scales analysis

In [None]:
if which_model in rsnn_models:
    # Layer name -> grid scales of that layer's cells. The loop variable no longer
    # reuses the name `rate_map`, so the global LSTM/RNN Ratemap is not overwritten.
    grid_scales = dict()
    for key, layer_rate_map in rate_maps.items():
        grid_scales[key] = layer_rate_map.get_grid_scale()
        print(f'{key} len: {len(grid_scales[key])}')
else:
    grid_scale = rate_map.get_grid_scale()
    print(f'len(grid_scale): {len(grid_scale)}')

In [None]:
# Per-layer colours for the RSNN grid-scale KDE plot (assigned in plotting order).
# colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', 
#           '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#17becf']
colors = ['Blue', 'Orange', 'Green', 'Red', 
          'Purple', 'Brown', 'Pink', 'Gray', 'Cyan']

# Layer -> legend label. Also determines which layers the grid-scale plots include.
layer_labels = {
    'spike_in': 'ISL', 
    'spike_rnn_1': 'RSL 1',
    'spike_rnn_2': 'RSL 2',
    'spike_rnn_3': 'RSL 3',
    'mem_out': 'OL',
}

### KDE plot of scales for LSTM and RNN

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np

# KDE of the LSTM/RNN grid scales. `grid_scale` only exists when which_model is
# 'lstm' or 'rnn' (see the grid-scale cell).
ax = sns.kdeplot(
    data=grid_scale, 
    # legend=True, 
    # label=which_model.upper(),
)

# Read back the KDE curve just drawn; x/y are reused by the peak-finding cell.
x = ax.get_lines()[-1].get_xdata()
y = ax.get_lines()[-1].get_ydata()
x_max = x[np.argmax(y)]  # mode of the estimated density

print(which_model)
print(f'max scale={x_max}')

# Mark the mode and write its value above the plot.
ax.axvline(x=x_max, linestyle='--', color='gray')
ax.text(x=x_max, y=ax.get_ylim()[1], 
        s=f'{x_max:.2f}', ha='center', va='bottom')


xlabel = 'Grid scale'
plt.xlabel(xlabel)
# plt.legend()

# plt.savefig('images/grid_scales_' + options.neuron_type + '.pdf')

In [None]:
# Re-plot the KDE curve (x, y from the previous cell) and list all of its local maxima.
plt.plot(x,y)
print(f'x.shape={x.shape}')
print(f'y.shape={y.shape}')

from scipy.signal import find_peaks

# find_peaks returns (indices, properties); only the indices are used.
peak_indices = find_peaks(y)[0]
x_peaks = x[peak_indices]
y_peaks = y[peak_indices]

print("x values at local maxima:", x_peaks)
print("y values at local maxima:", y_peaks)

In [None]:
# Filled KDE of the LSTM/RNN grid scales, annotated with hand-picked peaks and
# the scale ratio between neighbouring peaks.
ax = sns.kdeplot(
    data=grid_scale, 
    # legend=True, 
    fill=True, alpha=.5, linewidth=1,
    # label=which_model.upper(),
)

# ===============================================
# Peaks to annotate (indices into x_peaks from the peak-finding cell) and the
# height of the ratio labels. Chosen from which_model instead of by toggling comments.
if which_model == 'lstm':
    x_peaks_draw = [x_peaks[0]]
    ratio_text_y_frac = 0.9
else:  # rnn
    x_peaks_draw = x_peaks[1:4]
    ratio_text_y_frac = 0.91
# ===============================================

# Named `x_peak` so the KDE x-array `x` from the earlier cells is not overwritten.
for x_peak in x_peaks_draw:
    ax.axvline(x=x_peak, linestyle='--', color='gray')
    ax.text(x=x_peak, y=ax.get_ylim()[1], 
            s=f'{x_peak:.2f}', ha='center', va='bottom')


# Each pair of neighbouring peaks gets a <-> arrow with the ratio of their scales.
for x_left, x_right in zip(x_peaks_draw[:-1], x_peaks_draw[1:]):
    # <-> arrow
    ax.annotate(
        '', 
        xy    =(x_left,  ax.get_ylim()[1] * 0.9), 
        xytext=(x_right, ax.get_ylim()[1] * 0.9), 
        arrowprops=dict(arrowstyle='<->', color='blue'),
    )
    # ratio text, centred over the arrow
    avg = (x_left + x_right) / 2.0
    ratio = x_right / x_left
    ax.text(x=avg, 
            y=ax.get_ylim()[1] * ratio_text_y_frac,
            s=f'{ratio:.2f}', ha='center', va='bottom')


ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

xlabel = 'Grid scale (m)'
ylabel = 'Probability density'
plt.xlabel(xlabel, fontsize=14)
plt.ylabel(ylabel, fontsize=14)

# plt.legend()

plt.savefig('images/gridscales/grid_scales_' + which_model + '.pdf')

### KDE plot of scales for RSNNs

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np

# Layer -> mode of its grid-scale KDE (read by the next cell as well).
max_scales = dict()
idx_color = 0
ax = None  # first kdeplot draws on the current axes; later ones reuse `ax`

for key, scale in grid_scales.items():
    # Only layers that have a legend label are plotted (the same five layers the
    # previous hard-coded list allowed). Add e.g. 'mem_in' to layer_labels to include it.
    if key not in layer_labels:
        continue

    # A KDE needs at least two samples.
    if len(scale) < 2:
        continue

    ax = sns.kdeplot(
        data=scale, 
        ax=ax, 
        legend=True, 
        color=colors[idx_color], 
        label=layer_labels[key],
    )
    idx_color += 1

    # Read back the curve just drawn and take its maximum (the density mode).
    x = ax.get_lines()[-1].get_xdata()
    y = ax.get_lines()[-1].get_ydata()
    x_max = x[np.argmax(y)]
    max_scales[key] = x_max
    # max_scales[key] = round(x_max, 2)

    print(key)
    print(f'max scale={x_max}')


if max_scales:
    # Mark the smallest and the largest per-layer mode.
    for scale in [min(max_scales.values()), max(max_scales.values())]:
        ax.axvline(x=scale, linestyle='--', color='gray')
        ax.text(x=scale, y=ax.get_ylim()[1], 
                s=f'{scale:.2f}', ha='center', va='bottom')

    xlabel = 'Grid scale'
    plt.xlabel(xlabel)
    plt.legend()
else:
    # Previously min()/max() on an empty dict raised ValueError here.
    print('no layer has enough grid scales (>= 2) to plot a KDE')

# plt.savefig('images/grid_scales_' + options.neuron_type + '.pdf')

In [None]:
# Filled per-layer KDEs of the RSNN grid scales, annotated with the smallest and
# largest per-layer mode (max_scales from the previous cell) and their ratio.
ax = None  # first kdeplot draws on the current axes; later ones reuse `ax`
for key, scale in grid_scales.items():
    # Same layer selection as the previous cell: layers with a legend label.
    if key not in layer_labels:
        continue

    # A KDE needs at least two samples.
    if len(scale) < 2:
        continue

    ax = sns.kdeplot(
        data=scale, 
        ax=ax, 
        legend=True, 
        fill=True, alpha=.5, linewidth=1,
        # color=colors[idx_color], 
        label=layer_labels[key],
    )


if ax is not None and max_scales:
    lo_scale = min(max_scales.values())
    hi_scale = max(max_scales.values())

    # 2 vertical lines
    for scale in [lo_scale, hi_scale]:
        ax.axvline(x=scale, linestyle='--', color='gray')
        ax.text(x=scale, y=ax.get_ylim()[1], 
                s=f'{scale:.2f}', ha='center', va='bottom')

    # <-> arrow
    ax.annotate(
        '', 
        xy    =(lo_scale, ax.get_ylim()[1] * 0.9), 
        xytext=(hi_scale, ax.get_ylim()[1] * 0.9), 
        arrowprops=dict(arrowstyle='<->', color='blue'),
    )
    # ratio text, centred over the arrow
    avg = (lo_scale + hi_scale) / 2.0
    ratio = hi_scale / lo_scale
    ax.text(x=avg, y=ax.get_ylim()[1] * 0.9, 
            s=f'{ratio:.2f}', ha='center', va='bottom')


    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.legend(title='RSNN layer')

    xlabel = 'Grid scale (m)'
    ylabel = 'Probability density'
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel(ylabel, fontsize=14)

    # plt.legend()

    plt.savefig('images/gridscales/grid_scales_' + which_model + '.pdf')
else:
    # Previously this crashed (undefined/stale `ax`, or min()/max() of an empty dict).
    print('nothing to plot: no layer has enough grid scales, or max_scales is empty')

## Grid cells and gridness scores analysis

### plot grid cell rate maps

In [None]:
from visualize import plot_ratemaps_and_sacs

# Number of cells to plot (rate maps + spatial autocorrelograms), laid out in 4 rows.
top_num = 40  #########

if which_model in lstm_rnn_models:
    fig_rm_sac = plot_ratemaps_and_sacs(
        rate_map,
        title=which_model,
        n_plot=top_num, 
        n_cols=top_num//4, 
    )

if which_model in rsnn_models:
    # One figure per recorded layer. NOTE: leaves the global `layer` set to the last key.
    for layer in layer_labels.keys():
        fig_rm_sac = plot_ratemaps_and_sacs(
            rate_maps[layer], 
            # res=50,
            title=which_model + '_' + layer, 
            n_plot=top_num, 
            n_cols=top_num//4, 
        )

In [None]:
from visualize import plot_ratemaps_and_sacs
from matplotlib import pyplot as plt

selected_num = 3

# Hand-picked cell indices per model. Previously both lines were active, so the RNN
# indices silently replaced the LSTM ones. Any other model keeps the RNN indices,
# as before.
selected_idxs_by_model = {
    'lstm': [1, 8, 17],
    'rnn':  [8, 17, 23],
}
selected_idxs = selected_idxs_by_model.get(which_model, [8, 17, 23])


# ===========================================
# ===========================================
# RSNN
# The layer must be set explicitly. Previously this cell relied on `layer` left over
# from an earlier loop (whose last value is 'mem_out') and raised NameError otherwise.
layer = 'mem_out'

# Previously used hand-picked indices (uncomment one line to override selected_idxs):
# layer = 'spike_in'
# selected_idxs = [0, 1, 2]
# selected_idxs = [0, 1, 2]
# selected_idxs = [0, 1, 2]
# selected_idxs = [2, 18, 20]

# layer = 'spike_rnn_1'
# selected_idxs = [0, 1, 9]
# selected_idxs = [0, 1, 2]
# selected_idxs = [0, 1, 2]
# selected_idxs = [0, 2, 5]

# layer = 'spike_rnn_2'
# selected_idxs = [1, 2, 4]
# selected_idxs = [0, 1, 4]
# selected_idxs = [0, 1, 2]
# selected_idxs = [0, 1, 2]

# layer = 'spike_rnn_3'
# selected_idxs = [1, 3, 5]
# selected_idxs = [0, 1, 2]
# selected_idxs = [0, 1, 2]
# selected_idxs = [3, 7, 8]

# layer = 'mem_out'
# selected_idxs = [1, 10, 14]
# selected_idxs = [0, 2, 4]
# selected_idxs = [1, 3, 4]
# selected_idxs = [3, 12, 17]
# ===========================================
# ===========================================


if which_model in lstm_rnn_models:
    fig_rm_sac = plot_ratemaps_and_sacs(
        rate_map, 
        n_plot=selected_num, 
        n_cols=selected_num, 
        selected_idxs=selected_idxs,
    )
    # plt.savefig('images/gridcells/gridcells_' + which_model + '.pdf')


if which_model in rsnn_models:
    fig_rm_sac = plot_ratemaps_and_sacs(
        rate_maps[layer], 
        n_plot=selected_num, 
        n_cols=selected_num, 
        selected_idxs=selected_idxs,
    )
    # plt.savefig('images/gridcells/gridcells_' + which_model + '_' + layer + '.pdf')

### plot rate maps of heterogeneous cells

In [None]:
from visualize import plot_ratemaps_and_sacs

# Rate maps of the first 128 cells of every RSNN layer. Only RSNNs have per-layer
# `rate_maps`; for LSTM/RNN this used to fail with NameError.
if which_model in rsnn_models:
    for layer in layer_labels.keys():
        fig_rm_sac = plot_ratemaps_and_sacs(
            rate_maps[layer], 
            title=which_model + '_' + layer, 
            n_plot=128, 
        )
else:
    print(f'per-layer rate maps are only available for RSNN models, not {which_model!r}')

### visualize gridness scores

#### save scores

In [None]:
from matplotlib import pyplot as plt
import numpy as np

# Save the 60-degree gridness scores (score_60) as .npy files for the
# cross-model comparison below, and print their NaN-ignoring mean.
# NOTE: np.save does not create data/scores/; it must already exist.
if which_model in rsnn_models:
    # one file per layer: data/scores/score_<model>_<layer>.npy
    for layer in layer_labels.keys():
        print(f'------{layer}------')
        
        score_60 = rate_maps[layer].score_60
        print(f'len(score_60)={len(score_60)}')

        name = f'data/scores/score_{which_model}_{layer}.npy'
        np.save(name, np.array(score_60))

        avg_score_60 = np.nanmean(score_60)
        print(f'avg_score_60={avg_score_60}')
else:
    # single file: data/scores/score_<model>.npy
    print(f'------{which_model}------')
    
    score_60 = rate_map.score_60
    
    name = f'data/scores/score_{which_model}.npy'
    np.save(name, np.array(score_60))

    avg_score_60 = np.nanmean(score_60)
    print(f'avg_score_60={avg_score_60}')

#### load scores

In [None]:
from matplotlib import pyplot as plt
import numpy as np

# RSNN layer -> plot label; the keys also select which per-layer score files are loaded.
layer_labels = {
    # 'mem_in',
    'spike_in': 'ISL', 
    # 'mem_rnn_1',
    'spike_rnn_1': 'RSL 1',
    # 'mem_rnn_2',
    'spike_rnn_2': 'RSL 2',
    # 'mem_rnn_3',
    'spike_rnn_3': 'RSL 3',
    'mem_out': 'OL',
}

lstm_rnn_models = ['lstm', 'rnn']
rsnn_models = ['if', 'lif', 'alif', 'dexat', 'texat']

# load all scores written by the "save scores" cell:
#   LSTM/RNN -> data/scores/score_<model>.npy               (one array)
#   RSNNs    -> data/scores/score_<model>_<layer>.npy        (dict: layer -> array)
# Same files and key order as the previous hand-written literal.
scores = {}
for model_name in lstm_rnn_models:
    scores[model_name] = np.load(f'data/scores/score_{model_name}.npy')
for model_name in rsnn_models:
    scores[model_name] = {
        layer: np.load(f'data/scores/score_{model_name}_{layer}.npy')
        for layer in layer_labels
    }

# convert nan to -inf, so undefined scores sort below every real score and never
# pass a threshold. The loop variable is `model_name` (not `model`) so the global
# torch model is not overwritten.
for model_name, model_scores in scores.items():
    if model_name in lstm_rnn_models:
        scores[model_name] = np.nan_to_num(model_scores, nan=-np.inf)
    elif model_name in rsnn_models:
        for layer in model_scores:
            model_scores[layer] = np.nan_to_num(model_scores[layer], nan=-np.inf)


#### compute grid cell proportion: score > gridness threshold

In [None]:
grid_thresh = 0.37  # gridness-score threshold: cells scoring above it count as grid cells


def compute_ratio(values, threshold):
    """Return the fraction of ``values`` strictly greater than ``threshold``.

    NaN and -inf never exceed the threshold, so they count as "not above".
    Accepts any iterable (list, numpy array, generator). Returns NaN for empty input,
    where the proportion is undefined, instead of raising ZeroDivisionError.
    """
    values = list(values)
    if not values:
        return float('nan')
    n_above = sum(1 for value in values if value > threshold)
    return n_above / len(values)


# Print, per model (and per RSNN layer), the percentage of units whose grid
# score exceeds grid_thresh, rounded to one decimal.
for model in scores:
    print('------')
    if model in lstm_rnn_models:
        print(f'model={model}')
        ratio = round(compute_ratio(scores[model], grid_thresh) * 100, 1)
        print(ratio)
        continue
    if model not in rsnn_models:
        continue
    print(f'model={model}')
    for layer in scores[model]:
        print(f'layer={layer}')
        ratio = round(compute_ratio(scores[model][layer], grid_thresh) * 100, 1)
        print(ratio)


#### compute average grid scores (mean ± SEM of the top-scoring cells)

In [None]:
# compute average scores
# For each model (and, for RSNNs, each layer) average the `top_num` highest
# grid scores and report the standard error of that mean.
avg_scores = dict()
std_err = dict()

top_num = 10  # number of highest-scoring units averaged per model / layer


def _top_mean_sem(values, k):
    """Return (mean, SEM) of the ``k`` largest finite entries of ``values``.

    NaN scores were mapped to -inf earlier in the notebook so that they sort
    to the bottom. They (and any other non-finite entries) are dropped here,
    so a model/layer with fewer than ``k`` valid units no longer gets a mean
    of -inf and an SEM of nan. If nothing finite remains, numpy yields nan for
    both values (with a RuntimeWarning).
    """
    s = np.asarray(values, dtype=float).ravel()
    s = s[np.isfinite(s)]
    top_scores = np.sort(s)[-k:]
    # SEM with the population std (ddof=0), kept to match earlier results;
    # the usual sample SEM would use ddof=1.
    return np.mean(top_scores), np.std(top_scores) / np.sqrt(len(top_scores))


for model in scores.keys():
    if model in lstm_rnn_models:
        avg_scores[model], std_err[model] = _top_mean_sem(scores[model], top_num)

    elif model in rsnn_models:
        avg_scores[model] = dict()
        std_err[model] = dict()

        for layer in scores[model].keys():
            avg_scores[model][layer], std_err[model][layer] = _top_mean_sem(
                scores[model][layer], top_num
            )

print(avg_scores)
print(std_err)

#### visualize average scores as a bar chart

In [None]:
# Bar colours for the RSNN layers, indexed by layer position j in the bar-chart
# cell. These are matplotlib's "tab10" hex values (without olive).
# colors = ['Blue', 'Orange', 'Green', 'Red', 'Purple', 'Brown', 'Pink', 'Gray', 'Cyan']
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#17becf']
# darker versions:
# colors = ['#134d72', '#b3540a', '#194019', '#851b1b', '#5c4380', '#573d30', '#89497e', '#4c4c4c', '#107288']
# colors = ['#184f81', '#cc660b', '#238327', '#aa2121', '#775593', '#70483d', '#b9609b', '#666666', '#1392a8']

In [None]:
# Bar chart of the top-k average grid score per model: one gray bar for each
# non-spiking model, and one coloured bar per layer for each RSNN model, all
# with SEM error bars. Saved to images/gridness_scores.pdf.
# NOTE(review): `plt` and `layer_labels` (layer key -> legend text) are
# defined in earlier cells, not here.

# X labels
# e.g. 'LSTM', 'RNN', 'LIF-RSNN', in the insertion order of avg_scores.
x_labels = [
    model.upper() if model in lstm_rnn_models else (model.upper() + '-RSNN')
    for model in avg_scores.keys()
]

# Define the x-axis positions for each group
# Hand-placed group centres: the first two close together, then one unit apart.
# NOTE(review): assumes avg_scores holds exactly 7 models, presumably in the
# order lstm, rnn, if, lif, alif, dexat, texat. x[i] raises IndexError for
# more models; tick labels would not match for a different order.
# x = np.arange(len(x_labels))
x = [
    0,
    0.4,
    1.1,
    1.1 + 1.0,
    1.1 + 2.0,
    1.1 + 3.0,
    1.1 + 4.0,
]


fig, ax = plt.subplots(figsize=(12, 6))

for i, model in enumerate(avg_scores.keys()):
    if model in lstm_rnn_models:
        # single gray bar centred on the group position
        ax.bar(
            x[i], avg_scores[model], 
            ###
            yerr=std_err[model], 
            capsize=3,
            ###
            color='gray', 
            width=0.12, 
        )
        
    elif model in rsnn_models:
        # one bar per layer, offset by j*0.15 - 0.3 so that five layers span
        # x[i]-0.3 .. x[i]+0.3 around the group position
        for j, layer in enumerate(avg_scores[model].keys()):
            ax.bar(
                x[i] + j * 0.15 - 0.3, avg_scores[model][layer], 
                ###
                yerr=std_err[model][layer], 
                capsize=3,
                ###
                color=colors[j], 
                width=0.12, 
                label=layer_labels[layer],
            )


# Dashed reference lines at the LSTM and RNN average scores
plt.axhline(avg_scores['lstm'], color='gray', linestyle='--', linewidth=1)
plt.axhline(avg_scores['rnn'], color='gray', linestyle='--', linewidth=1)
# plt.axhline(0.37, color='gray', linestyle='--', linewidth=1)

   
# Set labels
ax.set_xticks(x)
ax.set_xticklabels(x_labels, fontsize=10)

ax.set_xlabel("Model", fontsize=14)
ax.set_ylabel("Grid score", fontsize=14)
# ax.set_title(f"Average gridness score for top {top_num} grid cells", fontsize=16)

# Add legend
# Every RSNN model adds the same layer labels again; the dict keeps one handle
# per label (the last one seen).
axhandles, axlabels = ax.get_legend_handles_labels()
by_label = dict(zip(axlabels, axhandles))  # Remove duplicate labels
ax.legend(by_label.values(), by_label.keys(), 
          title='RSNN layer',
          loc="upper left", bbox_to_anchor=(1.0, 1.0))

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Show the plot
plt.tight_layout()
# plt.show()

plt.savefig('images/gridness_scores.pdf')

In [None]:
# Show the hand-placed bar positions `x` from the bar-chart cell.
# NOTE: later cells rebind the global `x`, so this only shows the bar positions
# when run straight after the bar chart.
print(x)

# State space properties

## Manifold distance

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Rate-map resolution (bins per side); must match the loaded rate maps, since
# dists.reshape(res, res) below relies on it.
res = 50 #20

# 3x3 grid of reference positions (row, col) around the arena centre.
# NOTE(review): this evaluates as (offset * res) // 4 + res // 2. With res=50
# that gives {12, 25, 37}: the -1 offset lands 13 bins from the centre and +1
# only 12. Writing `(res // 4)` would make the grid symmetric.
origins = np.stack(np.mgrid[:3, :3] - 1) * res//4 + res//2
# print(f'origins={origins}')

# Number of top-scoring units whose activities form the population vector.
n_grid_cells = 128 #4096 #1024

###### RSNN ######
# layer = 'spike_in'
# layer = 'spike_rnn_1'
# layer = 'spike_rnn_2'
# layer = 'spike_rnn_3'
layer = 'mem_out'

# Unit indices ordered by descending score_60.
grid_sort = np.flip(np.argsort(rate_maps[layer].score_60))
# NOTE: rebinds the global `rate_map` from the rate-map object to a plain array
# indexed (unit, row, col).
rate_map = rate_maps[layer].ratemaps

###### RNN ######
# grid_sort = np.flip(np.argsort(rate_map.score_60))
# rate_map = rate_map.ratemaps


# For each reference position, plot the Euclidean distance in population-activity
# space between the activity there and the activity at every other bin.
fig = plt.figure(figsize=(8, 8))
for i in range(3):
    for j in range(3):
        plt.subplot(3, 3, 3*i + j + 1)
        x = origins[0, i, j]
        y = origins[1, i, j]
        
        # population vector at the reference bin, shape (n, 1, 1) for broadcasting
        r0 = rate_map[grid_sort[:n_grid_cells], x, y]
        r0 = r0.reshape(-1, 1, 1)
        
        dr0 = r0
        dr1 = rate_map[grid_sort[:n_grid_cells]]
        # dr0 = dr0.reshape(dr0.shape[0], -1)
        # dr1 = dr1.reshape(dr1.shape[0], -1)
        # norm over the unit axis -> one distance per spatial bin
        dists = np.linalg.norm(dr0 - dr1, axis=0)
        # print(f'dr0.shape={dr0.shape}')
        # print(f'dr1.shape={dr1.shape}')
        # print(f'dists.shape={dists.shape}')
        
        # each panel is scaled by its own maximum
        im = plt.imshow(dists.reshape(res, res) / np.max(dists),
                        cmap='viridis_r', 
                        # cmap='jet', 
                        interpolation='gaussian')
        plt.axis('off')
        # imshow puts rows on the vertical axis, hence (col, row) for the marker
        plt.scatter(y, x, marker='x', c='black')
        
# Shared colorbar, built from the last panel's image.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.12, 0.02, 0.74])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.locator_params(nbins=3)
cbar.ax.tick_params(labelsize=20)
cbar.outline.set_visible(False)

## Neural sheet

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# res = 20
res = 50  # rate-map resolution (bins per side); must match the loaded rate maps

# Number of top-scoring units kept for the state-space analysis below.
# n_grid_cells = 128
# n_grid_cells = 1024
# n_grid_cells = 2048
n_grid_cells = 4096


def _select_top_ratemaps(rm_obj, n_cells):
    """Return (grid_sort, maps) for the ``n_cells`` highest-scoring units of ``rm_obj``.

    ``rm_obj`` must expose ``score_60`` (per-unit scores used for ranking) and
    ``ratemaps`` (unit axis first). ``grid_sort`` lists unit indices by
    descending score. ``maps`` is ``ratemaps`` restricted to the top
    ``n_cells`` units, or fewer if the object has fewer units. Prints the
    shape before and after the selection.
    """
    order = np.flip(np.argsort(rm_obj.score_60))
    maps = rm_obj.ratemaps
    print(f'rate_map.shape={maps.shape}')
    maps = maps[order[:n_cells]]
    print(f'rate_map.shape={maps.shape}')
    return order, maps


if which_model in rsnn_models:
    ###### RSNN ######
    # layer = 'spike_in'
    # layer = 'spike_rnn_1'
    layer = 'spike_rnn_2'
    # layer = 'spike_rnn_3'
    # layer = 'mem_out'

    grid_sort, rate_map = _select_top_ratemaps(rate_maps[layer], n_grid_cells)

else:
    # `rate_map` is rebound to a plain ndarray just below, so remember the
    # original rate-map object the first time through. Previously, re-running
    # this cell failed with AttributeError on `.score_60`.
    if hasattr(rate_map, 'score_60'):
        rate_map_obj = rate_map
    grid_sort, rate_map = _select_top_ratemaps(rate_map_obj, n_grid_cells)

In [None]:
from tqdm import tqdm

# Fourier transform 
# Ng = options.Ng
# Use the number of maps actually present. The previous cell keeps at most
# n_grid_cells units and can keep fewer, in which case the former
# `Ng = n_grid_cells` raised IndexError.
Ng = rate_map.shape[0]

# np.fft.fft2 transforms over the last two axes, so a single call handles all
# maps. This replaces a per-unit loop that computed each FFT twice.
rm_fft = np.fft.fft2(np.asarray(rate_map, dtype=np.float64).reshape([Ng, res, res]))
# Real/imaginary parts kept under their previous names for any later use.
rm_fft_real = np.real(rm_fft)
rm_fft_imag = np.imag(rm_fft)

# Unit-averaged squared REAL part of the spectrum.
# NOTE(review): this is not the full power spectrum (np.abs(rm_fft)**2). It is
# kept unchanged because the wave vectors in the next cell were presumably
# picked from this plot.
im = (np.real(rm_fft)**2).mean(0)
im[0, 0] = 0  # zero the DC bin so it does not dominate the colour scale

In [None]:
from matplotlib.patches import FancyArrowPatch

# Plot the low-frequency part of the unit-averaged spectrum `im` (FFT cell) as
# coloured squares. Overlay the hand-picked wave vectors k1..k3, their
# negatives, and the hexagon joining their tips.
fig, ax = plt.subplots()

# Frequencies -(width-1) .. (width-1) along each axis. Negative entries of
# `idxs` index from the end of `im`, i.e. they select the negative-frequency
# FFT bins, so the zero-frequency bin sits in the middle of the plot.
width = 6
idxs = np.arange(-width + 1, width)
# x2 holds the column (second-axis) frequency, y2 the row (first-axis) frequency.
x2, y2 = np.meshgrid(
    # np.arange(2*width - 1),
    # np.arange(2*width - 1),
    np.arange(-width+1, width),
    np.arange(-width+1, width),
)

# im[idxs][:, idxs][r, c] == im[idxs[r], idxs[c]], drawn at (idxs[c], idxs[r]).
# scatter = plt.scatter(x2, y2, c=im[idxs][:, idxs], s=600, cmap='Oranges')  # circle
scatter = ax.scatter(x2, y2, c=im[idxs][:, idxs], marker='s', s=500, cmap='Oranges')  # square
# scatter = ax.scatter(x2, y2, c=im[idxs][:, idxs], marker='s', s=500, cmap='Blues')  # square

# color bar
# Only the extremes are labelled; absolute values are not shown.
cbar = plt.colorbar(scatter, shrink=0.6, pad=0.0)
cbar.set_ticks([cbar.vmin, cbar.vmax])  # Set ticks at the minimum and maximum
cbar.set_ticklabels(['Low', 'High'])    # Set the labels

# origin
ax.scatter(0, 0, color='black', s=10)

# ==============================
# Wave vectors as [row, col] FFT frequency indices, one set per layer/model.
# They are hard-coded (presumably read off this plot) and must be updated by
# hand when `layer`, `res` or the model changes.
###### 'spike_rnn_1' ######
# k1 = [2,1]
# k2 = [0,2]
# k3 = [-1,2]
###### 'spike_rnn_2' ######
k1 = [2,1]
k2 = [0,2]
k3 = [-2,1]
###### 'spike_rnn_3' ###### res50
# k1 = [2,1]
# k2 = [0,2]
# k3 = [-2,1]
###### 'mem_out' ######
# k1 = [3,0]
# k2 = [1,3]
# k3 = [-2,2]
# res50
# k1 = [1,1]
# k2 = [0,1]
# k3 = [-1,1]


###### rnn ######
# k1 = [3,0]
# k2 = [1,2]
# k3 = [-2,2]

# draw arrows
# Arrows end at (x=col, y=row) to match the scatter layout above; the label
# offsets below are hand-tuned per arrow.
for i, k in enumerate([k1, k2, k3]):
    # k arrow
    # plt.arrow(0, 0, 
    #           k[1], k[0], 
    #           head_width=0.05, head_length=0.05, fc='k', ec='k')
    arrow = FancyArrowPatch(
        (0, 0), 
        (k[1], k[0]), 
        arrowstyle="->",       # Arrow style
        color="black",         # Arrow color
        mutation_scale=10      # Scale of the arrowhead
    )
    ax.add_patch(arrow)
    
    if i == 0:
        xytext = (k[1]/2.0 + 0.15, k[0]/2.0)
        # xytext = (k[1]/2.0 + 0.1, k[0]/2.0)
    elif i == 1:
        xytext = (k[1]/2.0, k[0]/2.0 + 0.1)
        # xytext = (k[1]/2.0, k[0]/2.0 + 0.3)
    elif i == 2:
        xytext = (k[1]/2.0 + 0.1, k[0]/2.0 )
    plt.annotate(f'$k_{i+1}$', 
                 xy=(k[1], k[0]), 
                 xytext=xytext, 
                 textcoords='data')

    # -k arrow
    # plt.arrow(0, 0, 
    #           -k[1], -k[0], 
    #           head_width=0.1, head_length=0.1, fc='k', ec='k')
    arrow = FancyArrowPatch(
        (0, 0), 
        (-k[1], -k[0]), 
        arrowstyle="->",       # Arrow style
        color="black",         # Arrow color
        mutation_scale=10      # Scale of the arrowhead
    )
    ax.add_patch(arrow)
    if i == 0:
        xytext = (-k[1]/2.0 - 0.7, -k[0]/2.0)
    elif i == 1:
        xytext = (-k[1]/2.0 - 0.5, -k[0]/2.0 + 0.1)
        # xytext = (-k[1]/2.0 - 0.5, -k[0]/2.0 + 0.2)
    elif i == 2:
        xytext = (-k[1]/2.0 - 0.75, -k[0]/2.0)
        # xytext = (-k[1]/2.0 - 0.75, -k[0]/2.0 - 0.1)
    plt.annotate(f'$-k_{i+1}$', 
                 xy=(-k[1], -k[0]), 
                 xytext=xytext, 
                 textcoords='data')
    
# draw dotted lines
# Hexagon through the tips k1 -> k2 -> k3 -> -k1 -> -k2 -> -k3 -> k1
# (x = col, y = row). NOTE: this rebinds the globals `x` and `y`.
xlist = [
    [k1[1], k2[1]],
    [k2[1], k3[1]],
    [k3[1], -k1[1]],
    [-k1[1], -k2[1]],
    [-k2[1], -k3[1]],
    [-k3[1], k1[1]],
]
ylist = [
    [k1[0], k2[0]],
    [k2[0], k3[0]],
    [k3[0], -k1[0]],
    [-k1[0], -k2[0]],
    [-k2[0], -k3[0]],
    [-k3[0], k1[0]],
]
for i in range(len(xlist)):
    x = xlist[i]
    y = ylist[i]
    plt.plot(x, y, linestyle='dotted', color='blue', alpha=0.8)
# ==============================

plt.axis('equal')
plt.axis('off')
# plt.title('Mean Fourier Power')
plt.tight_layout()

# plt.savefig('images/statespace/mean_fourier_power.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/statespace/mean_fourier_power_res50.pdf', format='pdf', bbox_inches='tight')

# plt.savefig('images/statespace/mean_fourier_power_rnn.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Only k1..k3 are real wave vectors. k4..k6 just pad `ks` to six entries; only
# phases[0..2] are used in the cells that follow.
k4 = k5 = k6 = k1

freq = 1  # harmonic multiplier applied to every wave vector
ks = freq * np.array([k1, k2, k3, k4, k5, k6])
ks = ks.astype('int')  # integer [row, col] indices into rm_fft

# Complex Fourier coefficient of every unit at each wave vector: one row per k.
modes = np.stack([rm_fft[:, k[0], k[1]] for k in ks])

# Find phases
# Per-unit phase of each mode, in radians in (-pi, pi].
phases = [np.angle(mode) for mode in modes]

# Pairwise scatter of the unit phases: (phi1, phi2), (phi2, phi3), (phi3, phi1).
plt.figure(figsize=(15,5))

plt.subplot(131)
plt.scatter(phases[0], phases[1], c='black', s=10)
plt.xlabel(r'$\phi_1$')
plt.ylabel(r'$\phi_2$')

plt.subplot(132)
plt.scatter(phases[1], phases[2], c='black', s=10)
plt.xlabel(r'$\phi_2$')
plt.ylabel(r'$\phi_3$')

plt.subplot(133)
plt.scatter(phases[2], phases[0], c='black', s=10)
plt.xlabel(r'$\phi_3$')
plt.ylabel(r'$\phi_1$')

In [None]:
# 3-D scatter of the first three mode phases, one point per unit.
# from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')
# ax = Axes3D(fig)

ax.scatter(phases[0], phases[1], phases[2], c='black', s=10)
ax.view_init(azim=60)  # rotate the camera azimuth to 60 degrees

# ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
# ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
# ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

# ax.xaxis._axinfo["grid"]['color'] = (1,1,1,0)
# ax.yaxis._axinfo["grid"]['color'] = (1,1,1,0)
# ax.zaxis._axinfo["grid"]['color'] = (1,1,1,0)

ax.set_xlabel(r'$\phi_1$', fontsize=16)
ax.set_ylabel(r'$\phi_2$', fontsize=16)
ax.set_zlabel(r'$\phi_3$', fontsize=16)

# plt.title('Phase distribution')
plt.tight_layout()

# plt.savefig('images/statespace/phase_distribution.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/statespace/phase_distribution_200.pdf', format='pdf', bbox_inches='tight')

# plt.savefig('images/statespace/phase_distribution_rnn.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Place each unit's (phi, theta) phase pair on a torus in 3-D. R is the distance
# from the torus centre to the tube centre; r is the tube radius.
R = 5.0
r = 2.0

# =====================
phi   = phases[0]  # angle around the tube
theta = phases[1]  # angle around the central (z) axis
# =====================

# Standard torus parametrisation. NOTE: rebinds the globals `x` and `y`.
x = (R + r * np.cos(phi)) * np.cos(theta)
y = (R + r * np.cos(phi)) * np.sin(theta)
z = r * np.sin(phi)

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(
    x, y, z, 
    # c='lightgreen', 
    # colour by a phase-shifted cosine of phi; the % (2*pi) does not change
    # the cosine apart from floating-point rounding
    c=np.cos((phi - 0.3) % (2 * np.pi)), 
    # c=z, 
    cmap='viridis', 
    # cmap='summer', 
    alpha=0.8,
    s=60,
)

ax.axis('off')
ax.view_init(elev=60, azim=0)

# ax.set_xlim([-5, 5])
# ax.set_ylim([-5, 5])
# ax.set_zlim([-10, 10])
# ax.set_zlim(-r/2, r/2)

# Toroidal topology
plt.tight_layout()

# plt.savefig('images/statespace/toroidal_topology.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/statespace/toroidal_topology_rnn.pdf', format='pdf', bbox_inches='tight')

In [None]:
import random  # NOTE(review): not used in this cell

# For each of k1..k3, project the population activity at every spatial bin onto
# the 2-D plane spanned by the units' cos/sin phases of that mode. Colour each
# point by the plane wave cos(k_i . x) at that bin.
freq = 1
crop = 0  # bins trimmed from each border of the arena before plotting
cmaps = ['Blues', 'Oranges', 'Greens']

# Spatial coordinates of every bin, scaled to [0, 2*pi): shape (2, res*res),
# flattened row-major like rate_map.reshape([Ng, -1]) below.
x = np.mgrid[:res, :res] * 2 * np.pi / res
x = x.reshape(2, -1)
# Rows 0-2: cos(k_i . x); rows 3-5: sin(k_i . x).
k = freq * np.stack([k1, k2, k3])
X = np.concatenate([np.cos(k.dot(x)), np.sin(k.dot(x))], axis=0)
# print(f'X[0].shape={X[0].shape}')

# Flat indices of the bins left after trimming `crop` bins from each edge.
idxs1, idxs2 = np.mgrid[crop:res-crop, crop:res-crop]
idxs = np.ravel_multi_index((idxs1, idxs2), (res, res)).ravel()
# print(f'idxs={idxs}')


plt.figure(figsize=(12,4))

for i in range(3):
    plt.subplot(1, 3, i+1)
    
    # cos/sin of each unit's phase for mode i: shape (2, Ng)
    B = np.stack([np.cos(phases[i]), np.sin(phases[i])])
    # 2-D projection of the population vector at every bin: shape (2, res*res)
    test = B @ (rate_map.reshape([Ng, -1]))
    
    # Select the same (cropped) bins for both positions and colours. Previously
    # only the colours were cropped, so any crop > 0 gave mismatched lengths.
    plt.scatter(test[0][idxs], test[1][idxs], c=X[i][idxs], cmap=cmaps[i], s=30)
    
    plt.axis('off')
    plt.title(f'$k_{i+1}$')
    
# plt.savefig('images/statespace/projections.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/statespace/projections_res50.pdf', format='pdf', bbox_inches='tight')
# plt.savefig('images/statespace/projections_200_res50.pdf', format='pdf', bbox_inches='tight')


In [None]:
# Sanity check: B comes from the last iteration of the projection loop and is
# (2, number of units); phases[0] has one entry per unit.
print(f'B.shape={B.shape}')
print(f'phases[0].shape={phases[0].shape}')