## Import Required Libraries

In [1]:
%load_ext autoreload
%autoreload 2

## Load the gc_periodic function

Import the main simulation function, `gc_periodic`, from the `gc_periodic` Python module (`gc_periodic.py`).

In [3]:
# Import the gc_periodic function
from gc_periodic import gc_periodic

## Initialize Network Parameters

**Warning:** The following are the parameters used in the associated paper. Altering them will likely lead to an unsuccessful simulation.

In [None]:
# ====================================
# SIMULATION PARAMETERS
# ====================================

# Side length of the square neural sheet. The network is an n x n grid, so it
# holds n**2 neurons (128 x 128 = 16384 here). The original notes require n to
# be a power of 2; that is not checked in this cell.
n = 1 << 7  # 2**7 == 128

# Integration timestep (ms) and neuron time constant (ms)
dt, tau = 0.5, 5

# --- Simulation options ---
# True selects the spiking model, False the rate-based model.
useSpiking = False
# Trajectory data file. Per the original notes, a random trajectory is used
# when this file is not found.
filename = "trajectory_data.npz"


# ====================================
# MODULE CONFIGURATION
# ====================================

def set_module_params(module_number, *, scale_alpha=False):
    """
    Set parameters based on module number to simulate different grid scales.

    Each module represents a different spatial scale of grid cells. The grid
    wavelength grows by a factor of 1.4 per module (module 1 is the smallest),
    following the 1.4x scaling ratio observed in biological grid cells.

    Parameters:
    -----------
    module_number : int
        Module identifier (1-4), where higher numbers = larger spatial scales.
        The value is not range-checked; it is used directly in
        ``1.4 ** (module_number - 1)``.
    scale_alpha : bool, optional (keyword-only, default False)
        Controls the velocity gain ``alpha``.
        - False (default): ``alpha`` is fixed at 1 for every module. This is
          the behaviour this notebook has always had.
        - True: ``alpha = 1 / 1.4 ** (module_number - 1)``, so the gain
          shrinks as the module's spatial scale grows.

    Returns:
    --------
    beta : float
        Spatial scale parameter for the weight matrix: ``3 / lambda**2`` with
        ``lambda = 13 * 1.4 ** (module_number - 1)``.
    gamma : float
        Interaction range parameter: ``1.05 * beta``.
    a : int
        Excitatory amplitude of the Mexican-hat kernel (always 1).
    wtphase : int
        Phase shift for directional selectivity, in pixels (always 2).
    alpha : int or float
        Velocity gain parameter for path integration (see ``scale_alpha``).
    """
    # Scale factor increases by 1.4x per module (biological observation)
    scale_factor = 1.4 ** (module_number - 1)

    # Spatial wavelength parameter (Equation 3 from paper)
    lambda_param = 13 * scale_factor

    # Weight matrix spatial scale (Equation 3)
    beta = 3 / lambda_param**2

    # Interaction range parameter, defined through alphabar = gamma / beta
    alphabar = 1.05
    gamma = alphabar * beta

    # Excitatory amplitude (original note: must be <= alphabar^2 for stability)
    a = 1

    # Directional phase shift ('l' from Equation 2)
    wtphase = 2

    # Velocity gain parameter (Equation 4). Previously the scaled value was
    # computed and then unconditionally overwritten with 1 ("override for
    # testing"). The choice is now explicit, and the default preserves the old
    # fixed value of 1.
    alpha = 1 / scale_factor if scale_alpha else 1

    return beta, gamma, a, wtphase, alpha

In [None]:
 #====================================
# RUN SIMULATION
# ====================================

# Select which module to simulate
MODULE = 1  # Options: 1 (smallest), 2, 3, 4 (largest)

# Optional: after the main run, also simulate every module listed in
# MODULES_TO_RUN. This replaces the previously commented-out loop; flip the
# flag instead of uncommenting code.
RUN_ALL_MODULES = False
MODULES_TO_RUN = [1, 2, 3, 4]

# Keyword arguments shared by every gc_periodic call in this cell. The
# module-specific values (beta, gamma, abar, wtphase, alpha, module) are
# supplied per call.
common_sim_kwargs = dict(
    filename=filename,
    n=n,
    tau=tau,
    dt=dt,
    useSpiking=useSpiking,
    GET_BAND=False,      # Use isotropic grid cells (not band cells)
    BAND_ANGLE=0,        # N/A for grid cells
    duration=1000000,    # Steps for random trajectory if no data file
)

# Get module-specific parameters
beta, gamma, abar, wtphase, alpha = set_module_params(MODULE)

