In [None]:
from canns.data import load_example_data

# Fetch the 1D bump example bundled with canns and report what it contains.
data_dict = load_example_data('bump_1d_example')

print("Available keys:", data_dict.keys())
print("Activity shape:", data_dict['activity'].shape)
print("Time points:", data_dict['time'].shape)

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp  # :cite:p:`jax2018github`

# Time-step index shown in the single-snapshot plot. It is clamped to the last
# available step so the cell also runs on recordings shorter than
# SNAPSHOT_STEP + 1 steps, instead of raising IndexError.
SNAPSHOT_STEP = 100
snapshot_step = min(SNAPSHOT_STEP, data_dict['activity'].shape[0] - 1)

# Plot activity heatmap (activity is transposed so neurons run along the y-axis)
fig, ax = plt.subplots(figsize=(10, 4))
im = ax.imshow(data_dict['activity'].T, aspect='auto', cmap='viridis')
ax.set_xlabel('Time step')
ax.set_ylabel('Neuron index')
ax.set_title('Neural Population Activity')
fig.colorbar(im, ax=ax, label='Activity')
plt.show()

# Plot the population activity profile at one time step. The label names the step
# index explicitly; `data_dict['time']` (ms) is a different quantity.
fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(data_dict['positions'], data_dict['activity'][snapshot_step])
ax.set_xlabel('Position (rad)')
ax.set_ylabel('Activity')
ax.set_title(f'Activity snapshot at time step {snapshot_step}')
ax.grid(True)
plt.show()

In [None]:
from canns.analyzer.data import BumpAnalyzer1D

# Fit a bump to every time step of the example population activity.
analyzer = BumpAnalyzer1D(positions=data_dict['positions'])
results = analyzer.fit_bumps(data_dict['activity'])

# Preview the first few fitted parameters of each kind.
N_PREVIEW = 10
for label, key in (
    ("Detected bump centers", 'centers'),
    ("Bump widths", 'widths'),
    ("Bump amplitudes", 'amplitudes'),
):
    print(f"{label}: {results[key][:N_PREVIEW]}")

In [None]:
# One figure per fitted bump parameter, each plotted against time:
# (results key, y-axis label, title, extra line-style keyword arguments).
trace_specs = [
    ('centers', 'Bump Position (rad)', 'Decoded Bump Trajectory', {}),
    ('widths', 'Bump Width (rad)', 'Bump Width Dynamics', {'color': 'orange'}),
]

for result_key, y_label, fig_title, line_style in trace_specs:
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(data_dict['time'], results[result_key], linewidth=2, **line_style)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel(y_label)
    ax.set_title(fig_title)
    ax.grid(True)
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp  # :cite:p:`jax2018github`

# R² below this value counts as a poor bump fit. Defined once so the reference
# line in the plot and the count below cannot drift apart.
FIT_QUALITY_THRESHOLD = 0.8

fit_quality = jnp.asarray(results['fit_quality'])

# Plot fit quality over time
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(data_dict['time'], fit_quality, linewidth=2)
ax.axhline(y=FIT_QUALITY_THRESHOLD, color='r', linestyle='--', label='Quality threshold')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Fit Quality (R²)')
ax.set_title('Bump Fit Quality')
ax.legend()
ax.grid(True)
plt.show()

# Identify low-quality fits. NaN compares False against any threshold, so a fit
# that failed outright (NaN R², if the analyzer ever reports one) would otherwise
# be silently counted as a good fit.
poor_fit_mask = (fit_quality < FIT_QUALITY_THRESHOLD) | jnp.isnan(fit_quality)
low_quality_indices = jnp.where(poor_fit_mask)[0]
print(f"Time points with poor fits: {len(low_quality_indices)} / {len(data_dict['time'])}")

In [None]:
from canns.data import load_example_data
from canns.analyzer.data import BumpAnalyzer1D
import matplotlib.pyplot as plt

# 1. Load data
data = load_example_data('bump_1d_example')

# 2. Create analyzer
analyzer = BumpAnalyzer1D(positions=data['positions'])

# 3. Fit bumps
results = analyzer.fit_bumps(data['activity'])

# 4. Visualize trajectory
time_axis = data['time']
pos_axis = data['positions']

