# Steinmetz Data Integration Notebook
#### This notebook provides functions for loading and integrating the three Steinmetz datasets:
- steinmetz_st.npz (spike times)
- steinmetz_lfp.npz (local field potentials)
- steinmetz_wav.npz (waveforms)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os, requests
import pandas as pd
from matplotlib import rcParams
import seaborn as sns
from scipy import signal
from sklearn.decomposition import PCA

# Set matplotlib defaults
rcParams['figure.figsize'] = [20, 4]
rcParams['font.size'] = 15
rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False
rcParams['figure.autolayout'] = True

## Cell 1: Data Download Function

In [2]:
def download_steinmetz_data(data_dir='.'):
    """
    Download Steinmetz datasets if they don't exist locally
    
    Parameters:
    data_dir : str
        Directory where data should be stored
    """
    # Create data directory if it doesn't exist
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    
    # Define filenames and URLs
    fname = ['steinmetz_st.npz', 'steinmetz_wav.npz', 'steinmetz_lfp.npz']
    url = [
        "https://osf.io/4bjns/download",
        "https://osf.io/ugm9v/download",
        "https://osf.io/kx3v9/download"
    ]
    
    # Download each file if it doesn't exist
    for j in range(len(url)):
        file_path = os.path.join(data_dir, fname[j])
        if not os.path.isfile(file_path):
            try:
                print(f"Downloading {fname[j]}...")
                r = requests.get(url[j])
                if r.status_code == requests.codes.ok:
                    with open(file_path, "wb") as fid:
                        fid.write(r.content)
                    print(f"Successfully downloaded {fname[j]}")
                else:
                    print(f"Failed to download {fname[j]}, status code: {r.status_code}")
            except requests.ConnectionError:
                print(f"Connection error while downloading {fname[j]}")
        else:
            print(f"{fname[j]} already exists in {data_dir}")
    
    print("Download process complete.")

In [3]:

# Cell 2: Data Loading Function
def load_steinmetz_data(data_dir='.', verbose=True):
    """
    Load all three Steinmetz datasets
    
    Parameters:
    data_dir : str
        Directory where data is stored
    verbose : bool
        Whether to print detailed information
        
    Returns:
    st_data : numpy.ndarray
        Spike times data
    lfp_data : numpy.ndarray
        LFP data
    wav_data : numpy.ndarray
        Waveform data
    """
    # Check if files exist, download if needed
    for fname in ['steinmetz_st.npz', 'steinmetz_wav.npz', 'steinmetz_lfp.npz']:
        if not os.path.isfile(os.path.join(data_dir, fname)):
            print(f"File {fname} not found. Downloading data...")
            download_steinmetz_data(data_dir)
            break
    
    # Load datasets
    if verbose:
        print("Loading datasets...")
    
    st_path = os.path.join(data_dir, 'steinmetz_st.npz')
    lfp_path = os.path.join(data_dir, 'steinmetz_lfp.npz')
    wav_path = os.path.join(data_dir, 'steinmetz_wav.npz')
    
    st_data = np.load(st_path, allow_pickle=True)['dat']
    lfp_data = np.load(lfp_path, allow_pickle=True)['dat']
    wav_data = np.load(wav_path, allow_pickle=True)['dat']
    
    if verbose:
        print(f"Loaded {len(st_data)} sessions from spike times data")
        print(f"Loaded {len(lfp_data)} sessions from LFP data")
        print(f"Loaded {len(wav_data)} sessions from waveform data")
        
        # Print keys for the first session of each dataset
        print("\nKeys in first session of each dataset:")
        print(f"Spike times data: {list(st_data[0].keys())}")
        print(f"LFP data: {list(lfp_data[0].keys())}")
        print(f"Waveform data: {list(wav_data[0].keys())}")
    
    return st_data, lfp_data, wav_data

In [4]:

# Cell 3: Data Integration Function
def integrate_steinmetz_data(st_data, lfp_data, wav_data, verbose=True):
    """
    Integrate data from all three Steinmetz datasets
    
    Parameters:
    st_data : numpy.ndarray
        Spike times data
    lfp_data : numpy.ndarray
        LFP data
    wav_data : numpy.ndarray
        Waveform data
    verbose : bool
        Whether to print detailed information
        
    Returns:
    integrated_data : list
        List of dictionaries, each containing integrated data for one session
    """
    # Verify all datasets have the same number of sessions
    num_st_sessions = len(st_data)
    num_lfp_sessions = len(lfp_data)
    num_wav_sessions = len(wav_data)
    
    if not (num_st_sessions == num_lfp_sessions == num_wav_sessions):
        print("WARNING: The datasets have different numbers of sessions!")
        num_sessions = min(num_st_sessions, num_lfp_sessions, num_wav_sessions)
        print(f"Will proceed with the first {num_sessions} sessions only")
    else:
        num_sessions = num_st_sessions
        if verbose:
            print(f"All datasets have {num_sessions} sessions")
    
    # Create integrated dataset
    integrated_data = []
    
    for i in range(num_sessions):
        if verbose and i % 10 == 0:
            print(f"Integrating session {i}...")
        
        # Create a new dictionary for this session
        session_data = {}
        
        # Add all keys from spike times data
        for key in st_data[i].keys():
            session_data[key] = st_data[i][key]
        
        # Add all keys from LFP data
        for key in lfp_data[i].keys():
            # Avoid key collisions by adding prefix if key already exists
            if key in session_data:
                session_data[f"lfp_{key}"] = lfp_data[i][key]
            else:
                session_data[key] = lfp_data[i][key]
        
        # Add all keys from waveform data
        for key in wav_data[i].keys():
            # Avoid key collisions by adding prefix if key already exists
            if key in session_data:
                session_data[f"wav_{key}"] = wav_data[i][key]
            else:
                session_data[key] = wav_data[i][key]
        
        # Add the session to the list
        integrated_data.append(session_data)
    
    if verbose:
        print(f"Successfully integrated {len(integrated_data)} sessions")
        print(f"Keys in first integrated session: {list(integrated_data[0].keys())}")
    
    return integrated_data

In [6]:


