# Cross-Regional Communication Analysis of the Steinmetz Dataset

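This notebook analyzes interactions between brain regions in one recording session of the Steinmetz dataset, covering:
1. Inter-regional spike correlations (between region-averaged firing rates)
2. LFP coherence between regions, per frequency band
3. Directed information flow (Granger causality between regional LFPs)
4. Region-specific population dynamics
5. Task-dependent connectivity (early vs. late LFP coherence)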

In [None]:
import sys
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal, stats
from sklearn.preprocessing import StandardScaler
from data_loader import SteinmetzDataLoader
from neural_analysis import NeuralAnalyzer

# Set plotting style. Matplotlib 3.6 renamed its bundled 'seaborn' style to
# 'seaborn-v0_8' (and 3.8 removed the old name, raising OSError), so use
# whichever name this matplotlib provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_context("talk")

## 1. Data Loading and Preparation

In [None]:
# Load one example session through the project's SteinmetzDataLoader.
# download_data() runs before load_session(); presumably it fetches the raw
# files when they are missing — confirm in data_loader.py.
SESSION_INDEX = 11  # example session used throughout this notebook

loader = SteinmetzDataLoader()
loader.download_data()
session_data = loader.load_session(SESSION_INDEX)

# Analyzer from the project's neural_analysis module.
analyzer = NeuralAnalyzer()

lfp_area_names = np.unique(session_data['brain_area_lfp'])
print("Available brain regions:", lfp_area_names)

## 2. Inter-regional Spike Correlations

In [None]:
def compute_region_correlations(spikes, brain_regions, time_window=(-0.5, 0.5), bin_size=0.01,
                                rate_fn=None):
    """Correlate region-averaged firing-rate time courses between every pair of regions.

    Parameters
    ----------
    spikes : array_like
        Spike data, passed unchanged to ``rate_fn``.
    brain_regions : array_like
        Region label for each neuron; must line up with the first axis of the
        firing-rate array returned by ``rate_fn``.
    time_window : tuple of float
        (start, stop) of the analysis window, in the same units as ``bin_size``.
    bin_size : float
        Width of each time bin (must be positive).
    rate_fn : callable, optional
        ``rate_fn(spikes, time_bins) -> firing_rates``. Defaults to the
        notebook-global ``loader.compute_firing_rates``.

    Returns
    -------
    correlation_matrix : ndarray, shape (n_regions, n_regions)
        Pearson correlation between the mean firing-rate time courses of each
        pair of regions; NaN wherever a region's mean rate is constant.
    unique_regions : ndarray
        Sorted region labels indexing the rows/columns of the matrix.

    Raises
    ------
    ValueError
        If ``bin_size`` is not positive or ``time_window`` is not increasing.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")
    if time_window[1] <= time_window[0]:
        raise ValueError(f"time_window must be increasing, got {time_window}")
    if rate_fn is None:
        # Fall back to the notebook-global loader created in the data-loading cell.
        rate_fn = loader.compute_firing_rates

    brain_regions = np.asarray(brain_regions)
    unique_regions = np.unique(brain_regions)
    if unique_regions.size == 0:
        return np.zeros((0, 0)), unique_regions

    # Build bin edges from an integer bin count: np.arange with a float step can
    # gain or drop the final edge through rounding error. The small tolerance
    # keeps an exact multiple (e.g. 1.0 / 0.01) from rounding up to an extra bin.
    start, stop = time_window
    n_bins = int(np.ceil((stop - start) / bin_size - 1e-9))
    time_bins = start + bin_size * np.arange(n_bins + 1)
    firing_rates = np.asarray(rate_fn(spikes, time_bins))

    # Mean firing-rate time course per region, rows ordered like unique_regions.
    # NOTE(review): assumes firing_rates is (n_neurons, n_time_bins) so each
    # region mean is 1-D — confirm against loader.compute_firing_rates.
    region_rates = np.stack([firing_rates[brain_regions == region].mean(axis=0)
                             for region in unique_regions])

    # One vectorised call replaces n_regions**2 pairwise pearsonr calls;
    # atleast_2d keeps a (1, 1) shape when only one region is present.
    correlation_matrix = np.atleast_2d(np.corrcoef(region_rates))
    return correlation_matrix, unique_regions

# Placeholder region labels: every neuron gets a random label so the pipeline
# can be exercised end to end.
# TODO: replace with the real per-neuron area labels for this session.
# A seeded generator keeps the placeholder labels (and therefore the figure)
# identical across Restart & Run All.
EXAMPLE_REGION_SEED = 0
example_rng = np.random.default_rng(EXAMPLE_REGION_SEED)
example_regions = example_rng.choice(['V1', 'V2', 'MT', 'LGN'], size=len(session_data['spikes']))

# Compute and plot the region-by-region correlation matrix.
corr_matrix, regions = compute_region_correlations(session_data['spikes'], example_regions)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr_matrix, xticklabels=regions, yticklabels=regions,
            cmap='RdBu_r', center=0, vmin=-1, vmax=1, ax=ax,
            cbar_kws={'label': 'Pearson r'})
ax.set_title('Inter-regional Spike Correlations')
ax.set_xlabel('Region')
ax.set_ylabel('Region')
plt.show()

## 3. LFP Coherence Analysis

In [None]:
def compute_frequency_band_coherence(lfp1, lfp2, band_range=(4, 8), fs=100, nperseg=None):
    """Mean magnitude-squared coherence of two signals within a frequency band.

    Parameters
    ----------
    lfp1, lfp2 : array_like, 1-D
        Signals sampled at ``fs``.
    band_range : tuple of float
        (low, high) band edges in Hz, inclusive (e.g. theta: 4-8 Hz).
    fs : float
        Sampling rate in Hz. Coherence is only estimated up to fs / 2.
    nperseg : int, optional
        Welch segment length forwarded to ``scipy.signal.coherence``;
        ``None`` keeps SciPy's default.

    Returns
    -------
    float
        Mean coherence over the estimated frequencies inside ``band_range``, or
        NaN when no estimated frequency falls inside the band (for example a
        band entirely above the Nyquist frequency).
    """
    f, Cxy = signal.coherence(lfp1, lfp2, fs=fs, nperseg=nperseg)
    low, high = band_range
    in_band = (f >= low) & (f <= high)
    if not in_band.any():
        # Explicit NaN instead of np.mean([]) and its RuntimeWarning.
        return np.nan
    return np.mean(Cxy[in_band])

# Frequency bands of interest (Hz).
bands = {
    'theta': (4, 8),
    'alpha': (8, 13),
    'beta': (13, 30),
    # At LFP_FS = 100 Hz the coherence spectrum stops at the 50 Hz Nyquist
    # frequency, so only 30-50 Hz of this band is actually averaged.
    'gamma': (30, 80)
}

# LFP sampling rate (Hz) passed to the coherence helper.
# NOTE(review): confirm this matches the loader's LFP sampling/binning.
LFP_FS = 100

lfp_area_labels = np.asarray(session_data['brain_area_lfp'])
unique_areas = np.unique(lfp_area_labels)
n_areas = len(unique_areas)

# Average each area's channels once; the average depends neither on the band
# nor on the partner area.
# NOTE(review): assumes session_data['lfp'] is (time, channels) with
# brain_area_lfp labelling axis 1 — confirm against SteinmetzDataLoader.
area_mean_lfp = [session_data['lfp'][:, lfp_area_labels == area].mean(axis=1)
                 for area in unique_areas]

# Band-averaged coherence between every pair of areas; the diagonal stays 0.
band_coherence = {band: np.zeros((n_areas, n_areas)) for band in bands}
for band_name, freq_range in bands.items():
    for i in range(n_areas):
        for j in range(i + 1, n_areas):
            # Magnitude-squared coherence is symmetric in its two inputs, so each
            # unordered pair is computed once and mirrored.
            pair_coherence = compute_frequency_band_coherence(
                area_mean_lfp[i], area_mean_lfp[j], freq_range, fs=LFP_FS
            )
            band_coherence[band_name][i, j] = pair_coherence
            band_coherence[band_name][j, i] = pair_coherence

# One coherence heatmap per frequency band, on a two-column grid whose row
# count follows the number of bands (four bands -> 2x2, 15x15 inches).
n_bands = len(band_coherence)
n_cols = 2
n_rows = max(1, (n_bands + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 7.5 * n_rows), squeeze=False)
axes = axes.ravel()

for ax, (band_name, coherence) in zip(axes, band_coherence.items()):
    sns.heatmap(coherence, xticklabels=unique_areas, yticklabels=unique_areas,
                cmap='viridis', vmin=0, vmax=1, ax=ax)
    ax.set_title(f'{band_name.capitalize()} Band Coherence')

# Hide leftover panels when the number of bands does not fill the grid.
for ax in axes[n_bands:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

## 4. Information Flow Analysis (Granger Causality)

In [None]:
def compute_granger_causality(signal1, signal2, max_lag=10):
    """Test Granger causality between two signals in both directions.

    Uses statsmodels' ``grangercausalitytests``, which tests whether the series
    in the *second* column Granger-causes the series in the *first* column;
    the column order below is chosen accordingly for each direction.

    Parameters
    ----------
    signal1, signal2 : array_like, 1-D
        Equal-length time series.
    max_lag : int
        Lags 1..max_lag are each tested.

    Returns
    -------
    p_1to2 : float
        Smallest SSR chi-square p-value over lags 1..max_lag for the hypothesis
        "signal1 Granger-causes signal2".
    p_2to1 : float
        Same for "signal2 Granger-causes signal1".

    Notes
    -----
    Taking the minimum p-value over ``max_lag`` separate tests is not corrected
    for multiple comparisons, so these p-values are optimistic.
    """
    import warnings
    from statsmodels.tsa.stattools import grangercausalitytests

    def _min_p(effect, cause):
        # Column order is (effect, cause): statsmodels tests column 2 -> column 1.
        data = np.column_stack([effect, cause])
        with warnings.catch_warnings():
            # verbose=False suppresses printing on older statsmodels; newer
            # releases deprecate the argument and emit a FutureWarning.
            warnings.simplefilter('ignore', FutureWarning)
            results = grangercausalitytests(data, maxlag=max_lag, verbose=False)
        # results[lag][0] holds the test dict; ['ssr_chi2test'][1] is its p-value.
        return min(results[lag][0]['ssr_chi2test'][1] for lag in range(1, max_lag + 1))

    p_1to2 = _min_p(effect=signal2, cause=signal1)
    p_2to1 = _min_p(effect=signal1, cause=signal2)
    return p_1to2, p_2to1

# Directed Granger "flow" between LFP areas.
# flow_matrix[i, j] = -log10(p) for the direction compute_granger_causality
# reports as area_i -> area_j (larger = stronger evidence); the diagonal stays 0.
flow_matrix = np.zeros((n_areas, n_areas))

# Average each area's channels once rather than once per ordered pair.
# NOTE(review): assumes session_data['lfp'] is (time, channels) with
# brain_area_lfp labelling axis 1 — confirm against SteinmetzDataLoader.
lfp_area_labels = np.asarray(session_data['brain_area_lfp'])
area_mean_lfp = [session_data['lfp'][:, lfp_area_labels == area].mean(axis=1)
                 for area in unique_areas]

# Smallest positive float: p-values that underflow to 0 would otherwise give
# -log10(0) = inf and wreck the heatmap colour scale.
P_FLOOR = np.finfo(float).tiny

for i in range(n_areas):
    for j in range(i + 1, n_areas):
        # A single call returns both directions, so each unordered pair is
        # tested once instead of twice.
        p_i_to_j, p_j_to_i = compute_granger_causality(area_mean_lfp[i], area_mean_lfp[j])
        flow_matrix[i, j] = -np.log10(np.maximum(p_i_to_j, P_FLOOR))
        flow_matrix[j, i] = -np.log10(np.maximum(p_j_to_i, P_FLOOR))

# Plot the information-flow matrix (rows = source, columns = target).
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(flow_matrix, xticklabels=unique_areas, yticklabels=unique_areas,
            cmap='YlOrRd', ax=ax, cbar_kws={'label': '-log10(p)'})
ax.set_title('Information Flow between Regions')
ax.set_xlabel('Target region')
ax.set_ylabel('Source region')
plt.show()

## 5. Task-Dependent Connectivity

In [None]:
def _mean_pairwise_coherence(lfp_segment, brain_areas, unique_areas, fs, band_range):
    """Return an (n_areas, n_areas) matrix of frequency-averaged coherence between area-mean LFPs (diagonal 0)."""
    n_areas = len(unique_areas)
    coherence = np.zeros((n_areas, n_areas))
    # Average each area's channels once rather than once per pair.
    area_means = [lfp_segment[:, brain_areas == area].mean(axis=1) for area in unique_areas]
    for i in range(n_areas):
        for j in range(i + 1, n_areas):
            f, Cxy = signal.coherence(area_means[i], area_means[j], fs=fs)
            if band_range is None:
                value = np.mean(Cxy)
            else:
                in_band = (f >= band_range[0]) & (f <= band_range[1])
                value = np.mean(Cxy[in_band]) if in_band.any() else np.nan
            # Magnitude-squared coherence is symmetric in its two inputs.
            coherence[i, j] = value
            coherence[j, i] = value
    return coherence


def analyze_task_dependent_connectivity(lfp_data, brain_areas, time_window=(-0.5, 0.5),
                                        fs=100, band_range=None):
    """Compare area-to-area LFP coherence between the first and second half of the data.

    ``lfp_data`` is split at the midpoint of axis 0 into an "early" and a "late"
    segment. Within each segment the channels of every area are averaged, and
    the magnitude-squared coherence between each pair of area means is averaged
    over frequency.

    Parameters
    ----------
    lfp_data : array_like
        LFP indexed as ``lfp_data[sample, channel]``. NOTE(review): assumes
        axis 0 is time within the task period of interest — confirm.
    brain_areas : array_like
        Area label for each channel (axis 1 of ``lfp_data``).
    time_window : tuple of float
        Currently unused — the split is always at the sample midpoint. Kept for
        backward compatibility.
    fs : float
        Sampling rate in Hz passed to ``scipy.signal.coherence``.
    band_range : tuple of float, optional
        (low, high) in Hz restricting the frequency average; ``None`` averages
        over the whole spectrum.

    Returns
    -------
    early_coherence, late_coherence : ndarray, shape (n_areas, n_areas)
        Symmetric matrices with zeros on the diagonal; NaN entries when no
        estimated frequency lies inside ``band_range``.
    unique_areas : ndarray
        Sorted area labels indexing the rows/columns.
    """
    lfp_data = np.asarray(lfp_data)
    brain_areas = np.asarray(brain_areas)
    unique_areas = np.unique(brain_areas)

    mid_point = len(lfp_data) // 2
    early_coherence = _mean_pairwise_coherence(lfp_data[:mid_point], brain_areas,
                                               unique_areas, fs, band_range)
    late_coherence = _mean_pairwise_coherence(lfp_data[mid_point:], brain_areas,
                                              unique_areas, fs, band_range)
    return early_coherence, late_coherence, unique_areas

# Early- vs. late-half coherence for every pair of LFP areas (default settings).
early_coh, late_coh, areas = analyze_task_dependent_connectivity(
    session_data['lfp'], session_data['brain_area_lfp'])

# Side-by-side heatmaps of the two halves on a shared 0-1 colour scale.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
for period_ax, period_coh, period_title in (
        (ax1, early_coh, 'Early Task Period Coherence'),
        (ax2, late_coh, 'Late Task Period Coherence')):
    sns.heatmap(period_coh, xticklabels=areas, yticklabels=areas,
                cmap='viridis', vmin=0, vmax=1, ax=period_ax)
    period_ax.set_title(period_title)

plt.tight_layout()
plt.show()

# Difference map: positive entries mean higher coherence in the late half.
coherence_change = late_coh - early_coh
diff_fig, diff_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(coherence_change, xticklabels=areas, yticklabels=areas,
            cmap='RdBu_r', center=0, ax=diff_ax)
diff_ax.set_title('Change in Coherence (Late - Early)')
plt.show()