# imshow's `extent` gives the outer edges of the image, not the pixel centers.
# Pad each axis by half a sample so each pixel is centered on its time / position
# value and lines up with the fitted centers drawn on top.
# NOTE(review): assumes uniformly spaced time and position samples — confirm.
n_t, n_pos = len(time_axis), len(pos_axis)
half_dt = 0.5 * float(time_axis[-1] - time_axis[0]) / (n_t - 1) if n_t > 1 else 0.5
half_dpos = 0.5 * float(pos_axis[-1] - pos_axis[0]) / (n_pos - 1) if n_pos > 1 else 0.5
heatmap_extent = [
    float(time_axis[0]) - half_dt, float(time_axis[-1]) + half_dt,
    float(pos_axis[0]) - half_dpos, float(pos_axis[-1]) + half_dpos,
]

fig, (ax_heat, ax_traj) = plt.subplots(2, 1, figsize=(12, 6))

# Subplot 1: Activity heatmap with fitted centers overlaid, both against time (ms).
# origin='lower' draws row 0 (the neuron at positions[0]) at the extent's bottom
# edge. With the default origin='upper', row 0 would be drawn at positions[-1],
# flipping the heatmap against the Position axis and the overlaid centers.
im = ax_heat.imshow(
    data['activity'].T,
    aspect='auto',
    cmap='viridis',
    origin='lower',
    extent=heatmap_extent,
)
ax_heat.plot(time_axis, results['centers'], 'r-', linewidth=2, label='Fitted bump center')
ax_heat.set_xlabel('Time (ms)')
ax_heat.set_ylabel('Position (rad)')
ax_heat.set_title('Neural Activity with Detected Bump Trajectory')
ax_heat.legend()
fig.colorbar(im, ax=ax_heat, label='Activity')

# Subplot 2: Bump position over time
ax_traj.plot(time_axis, results['centers'], linewidth=2)
ax_traj.set_xlabel('Time (ms)')
ax_traj.set_ylabel('Bump Center (rad)')
ax_traj.set_title('Decoded Position Trajectory')
ax_traj.grid(True)

fig.tight_layout()
fig.savefig('experimental_bump_analysis.png', dpi=150)
plt.show()

print("Analysis complete! Results saved.")

In [None]:
from canns.analyzer.data import TopologyAnalyzer

# Topological data analysis of the example population activity loaded in the
# complete-example cell. The persistence result is kept for downstream inspection.
example_activity = data['activity']
tda = TopologyAnalyzer()
persistence = tda.compute_persistence(example_activity)

In [None]:
from canns.analyzer.data import RNNAnalyzer

# Find fixed points :cite:p:`sussillo2013opening,golub2018fixedpointfinder` and slow manifolds in trained RNN models.
# `my_rnn` is a placeholder for your own trained model. Without this guard the
# cell raises NameError and breaks Restart & Run All when no model is defined.
if 'my_rnn' in globals():
    rnn_analyzer = RNNAnalyzer(model=my_rnn)
    fixed_points = rnn_analyzer.find_fixed_points()
else:
    print("Define `my_rnn` (a trained RNN model) to run the fixed-point analysis.")

In [None]:
from pathlib import Path

import jax.numpy as jnp  # :cite:p:`jax2018github`
from canns.analyzer.data import BumpAnalyzer1D

# Path to your own recording, saved as a .npy array of shape (time, neurons).
# Convert CSV or other formats to an array first.
MY_ACTIVITY_PATH = Path('my_experiment.npy')

if MY_ACTIVITY_PATH.exists():
    my_activity = jnp.load(MY_ACTIVITY_PATH)  # Shape: (time, neurons)
    # One preferred position per recorded neuron, derived from the data rather
    # than from a separately maintained constant.
    num_neurons = my_activity.shape[1]
    # Positions evenly spaced around a ring. endpoint=False because -pi and +pi
    # are the same angle; including both would give two neurons one position.
    # NOTE(review): assumes your neurons tile a periodic (ring) variable.
    my_positions = jnp.linspace(-jnp.pi, jnp.pi, num_neurons, endpoint=False)

    # Create analyzer with your neuron positions
    analyzer = BumpAnalyzer1D(positions=my_positions)

    # Analyze
    results = analyzer.fit_bumps(my_activity)
else:
    print(f"{MY_ACTIVITY_PATH} not found; point MY_ACTIVITY_PATH at your own data to run this cell.")

In [None]:
import jax.numpy as jnp  # :cite:p:`jax2018github`

# Remove time steps (rows) that contain any NaN. Demonstrated on the example data
# loaded above; substitute your own (time, neurons) array here.
activity = jnp.asarray(data['activity'])
valid_indices = ~jnp.isnan(activity).any(axis=1)
clean_activity = activity[valid_indices]
# Filter the time vector with the same mask. Otherwise it no longer lines up
# row-for-row with the cleaned activity.
clean_time = jnp.asarray(data['time'])[valid_indices]