# Cell 4: Verification Function
def verify_data_correspondence(integrated_data, session_idx=0, verbose=True):
    """
    Verify that data from different datasets corresponds correctly
    
    Parameters:
    integrated_data : list
        List of integrated session data
    session_idx : int
        Index of session to verify
    verbose : bool
        Whether to print detailed information
        
    Returns:
    correspondence : dict
        Dictionary with verification results
    """
    if session_idx >= len(integrated_data):
        print(f"Session index {session_idx} out of range (max: {len(integrated_data)-1})")
        return None
    
    session = integrated_data[session_idx]
    correspondence = {"session_idx": session_idx, "checks": {}}
    
    # Check 1: Neuron count consistency
    if 'ss' in session and 'waveform_w' in session:
        num_neurons_st = len(session['ss'])
        num_neurons_wav = session['waveform_w'].shape[0]
        
        correspondence["checks"]["neuron_count_match"] = (num_neurons_st == num_neurons_wav)
        
        if verbose:
            print(f"Neuron count in spike times data: {num_neurons_st}")
            print(f"Neuron count in waveform data: {num_neurons_wav}")
            print(f"Neuron counts match: {correspondence['checks']['neuron_count_match']}")
    
    # Check 2: Trial count consistency
    if 'ss' in session and len(session['ss']) > 0:
        num_trials = len(session['ss'][0])
        
        if verbose:
            print(f"Number of trials: {num_trials}")
        
        correspondence["checks"]["trial_count"] = num_trials
    
    # Check 3: LFP data shape
    if 'lfp' in session:
        lfp_shape = session['lfp'].shape
        
        if verbose:
            print(f"LFP data shape: {lfp_shape}")
        
        correspondence["checks"]["lfp_shape"] = lfp_shape
    
    # Check 4: Brain areas
    if 'brain_area' in session:
        unique_areas = np.unique(session['brain_area'])
        
        if verbose:
            print(f"Number of unique brain areas: {len(unique_areas)}")
            print(f"Brain areas: {unique_areas}")
        
        correspondence["checks"]["brain_areas"] = unique_areas.tolist()
    
    # Check 5: LFP brain areas
    if 'brain_area_lfp' in session:
        lfp_areas = session['brain_area_lfp']
        
        if verbose:
            print(f"Number of LFP brain areas: {len(lfp_areas)}")
            print(f"LFP brain areas: {lfp_areas}")
        
        correspondence["checks"]["lfp_brain_areas"] = lfp_areas.tolist()
    
    return correspondence

# Cell 5: Brain Region Grouping
def define_brain_region_groups():
    """
    Define groupings of brain regions
    
    Returns:
    regions : list
        List of region group names
    brain_groups : list
        List of lists, each containing brain areas in a group
    """
    regions = ["vis ctx", "thal", "hipp", "other ctx", "midbrain", "basal ganglia", "cortical subplate", "other"]
    
    brain_groups = [
        ["VISa", "VISam", "VISl", "VISp", "VISpm", "VISrl"],  # visual cortex
        ["CL", "LD", "LGd", "LH", "LP", "MD", "MG", "PO", "POL", "PT", "RT", "SPF", "TH", "VAL", "VPL", "VPM"],  # thalamus
        ["CA", "CA1", "CA2", "CA3", "DG", "SUB", "POST"],  # hippocampal
        ["ACA", "AUD", "COA", "DP", "ILA", "MOp", "MOs", "OLF", "ORB", "ORBm", "PIR", "PL", "SSp", "SSs", "RSP", "TT"],  # non-visual cortex
        ["APN", "IC", "MB", "MRN", "NB", "PAG", "RN", "SCs", "SCm", "SCig", "SCsg", "ZI"],  # midbrain
        ["ACB", "CP", "GPe", "LS", "LSc", "LSr", "MS", "OT", "SNr", "SI"],  # basal ganglia
        ["BLA", "BMA", "EP", "EPd", "MEA"]  # cortical subplate
    ]
    
    return regions, brain_groups

# Cell 6: Session Selection Function
def find_sessions_with_areas(integrated_data, areas_of_interest, min_neuron_count=5):
    """
    Find sessions containing neurons from specific brain areas
    
    Parameters:
    integrated_data : list
        List of integrated session data
    areas_of_interest : list
        List of brain areas to look for
    min_neuron_count : int
        Minimum number of neurons required in each area
        
    Returns:
    matching_sessions : list
        List of dictionaries with session indices and neuron counts
    """
    matching_sessions = []
    
    for i, session in enumerate(integrated_data):
        if 'brain_area' not in session:
            continue
        
        # Count neurons in each area of interest
        area_counts = {}
        for area in areas_of_interest:
            area_counts[area] = np.sum(np.array(session['brain_area']) == area)
        
        # Check if all areas have enough neurons
        if all(area_counts[area] >= min_neuron_count for area in areas_of_interest):
            matching_sessions.append({
                "session_idx": i,
                "neuron_counts": area_counts
            })
    
    return matching_sessions

