# Population Dynamics Analysis of Steinmetz Dataset

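This notebook analyzes neural population dynamics in a single session of the Steinmetz dataset, including:
1. Dimensionality reduction using PCA
2. Neural trajectories during decision-making
3. Population dynamics by brain region (PCA fit separately within each region)
4. State space analysis (trajectory speed and distance from the initial state)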

In [None]:
import sys
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from data_loader import SteinmetzDataLoader
from neural_analysis import NeuralAnalyzer

# Set plotting style.
# The bare 'seaborn' style was deprecated in matplotlib 3.6 (renamed 'seaborn-v0_8')
# and removed in 3.8, so pick whichever name this matplotlib version provides.
PLOT_STYLE = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(PLOT_STYLE)
sns.set_context("talk")

## 1. Data Loading and Preparation

In [None]:
# Analysis parameters (edit here rather than inside the cells below)
SESSION_ID = 11                          # example session
WINDOW_START, WINDOW_STOP = -0.5, 0.5    # seconds around time 0 (the loader's alignment event — TODO confirm which)
BIN_SIZE = 0.01                          # 10 ms bins

# Initialize data loader and load session
loader = SteinmetzDataLoader()
loader.download_data()
session_data = loader.load_session(SESSION_ID)

# Initialize neural analyzer
# NOTE(review): `analyzer` is not used by any of the cells visible here.
analyzer = NeuralAnalyzer()

# Bin edges from -500 ms to +500 ms inclusive. np.arange(-0.5, 0.5, 0.01) stops at
# 0.49 and silently drops the final 10 ms bin; a float-step arange can also gain or
# lose an edge through rounding. linspace pins both endpoints exactly.
n_bins = int(round((WINDOW_STOP - WINDOW_START) / BIN_SIZE))
time_bins = np.linspace(WINDOW_START, WINDOW_STOP, n_bins + 1)
firing_rates = loader.compute_firing_rates(session_data['spikes'], time_bins)

# Later cells index rows by neuron (one per entry of session_data['spikes']) and plot
# columns against time_bins[:-1], so fail early if the layout is different.
assert firing_rates.ndim == 2 and firing_rates.shape[1] == len(time_bins) - 1, (
    f"expected (n_neurons, {len(time_bins) - 1}) firing rates, got {firing_rates.shape}"
)

print(f"Population activity shape: {firing_rates.shape}")

## 2. Dimensionality Reduction

In [None]:
VARIANCE_THRESHOLD = 0.8  # fraction of variance used to report effective dimensionality

# z-score each neuron over time so that high-rate units do not dominate the PCs.
# StandardScaler standardizes columns, hence the transposes: fit on (time, neurons),
# then transpose back to (neurons, time).
scaler = StandardScaler()
scaled_rates = scaler.fit_transform(firing_rates.T).T

# Time bins are samples and neurons are features; transposing the scores gives
# projected_data with shape (n_components, n_time).
pca = PCA()
projected_data = pca.fit_transform(scaled_rates.T).T

cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
# Entry k of the cumulative sum covers k + 1 components, so the x axis starts at 1
# (plotting the array alone would put it at 0..n-1 and mislabel every point).
component_counts = np.arange(1, len(cumulative_variance) + 1)

# Plot explained variance ratio
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(component_counts, cumulative_variance)
ax.axhline(VARIANCE_THRESHOLD, color='gray', linestyle='--', alpha=0.7)
ax.set_xlabel('Number of Components')
ax.set_ylabel('Cumulative Explained Variance Ratio')
ax.set_title('PCA Explained Variance')
ax.grid(True)
plt.show()

# Smallest number of components whose cumulative variance reaches the threshold.
# cumulative_variance is non-decreasing, so searchsorted finds the first crossing;
# min() caps the count if rounding keeps the total just below the threshold.
n_components_80 = min(
    int(np.searchsorted(cumulative_variance, VARIANCE_THRESHOLD)) + 1,
    len(cumulative_variance),
)
print(f"Number of components needed for {VARIANCE_THRESHOLD:.0%} variance: {n_components_80}")

## 3. Neural Trajectories

In [None]:
def plot_3d_trajectory(projected_data, time_points):
    """Plot a neural trajectory in the space of the first three principal components.

    Each time point is drawn as one marker, coloured by its time value.

    Parameters
    ----------
    projected_data : array-like, shape (n_components >= 3, n_time)
        PC scores with components along rows and time along columns.
        Only the first three rows are plotted.
    time_points : array-like, shape (n_time,)
        Time of each column, used for the colour scale.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.

    Raises
    ------
    ValueError
        If fewer than three components are supplied, or if the number of
        columns does not match the length of ``time_points``.
    """
    projected_data = np.asarray(projected_data)
    time_points = np.asarray(time_points)
    # Validate before touching matplotlib so callers get a clear message instead of
    # an IndexError or a colour-array mismatch from deep inside scatter().
    if projected_data.ndim != 2 or projected_data.shape[0] < 3:
        raise ValueError(
            f"projected_data must be 2-D with at least 3 rows (PCs), got shape {projected_data.shape}"
        )
    if projected_data.shape[1] != time_points.shape[0]:
        raise ValueError(
            f"projected_data has {projected_data.shape[1]} time points but "
            f"time_points has {time_points.shape[0]}"
        )

    # Importing Axes3D registers the '3d' projection on matplotlib < 3.2;
    # it is otherwise unused.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    scatter = ax.scatter(projected_data[0],
                         projected_data[1],
                         projected_data[2],
                         c=time_points,
                         cmap='viridis')

    fig.colorbar(scatter, ax=ax, label='Time (s)')
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title('Neural Population Trajectory')

    return fig

# Trajectory through PC1-3; colours use the left edge of each time bin,
# which lines up one-to-one with the columns of projected_data.
plot_3d_trajectory(projected_data=projected_data[:3],
                   time_points=time_bins[:-1])
plt.show()

## 4. Population Dynamics by Brain Region

In [None]:
def analyze_regional_dynamics(firing_rates, brain_regions, time_points=None):
    """Plot the first principal component of each brain region's population activity.

    PCA is fit separately for every region, with time bins as samples and that
    region's neurons as features. One panel per region shows PC1 over time, with a
    dashed line at t = 0. Note that the sign of a principal component is arbitrary,
    so a trace may appear flipped between regions or runs.

    Parameters
    ----------
    firing_rates : array-like, shape (n_neurons, n_time)
        Activity of each neuron (rows) in each time bin (columns).
    brain_regions : array-like, shape (n_neurons,)
        Region label for each row of ``firing_rates``. Lists are accepted.
    time_points : array-like, shape (n_time,), optional
        x-axis values for the columns. Defaults to the notebook-level
        ``time_bins[:-1]`` (left bin edges) for backward compatibility.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with one row of axes per unique region.

    Raises
    ------
    ValueError
        If ``brain_regions`` is empty or its length differs from the number
        of rows in ``firing_rates``.
    """
    firing_rates = np.asarray(firing_rates)
    # Convert so `brain_regions == region` compares elementwise; with a plain list the
    # comparison yields a single False and every region silently selects no neurons.
    brain_regions = np.asarray(brain_regions)
    if brain_regions.size == 0:
        raise ValueError("brain_regions is empty; there is no region to plot")
    if brain_regions.shape[0] != firing_rates.shape[0]:
        raise ValueError(
            f"got {brain_regions.shape[0]} region labels for {firing_rates.shape[0]} neurons"
        )
    if time_points is None:
        # Backward-compatible default: left edges of the notebook's global time bins.
        time_points = time_bins[:-1]

    unique_regions = np.unique(brain_regions)

    # squeeze=False always returns a 2-D axes array, even for a single region.
    fig, axes = plt.subplots(len(unique_regions), 1,
                             figsize=(12, 4 * len(unique_regions)),
                             squeeze=False)

    for ax, region in zip(axes[:, 0], unique_regions):
        # Every unique label occurs at least once, so each selection is non-empty.
        region_rates = firing_rates[brain_regions == region]
        n_neurons = region_rates.shape[0]

        # PCA cannot fit more components than min(n_samples, n_features); clip so
        # regions with fewer than 3 neurons (or very few bins) do not raise.
        region_pca = PCA(n_components=min(3, *region_rates.shape))
        projected = region_pca.fit_transform(region_rates.T).T

        # Plot first PC
        ax.plot(time_points, projected[0])
        ax.set_title(f'{region} (n={n_neurons} neurons)')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('PC1')
        ax.axvline(x=0, color='r', linestyle='--', alpha=0.5)

    fig.tight_layout()
    return fig

# Placeholder labels: every neuron is assigned a random region purely to demonstrate
# the function. These are NOT the recorded brain areas for this session.
# TODO: replace with the per-neuron area labels from session_data (the key name
# depends on SteinmetzDataLoader).
REGION_SEED = 0
# Seeded Generator so Restart & Run All reproduces the same figure.
region_rng = np.random.default_rng(REGION_SEED)
example_regions = region_rng.choice(['V1', 'V2', 'MT'], size=len(session_data['spikes']))
analyze_regional_dynamics(firing_rates, example_regions)
plt.show()

## 5. State Space Analysis

In [None]:
def compute_state_space_metrics(projected_data):
    """Measure how a trajectory moves through neural state space.

    Parameters
    ----------
    projected_data : ndarray, shape (n_dims, n_time)
        Trajectory coordinates; each column is one time point.

    Returns
    -------
    speed : ndarray, shape (n_time - 1,)
        Euclidean length of each step between consecutive columns
        (units per time bin, not per second).
    distances : ndarray, shape (n_time,)
        Euclidean distance of every column from the first column.
    """
    # Displacement between neighbouring time points, one column per step.
    steps = projected_data[:, 1:] - projected_data[:, :-1]
    speed = np.linalg.norm(steps, axis=0)

    # Keep the origin 2-D (n_dims, 1) so it broadcasts across all columns.
    origin = projected_data[:, :1]
    distances = np.linalg.norm(projected_data - origin, axis=0)

    return speed, distances

# Compute state space metrics in the space of the first three PCs
speed, distances = compute_state_space_metrics(projected_data[:3])

fig, (ax_speed, ax_dist) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# speed[i] is the step from bin i to bin i + 1, so it is drawn at the left edge of
# bin i + 1 (time_bins[1:-1]); distances has one value per bin (time_bins[:-1]).
ax_speed.plot(time_bins[1:-1], speed)
ax_speed.set_ylabel('Speed (PC units / bin)')
ax_speed.set_title('Speed of the Population Trajectory (PC1-3)')

ax_dist.plot(time_bins[:-1], distances)
ax_dist.set_xlabel('Time (s)')
ax_dist.set_ylabel('Distance from Start')
ax_dist.set_title('Distance from Initial State (PC1-3)')

# Mark t = 0 on both panels
for panel in (ax_speed, ax_dist):
    panel.axvline(x=0, color='r', linestyle='--', alpha=0.5)

fig.tight_layout()
plt.show()