# In [None]:
"""
Neural Oscillation Analysis Module

This module processes spike train data to detect neural oscillations using spectral analysis.
It converts discrete spike times into smoothed firing rate signals, computes power spectra,
and uses the SpectralModel (FOOOF) algorithm to separate periodic oscillations from 
aperiodic background activity.

Key Features:
- Double exponential kernel smoothing for realistic neural dynamics
- Welch's method for robust power spectral density estimation
- Automated oscillation detection with parameterized fitting
- Batch processing with comprehensive error handling

Dependencies:
    - numpy, scipy
    - neurodsp (for spectral analysis)
    - specparam (FOOOF algorithm)
    - pathlib, json, pickle
"""

import json
import logging
import pickle
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Union, Optional

import numpy as np
from neurodsp.spectral import compute_spectrum_welch
from specparam import SpectralModel
from scipy.signal import find_peaks


class NeuralSpectralAnalyzer:
    """
    Analyzes neural oscillations in spike train data using spectral methods.

    This class processes spike timing data through several stages:
    1. Converts spike times to a smoothed signal using a double exponential kernel
    2. Computes power spectral density using Welch's method
    3. Fits parametric models to separate oscillations from aperiodic activity
    4. Saves results for further analysis

    All time-domain processing happens on a uniform grid of ``bin_size_ms``
    wide bins, so ``sampling_rate`` (Hz) is exactly ``1000 / bin_size_ms``.
    """

    def __init__(self,
                 base_path: Optional[Union[str, Path]] = None,
                 output_subdir: str = "spectral_analysis",
                 bin_size_ms: float = 2.0,
                 kernel_window_ms: int = 300,
                 tau_rise_ms: float = 2.0,
                 tau_decay_ms: float = 25.0):
        """
        Initialize the spectral analyzer with processing parameters.

        Parameters:
        -----------
        base_path : str or Path, optional
            Base directory for data processing (defaults to the current
            working directory)
        output_subdir : str
            Subdirectory name (under base_path) for output files
        bin_size_ms : float
            Temporal resolution for spike histogram (milliseconds)
        kernel_window_ms : int
            Duration of smoothing kernel (milliseconds)
        tau_rise_ms : float
            Rise time constant for double exponential kernel (milliseconds)
        tau_decay_ms : float
            Decay time constant for double exponential kernel (milliseconds)

        Raises:
        -------
        ValueError
            If a time parameter is not strictly positive, or the kernel
            window is not longer than one bin (the kernel would then have a
            single sample at t=0, where it is zero).
        """
        self.base_path = Path(base_path) if base_path else Path.cwd()
        self.output_dir = self.base_path / output_subdir

        # Temporal processing parameters
        self.bin_size_ms = float(bin_size_ms)
        self.kernel_window_ms = int(kernel_window_ms)
        self.tau_rise_ms = float(tau_rise_ms)
        self.tau_decay_ms = float(tau_decay_ms)

        # `not x > 0` also rejects NaN, which `x <= 0` would let through.
        if not self.bin_size_ms > 0:
            raise ValueError(f"bin_size_ms must be positive, got {bin_size_ms!r}")
        if not self.tau_rise_ms > 0 or not self.tau_decay_ms > 0:
            raise ValueError(
                f"tau_rise_ms and tau_decay_ms must be positive, "
                f"got {tau_rise_ms!r} and {tau_decay_ms!r}"
            )
        if not self.kernel_window_ms > self.bin_size_ms:
            raise ValueError(
                f"kernel_window_ms ({kernel_window_ms!r}) must be longer than "
                f"bin_size_ms ({bin_size_ms!r})"
            )

        # Derived parameters
        self.sampling_rate = 1000.0 / self.bin_size_ms  # Hz

        self.setup_logging()

    def setup_logging(self):
        """
        Configure logging for processing tracking.

        Sets ``self.logger`` and, if the root logger has no handlers yet,
        installs a console handler plus a timestamped log file in
        ``base_path``. When the root logger is already configured nothing is
        installed (``logging.basicConfig`` would ignore it anyway), which also
        avoids opening a log file that would never receive a record. If the
        log file cannot be created (e.g. ``base_path`` does not exist),
        logging falls back to console only instead of raising.
        """
        self.logger = logging.getLogger(__name__)

        if logging.getLogger().handlers:
            return

        handlers = [logging.StreamHandler()]
        file_problem = None
        log_file = self.base_path / f'spectral_analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
        try:
            handlers.insert(0, logging.FileHandler(log_file))
        except OSError as exc:
            file_problem = exc

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=handlers
        )

        if file_problem is not None:
            self.logger.warning(
                "File logging disabled (%s); logging to console only", file_problem
            )

    @staticmethod
    def double_exponential_kernel(t: np.ndarray, tau_rise: float, tau_decay: float) -> np.ndarray:
        """
        Create a normalized double exponential kernel.

        The kernel is ``(1 - exp(-t / tau_rise)) * exp(-t / tau_decay)``,
        a rise-then-decay shape intended to resemble synaptic dynamics,
        divided by its sum so that convolving with it preserves the total
        area of the input.

        Parameters:
        -----------
        t : np.ndarray
            Time points for kernel evaluation (same unit as the time constants)
        tau_rise : float
            Rise time constant
        tau_decay : float
            Decay time constant

        Returns:
        --------
        np.ndarray
            Kernel values that sum to 1

        Raises:
        -------
        ValueError
            If the kernel sums to zero or is not finite (for example when
            ``t`` only contains 0), since normalizing would produce NaN/Inf.
        """
        t = np.asarray(t, dtype=float)
        kernel = (1.0 - np.exp(-t / tau_rise)) * np.exp(-t / tau_decay)
        total = np.sum(kernel)
        if not np.isfinite(total) or total == 0:
            raise ValueError(
                "Double exponential kernel cannot be normalized "
                f"(sum={total!r}); check the time axis and time constants"
            )
        return kernel / total

    def spikes_to_smoothed_signal(self, spike_times: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Convert discrete spike times to a continuous smoothed signal.

        This method:
        1. Histograms the spike times into bins exactly ``bin_size_ms`` wide,
           starting at t=0 (spikes before t=0 fall outside the histogram)
        2. Convolves the histogram with the double exponential kernel,
           sampled on the same ``bin_size_ms`` grid so the time constants
           keep their meaning in milliseconds
        3. Returns the left bin edges and the smoothed signal

        The convolution is causal (each spike only affects its own and later
        bins) and the output always has one sample per histogram bin, even
        when the recording is shorter than the kernel.

        Parameters:
        -----------
        spike_times : np.ndarray
            Array of spike times in seconds

        Returns:
        --------
        tuple
            (time_bins, smoothed_signal) where:
            - time_bins: Left edge of each bin (seconds)
            - smoothed_signal: Smoothed spike density. The histogram uses
              ``density=True``, so it is normalized by the total spike
              count and integrates to 1 over the recording.
            Both arrays are empty when there are no spikes or the latest
            spike time is not a positive finite number.
        """
        spike_times = np.asarray(spike_times, dtype=float)
        if spike_times.size == 0:
            return np.array([]), np.array([])

        max_time = np.max(spike_times)
        if not (np.isfinite(max_time) and max_time > 0):
            # No usable time axis can be built from this input.
            return np.array([]), np.array([])

        # Uniform bin edges so the signal is sampled at exactly sampling_rate.
        bin_size_sec = self.bin_size_ms / 1000.0
        num_bins = int(np.ceil((max_time * 1000) / self.bin_size_ms))
        bin_edges = np.arange(num_bins + 1) * bin_size_sec
        if bin_edges[-1] < max_time:
            # Floating-point rounding left the last spike just past the final
            # edge; add one more bin so it is still counted.
            bin_edges = np.append(bin_edges, bin_edges[-1] + bin_size_sec)

        # Create spike histogram
        spike_histogram, _ = np.histogram(spike_times, bins=bin_edges, density=True)

        # Kernel time axis in milliseconds, one sample per histogram bin.
        kernel_time = np.arange(0.0, self.kernel_window_ms, self.bin_size_ms)
        kernel = self.double_exponential_kernel(
            kernel_time, self.tau_rise_ms, self.tau_decay_ms
        )

        # Causal filtering: keep the first len(histogram) samples of the full
        # convolution so output sample i lines up with bin i.
        smoothed_signal = np.convolve(spike_histogram, kernel, mode='full')[:spike_histogram.size]

        return bin_edges[:-1], smoothed_signal

    def compute_power_spectrum(self, signal: np.ndarray,
                             frequency_range: Tuple[float, float] = (0.05, 40.0),
                             window_duration_sec: float = 60.0) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute power spectral density using Welch's method.

        Segments are Hann-windowed with 50% overlap and combined with the
        median. If the signal is shorter than the requested window, the
        window is shortened to the signal length (and the overlap with it),
        so short recordings yield a coarser spectrum instead of an error.

        Parameters:
        -----------
        signal : np.ndarray
            Input signal sampled at ``self.sampling_rate``
        frequency_range : tuple
            (min_freq, max_freq) in Hz for analysis
        window_duration_sec : float
            Duration of each window for Welch's method (seconds)

        Returns:
        --------
        tuple
            (frequencies, power_spectrum)
        """
        signal = np.asarray(signal, dtype=float)

        # Calculate Welch parameters, never asking for more samples than exist.
        requested_nperseg = int(self.sampling_rate * window_duration_sec)
        nperseg = max(1, min(requested_nperseg, signal.size))
        if nperseg < requested_nperseg:
            self.logger.warning(
                "Signal has %d samples, fewer than the requested %d-sample Welch "
                "window; using a shortened window",
                signal.size, requested_nperseg
            )
        # Derived from the effective window so that noverlap < nperseg holds.
        noverlap = nperseg // 2

        frequencies, power_spectrum = compute_spectrum_welch(
            signal,
            self.sampling_rate,
            avg_type='median',
            window='hann',
            nperseg=nperseg,
            noverlap=noverlap,
            f_range=frequency_range
        )

        return frequencies, power_spectrum

    def fit_spectral_model(self, frequencies: np.ndarray, power_spectrum: np.ndarray,
                          fit_range: Tuple[float, float] = (0.5, 13.0)) -> "SpectralModel":
        """
        Fit parametric spectral model to separate oscillations from aperiodic activity.

        Uses the SpectralModel (FOOOF) algorithm to decompose power spectra into:
        - Aperiodic component (1/f-like background)
        - Periodic components (oscillatory peaks)

        Parameters:
        -----------
        frequencies : np.ndarray
            Frequency values (Hz)
        power_spectrum : np.ndarray
            Power spectral density values
        fit_range : tuple
            (min_freq, max_freq) for model fitting

        Returns:
        --------
        SpectralModel
            Fitted model containing oscillation parameters
        """
        spectral_model = SpectralModel(
            peak_width_limits=[1, 8],      # Expected oscillation bandwidth (Hz)
            min_peak_height=0.2,           # Minimum peak height above aperiodic
            max_n_peaks=4,                 # Maximum number of oscillations to detect
            peak_threshold=1.5,            # Statistical threshold for peak detection
            aperiodic_mode='fixed'         # Use fixed aperiodic fitting
        )

        spectral_model.fit(frequencies, power_spectrum, fit_range)
        return spectral_model

    def analyze_well_data(self, well_data: Dict,
                         frequency_range: Tuple[float, float] = (0.05, 40.0),
                         fit_range: Tuple[float, float] = (0.5, 13.0)) -> Dict:
        """
        Complete spectral analysis pipeline for a single well.

        Parameters:
        -----------
        well_data : dict
            Dictionary containing 'all_spikes' key with spike times (seconds).
            A missing key, a null value, or NaN/Inf entries are treated as
            "no spike" rather than as errors.
        frequency_range : tuple
            Frequency range for power spectrum computation
        fit_range : tuple
            Frequency range for oscillation model fitting

        Returns:
        --------
        dict
            Analysis results with keys 'time_bins', 'smoothed_signal',
            'frequencies', 'power_spectrum', 'spectral_model' and
            'processing_metadata'. When no usable spikes are present the
            arrays are empty, 'spectral_model' is None and the metadata
            carries an 'error' entry.
        """
        # Extract and validate spike data
        raw_spikes = well_data.get('all_spikes')
        if raw_spikes is None:
            raw_spikes = []
        spike_times = np.asarray(raw_spikes, dtype=float).ravel()
        # Non-finite entries would make the time axis undefined.
        spike_times = spike_times[np.isfinite(spike_times)]

        if spike_times.size == 0:
            self.logger.warning("No spikes found in well data")
            return self._create_empty_results()

        # Convert spikes to smoothed signal
        time_bins, smoothed_signal = self.spikes_to_smoothed_signal(spike_times)

        if len(smoothed_signal) == 0:
            self.logger.warning("Failed to create smoothed signal")
            return self._create_empty_results(
                'Could not build a smoothed signal from the spike times'
            )

        # Compute power spectrum
        frequencies, power_spectrum = self.compute_power_spectrum(
            smoothed_signal, frequency_range
        )

        # Fit spectral model for oscillation detection
        spectral_model = self.fit_spectral_model(frequencies, power_spectrum, fit_range)

        # Compile results
        results = {
            'time_bins': time_bins,
            'smoothed_signal': smoothed_signal,
            'frequencies': frequencies,
            'power_spectrum': power_spectrum,
            'spectral_model': spectral_model,
            'processing_metadata': {
                'sampling_rate_hz': self.sampling_rate,
                'bin_size_ms': self.bin_size_ms,
                'kernel_parameters': {
                    'window_ms': self.kernel_window_ms,
                    'tau_rise_ms': self.tau_rise_ms,
                    'tau_decay_ms': self.tau_decay_ms
                },
                'spectral_parameters': {
                    'frequency_range': frequency_range,
                    'fit_range': fit_range
                },
                'n_spikes': len(spike_times),
                # Duration is taken as the time of the last spike.
                'recording_duration_sec': float(np.max(spike_times))
            }
        }

        return results

    def _create_empty_results(self, reason: str = 'No valid spike data found') -> Dict:
        """
        Create empty results structure for failed processing.

        Parameters:
        -----------
        reason : str
            Text stored under processing_metadata['error']

        Returns:
        --------
        dict
            Same top-level keys as a successful analysis, with empty arrays
            and no spectral model.
        """
        return {
            'time_bins': np.array([]),
            'smoothed_signal': np.array([]),
            'frequencies': np.array([]),
            'power_spectrum': np.array([]),
            'spectral_model': None,
            'processing_metadata': {
                'error': reason,
                'sampling_rate_hz': self.sampling_rate
            }
        }

    def process_json_file(self, json_file_path: Union[str, Path]) -> Optional[Dict]:
        """
        Process a single JSON file containing well data.

        Any failure (unreadable file, invalid JSON, analysis error) is logged
        with its traceback and reported as None so a batch run can continue.

        Parameters:
        -----------
        json_file_path : str or Path
            Path to JSON file with spike data

        Returns:
        --------
        dict or None
            Analysis results or None if processing failed
        """
        json_file_path = Path(json_file_path)

        try:
            with open(json_file_path, 'r', encoding='utf-8') as f:
                well_data = json.load(f)

            results = self.analyze_well_data(well_data)
        except Exception as e:
            # Batch boundary: record the full traceback, then keep going.
            self.logger.exception("Error processing %s: %s", json_file_path.name, e)
            return None

        self.logger.info("Successfully processed %s", json_file_path.name)
        return results

    def save_results(self, results: Dict, output_file_path: Union[str, Path]) -> bool:
        """
        Save analysis results to pickle file.

        Parent directories are created as needed. Failures are logged, not
        raised; the return value tells the caller whether the file was written.

        Parameters:
        -----------
        results : dict
            Analysis results from analyze_well_data
        output_file_path : str or Path
            Path for output pickle file

        Returns:
        --------
        bool
            True if the file was written, False otherwise
        """
        output_file_path = Path(output_file_path)

        try:
            output_file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_file_path, 'wb') as f:
                pickle.dump(results, f)
        except Exception as e:
            self.logger.exception("Error saving results to %s: %s", output_file_path, e)
            return False

        self.logger.info("Results saved to %s", output_file_path)
        return True

    def process_folder(self, input_folder: Union[str, Path],
                      output_folder: Optional[Union[str, Path]] = None
                      ) -> Dict[str, Union[int, List[Dict[str, str]]]]:
        """
        Process all JSON files in a folder and save results as pickle files.

        Files are handled in sorted order. A file counts as processed only
        when its analysis succeeded and its pickle was written. If anything
        failed, the collected errors are also written to a timestamped
        ``processing_errors_*.json`` file in the output folder.

        Parameters:
        -----------
        input_folder : str or Path
            Folder containing JSON files to process
        output_folder : str or Path, optional
            Folder for saving pickle files. If None, uses
            ``output_dir / input_folder.name``.

        Returns:
        --------
        dict
            Processing statistics: {'processed': int, 'failed': int, 'errors': list}
        """
        input_folder = Path(input_folder)
        if output_folder is None:
            output_folder = self.output_dir / input_folder.name
        else:
            output_folder = Path(output_folder)

        output_folder.mkdir(parents=True, exist_ok=True)

        # Initialize processing statistics
        stats = {'processed': 0, 'failed': 0, 'errors': []}

        # Sorted so runs are reproducible regardless of filesystem order.
        json_files = sorted(input_folder.glob('*.json'))
        self.logger.info("Found %d JSON files to process in %s", len(json_files), input_folder)

        for json_file in json_files:
            failure = None
            try:
                results = self.process_json_file(json_file)

                if results is None:
                    failure = {'error': 'Processing returned None'}
                else:
                    pickle_file = output_folder / f"{json_file.stem}.pkl"
                    saved = self.save_results(results, pickle_file)
                    # `is False` keeps overrides that return None working.
                    if saved is False:
                        failure = {'error': 'Saving results failed'}
                    else:
                        stats['processed'] += 1

            except Exception as e:
                failure = {'error': str(e), 'error_type': type(e).__name__}
                self.logger.error("Failed to process %s: %s", json_file.name, e)

            if failure is not None:
                stats['failed'] += 1
                stats['errors'].append({
                    'file': str(json_file),
                    **failure,
                    'timestamp': datetime.now().isoformat()
                })

        # Save error log if there were failures
        if stats['errors']:
            error_log_path = output_folder / f"processing_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(error_log_path, 'w', encoding='utf-8') as f:
                json.dump(stats['errors'], f, indent=2)
            self.logger.info("Error log saved to %s", error_log_path)

        self.logger.info(
            "Processing complete: %d successful, %d failed",
            stats['processed'], stats['failed']
        )
        return stats

    def process_nested_folders(self, base_input_folder: Union[str, Path],
                              base_output_folder: Optional[Union[str, Path]] = None) -> Dict[str, Dict]:
        """
        Process multiple folders containing JSON files.

        Only the immediate sub-folders of ``base_input_folder`` are processed
        (in sorted order); JSON files lying directly in ``base_input_folder``
        are not touched. Each sub-folder's results go to a sub-folder of the
        same name under the output base.

        Parameters:
        -----------
        base_input_folder : str or Path
            Base folder containing subfolders with JSON files
        base_output_folder : str or Path, optional
            Base folder for saving results. If None, uses output_dir.

        Returns:
        --------
        dict
            Processing statistics for each subfolder, keyed by folder name
        """
        base_input_folder = Path(base_input_folder)
        if base_output_folder is None:
            base_output_folder = self.output_dir
        else:
            base_output_folder = Path(base_output_folder)

        all_stats = {}

        # Process each subfolder
        for subfolder in sorted(base_input_folder.iterdir()):
            if not subfolder.is_dir():
                continue

            self.logger.info("Processing folder: %s", subfolder.name)
            output_subfolder = base_output_folder / subfolder.name

            try:
                all_stats[subfolder.name] = self.process_folder(subfolder, output_subfolder)
            except Exception as e:
                self.logger.error("Error processing folder %s: %s", subfolder.name, e)
                all_stats[subfolder.name] = {
                    'processed': 0,
                    'failed': 0,
                    'errors': [{'folder_error': str(e)}]
                }

        return all_stats