# Cell 7: Data Extraction Function
def extract_session_data(integrated_data, session_idx, areas_of_interest=None):
    """
    Extract and preprocess data for a specific session
    
    Parameters:
    integrated_data : list
        List of integrated session data
    session_idx : int
        Index of session to extract
    areas_of_interest : list, optional
        List of brain areas to focus on
        
    Returns:
    extracted_data : dict
        Dictionary with extracted and preprocessed data
    """
    if session_idx >= len(integrated_data):
        print(f"Session index {session_idx} out of range (max: {len(integrated_data)-1})")
        return None
    
    session = integrated_data[session_idx]
    extracted_data = {"session_idx": session_idx}
    
    # Extract spike data
    if 'ss' in session:
        extracted_data["spike_times"] = session['ss']
        
        # Filter by brain area if specified
        if areas_of_interest is not None and 'brain_area' in session:
            area_mask = np.array([area in areas_of_interest for area in session['brain_area']])
            extracted_data["area_mask"] = area_mask
            extracted_data["filtered_spike_times"] = [session['ss'][i] for i in np.where(area_mask)[0]]
            extracted_data["filtered_brain_areas"] = np.array(session['brain_area'])[area_mask]
    
    # Extract LFP data
    if 'lfp' in session:
        extracted_data["lfp"] = session['lfp']
        
        # Filter by brain area if specified
        if areas_of_interest is not None and 'brain_area_lfp' in session:
            lfp_area_mask = np.array([area in areas_of_interest for area in session['brain_area_lfp']])
            extracted_data["lfp_area_mask"] = lfp_area_mask
            
            if np.any(lfp_area_mask):
                extracted_data["filtered_lfp"] = session['lfp'][lfp_area_mask]
                extracted_data["filtered_lfp_areas"] = np.array(session['brain_area_lfp'])[lfp_area_mask]
    
    # Extract waveform data
    if 'waveform_w' in session:
        extracted_data["waveform_w"] = session['waveform_w']
        
        # Filter by brain area if specified
        if areas_of_interest is not None and 'brain_area' in session and 'area_mask' in extracted_data:
            extracted_data["filtered_waveform_w"] = session['waveform_w'][np.where(extracted_data["area_mask"])[0]]
    
    if 'waveform_u' in session:
        extracted_data["waveform_u"] = session['waveform_u']
        
        # Filter by brain area if specified
        if areas_of_interest is not None and 'brain_area' in session and 'area_mask' in extracted_data:
            extracted_data["filtered_waveform_u"] = session['waveform_u'][np.where(extracted_data["area_mask"])[0]]
    
    # Extract trial information if available
    for key in ['contrast_left', 'contrast_right', 'response', 'reaction_time', 'feedback_type']:
        if key in session:
            extracted_data[key] = session[key]
    
    return extracted_data



In [7]:
# Cell 8: Example Usage - Data Loading and Integration
# Run this cell to load and integrate the data
print("Loading and integrating Steinmetz datasets...")
st_data, lfp_data, wav_data = load_steinmetz_data()
integrated_data = integrate_steinmetz_data(st_data, lfp_data, wav_data)


Loading and integrating Steinmetz datasets...
Loading datasets...
Loaded 39 sessions from spike times data
Loaded 39 sessions from LFP data
Loaded 39 sessions from waveform data

Keys in first session of each dataset:
Spike times data: ['ss', 'ss_passive']
LFP data: ['lfp', 'lfp_passive', 'brain_area_lfp']
Waveform data: ['waveform_w', 'waveform_u', 'trough_to_peak']
All datasets have 39 sessions
Integrating session 0...
Integrating session 10...
Integrating session 20...
Integrating session 30...
Successfully integrated 39 sessions
Keys in first integrated session: ['ss', 'ss_passive', 'lfp', 'lfp_passive', 'brain_area_lfp', 'waveform_w', 'waveform_u', 'trough_to_peak']


In [8]:

# Cell 9: Example Usage - Verify Data Correspondence
# Verify data correspondence for the first session
print("\nVerifying data correspondence for session 0...")
correspondence = verify_data_correspondence(integrated_data, session_idx=0)




Verifying data correspondence for session 0...
Neuron count in spike times data: 734
Neuron count in waveform data: 734
Neuron counts match: True
Number of trials: 214
LFP data shape: (7, 214, 250)
Number of LFP brain areas: 7
LFP brain areas: [np.str_('ACA'), np.str_('LS'), np.str_('MOs'), np.str_('CA3'), np.str_('DG'), np.str_('SUB'), np.str_('VISp')]


AttributeError: 'list' object has no attribute 'tolist'

In [None]:
# Cell 10: Example Usage - Find Sessions with Areas of Interest
# Define areas of interest for research questions
areas_of_interest = ['MOs', 'ACB', 'CP', 'ACA', 'PL']  # MOs, basal ganglia, prefrontal cortex
print(f"\nFinding sessions with neurons in {areas_of_interest}...")
matching_sessions = find_sessions_with_areas(integrated_data, areas_of_interest)

print(f"Found {len(matching_sessions)} sessions with neurons in all areas of interest")
if matching_sessions:
    print("Top 3 matching sessions:")
    for i, session in enumerate(matching_sessions[:3]):
        print(f"Session {session['session_idx']}: {session['neuron_counts']}")



In [None]:
# Cell 11: Example Usage - Extract Data from a Session
# Extract data from the first matching session if available
if matching_sessions:
    best_session_idx = matching_sessions[0]["session_idx"]
    print(f"\nExtracting data from session {best_session_idx}...")
    session_data = extract_session_data(integrated_data, best_session_idx, areas_of_interest)
    
    # Print summary of extracted data
    print("Extracted data summary:")
    for key, value in session_data.items():
        if isinstance(value, np.ndarray):
            print(f"{key}: shape {value.shape}")
        elif isinstance(value, list) and key == "filtered_spike_times":
            print(f"{key}: {len(value)} neurons")
        elif not isinstance(value, list) and not isinstance(value, np.ndarray):
            print(f"{key}: {value}")

In [None]:
# Cell 12: Spike Train to Firing Rate Function
def spike_times_to_firing_rate(spike_times, bin_size=0.01, t_start=-0.5, t_end=1.0):
    """
    Convert spike times to binned firing rates
    
    Parameters:
    spike_times : list
        List of spike times for each neuron and trial
    bin_size : float
        Size of time bins in seconds
    t_start : float
        Start time in seconds (relative to stimulus onset)
    t_end : float
        End time in seconds (relative to stimulus onset)
        
    Returns:
    firing_rates : numpy.ndarray
        Array of firing rates (neurons x trials x time bins)
    time_bins : numpy.ndarray
        Array of time bin centers
    """
    if not spike_times:
        return None, None
    
    # Create time bins
    time_bins = np.arange(t_start, t_end, bin_size)
    bin_centers = time_bins[:-1] + bin_size/2
    
    n_neurons = len(spike_times)
    n_trials = len(spike_times[0])
    n_bins = len(time_bins) - 1
    
    # Initialize firing rates array
    firing_rates = np.zeros((n_neurons, n_trials, n_bins))
    
    # Bin spike times into firing rates
    for n in range(n_neurons):
        for t in range(n_trials):
            spikes = np.array(spike_times[n][t])
            for b in range(n_bins):
                bin_start = time_bins[b]
                bin_end = time_bins[b+1]
                firing_rates[n, t, b] = np.sum((spikes >= bin_start) & (spikes < bin_end)) / bin_size
    
    return firing_rates, bin_centers

