# LFP Analysis of Steinmetz Dataset

This notebook analyzes Local Field Potential (LFP) recordings from the Steinmetz dataset, covering:
1. Power spectral analysis across brain regions
2. Phase relationships between regions
3. LFP patterns during decision-making
4. Spike-LFP relationships

In [None]:
import sys
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal
from data_loader import SteinmetzDataLoader
from neural_analysis import NeuralAnalyzer

# Set plotting style. The bare 'seaborn' matplotlib style was deprecated in
# matplotlib 3.6 (renamed 'seaborn-v0_8') and later removed, so prefer the
# new name when it exists and fall back for older matplotlib releases.
if 'seaborn-v0_8' in plt.style.available:
    plt.style.use('seaborn-v0_8')
else:
    plt.style.use('seaborn')
sns.set_context("talk")

## 1. Data Loading

In [None]:
# Make sure the dataset is available locally, then load a single session.
# Session 11 is only an illustrative choice.
loader = SteinmetzDataLoader()
loader.download_data()
session_data = loader.load_session(11)

# Shared analysis helper used by the cells further down.
analyzer = NeuralAnalyzer()

# Quick look at what was loaded: LFP array shape and the distinct area labels.
print("LFP shape:", session_data['lfp'].shape)
print("Brain areas:", np.unique(session_data['brain_area_lfp']))

## 2. Power Spectral Analysis

In [None]:
def plot_power_spectrum_by_region(lfp_data, brain_areas, freq_range=(1, 100)):
    """Plot the power spectrum of the channel-averaged LFP for each brain region.

    Channels are grouped by their label in ``brain_areas``. Within each group
    the LFP is averaged across channels (axis 1 of ``lfp_data``), and the
    spectrum is computed with the notebook-level ``analyzer.compute_lfp_power``.
    Each region gets its own subplot with a log-scaled power axis, and the
    subplots are stacked vertically.

    Parameters
    ----------
    lfp_data : array-like
        LFP samples indexed as ``lfp_data[:, channel]`` (channels on axis 1).
    brain_areas : array-like
        One region label per channel; plain lists are accepted.
    freq_range : tuple, optional
        Forwarded unchanged to ``analyzer.compute_lfp_power``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding one axes per region.

    Raises
    ------
    ValueError
        If ``brain_areas`` contains no labels.
    """
    # Normalize inputs so ``brain_areas == area`` is always an elementwise
    # comparison, whatever sequence type the caller passes.
    lfp_data = np.asarray(lfp_data)
    brain_areas = np.asarray(brain_areas)
    unique_areas = np.unique(brain_areas)
    n_areas = len(unique_areas)
    if n_areas == 0:
        raise ValueError("brain_areas is empty; there are no regions to plot")

    # squeeze=False always returns a 2-D array of axes, so the single-region
    # case needs no special handling.
    fig, axes = plt.subplots(n_areas, 1, figsize=(12, 4 * n_areas), squeeze=False)

    for ax, area in zip(axes[:, 0], unique_areas):
        area_channels = np.flatnonzero(brain_areas == area)

        # Average LFP across the channels assigned to this area.
        area_lfp = np.mean(lfp_data[:, area_channels], axis=1)

        # Relies on the global ``analyzer`` created in the data-loading cell.
        freqs, power = analyzer.compute_lfp_power(area_lfp, freq_range)

        ax.semilogy(freqs, power)
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power')
        ax.set_title(f'Power Spectrum - {area}')
        ax.grid(True)

    plt.tight_layout()
    return fig

# One log-scale spectrum panel per recorded brain region, 1-100 Hz.
plot_power_spectrum_by_region(
    session_data['lfp'],
    session_data['brain_area_lfp'],
    freq_range=(1, 100),
)
plt.show()

## 3. Time-Frequency Analysis

In [None]:
def compute_spectrogram(lfp_data, fs=100, nperseg=256, noverlap=128):
    """Compute a short-time Fourier transform spectrogram of ``lfp_data``.

    Thin wrapper around ``scipy.signal.spectrogram`` (taken along the last
    axis) whose segment settings adapt to short signals.

    Parameters
    ----------
    lfp_data : array-like
        Signal samples; must contain at least one sample along the last axis.
    fs : float, optional
        Sampling rate in Hz (default 100).
    nperseg : int, optional
        Segment length in samples (default 256). When the signal is shorter,
        the whole signal is used as a single segment.
    noverlap : int, optional
        Samples shared by consecutive segments (default 128). It is only
        reduced when ``nperseg`` had to be shortened to fit the signal, so
        that it stays below the segment length.

    Returns
    -------
    f : ndarray
        Sample frequencies in Hz.
    t : ndarray
        Segment times in seconds.
    Sxx : ndarray
        Spectrogram power, shaped (..., len(f), len(t)).

    Raises
    ------
    ValueError
        If ``lfp_data`` is scalar or has no samples along the last axis.
    """
    data = np.asarray(lfp_data)
    if data.ndim == 0 or data.shape[-1] == 0:
        raise ValueError("lfp_data must contain at least one sample")

    n_samples = data.shape[-1]
    seg_len = min(nperseg, n_samples)
    if seg_len < nperseg:
        # With the fixed defaults, signals of <=128 samples made scipy raise
        # "noverlap must be less than nperseg"; keep the overlap legal instead.
        noverlap = min(noverlap, seg_len - 1)

    f, t, Sxx = signal.spectrogram(data, fs=fs, nperseg=seg_len, noverlap=noverlap)
    return f, t, Sxx