# In [None]:


def main():
    """
    Example usage of the NeuralSpectralAnalyzer.

    Configure the paths below to match your data organization. The input
    folder is checked before the analyzer is created, because creating the
    analyzer sets up logging under the base directory; with a placeholder or
    mistyped path the user gets the hint below instead.

    process_nested_folders handles the sub-folders of the input folder; use
    process_folder (see the commented alternative) for a single folder of
    JSON files.
    """
    # Configuration
    base_directory = Path("your_data_directory_here")  # Update this path
    json_input_folder = base_directory / "processed_spike_data"  # Folder with JSON files

    # Bail out early with a helpful message if the data is not where expected.
    if not json_input_folder.exists():
        print(f"Input folder {json_input_folder} does not exist. Please update the path.")
        return

    # Initialize analyzer with custom parameters
    analyzer = NeuralSpectralAnalyzer(
        base_path=base_directory,
        output_subdir="oscillation_analysis",
        bin_size_ms=2.0,           # 2ms temporal resolution
        kernel_window_ms=300,      # 300ms smoothing window
        tau_rise_ms=2.0,          # Fast synaptic rise
        tau_decay_ms=25.0         # Slower synaptic decay
    )

    # Process all JSON files in nested folder structure
    results = analyzer.process_nested_folders(json_input_folder)
    print(f"Processing complete. Results: {results}")

    # Alternative: Process a single folder
    # single_folder = json_input_folder / "experimental_condition_1"
    # if single_folder.exists():
    #     stats = analyzer.process_folder(single_folder)
    #     print(f"Processed {stats['processed']} files, {stats['failed']} failed")


if __name__ == "__main__":
    main()