# Cell 13: LFP Analysis Function
def analyze_lfp(lfp_data, brain_areas, sampling_rate=100):
    """
    Analyze LFP data: compute power spectra and filter in frequency bands
    
    Parameters:
    lfp_data : numpy.ndarray
        LFP data (channels x time)
    brain_areas : list
        List of brain areas for each channel
    sampling_rate : float
        Sampling rate in Hz
        
    Returns:
    lfp_analysis : dict
        Dictionary with LFP analysis results
    """
    n_channels, n_timepoints = lfp_data.shape
    
    # Initialize results dictionary
    lfp_analysis = {
        "brain_areas": brain_areas,
        "power_spectra": [],
        "frequencies": None,
        "theta_power": np.zeros(n_channels),
        "beta_power": np.zeros(n_channels),
        "gamma_power": np.zeros(n_channels)
    }
    
    # Compute power spectra for each channel
    for ch in range(n_channels):
        # Compute power spectrum
        freqs, psd = signal.welch(lfp_data[ch], fs=sampling_rate, nperseg=256)
        
        # Store results
        lfp_analysis["power_spectra"].append(psd)
        if lfp_analysis["frequencies"] is None:
            lfp_analysis["frequencies"] = freqs
        
        # Compute power in specific frequency bands
        theta_mask = (freqs >= 4) & (freqs <= 8)
        beta_mask = (freqs >= 13) & (freqs <= 30)
        gamma_mask = (freqs >= 30) & (freqs <= 80)
        
        lfp_analysis["theta_power"][ch] = np.mean(psd[theta_mask])
        lfp_analysis["beta_power"][ch] = np.mean(psd[beta_mask])
        lfp_analysis["gamma_power"][ch] = np.mean(psd[gamma_mask])
    
    # Convert power spectra to array
    lfp_analysis["power_spectra"] = np.array(lfp_analysis["power_spectra"])
    
    return lfp_analysis

# Cell 14: Cross-Regional Connectivity Function
def compute_cross_regional_connectivity(firing_rates, brain_areas):
    """
    Compute functional connectivity between brain regions
    
    Parameters:
    firing_rates : numpy.ndarray
        Firing rates (neurons x trials x time)
    brain_areas : numpy.ndarray
        Brain area for each neuron
        
    Returns:
    connectivity : dict
        Dictionary with connectivity results
    """
    # Get unique brain areas
    unique_areas = np.unique(brain_areas)
    n_areas = len(unique_areas)
    
    # Average firing rates across trials
    mean_rates = np.mean(firing_rates, axis=1)  # neurons x time
    
    # Compute correlation matrix between all neurons
    corr_matrix = np.corrcoef(mean_rates)
    
    # Compute average correlation between brain areas
    area_corr = np.zeros((n_areas, n_areas))
    
    for i, area1 in enumerate(unique_areas):
        for j, area2 in enumerate(unique_areas):
            # Get indices of neurons in each area
            idx1 = np.where(brain_areas == area1)[0]
            idx2 = np.where(brain_areas == area2)[0]
            
            # Compute average correlation between areas
            if i == j:  # Within-area correlation
                # Exclude self-correlations (diagonal)
                area_corr[i, j] = np.mean(corr_matrix[np.ix_(idx1, idx1)] - np.eye(len(idx1)))
            else:  # Between-area correlation
                area_corr[i, j] = np.mean(corr_matrix[np.ix_(idx1, idx2)])
    
    # Create connectivity dictionary
    connectivity = {
        "neuron_corr": corr_matrix,
        "area_corr": area_corr,
        "areas": unique_areas
    }
    
    return connectivity

# Cell 15: Neural Dynamics Analysis Function
def analyze_neural_dynamics(firing_rates, brain_areas, trial_info=None):
    """
    Analyze neural dynamics using dimensionality reduction
    
    Parameters:
    firing_rates : numpy.ndarray
        Firing rates (neurons x trials x time)
    brain_areas : numpy.ndarray
        Brain area for each neuron
    trial_info : dict, optional
        Dictionary with trial information
        
    Returns:
    dynamics : dict
        Dictionary with neural dynamics results
    """
    # Get unique brain areas
    unique_areas = np.unique(brain_areas)
    
    # Initialize results dictionary
    dynamics = {
        "areas": unique_areas,
        "pca_results": {},
        "trajectories": {}
    }
    
    # Analyze each brain area separately
    for area in unique_areas:
        # Get neurons in this area
        area_mask = brain_areas == area
        area_rates = firing_rates[area_mask]
        
        if len(area_rates) < 3:
            # Skip areas with too few neurons
            continue
        
        # Reshape for PCA: (neurons x (trials*time))
        n_neurons, n_trials, n_timepoints = area_rates.shape
        reshaped_rates = area_rates.reshape(n_neurons, -1).T
        
        # Run PCA
        pca = PCA(n_components=min(3, n_neurons))
        pca_result = pca.fit_transform(reshaped_rates)
        
        # Reshape back to (trials x time x components)
        trajectories = pca_result.T.reshape(min(3, n_neurons), n_trials, n_timepoints)
        
        # Store results
        dynamics["pca_results"][area] = {
            "explained_variance_ratio": pca.explained_variance_ratio_,
            "components": pca.components_
        }
        dynamics["trajectories"][area] = trajectories
    
    # If trial information is provided, add condition-specific trajectories
    if trial_info is not None and 'response' in trial_info:
        dynamics["condition_trajectories"] = {}
        
        for area in unique_areas:
            if area not in dynamics["trajectories"]:
                continue
            
            # Get trajectories for this area
            area_trajectories = dynamics["trajectories"][area]
            
            # Split by response
            left_mask = trial_info['response'] == -1
            right_mask = trial_info['response'] == 1
            
            dynamics["condition_trajectories"][area] = {
                "left": area_trajectories[:, left_mask, :],
                "right": area_trajectories[:, right_mask, :]
            }
    
    return dynamics

# Cell 16: Example Usage - Compute Firing Rates
# Compute firing rates for the extracted session data
if 'session_data' in locals() and 'filtered_spike_times' in session_data:
    print("\nComputing firing rates...")
    firing_rates, time_bins = spike_times_to_firing_rate(session_data["filtered_spike_times"])
    
    if firing_rates is not None:
        print(f"Firing rates shape: {firing_rates.shape}")
        print(f"Time bins: {time_bins[0]} to {time_bins[-1]} seconds")