print("=" * 60)
print(f"STARTING GRID CELL SIMULATION - MODULE {MODULE}")
print("=" * 60)
print(f"Network size:      {n} x {n} neurons")
print(f"Time constant:     {tau} ms")
print(f"Time step:         {dt} ms")
print(f"Beta (scale):      {beta:.6f}")
print(f"Gamma (range):     {gamma:.6f}")
print(f"A (amplitude):     {abar}")
print(f"Phase shift:       {wtphase} pixels")
print(f"Velocity gain:     {alpha:.4f}")
print(f"Model type:        {'Spiking' if useSpiking else 'Rate-based'}")
print("=" * 60)
print("\nThis may take several minutes depending on trajectory length...\n")

# Run the main simulation; its outputs are used by the analysis cell below.
spikes, integrated_path_x_cm, integrated_path_y_cm, error, position_x, position_y = gc_periodic(
    beta=beta,
    gamma=gamma,
    abar=abar,
    wtphase=wtphase,
    alpha=alpha,
    module=MODULE,
    **common_sim_kwargs,
)

print("\n" + "=" * 60)
print("SIMULATION COMPLETE!")
print("=" * 60)


# ====================================
# MULTI-MODULE SIMULATION (Optional)
# ====================================
# Each module's full gc_periodic return value is stored in `results`, keyed by
# module id. Module parameters go into separate names so the main run's
# variables (beta, gamma, ..., spikes) are not overwritten.
results = {}
if RUN_ALL_MODULES:
    for module_id in MODULES_TO_RUN:
        print(f"\n{'='*60}")
        print(f"Running Module {module_id}...")
        print(f"{'='*60}")

        m_beta, m_gamma, m_abar, m_wtphase, m_alpha = set_module_params(module_id)

        results[module_id] = gc_periodic(
            beta=m_beta,
            gamma=m_gamma,
            abar=m_abar,
            wtphase=m_wtphase,
            alpha=m_alpha,
            module=module_id,
            **common_sim_kwargs,
        )


## Analyze Results

Visualize spike statistics if using the spiking model, or view the path integration plots generated during simulation if using the rate-based model.

In [None]:
# ====================================
# SPIKE ANALYSIS (For Spiking Model)
# ====================================
# Reads `useSpiking`, `spikes` and `MODULE` from the simulation cells above.

if not useSpiking:
    print("\nRate-based model used (no spike data).")
    print("See path integration plots generated during simulation.")
elif spikes is None or len(spikes) == 0:
    # Previously this case fell into the rate-based message (spikes is None)
    # or crashed in np.max on an empty sequence (no time steps).
    print("\nSpiking model selected, but no spike data was returned.")
else:
    import matplotlib.pyplot as plt
    import numpy as np

    # Total spikes across the population at each time step. A None entry
    # counts as zero spikes for that step.
    spike_counts = np.array(
        [0 if spike_matrix is None else np.sum(spike_matrix) for spike_matrix in spikes]
    )

    # Visualize population spiking activity over time
    plt.figure(figsize=(12, 4))
    plt.plot(spike_counts, linewidth=0.5, alpha=0.8)
    plt.xlabel('Time Step', fontsize=12)
    plt.ylabel('Total Spikes in Population', fontsize=12)
    plt.title(f'Population Spiking Activity Over Time (Module {MODULE})', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\n" + "="*50)
    print("SPIKE STATISTICS")
    print("="*50)
    print(f"Total simulation steps:    {len(spikes)}")
    print(f"Average spikes/timestep:   {np.mean(spike_counts):.2f}")
    print(f"Max spikes (single step):  {np.max(spike_counts):.0f}")
    print(f"Min spikes (single step):  {np.min(spike_counts):.0f}")
    print(f"Total spikes:              {np.sum(spike_counts):.0f}")
    print("="*50)


In [None]:
# ====================================
# QUICK TEST (Optional)
# ====================================
# Set RUN_QUICK_TEST = True to run a quick test with a smaller grid size for
# debugging. This replaces the previously commented-out test code, so the code
# stays syntax-checked and needs only one toggle.
RUN_QUICK_TEST = False

if RUN_QUICK_TEST:
    print("\n" + "="*60)
    print("RUNNING QUICK TEST (n=64)")
    print("="*60)

    beta_test, gamma_test, abar_test, wtphase_test, alpha_test = set_module_params(1)

    spikes_test, int_x_test, int_y_test, err_test, pos_x_test, pos_y_test = gc_periodic(
        filename='nonexistent.npz',  # Missing file -> random trajectory (assumes it does not exist)
        n=64,                          # Smaller grid for faster testing
        tau=tau,
        dt=dt,
        beta=beta_test,
        gamma=gamma_test,
        abar=abar_test,
        wtphase=wtphase_test,
        alpha=alpha_test,
        useSpiking=False,
        module=0,
        duration=10000                 # Shorter trajectory
    )

    print("\nTest complete!")