# Select an example channel (index 0) for the time-frequency view.
channel_idx = 0
lfp_channel = session_data['lfp'][:, channel_idx]

# Compute and plot spectrogram
f, t, Sxx = compute_spectrogram(lfp_channel)

plt.figure(figsize=(12, 6))
# Floor the power at the smallest normal value of its dtype before taking
# log10: bins with exactly zero power would otherwise become -inf and emit a
# divide-by-zero RuntimeWarning. Non-zero (normal) values are unchanged.
plt.pcolormesh(t, f, 10 * np.log10(np.maximum(Sxx, np.finfo(Sxx.dtype).tiny)),
               shading='gouraud')
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.title(f'Spectrogram - {session_data["brain_area_lfp"][channel_idx]}')
plt.colorbar(label='Power (dB)')
plt.show()

## 4. Cross-Regional Coherence

In [None]:
def compute_coherence_matrix(lfp_data, brain_areas, fs=100, freq_range=None):
    """Compute mean magnitude-squared coherence between every pair of regions.

    Channels are grouped by their label in ``brain_areas`` and averaged across
    channels (axis 1 of ``lfp_data``) to give one trace per region. Each pair
    of traces is passed to ``scipy.signal.coherence``, and the resulting
    coherence spectrum is averaged over frequency.

    Parameters
    ----------
    lfp_data : array-like
        LFP samples indexed as ``lfp_data[:, channel]`` (channels on axis 1).
    brain_areas : array-like
        One region label per channel; plain lists are accepted.
    fs : float, optional
        Sampling rate in Hz (default 100, the value previously hard-coded).
    freq_range : tuple of (low, high), optional
        If given, only frequencies with ``low <= f <= high`` are averaged.
        If omitted (default), all returned frequencies are averaged.

    Returns
    -------
    coherence_matrix : ndarray, shape (n_areas, n_areas)
        Symmetric matrix of frequency-averaged coherence values.
    unique_areas : ndarray
        Sorted region labels giving the row/column order of the matrix.

    Raises
    ------
    ValueError
        If ``freq_range`` selects no frequency bins.
    """
    lfp_data = np.asarray(lfp_data)
    brain_areas = np.asarray(brain_areas)
    unique_areas = np.unique(brain_areas)
    n_areas = len(unique_areas)

    # Average each region's channels once, rather than re-averaging them for
    # every pair inside the double loop.
    region_lfp = [
        np.mean(lfp_data[:, np.flatnonzero(brain_areas == area)], axis=1)
        for area in unique_areas
    ]

    coherence_matrix = np.zeros((n_areas, n_areas))
    for i in range(n_areas):
        # Magnitude-squared coherence is symmetric in its two inputs, so only
        # the upper triangle (with the diagonal) is computed and then mirrored.
        for j in range(i, n_areas):
            freqs, Cxy = signal.coherence(region_lfp[i], region_lfp[j], fs=fs)
            if freq_range is not None:
                low, high = freq_range
                band = (freqs >= low) & (freqs <= high)
                if not band.any():
                    raise ValueError(
                        f"freq_range {freq_range} selects no frequency bins"
                    )
                Cxy = Cxy[band]
            coherence_matrix[i, j] = coherence_matrix[j, i] = np.mean(Cxy)

    return coherence_matrix, unique_areas

# Frequency-averaged coherence for every pair of regions, drawn as a heatmap
# on a fixed 0-1 color scale.
coherence_matrix, areas = compute_coherence_matrix(
    session_data['lfp'], session_data['brain_area_lfp'])

plt.figure(figsize=(10, 8))
sns.heatmap(
    coherence_matrix,
    xticklabels=areas,
    yticklabels=areas,
    cmap='viridis',
    vmin=0,
    vmax=1,
)
plt.title('Inter-regional LFP Coherence')
plt.show()

## 5. Spike-LFP Relationships

In [None]:
# Example unit and LFP channel, both index 0. NOTE(review): nothing here
# verifies that this channel is physically close to the unit.
neuron_idx = 0
lfp_channel = 0

# Spike-triggered LFP average over a window from 100 ms before to 100 ms
# after each spike (window given in seconds). The unit's spike-time arrays
# are concatenated first; presumably one array per trial -- confirm they
# share a time base with the LFP trace.
time_points, avg_lfp = analyzer.compute_spike_triggered_lfp(
    np.concatenate(session_data['spikes'][neuron_idx]),
    session_data['lfp'][:, lfp_channel],
    window=(-0.1, 0.1),
)

plt.figure(figsize=(10, 6))
plt.plot(1000 * time_points, avg_lfp)  # seconds -> milliseconds on the x axis
plt.axvline(x=0, color='r', linestyle='--', alpha=0.5)  # spike time
plt.xlabel('Time from spike (ms)')
plt.ylabel('Average LFP')
plt.title('Spike-Triggered Average LFP')
plt.grid(True)
plt.show()