In [None]:

# Cell 17: Example Usage - Analyze LFP
# Analyze LFP data for the extracted session
if 'session_data' in locals() and 'filtered_lfp' in session_data and 'filtered_lfp_areas' in session_data:
    print("\nAnalyzing LFP data...")
    lfp_analysis = analyze_lfp(session_data["filtered_lfp"], session_data["filtered_lfp_areas"])
    
    print("LFP analysis results:")
    print(f"Analyzed {len(lfp_analysis['brain_areas'])} LFP channels")
    print(f"Frequency range: {lfp_analysis['frequencies'][0]} to {lfp_analysis['frequencies'][-1]} Hz")
    
    # Print average power in different frequency bands for each area
    unique_lfp_areas = np.unique(session_data["filtered_lfp_areas"])
    print("\nAverage power by brain area and frequency band:")
    for area in unique_lfp_areas:
        area_mask = np.array(session_data["filtered_lfp_areas"]) == area
        print(f"{area}:")
        print(f"  Theta (4-8 Hz): {np.mean(lfp_analysis['theta_power'][area_mask]):.2f}")
        print(f"  Beta (13-30 Hz): {np.mean(lfp_analysis['beta_power'][area_mask]):.2f}")
        print(f"  Gamma (30-80 Hz): {np.mean(lfp_analysis['gamma_power'][area_mask]):.2f}")

# Cell 18: Example Usage - Compute Cross-Regional Connectivity
# Compute functional connectivity between brain regions
if 'firing_rates' in locals() and 'session_data' in locals() and 'filtered_brain_areas' in session_data:
    print("\nComputing cross-regional connectivity...")
    connectivity = compute_cross_regional_connectivity(firing_rates, session_data["filtered_brain_areas"])
    
    print("Connectivity results:")
    print(f"Analyzed {len(connectivity['areas'])} brain areas: {connectivity['areas']}")
    
    # Plot connectivity matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(connectivity['area_corr'], annot=True, cmap='coolwarm', 
                xticklabels=connectivity['areas'], yticklabels=connectivity['areas'])
    plt.title('Cross-Regional Functional Connectivity')
    plt.tight_layout()
    plt.show()

# Cell 19: Example Usage - Analyze Neural Dynamics
# Analyze neural dynamics using dimensionality reduction
if 'firing_rates' in locals() and 'session_data' in locals() and 'filtered_brain_areas' in session_data:
    print("\nAnalyzing neural dynamics...")
    
    # Create trial info dictionary
    trial_info = {}
    for key in ['response', 'contrast_left', 'contrast_right']:
        if key in session_data:
            trial_info[key] = session_data[key]
    
    dynamics = analyze_neural_dynamics(firing_rates, session_data["filtered_brain_areas"], trial_info)
    
    print("Neural dynamics results:")
    print(f"Analyzed {len(dynamics['areas'])} brain areas: {dynamics['areas']}")
    
    # Plot explained variance for each area
    plt.figure(figsize=(12, 5))
    for i, area in enumerate(dynamics['pca_results'].keys()):
        plt.subplot(1, len(dynamics['pca_results']), i+1)
        plt.bar(range(1, len(dynamics['pca_results'][area]['explained_variance_ratio'])+1), 
                dynamics['pca_results'][area]['explained_variance_ratio'])
        plt.title(f'{area} Explained Variance')
        plt.xlabel('PC')
        plt.ylabel('Explained Variance Ratio')
    plt.tight_layout()
    plt.show()

