## Import Required Libraries

In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import resample
from scipy.fft import fft2, ifft2, fftshift
import os
from IPython.display import clear_output
import time


## Load the gc_periodic function

Import the main simulation function, `gc_periodic`, from the converted Python module `gc_periodic.py`.

In [4]:
# Import the gc_periodic function
from gc_periodic import gc_periodic

## Initialize Network Parameters

**Warning:** The parameters below are the ones used in the associated paper. Changing them will likely cause the simulation to fail.

In [79]:
# Timestep in ms
dt = 0.5

# Grid side length: the network is an n x n sheet of neurons, so n = 2**7 = 128
# gives 128 * 128 = 16384 neurons in total. n must be a power of 2.
n = 2**7

# Neuron time-constant (in ms)
tau = 5

# Module numbers available to simulate (each is turned into parameters by
# set_module_params).
modules = [1, 2, 3, 4]

# Enforce the documented power-of-2 requirement here, so a bad value fails
# loudly now rather than deep inside the simulation.
assert n > 0 and (n & (n - 1)) == 0, f"n must be a power of 2, got {n}"


def set_module_params(module_number, scale_base=1.4, lambda_base=13,
                      alphabar=1.05, a=1, wtphase=2):
    """Set network parameters based on the grid-cell module number.

    Module ``m`` uses a scale factor of ``scale_base ** (m - 1)``. The
    ``lambda`` parameter of Equation (3) is multiplied by that factor and the
    velocity gain of Equation (4) is divided by it, so module 1 uses the
    unscaled base values.

    Parameters
    ----------
    module_number : int
        Module index; module 1 has a scale factor of 1.
    scale_base : float, optional
        Ratio between successive modules (default 1.4).
    lambda_base : float, optional
        ``lambda`` for module 1, Equation (3) (default 13).
    alphabar : float, optional
        Ratio ``gamma / beta`` from Equation (3) (default 1.05).
    a : float, optional
        Weight parameter from Equation (3); must satisfy ``a <= alphabar**2``
        (default 1).
    wtphase : float, optional
        The ``l`` parameter from Equation (2) (default 2).

    Returns
    -------
    tuple
        ``(beta, gamma, a, wtphase, alpha)``.

    Raises
    ------
    ValueError
        If ``a > alphabar**2``, violating the Equation (3) constraint.
    """
    scale_factor = scale_base ** (module_number - 1)

    # Envelope and weight-matrix parameters, Equation (3)
    lambda_param = lambda_base * scale_factor
    beta = 3 / lambda_param**2
    gamma = alphabar * beta  # alphabar = gamma / beta
    if a > alphabar**2:
        raise ValueError(f"a must be <= alphabar**2 ({alphabar**2}), got {a}")

    # Velocity gain, Equation (4): higher-numbered modules get a
    # proportionally smaller gain.
    alpha = 1 / scale_factor

    return beta, gamma, a, wtphase, alpha


# Simulation options
useSpiking = False  # True: spiking model; False: rate model

# Trajectory data file; a random trajectory is used if it is not found.
filename = 'trajectory_data.npz'

# The module-dependent parameters (beta, gamma, abar, wtphase, alpha) are not
# defined in this cell. They are computed by set_module_params in the
# simulation cell, so inspect or print them after that call.

In [95]:
# Run the simulation for a single grid-cell module.
MODULE = 1
# Simulation length passed to gc_periodic as `duration`
# (units as defined by gc_periodic — TODO confirm).
DURATION = 1000000

print("Starting simulation...")
print("This may take several minutes depending on trajectory length.\n")

# Parameters depend on the module number, so they must be recomputed for every
# module. To sweep all modules, call set_module_params inside a loop over
# `modules` and pass its results to gc_periodic on each iteration.
beta, gamma, abar, wtphase, alpha = set_module_params(MODULE)

spikes = gc_periodic(
    filename=filename,
    n=n,
    tau=tau,
    dt=dt,
    beta=beta,
    gamma=gamma,
    abar=abar,
    wtphase=wtphase,
    alpha=alpha,
    useSpiking=useSpiking,
    module=MODULE,
    GET_BAND=False,
    BAND_ANGLE=30,
    duration=DURATION,
)