# Cell 20: Visualization Function - Plot Neural Trajectories
def plot_neural_trajectories(dynamics, area, condition='all'):
    """
    Plot neural trajectories in PCA space
    
    Parameters:
    dynamics : dict
        Dictionary with neural dynamics results from analyze_neural_dynamics
    area : str
        Brain area to plot
    condition : str
        'all', 'left', or 'right' to plot all trials or specific conditions
    """
    if area not in dynamics['trajectories']:
        print(f"Area {area} not found in dynamics results")
        return
    
    # Get trajectories for this area
    if condition == 'all':
        trajectories = dynamics['trajectories'][area]
        n_components, n_trials, n_timepoints = trajectories.shape
        
        # Plot average trajectory
        plt.figure(figsize=(10, 8))
        
        # 3D plot if we have 3 components
        if n_components >= 3:
            ax = plt.subplot(111, projection='3d')
            mean_traj = np.mean(trajectories, axis=1)
            ax.plot(mean_traj[0], mean_traj[1], mean_traj[2], 'k-', linewidth=2)
            
            # Add points at specific timepoints
            timepoints = [0, n_timepoints//4, n_timepoints//2, 3*n_timepoints//4, n_timepoints-1]
            colors = ['blue', 'cyan', 'green', 'orange', 'red']
            
            for i, t in enumerate(timepoints):
                ax.scatter(mean_traj[0, t], mean_traj[1, t], mean_traj[2, t], 
                           color=colors[i], s=100, label=f'Time {t}')
            
            ax.set_xlabel('PC1')
            ax.set_ylabel('PC2')
            ax.set_zlabel('PC3')
            ax.set_title(f'{area} Neural Trajectory')
            ax.legend()
        
        # 2D plot if we have 2 components
        elif n_components >= 2:
            mean_traj = np.mean(trajectories, axis=1)
            plt.plot(mean_traj[0], mean_traj[1], 'k-', linewidth=2)
            
            # Add points at specific timepoints
            timepoints = [0, n_timepoints//4, n_timepoints//2, 3*n_timepoints//4, n_timepoints-1]
            colors = ['blue', 'cyan', 'green', 'orange', 'red']
            
            for i, t in enumerate(timepoints):
                plt.scatter(mean_traj[0, t], mean_traj[1, t], 
                           color=colors[i], s=100, label=f'Time {t}')
            
            plt.xlabel('PC1')
            plt.ylabel('PC2')
            plt.title(f'{area} Neural Trajectory')
            plt.legend()
    
    # Plot condition-specific trajectories
    elif condition in ['left', 'right'] and 'condition_trajectories' in dynamics:
        if area not in dynamics['condition_trajectories']:
            print(f"Area {area} not found in condition trajectories")
            return
        
        trajectories = dynamics['condition_trajectories'][area][condition]
        n_components, n_trials, n_timepoints = trajectories.shape
        
        # Plot average trajectory
        plt.figure(figsize=(10, 8))
        
        # 3D plot if we have 3 components
        if n_components >= 3:
            ax = plt.subplot(111, projection='3d')
            mean_traj = np.mean(trajectories, axis=1)
            ax.plot(mean_traj[0], mean_traj[1], mean_traj[2], 'k-', linewidth=2)
            
            # Add points at specific timepoints
            timepoints = [0, n_timepoints//4, n_timepoints//2, 3*n_timepoints//4, n_timepoints-1]
            colors = ['blue', 'cyan', 'green', 'orange', 'red']
            
            for i, t in enumerate(timepoints):
                ax.scatter(mean_traj[0, t], mean_traj[1, t], mean_traj[2, t], 
                           color=colors[i], s=100, label=f'Time {t}')
            
            ax.set_xlabel('PC1')
            ax.set_ylabel('PC2')
            ax.set_zlabel('PC3')
            ax.set_title(f'{area} Neural Trajectory - {condition.capitalize()} Trials')
            ax.legend()
        
        # 2D plot if we have 2 components
        elif n_components >= 2:
            mean_traj = np.mean(trajectories, axis=1)
            plt.plot(mean_traj[0], mean_traj[1], 'k-', linewidth=2)
            
            # Add points at specific timepoints
            timepoints = [0, n_timepoints//4, n_timepoints//2, 3*n_timepoints//4, n_timepoints-1]
            colors = ['blue', 'cyan', 'green', 'orange', 'red']
            
            for i, t in enumerate(timepoints):
                plt.scatter(mean_traj[0, t], mean_traj[1, t], 
                           color=colors[i], s=100, label=f'Time {t}')
            
            plt.xlabel('PC1')
            plt.ylabel('PC2')
            plt.title(f'{area} Neural Trajectory - {condition.capitalize()} Trials')
            plt.legend()
    
    else:
        print(f"Condition {condition} not recognized or condition trajectories not available")
    
    plt.tight_layout()
    plt.show()

# Cell 21: Example Usage - Plot Neural Trajectories
# Plot neural trajectories for a specific brain area
if 'dynamics' in locals():
    # Find an area with enough neurons for PCA
    for area in dynamics['areas']:
        if area in dynamics['trajectories']:
            print(f"\nPlotting neural trajectories for {area}...")
            plot_neural_trajectories(dynamics, area)
            
            # Plot condition-specific trajectories if available
            if 'condition_trajectories' in dynamics and area in dynamics['condition_trajectories']:
                plot_neural_trajectories(dynamics, area, 'left')
                plot_neural_trajectories(dynamics, area, 'right')
            
            break

# Cell 22: Age-Related Analysis Function
def analyze_age_differences(integrated_data, young_sessions, old_sessions, areas_of_interest):
    """
    Analyze age-related differences in neural activity and connectivity
    
    Parameters:
    integrated_data : list
        List of integrated session data
    young_sessions : list
        List of session indices for young mice
    old_sessions : list
        List of session indices for old mice
    areas_of_interest : list
        List of brain areas to analyze
        
    Returns:
    age_analysis : dict
        Dictionary with age-related analysis results
    """
    # Initialize results dictionary
    age_analysis = {
        "young": {"connectivity": [], "firing_rates": []},
        "old": {"connectivity": [], "firing_rates": []}
    }
    
    # Analyze young sessions
    for session_idx in young_sessions:
        # Extract session data
        session_data = extract_session_data(integrated_data, session_idx, areas_of_interest)
        
        if 'filtered_spike_times' in session_data:
            # Compute firing rates
            firing_rates, _ = spike_times_to_firing_rate(session_data["filtered_spike_times"])
            
            if firing_rates is not None:
                # Compute mean firing rate for each area
                for area in np.unique(session_data["filtered_brain_areas"]):
                    area_mask = session_data["filtered_brain_areas"] == area
                    area_rates = firing_rates[area_mask]
                    mean_rate = np.mean(area_rates)
                    
                    age_analysis["young"]["firing_rates"].append({
                        "session_idx": session_idx,
                        "area": area,
                        "mean_rate": mean_rate
                    })
                
                # Compute connectivity
                connectivity = compute_cross_regional_connectivity(firing_rates, session_data["filtered_brain_areas"])
                age_analysis["young"]["connectivity"].append({
                    "session_idx": session_idx,
                    "area_corr": connectivity["area_corr"],
                    "areas": connectivity["areas"]
                })
    
    # Analyze old sessions
    for session_idx in old_sessions:
        # Extract session data
        session_data = extract_session_data(integrated_data, session_idx, areas_of_interest)
        
        if 'filtered_spike_times' in session_data:
            # Compute firing rates
            firing_rates, _ = spike_times_to_firing_rate(session_data["filtered_spike_times"])
            
            if firing_rates is not None:
                # Compute mean firing rate for each area
                for area in np.unique(session_data["filtered_brain_areas"]):
                    area_mask = session_data["filtered_brain_areas"] == area
                    area_rates = firing_rates[area_mask]
                    mean_rate = np.mean(area_rates)
                    
                    age_analysis["old"]["firing_rates"].append({
                        "session_idx": session_idx,
                        "area": area,
                        "mean_rate": mean_rate
                    })
                
                # Compute connectivity
                connectivity = compute_cross_regional_connectivity(firing_rates, session_data["filtered_brain_areas"])
                age_analysis["old"]["connectivity"].append({
                    "session_idx": session_idx,
                    "area_corr": connectivity["area_corr"],
                    "areas": connectivity["areas"]
                })
    
    return age_analysis

# Cell 23: Example Usage - Age-Related Analysis
# For demonstration purposes, let's assume the first half of sessions are from young mice
# and the second half are from old mice
if 'integrated_data' in locals():
    num_sessions = len(integrated_data)
    young_sessions = list(range(num_sessions // 2))
    old_sessions = list(range(num_sessions // 2, num_sessions))
    
    print("\nPerforming age-related analysis...")
    print(f"Young sessions: {young_sessions[:5]}...")
    print(f"Old sessions: {old_sessions[:5]}...")
    
    # Define areas of interest
    areas_of_interest = ['MOs', 'ACB', 'CP', 'ACA', 'PL']  # MOs, basal ganglia, prefrontal cortex
    
    # Note: This analysis would be computationally intensive and time-consuming
    # For demonstration, we'll just show the function definition
    print("Note: Age-related analysis would analyze differences in neural activity and connectivity")
    print("between young and old mice across multiple sessions.")
    print("This would be computationally intensive and is not executed in this notebook.")

# Cell 24: Visualization Function - Plot LFP Power Spectra
def plot_lfp_power_spectra(lfp_analysis, areas_to_plot=None):
    """
    Plot LFP power spectra for specific brain areas
    
    Parameters:
    lfp_analysis : dict
        Dictionary with LFP analysis results from analyze_lfp
    areas_to_plot : list, optional
        List of brain areas to plot. If None, plot all areas.
    """
    # Get unique brain areas
    unique_areas = np.unique(lfp_analysis["brain_areas"])
    
    # Filter areas if specified
    if areas_to_plot is not None:
        plot_areas = [area for area in areas_to_plot if area in unique_areas]
    else:
        plot_areas = unique_areas
    
    # Plot power spectra
    plt.figure(figsize=(15, 10))
    
    for i, area in enumerate(plot_areas):
        # Get channels for this area
        area_mask = np.array(lfp_analysis["brain_areas"]) == area
        area_spectra = lfp_analysis["power_spectra"][area_mask]
        
        # Plot average spectrum
        plt.subplot(len(plot_areas), 1, i+1)
        mean_spectrum = np.mean(area_spectra, axis=0)
        std_spectrum = np.std(area_spectra, axis=0)
        
        plt.plot(lfp_analysis["frequencies"], mean_spectrum, 'k-', linewidth=2)
        plt.fill_between(lfp_analysis["frequencies"], 
                         mean_spectrum - std_spectrum, 
                         mean_spectrum + std_spectrum, 
                         alpha=0.3)
        
        # Mark frequency bands
        plt.axvspan(4, 8, alpha=0.2, color='blue', label='Theta (4-8 Hz)')
        plt.axvspan(13, 30, alpha=0.2, color='green', label='Beta (13-30 Hz)')
        plt.axvspan(30, 80, alpha=0.2, color='red', label='Gamma (30-80 Hz)')
        
        plt.title(f'{area} LFP Power Spectrum')
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power')
        plt.xlim(0, 100)  # Limit to 0-100 Hz for better visualization
        plt.legend()
    
    plt.tight_layout()
    plt.show()

# Cell 25: Example Usage - Plot LFP Power Spectra
# Plot LFP power spectra for specific brain areas
if 'lfp_analysis' in locals():
    print("\nPlotting LFP power spectra...")
    
    # Get unique brain areas
    unique_areas = np.unique(lfp_analysis["brain_areas"])
    print(f"Available areas: {unique_areas}")
    
    # Plot power spectra for up to 3 areas
    areas_to_plot = unique_areas[:min(3, len(unique_areas))]
    plot_lfp_power_spectra(lfp_analysis, areas_to_plot)

# Cell 26: Visualization Function - Plot Firing Rate Heatmap
def plot_firing_rate_heatmap(firing_rates, brain_areas, time_bins):
    """
    Plot firing rate heatmap for neurons grouped by brain area
    
    Parameters:
    firing_rates : numpy.ndarray
        Firing rates (neurons x trials x time)
    brain_areas : numpy.ndarray
        Brain area for each neuron
    time_bins : numpy.ndarray
        Time bin centers
    """
    # Get unique brain areas
    unique_areas = np.unique(brain_areas)
    
    # Compute mean firing rate across trials
    mean_rates = np.mean(firing_rates, axis=1)  # neurons x time
    
    # Sort neurons by brain area
    sorted_indices = np.argsort(brain_areas)
    sorted_rates = mean_rates[sorted_indices]
    sorted_areas = brain_areas[sorted_indices]
    
    # Create area boundaries for plotting
    area_boundaries = []
    current_area = sorted_areas[0]
    current_count = 0
    
    for area in sorted_areas:
        if area == current_area:
            current_count += 1
        else:
            area_boundaries.append((current_area, current_count))
            current_area = area
            current_count = 1
    
    area_boundaries.append((current_area, current_count))
    
    # Plot heatmap
    plt.figure(figsize=(15, 10))
    
    # Plot firing rate heatmap
    plt.imshow(sorted_rates, aspect='auto', cmap='viridis', 
               extent=[time_bins[0], time_bins[-1], len(sorted_rates), 0])
    
    # Add area labels
    y_pos = 0
    for area, count in area_boundaries:
        plt.axhline(y_pos + count, color='white', linestyle='-')
        plt.text(time_bins[-1] + 0.05, y_pos + count/2, area, 
                 verticalalignment='center', fontsize=10)
        y_pos += count
    
    # Add stimulus onset line
    plt.axvline(0, color='red', linestyle='--', label='Stimulus Onset')
    
    plt.colorbar(label='Firing Rate (Hz)')
    plt.xlabel('Time (s)')
    plt.ylabel('Neuron')
    plt.title('Firing Rate Heatmap by Brain Area')
    plt.tight_layout()
    plt.show()

# Cell 27: Example Usage - Plot Firing Rate Heatmap
# Plot firing rate heatmap for neurons grouped by brain area
if 'firing_rates' in locals() and 'time_bins' in locals() and 'session_data' in locals() and 'filtered_brain_areas' in session_data:
    print("\nPlotting firing rate heatmap...")
    plot_firing_rate_heatmap(firing_rates, session_data["filtered_brain_areas"], time_bins)

# Cell 28: Research Question Analysis Function
def analyze_research_questions(integrated_data, session_idx, areas_of_interest):
    """
    Analyze data specifically for the research questions:
    1. Neural dynamics across MOs, basal ganglia, and prefrontal cortex in decision-making
    2. Age-related differences in functional connectivity
    
    Parameters:
    integrated_data : list
        List of integrated session data
    session_idx : int
        Index of session to analyze
    areas_of_interest : list
        List of brain areas to analyze
        
    Returns:
    analysis : dict
        Dictionary with analysis results
    """
    # Extract session data
    session_data = extract_session_data(integrated_data, session_idx, areas_of_interest)
    
    if 'filtered_spike_times' not in session_data:
        print("No spike data available for the specified areas")
        return None
    
    # Compute firing rates
    firing_rates, time_bins = spike_times_to_firing_rate(session_data["filtered_spike_times"])
    
    if firing_rates is None:
        print("Failed to compute firing rates")
        return None
    
    # Initialize results dictionary
    analysis = {
        "session_idx": session_idx,
        "areas_analyzed": areas_of_interest,
        "time_bins": time_bins
    }
    
    # 1. Neural dynamics analysis
    print("Analyzing neural dynamics...")
    
    # Create trial info dictionary
    trial_info = {}
    for key in ['response', 'contrast_left', 'contrast_right']:
        if key in session_data:
            trial_info[key] = session_data[key]
    
    # Analyze neural dynamics
    dynamics = analyze_neural_dynamics(firing_rates, session_data["filtered_brain_areas"], trial_info)
    analysis["neural_dynamics"] = dynamics
    
    # 2. Functional connectivity analysis
    print("Analyzing functional connectivity...")
    connectivity = compute_cross_regional_connectivity(firing_rates, session_data["filtered_brain_areas"])
    analysis["connectivity"] = connectivity
    
    # 3. Decision-related activity
    if 'response' in session_data:
        print("Analyzing decision-related activity...")
        
        # Split trials by response
        left_trials = session_data['response'] == -1
        right_trials = session_data['response'] == 1
        
        # Compute mean firing rates for each condition
        left_rates = np.mean(firing_rates[:, left_trials, :], axis=1)  # neurons x time
        right_rates = np.mean(firing_rates[:, right_trials, :], axis=1)  # neurons x time
        
        # Compute selectivity index for each neuron
        selectivity = (right_rates - left_rates) / (right_rates + left_rates + 1e-10)  # Add small constant to avoid division by zero
        
        # Compute mean selectivity for each brain area
        area_selectivity = {}
        for area in np.unique(session_data["filtered_brain_areas"]):
            area_mask = session_data["filtered_brain_areas"] == area
            area_selectivity[area] = np.mean(selectivity[area_mask], axis=0)
        
        analysis["decision_selectivity"] = {
            "neuron_selectivity": selectivity,
            "area_selectivity": area_selectivity
        }
    
    # 4. LFP analysis if available
    if 'filtered_lfp' in session_data and 'filtered_lfp_areas' in session_data:
        print("Analyzing LFP data...")
        lfp_analysis = analyze_lfp(session_data["filtered_lfp"], session_data["filtered_lfp_areas"])
        analysis["lfp_analysis"] = lfp_analysis
    
    return analysis

# Cell 29: Example Usage - Research Question Analysis
# Analyze data specifically for the research questions
if 'integrated_data' in locals() and 'matching_sessions' in locals() and matching_sessions:
    print("\nAnalyzing data for research questions...")
    
    # Choose the best session
    best_session_idx = matching_sessions[0]["session_idx"]
    
    # Define areas of interest
    areas_of_interest = ['MOs', 'ACB', 'CP', 'ACA', 'PL']  # MOs, basal ganglia, prefrontal cortex
    
    # Analyze data
    research_analysis = analyze_research_questions(integrated_data, best_session_idx, areas_of_interest)
    
    if research_analysis is not None:
        print("Research question analysis complete")
        print(f"Analyzed session {research_analysis['session_idx']}")
        print(f"Areas analyzed: {research_analysis['areas_analyzed']}")
        
        # Plot decision selectivity
        if 'decision_selectivity' in research_analysis:
            plt.figure(figsize=(15, 5))
            for i, (area, selectivity) in enumerate(research_analysis['decision_selectivity']['area_selectivity'].items()):
                plt.plot(research_analysis['time_bins'], selectivity, label=area)
            
            plt.axvline(0, color='red', linestyle='--', label='Stimulus Onset')
            plt.xlabel('Time (s)')
            plt.ylabel('Decision Selectivity')
            plt.title('Decision Selectivity by Brain Area')
            plt.legend()
            plt.tight_layout()
            plt.show()



# Summary and Conclusion

This notebook provides a comprehensive framework for loading, integrating, and analyzing the Steinmetz datasets:
- steinmetz_st.npz (spike times)
- steinmetz_lfp.npz (local field potentials)
- steinmetz_wav.npz (waveforms)

The key components of this notebook are:

1. **Data Loading and Integration**
   - Functions to download and load all three datasets
   - Integration of data across datasets while preserving original structure
   - Verification of data correspondence between datasets

2. **Data Exploration and Selection**
   - Functions to explore available sessions and brain areas
   - Selection of sessions containing neurons from areas of interest
   - Extraction and preprocessing of data for analysis

3. **Neural Activity Analysis**
   - Conversion of spike times to firing rates
   - Analysis of LFP data in different frequency bands
   - Visualization of neural activity patterns

4. **Cross-Regional Connectivity Analysis**
   - Computation of functional connectivity between brain regions
   - Visualization of connectivity matrices
   - Analysis of connectivity patterns in relation to behavior

5. **Neural Dynamics Analysis**
   - Dimensionality reduction of neural activity
   - Visualization of neural trajectories in state space
   - Comparison of neural dynamics between conditions

6. **Research Question-Specific Analysis**
   - Analysis of neural dynamics across MOs, basal ganglia, and prefrontal cortex
   - Investigation of decision-related activity in these regions
   - Framework for analyzing age-related differences in functional connectivity

This notebook provides a foundation for addressing the research questions:
1. How do neural dynamics across MOs, basal ganglia, and prefrontal cortex drive strategy selection and decision-making during visual discrimination tasks?
2. How do age-related differences in functional connectivity between these regions influence cognitive processes and behavioral performance?

Further analyses could include:
- More detailed investigation of information flow between regions using Granger causality or other methods
- Analysis of trial-by-trial variability in neural dynamics and its relationship to behavior
- Comparison of neural activity patterns between young and old mice
- Integration with behavioral models to link neural activity to cognitive processes