print("\nSimulation complete!")

Starting simulation...
This may take several minutes depending on trajectory length.

Saved frame 000000 to plots/simulation/module_1/frame_000000.png
Saved frame 000001 to plots/simulation/module_1/frame_000001.png
Saved frame 000002 to plots/simulation/module_1/frame_000002.png
Saved frame 000003 to plots/simulation/module_1/frame_000003.png
Saved frame 000004 to plots/simulation/module_1/frame_000004.png
Saved frame 000005 to plots/simulation/module_1/frame_000005.png
Saved frame 000006 to plots/simulation/module_1/frame_000006.png
Saved frame 000007 to plots/simulation/module_1/frame_000007.png
Saved frame 000008 to plots/simulation/module_1/frame_000008.png
Saved frame 000009 to plots/simulation/module_1/frame_000009.png
Saved frame 000010 to plots/simulation/module_1/frame_000010.png
Saved frame 000011 to plots/simulation/module_1/frame_000011.png
Saved frame 000012 to plots/simulation/module_1/frame_000012.png
Saved frame 000013 to plots/simulation/module_1/frame_000013.png
Save

KeyboardInterrupt: 

## Analyze Results

If the spiking model was used, analyze the spike data.

In [None]:
def count_population_spikes(spike_history):
    """Return the total spike count for each entry of ``spike_history``.

    Each entry is summed with ``np.sum``; ``None`` entries count as 0.
    """
    return [0 if frame is None else np.sum(frame) for frame in spike_history]


if not useSpiking:
    print("Rate-based model used (no spike data). See visualizations from simulation above.")
elif spikes is None:
    # Previously this case fell through to the "rate-based" message, which was wrong.
    print("Spiking model was selected but no spike data was returned.")
else:
    spike_counts = count_population_spikes(spikes)
    if not spike_counts:
        # np.max would raise on an empty sequence.
        print("Spike data is empty; nothing to analyze.")
    else:
        # Plot spike count over time
        fig, ax = plt.subplots(figsize=(12, 4))
        ax.plot(spike_counts)
        ax.set(xlabel='Time step', ylabel='Total spikes in population',
               title='Population Spiking Activity Over Time')
        ax.grid(True, alpha=0.3)
        plt.show()

        print(f"Total simulation steps: {len(spikes)}")
        print(f"Average spikes per timestep: {np.mean(spike_counts):.2f}")
        print(f"Max spikes in single timestep: {np.max(spike_counts):.0f}")

Rate-based model used (no spike data). See visualizations from simulation above.


## Quick Test with Smaller Grid (Optional)

For faster testing, you can run a smaller simulation:

In [None]:
# Set to True to run a quick sanity check on a smaller 64 x 64 grid.
RUN_QUICK_TEST = False
# Shorter than the main run's duration of 1000000 (units as defined by
# gc_periodic — TODO confirm this value is suitable).
QUICK_TEST_DURATION = 100000

if RUN_QUICK_TEST:
    print("Running quick test with n=64...")
    # Reuses beta, gamma, abar, wtphase, alpha and MODULE from the simulation
    # cell above, so run that cell first. gc_periodic takes `gamma`, not
    # `alphabar`; alphabar only exists inside set_module_params.
    spikes_test = gc_periodic(
        filename='nonexistent.npz',  # Will use random trajectory
        n=64,
        tau=tau,
        dt=dt,
        beta=beta,
        gamma=gamma,
        abar=abar,
        wtphase=wtphase,
        alpha=alpha,
        useSpiking=False,
        module=MODULE,
        GET_BAND=False,
        BAND_ANGLE=30,
        duration=QUICK_TEST_DURATION,
    )

## Notes

- The simulation generates a random trajectory if the trajectory data file (`filename`) is not found
- Network formation occurs during the first 1000 iterations (with aperiodic boundaries)
- The envelope switches to uniform input at iteration 800
- Visualizations update every 20 time steps during the simulation
- The final plots show:
  - **Top panel**: Neural population activity (rate or spike-based)
  - **Bottom panel**: Animal trajectory with single neuron firing locations marked