In [None]:
"""
Oscillatory Burst Analysis Module

This module provides comprehensive analysis of oscillatory bursts in neural time series data.
It combines temporal and spectral approaches to characterize burst dynamics across different
frequency bands (delta, theta, alpha) with detailed statistical and spectral parameterization.

Key Features:
- Dual-threshold burst detection with morphological dilation
- Frequency-band specific burst analysis (Delta: 1-4Hz, Theta: 4-8Hz, Alpha: 8-13Hz)
- Per-burst spectral parameterization using FOOOF/SpectralModel
- Comprehensive statistical characterization of burst properties
- Automated batch processing with visualization capabilities
- Standardized well plate mapping for experimental organization

Dependencies:
    - numpy, pandas, scipy, matplotlib
    - neurodsp (for filtering and spectral analysis)
    - specparam (FOOOF algorithm for spectral parameterization)
    - scipy.ndimage (for morphological operations and connected-component labeling)
"""

import json
import logging
import pickle
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Union, Optional, Any

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import hilbert, find_peaks
from scipy.ndimage import binary_dilation, find_objects, gaussian_filter1d, label

from neurodsp.filt import filter_signal
from neurodsp.spectral import compute_spectrum_welch
from specparam import SpectralModel


class OscillatoryBurstAnalyzer:
    """
    Analyzes oscillatory bursts in neural time series across multiple frequency bands.
    
    This analyzer detects transient oscillatory events (bursts) using a dual-threshold
    approach combined with frequency-specific filtering. Each detected burst is then
    characterized both temporally (duration, frequency) and spectrally (power, slope).
    """
    
    # Standard 48-well plate mapping (6 rows x 8 columns).
    # Keys are 'well<row><col>' with zero-based indices; values are
    # '<row letter><1-based column>', e.g. 'well00' -> 'A1', 'well57' -> 'F8'.
    # Entries are generated row by row (A1..A8, B1..B8, ...).
    DEFAULT_PLATEMAP = {
        f'well{row_index}{col_index}': f'{row_letter}{col_index + 1}'
        for row_index, row_letter in enumerate('ABCDEF')
        for col_index in range(8)
    }

    # Default frequency band definitions with analysis parameters.
    # Each table row is (band name, freq_range in Hz, min_cycles, cycles_dont_drop):
    #   min_cycles       -- minimum cycles for a valid burst
    #   cycles_dont_drop -- cycles used to bridge gaps in detection
    DEFAULT_FREQUENCY_BANDS = {
        band_name: {
            'freq_range': band_range,
            'min_cycles': band_min_cycles,
            'cycles_dont_drop': band_dont_drop,
        }
        for band_name, band_range, band_min_cycles, band_dont_drop in (
            ('Delta', (1, 4), 8, 3),
            ('Theta', (4, 8), 16, 4),
            ('Alpha', (8, 13), 24, 5),
        )
    }
    
    def __init__(self, 
                 base_path: Optional[Union[str, Path]] = None,
                 output_subdir: str = "burst_analysis",
                 frequency_bands: Optional[Dict] = None,
                 platemap: Optional[Dict] = None):
        """
        Initialize the oscillatory burst analyzer.
        
        Parameters:
        -----------
        base_path : str or Path, optional
            Base directory for data processing and for the log file.
            Defaults to the current working directory.
        output_subdir : str
            Subdirectory name (under ``base_path``) for output files; a
            ``plots`` folder inside it is also recorded. Neither directory
            is created here.
        frequency_bands : dict, optional
            Custom frequency band definitions. Uses a private copy of
            ``DEFAULT_FREQUENCY_BANDS`` if None (or empty).
        platemap : dict, optional
            Custom well-to-name mapping. Uses a private copy of the standard
            48-well mapping if None (or empty).
        """
        self.base_path = Path(base_path) if base_path else Path.cwd()
        self.output_dir = self.base_path / output_subdir
        self.plots_dir = self.output_dir / "plots"
        
        # Copy the class-level defaults: storing them by reference meant that
        # e.g. ``analyzer.frequency_bands['Delta']['min_cycles'] = 6`` silently
        # rewrote DEFAULT_FREQUENCY_BANDS for every later analyzer.
        # User-supplied dicts are kept as given (same aliasing as before).
        if frequency_bands:
            self.frequency_bands = frequency_bands
        else:
            self.frequency_bands = {
                name: dict(params)
                for name, params in self.DEFAULT_FREQUENCY_BANDS.items()
            }
        self.platemap = platemap if platemap else dict(self.DEFAULT_PLATEMAP)
        
        self.setup_logging()
        
    def setup_logging(self):
        """
        Configure logging for analysis tracking.

        Messages from this module's logger go to a timestamped
        ``burst_analysis_<YYYYmmdd_HHMMSS>.log`` file in ``base_path`` and to the
        console, formatted as ``time - LEVEL - message``.

        Handlers are attached directly to the module logger instead of via
        ``logging.basicConfig``. ``basicConfig`` is a no-op once the root logger
        has handlers, e.g. when a second analyzer is created in the same session.
        Previously that left the new log file empty, and its eagerly opened
        ``FileHandler`` was never closed.

        The logger object is shared by all analyzers in the process. Handlers
        installed by an earlier analyzer are closed and replaced, so every
        instance logs to the most recently created file.
        """
        log_file = self.base_path / f'burst_analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)

        # Remove (and close) only the handlers this class added earlier, leaving
        # any handlers the user configured untouched.
        for handler in list(logger.handlers):
            if getattr(handler, '_burst_analyzer_handler', False):
                logger.removeHandler(handler)
                handler.close()

        for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
            handler.setFormatter(formatter)
            handler._burst_analyzer_handler = True  # marker for the cleanup above
            logger.addHandler(handler)

        self.logger = logger

    def compute_amplitude_envelope(self, signal: np.ndarray, 
                                 sampling_rate: float,
                                 freq_range: Tuple[float, float]) -> np.ndarray:
        """
        Return the Gaussian-smoothed Hilbert amplitude envelope of ``signal``.

        The smoothing width adapts to the band. Sigma, in samples, is
        ``sampling_rate * (0.4 / f_center + 0.2 * (f_high - f_low) / f_center)``,
        so lower-frequency and relatively wider bands are smoothed more.

        Parameters:
        -----------
        signal : np.ndarray
            Input signal (in this module, the band-pass filtered signal)
        sampling_rate : float
            Sampling rate in Hz
        freq_range : tuple
            (low_freq, high_freq) edges of the band in Hz

        Returns:
        --------
        np.ndarray
            Smoothed amplitude envelope, same length as ``signal``
        """
        raw_envelope = np.abs(hilbert(signal))

        f_lo, f_hi = freq_range
        f_center = (f_lo + f_hi) / 2
        relative_bandwidth = (f_hi - f_lo) / f_center
        sigma_samples = sampling_rate * (0.4 / f_center + 0.2 * relative_bandwidth)

        return gaussian_filter1d(raw_envelope, sigma_samples)

    def detect_bursts_dual_threshold(self, 
                                   amplitude_envelope: np.ndarray,
                                   sampling_rate: float,
                                   freq_range: Tuple[float, float],
                                   min_cycles: int,
                                   cycles_dont_drop: int,
                                   low_percentile: float = 50,
                                   high_percentile: float = 75) -> np.ndarray:
        """
        Detect oscillatory bursts using dual-threshold method with morphological operations.
        
        This method:
        1. Sets thresholds based on amplitude distribution percentiles.
        2. Keeps connected regions above the low threshold that contain at
           least one high-threshold crossing.
        3. Dilates each kept region with a run of ``cycles_dont_drop`` cycles
           (measured at the band's lowest frequency). This pads the region, so
           nearby bursts overlap and merge in the output mask.
        4. Drops a region when its *own* dilated length is shorter than
           ``min_cycles`` cycles. The check is per region, before merging.
        
        Parameters:
        -----------
        amplitude_envelope : np.ndarray
            Smoothed amplitude envelope (1-D)
        sampling_rate : float
            Sampling rate in Hz
        freq_range : tuple
            (low_freq, high_freq); low_freq converts cycles to samples
        min_cycles : int
            Minimum number of cycles for a valid burst
        cycles_dont_drop : int
            Number of cycles to bridge gaps in detection. Values below one
            sample mean no bridging.
        low_percentile : float
            Percentile for low threshold (default: 50th percentile)
        high_percentile : float
            Percentile for high threshold (default: 75th percentile)
            
        Returns:
        --------
        np.ndarray
            Boolean array (same shape as ``amplitude_envelope``) indicating burst periods
        """
        amplitude_envelope = np.asarray(amplitude_envelope)
        is_burst = np.zeros_like(amplitude_envelope, dtype=bool)
        if amplitude_envelope.size == 0:
            # np.percentile raises on empty input; there is nothing to detect.
            return is_burst

        # Calculate thresholds from amplitude distribution
        low_threshold = np.percentile(amplitude_envelope, low_percentile)
        high_threshold = np.percentile(amplitude_envelope, high_percentile)
        
        # Cycle counts -> samples, using the lowest band frequency (longest cycle)
        samples_per_cycle = sampling_rate / freq_range[0]
        # binary_dilation raises on a zero-length structuring element; a length-1
        # element is the identity, i.e. "no gap bridging".
        max_drop_samples = max(1, int(cycles_dont_drop * samples_per_cycle))
        min_duration_samples = int(min_cycles * samples_per_cycle)
        dilation_structure = np.ones(max_drop_samples, dtype=bool)
        
        n_samples = amplitude_envelope.shape[0]
        high_threshold_crossings = amplitude_envelope >= high_threshold
        low_threshold_crossings = amplitude_envelope >= low_threshold
        
        # Label connected regions above low threshold
        labeled_regions, num_regions = label(low_threshold_crossings)
        if num_regions == 0:
            return is_burst
        
        # Candidate regions are those containing at least one high-threshold sample.
        candidate_ids = np.unique(labeled_regions[high_threshold_crossings])
        candidate_ids = candidate_ids[candidate_ids > 0]
        
        region_slices = find_objects(labeled_regions)
        for region_id in candidate_ids:
            region_span = region_slices[region_id - 1][0]
            
            # Dilation with a length-k structure reaches at most k-1 samples past
            # the region. Working in a window padded by k on each side gives
            # exactly the same result as dilating the full-length array, without
            # an O(n) pass per region.
            win_start = max(0, region_span.start - max_drop_samples)
            win_stop = min(n_samples, region_span.stop + max_drop_samples)
            region_mask = labeled_regions[win_start:win_stop] == region_id
            
            # Dilating one contiguous run yields one contiguous run, so no
            # further connected-component filtering is needed.
            dilated_region = binary_dilation(region_mask, structure=dilation_structure)
            
            # Check minimum duration requirement (on this region's dilated extent)
            if np.count_nonzero(dilated_region) >= min_duration_samples:
                is_burst[win_start:win_stop] |= dilated_region
        
        return is_burst

    def compute_burst_statistics(self, 
                               is_burst: np.ndarray, 
                               sampling_rate: float,
                               freq_range: Tuple[float, float],
                               cycles_dont_drop: int) -> Dict[str, float]:
        """
        Compute summary statistics of detected bursts.
        
        The burst mask is dilated once more before segments are counted, so
        bursts separated by short gaps count as one burst.
        
        Parameters:
        -----------
        is_burst : np.ndarray
            Boolean array indicating burst periods
        sampling_rate : float
            Sampling rate in Hz
        freq_range : tuple
            Frequency range for the analysis band; the upper edge sizes the dilation
        cycles_dont_drop : int
            Cycles parameter used in detection
            
        Returns:
        --------
        dict
            'n_bursts', 'duration_mean' and 'duration_std' (seconds),
            'percent_burst' (percent of samples in the merged mask) and
            'bursts_per_second'. All are zero for an empty mask.
        """
        is_burst = np.asarray(is_burst, dtype=bool)
        if is_burst.size == 0:
            # Avoid indexing an empty mask and dividing by a zero recording length.
            return {
                'n_bursts': 0,
                'duration_mean': 0.0,
                'duration_std': 0.0,
                'percent_burst': 0.0,
                'bursts_per_second': 0.0
            }
        
        # NOTE(review): detection sizes its structuring element from
        # freq_range[0], but this uses freq_range[1]. Dilating the already
        # dilated detection mask also pads every burst a second time, which
        # inflates durations and percent_burst. Kept unchanged so results stay
        # comparable with earlier runs; confirm whether this is intended.
        freq_high = freq_range[1]
        samples_per_cycle = sampling_rate / freq_high
        # binary_dilation raises on a zero-length structure; length 1 is a no-op.
        max_drop_samples = max(1, int(cycles_dont_drop * samples_per_cycle))
        dilation_structure = np.ones(max_drop_samples)
        
        is_burst_merged = binary_dilation(is_burst, structure=dilation_structure)
        
        # Pad with a False sample on each side so bursts touching the array
        # edges still produce a rising and a falling edge.
        # Starts are inclusive indices, ends are exclusive.
        padded = np.concatenate(([0], is_burst_merged.astype(np.int8), [0]))
        edges = np.diff(padded)
        burst_starts = np.flatnonzero(edges == 1)
        burst_ends = np.flatnonzero(edges == -1)
        
        # Calculate durations in seconds
        durations = (burst_ends - burst_starts) / sampling_rate
        total_recording_time = len(is_burst_merged) / sampling_rate
        n_bursts = len(durations)
        
        return {
            'n_bursts': int(n_bursts),
            'duration_mean': float(np.mean(durations)) if n_bursts > 0 else 0.0,
            'duration_std': float(np.std(durations)) if n_bursts > 0 else 0.0,
            'percent_burst': float(100 * np.mean(is_burst_merged)),
            'bursts_per_second': float(n_bursts / total_recording_time)
        }

    def analyze_burst_spectra(self, 
                            signal: np.ndarray,
                            is_burst: np.ndarray,
                            sampling_rate: float,
                            freq_range: Tuple[float, float]) -> Dict[str, List[float]]:
        """
        Analyze spectral properties of individual bursts.
        
        For each contiguous run of ``is_burst`` lasting at least one second:
        
        - Cut all of its samples from ``signal``.
        - Compute a Welch spectrum: median average, Hann window, segments of at
          most 4 s with 50 % overlap, 0.05-40 Hz.
        - Fit a fixed-aperiodic SpectralModel over 0.5-13 Hz.
        - Record the maximum of the flattened spectrum inside ``freq_range``.
        
        Parameters:
        -----------
        signal : np.ndarray
            Original (unfiltered) signal, sample-aligned with ``is_burst``
        is_burst : np.ndarray
            Boolean array indicating burst periods
        sampling_rate : float
            Sampling rate in Hz
        freq_range : tuple
            Frequency range for peak detection
            
        Returns:
        --------
        dict
            Lists keyed 'offsets', 'exponents', 'r_squared',
            'peak_frequencies' and 'peak_powers'.
            The first three hold one entry per successfully analyzed burst.
            The peak lists may be shorter when no spectral bin falls inside
            ``freq_range``.
            A burst whose analysis raises is logged and contributes to none
            of the lists.
        """
        # Find connected burst segments
        labeled_bursts, _ = label(is_burst)
        
        # Initialize lists for burst-wise parameters
        burst_offsets = []
        burst_exponents = []
        burst_r_squared = []
        burst_peak_freqs = []
        burst_peak_powers = []
        
        for burst_slices in find_objects(labeled_bursts):
            if burst_slices is None:  # ids from label() are consecutive; defensive
                continue
            
            # find_objects returns a half-open slice covering every burst sample,
            # including the last one (slicing start:last_index would drop it).
            burst_segment = signal[burst_slices[0]]
            
            # Only analyze bursts of at least 1 second for reliable spectral estimation
            if len(burst_segment) < sampling_rate:
                continue
                
            try:
                # Compute power spectrum for this burst
                nperseg = min(len(burst_segment), int(sampling_rate * 4))
                noverlap = nperseg // 2
                
                frequencies, power_spectrum = compute_spectrum_welch(
                    burst_segment, 
                    sampling_rate,
                    avg_type='median',
                    window='hann',
                    nperseg=nperseg,
                    noverlap=noverlap,
                    f_range=[0.05, 40]
                )
                
                # Fit a fresh SpectralModel to this burst
                burst_spectral_model = SpectralModel(
                    peak_width_limits=[1, 8],
                    min_peak_height=0.2,
                    max_n_peaks=4,
                    peak_threshold=1.5,
                    aperiodic_mode='fixed'
                )
                burst_spectral_model.fit(frequencies, power_spectrum, [0.5, 13])
                
                offset, exponent = burst_spectral_model.get_params('aperiodic_params')
                r_squared = burst_spectral_model.r_squared_
                
                # Peak of the flattened (aperiodic-removed) spectrum within the band
                band_mask = ((frequencies >= freq_range[0]) & 
                             (frequencies <= freq_range[1]))
                peak = None
                if np.any(band_mask):
                    flattened_spectrum = burst_spectral_model._spectrum_flat[band_mask]
                    peak_index = np.argmax(flattened_spectrum)
                    peak = (frequencies[band_mask][peak_index],
                            flattened_spectrum[peak_index])
                
            except Exception as e:
                self.logger.warning(f"Failed to analyze burst spectrum: {str(e)}")
                continue
            
            # Append only after every step succeeded, so a failure part-way
            # through cannot leave the per-burst lists misaligned.
            burst_offsets.append(offset)
            burst_exponents.append(exponent)
            burst_r_squared.append(r_squared)
            if peak is not None:
                burst_peak_freqs.append(peak[0])
                burst_peak_powers.append(peak[1])
        
        return {
            'offsets': burst_offsets,
            'exponents': burst_exponents,
            'r_squared': burst_r_squared,
            'peak_frequencies': burst_peak_freqs,
            'peak_powers': burst_peak_powers
        }

    def analyze_frequency_band(self, 
                             signal: np.ndarray, 
                             sampling_rate: float,
                             band_name: str,
                             band_params: Dict) -> Dict[str, Any]:
        """
        Complete analysis pipeline for a single frequency band.
        
        Steps:
        1. Band-pass filter (FIR, filter edges removed).
        2. Compute the amplitude envelope.
        3. Detect bursts with the dual-threshold method.
        4. Compute burst statistics.
        5. Fit spectra of individual bursts.
        
        Parameters:
        -----------
        signal : np.ndarray
            Input neural signal
        sampling_rate : float
            Sampling rate in Hz
        band_name : str
            Name of frequency band (e.g., 'Delta', 'Theta', 'Alpha')
        band_params : dict
            Must contain 'freq_range', 'min_cycles' and 'cycles_dont_drop'
            
        Returns:
        --------
        dict
            ``{'statistics': ..., 'analysis_data': ...}``.
            ``analysis_data`` holds:
            
            - the edge-trimmed filtered signal;
            - the matching samples of the original signal ('original_signal');
            - the envelope;
            - the burst mask;
            - the per-burst spectral parameters.
            
            If filtering fails or leaves no samples, the empty-result
            structure is returned instead.
        """
        freq_range = band_params['freq_range']
        min_cycles = band_params['min_cycles']
        cycles_dont_drop = band_params['cycles_dont_drop']
        
        # Filter signal for burst detection
        try:
            filtered_signal = filter_signal(
                signal, sampling_rate, 
                pass_type='bandpass', 
                f_range=freq_range,
                filter_type='fir', 
                remove_edges=True
            )
            
            # Drop the NaN samples from filtering and keep the original signal
            # sample-aligned with what remains.
            valid_indices = ~np.isnan(filtered_signal)
            filtered_signal = filtered_signal[valid_indices]
            original_signal_trimmed = signal[valid_indices]
            
        except Exception as e:
            self.logger.error(f"Failed to filter signal for {band_name}: {str(e)}")
            return self._create_empty_band_results()
        
        if len(filtered_signal) == 0:
            self.logger.warning(f"No valid signal after filtering for {band_name}")
            return self._create_empty_band_results()
        
        amplitude_envelope = self.compute_amplitude_envelope(
            filtered_signal, sampling_rate, freq_range
        )
        
        is_burst = self.detect_bursts_dual_threshold(
            amplitude_envelope, sampling_rate, freq_range, 
            min_cycles, cycles_dont_drop
        )
        
        burst_stats = self.compute_burst_statistics(
            is_burst, sampling_rate, freq_range, cycles_dont_drop
        )
        
        burst_spectral_params = self.analyze_burst_spectra(
            original_signal_trimmed, is_burst, sampling_rate, freq_range
        )
        
        def _summary(values, reducer):
            # The peak lists can be empty even when offsets are not; np.mean/np.std
            # of [] is NaN plus a RuntimeWarning, so return NaN directly instead.
            return float(reducer(values)) if len(values) > 0 else float('nan')
        
        # Add averaged spectral parameters if any burst was fitted
        if burst_spectral_params['offsets']:
            burst_stats.update({
                'mean_offset': _summary(burst_spectral_params['offsets'], np.mean),
                'std_offset': _summary(burst_spectral_params['offsets'], np.std),
                'mean_exponent': _summary(burst_spectral_params['exponents'], np.mean),
                'std_exponent': _summary(burst_spectral_params['exponents'], np.std),
                'mean_r_squared': _summary(burst_spectral_params['r_squared'], np.mean),
                'mean_peak_freq': _summary(burst_spectral_params['peak_frequencies'], np.mean),
                'std_peak_freq': _summary(burst_spectral_params['peak_frequencies'], np.std),
                'mean_peak_power': _summary(burst_spectral_params['peak_powers'], np.mean),
                'std_peak_power': _summary(burst_spectral_params['peak_powers'], np.std)
            })
        
        # Store additional analysis data for plotting. 'original_signal' is
        # aligned with 'filtered_signal', so plots can overlay them in time.
        analysis_data = {
            'filtered_signal': filtered_signal,
            'original_signal': original_signal_trimmed,
            'amplitude_envelope': amplitude_envelope,
            'is_burst': is_burst,
            'burst_spectral_params': burst_spectral_params
        }
        
        return {
            'statistics': burst_stats,
            'analysis_data': analysis_data
        }

    def _create_empty_band_results(self) -> Dict[str, Any]:
        """
        Build the placeholder result used when a band analysis fails.

        Returns:
        --------
        dict
            Zeroed 'statistics' (burst count 0, all float statistics 0.0)
            and 'analysis_data' set to None.
        """
        zeroed_statistics = {'n_bursts': 0}
        zeroed_statistics.update(
            dict.fromkeys(
                ('duration_mean', 'duration_std', 'percent_burst', 'bursts_per_second'),
                0.0,
            )
        )
        return dict(statistics=zeroed_statistics, analysis_data=None)

    def create_diagnostic_plots(self, 
                              signal: np.ndarray,
                              sampling_rate: float,
                              band_results: Dict,
                              band_name: str) -> plt.Figure:
        """
        Create diagnostic plots for burst analysis.
        
        The figure has three stacked panels sharing the time axis:
        
        1. the original signal, aligned with the filtered samples;
        2. the band-pass filtered signal;
        3. the amplitude envelope with the 50th/75th percentile thresholds
           and the detected bursts highlighted.
        
        A summary text box is added at the bottom.
        
        Parameters:
        -----------
        signal : np.ndarray
            Original (untrimmed) signal
        sampling_rate : float
            Sampling rate in Hz
        band_results : dict
            Results from analyze_frequency_band
        band_name : str
            Name of the frequency band
            
        Returns:
        --------
        matplotlib.Figure
            Figure with diagnostic plots, or a single-message figure if the
            analysis failed
        """
        analysis_data = band_results['analysis_data']
        if analysis_data is None:
            # Create empty plot for failed analysis
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.text(0.5, 0.5, f'Analysis failed for {band_name} band', 
                   ha='center', va='center', transform=ax.transAxes)
            return fig
        
        filtered_signal = analysis_data['filtered_signal']
        amplitude_envelope = analysis_data['amplitude_envelope']
        is_burst = analysis_data['is_burst']
        stats = band_results['statistics']
        n_samples = len(filtered_signal)
        
        # The filtered signal had its NaN edge samples removed, so the raw signal
        # must be cropped the same way (signal[:n_samples] is shifted in time).
        original_signal = analysis_data.get('original_signal')
        if original_signal is None:
            # No aligned copy stored: assume the removed filter edges are
            # symmetric (TODO confirm for the filter used) and centre-crop.
            start = max(0, (len(signal) - n_samples) // 2)
            original_signal = signal[start:start + n_samples]
        
        time_axis = np.arange(n_samples) / sampling_rate
        
        fig, axes = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
        
        # Plot 1: Original signal (aligned with the filtered samples)
        axes[0].plot(time_axis, original_signal, color='#2E4053', alpha=0.7, linewidth=0.8)
        axes[0].set_ylabel('Amplitude')
        axes[0].set_title(f'{band_name} Band Analysis - Original Signal')
        axes[0].grid(True, alpha=0.3)
        
        # Plot 2: Filtered signal
        axes[1].plot(time_axis, filtered_signal, color='#2E4053', linewidth=0.8)
        axes[1].set_ylabel('Amplitude')
        axes[1].set_title('Band-pass Filtered Signal')
        axes[1].grid(True, alpha=0.3)
        
        # Plot 3: Burst detection
        axes[2].plot(time_axis, amplitude_envelope, color='#27AE60', alpha=0.6, 
                    linewidth=1, label='Amplitude Envelope')
        
        # Threshold lines use detect_bursts_dual_threshold's default percentiles
        low_threshold = np.percentile(amplitude_envelope, 50)
        high_threshold = np.percentile(amplitude_envelope, 75)
        axes[2].axhline(y=high_threshold, color='#C0392B', linestyle='--', 
                       alpha=0.8, label='High Threshold (75th %ile)')
        axes[2].axhline(y=low_threshold, color='#E67E22', linestyle='--', 
                       alpha=0.8, label='Low Threshold (50th %ile)')
        
        # Highlight detected bursts
        burst_envelope = np.ma.masked_where(~is_burst, amplitude_envelope)
        axes[2].plot(time_axis, burst_envelope, color='#E74C3C', 
                    linewidth=2, label='Detected Bursts')
        
        axes[2].set_ylabel('Amplitude')
        axes[2].set_xlabel('Time (s)')
        axes[2].legend(loc='upper right')
        axes[2].set_title('Burst Detection Results')
        axes[2].grid(True, alpha=0.3)
        
        stats_text = (
            f"Bursts detected: {stats['n_bursts']}\n"
            f"Mean duration: {stats['duration_mean']:.3f} s\n"
            f"Burst rate: {stats['bursts_per_second']:.3f} bursts/s\n"
            f"Time in bursts: {stats['percent_burst']:.1f}%"
        )
        
        # Use figure methods rather than pyplot's implicit "current figure",
        # which may not be this figure in batch or interactive use.
        fig.text(0.02, 0.02, stats_text, fontsize=10, 
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8))
        
        fig.tight_layout()
        fig.subplots_adjust(bottom=0.15)
        
        return fig

    def create_spectral_plot(self, spectral_model: SpectralModel) -> plt.Figure:
        """
        Plot a fitted whole-signal spectral model on a new 10x6 figure.

        Parameters:
        -----------
        spectral_model : SpectralModel
            Fitted model; its own ``plot`` method draws onto the new axes

        Returns:
        --------
        matplotlib.Figure
            Figure whose x-axis is limited to 0-30 Hz
        """
        figure, axis = plt.subplots(figsize=(10, 6))
        spectral_model.plot(ax=axis)
        axis.set_xlim(0, 30)  # the default analysis bands end at 13 Hz
        axis.set_title('Whole-Signal Spectral Analysis')
        plt.tight_layout()
        return figure

    def analyze_well_file(self, 
                         pickle_file_path: Union[str, Path],
                         create_plots: bool = False) -> Dict[str, Any]:
        """
        Analyze a single well's preprocessed data file.
        
        Parameters:
        -----------
        pickle_file_path : str or Path
            Path to a pickle holding a dict with at least these keys:
            
            - 'smoothed': the signal;
            - 'new_fs': its sampling rate.
            
            An optional 'spectral_model' entry is passed through as the
            whole-signal spectral fit.
        create_plots : bool
            Whether to generate diagnostic plots
            
        Returns:
        --------
        dict
            Keys:
            
            - 'metadata';
            - 'band_analyses' (keyed by band name);
            - 'whole_signal_spectral';
            - 'plots' (figures, or None when plots were not requested).
            
            If the file cannot be loaded or parsed, or the signal is empty,
            the empty well-result structure is returned.
            A band analysis that raises is stored as the empty band result.
            A plot that fails is only logged, so it never discards the
            analysis results.
        """
        pickle_file_path = Path(pickle_file_path)
        
        try:
            # SECURITY: pickle.load can execute arbitrary code; only load files
            # produced by a trusted preprocessing pipeline.
            with open(pickle_file_path, 'rb') as f:
                processed_data = pickle.load(f)
            
            # Extract signal and metadata
            signal = np.array(processed_data['smoothed'])
            sampling_rate = processed_data['new_fs']
            spectral_model = processed_data.get('spectral_model', None)
            
            if len(signal) == 0:
                self.logger.warning(f"Empty signal in {pickle_file_path.name}")
                return self._create_empty_well_results()
            
            well_results = {
                'metadata': {
                    'file_name': pickle_file_path.name,
                    'sampling_rate': sampling_rate,
                    'signal_length': len(signal),
                    'duration_seconds': len(signal) / sampling_rate
                },
                'band_analyses': {},
                'whole_signal_spectral': None,
                'plots': {} if create_plots else None
            }
            
            for band_name, band_params in self.frequency_bands.items():
                self.logger.info(f"Analyzing {band_name} band for {pickle_file_path.name}")
                
                try:
                    band_results = self.analyze_frequency_band(
                        signal, sampling_rate, band_name, band_params
                    )
                except Exception as e:
                    self.logger.error(f"Error analyzing {band_name} band: {str(e)}")
                    band_results = self._create_empty_band_results()
                well_results['band_analyses'][band_name] = band_results
                
                # Plot separately: a plotting error must not overwrite a
                # successful band analysis with empty results.
                if create_plots:
                    try:
                        well_results['plots'][f'{band_name}_analysis'] = self.create_diagnostic_plots(
                            signal, sampling_rate, band_results, band_name
                        )
                    except Exception as e:
                        self.logger.error(f"Error plotting {band_name} band: {str(e)}")
            
            # Store whole-signal spectral analysis
            if spectral_model is not None:
                well_results['whole_signal_spectral'] = spectral_model
                
                if create_plots:
                    # A failing spectrum plot previously reached the outer handler
                    # and discarded every band result for the well.
                    try:
                        well_results['plots']['whole_signal_spectrum'] = self.create_spectral_plot(
                            spectral_model
                        )
                    except Exception as e:
                        self.logger.error(
                            f"Error plotting whole-signal spectrum for {pickle_file_path.name}: {str(e)}"
                        )
            
            return well_results
            
        except Exception as e:
            self.logger.error(f"Error analyzing {pickle_file_path.name}: {str(e)}")
            return self._create_empty_well_results()

    def _create_empty_well_results(self) -> Dict[str, Any]:
        """
        Build the placeholder result returned when a whole well cannot be analyzed.

        Returns:
        --------
        dict
            Same keys as a normal well result. 'metadata' only carries an
            'error' message, 'band_analyses' is empty, and the spectral model
            and plots are None.
        """
        return dict(
            metadata=dict(error='Analysis failed'),
            band_analyses={},
            whole_signal_spectral=None,
            plots=None,
        )

    def extract_experimental_info(self, file_path: Path) -> Dict[str, str]:
        """
        Derive well and experiment labels from a file path.

        Naming rules:
        
        - The well id is the first '_'-separated token of the file stem,
          lower-cased, and is mapped through ``self.platemap``. Unknown ids
          become ``'Unknown-<id>'``.
        - A parent folder name with fewer than two '_'-separated parts
          yields timepoint and condition 'unknown'.
        - Folders ending in '_PTX' or '_Eggan' take that suffix as the
          condition and the preceding token as the timepoint.
        - Any other folder takes its last token as the timepoint and
          everything before it as the condition.

        This method should be customized based on your naming conventions.
        """
        stem = file_path.stem
        folder = file_path.parent.name

        # str.split always returns at least one item, so [0] is always defined.
        well_number = stem.split('_')[0].lower()
        well_name = self.platemap.get(well_number, f"Unknown-{well_number}")

        timepoint, condition = 'unknown', 'unknown'
        folder_tokens = folder.split('_')
        if len(folder_tokens) >= 2:
            # Example logic - adapt to your naming scheme
            for known_condition in ('PTX', 'Eggan'):
                if folder.endswith(f'_{known_condition}'):
                    timepoint, condition = folder_tokens[-2], known_condition
                    break
            else:
                timepoint = folder_tokens[-1]
                condition = '_'.join(folder_tokens[:-1])

        return dict(
            well_number=well_number,
            well_name=well_name,
            timepoint=timepoint,
            condition=condition,
        )

    def compile_results_to_dataframe(self, well_results: Dict[str, Any], 
                                   experimental_info: Dict[str, str]) -> Dict[str, Any]:
        """
        Compile analysis results into a flat dictionary suitable for DataFrame conversion.

        Every band always produces the same burst-statistic column names:
        statistics missing from a band's results are filled with 0 / 0.0
        under those same names, so rows from different wells line up in a
        single DataFrame.

        Parameters:
        -----------
        well_results : dict
            Results from analyze_well_file. Reads 'whole_signal_spectral'
            (a fitted spectral model, or None) and 'band_analyses'
            (band name -> dict holding a 'statistics' dict).
        experimental_info : dict
            Experimental metadata with keys 'well_number', 'well_name',
            'timepoint' and 'condition'.

        Returns:
        --------
        dict
            Flattened results dictionary (one DataFrame row).
        """
        compiled_results = {
            'Well': experimental_info['well_number'],
            'Well_name': experimental_info['well_name'],
            'Timepoint': experimental_info['timepoint'],
            'Condition': experimental_info['condition']
        }

        # Add whole-signal spectral parameters
        spectral_model = well_results.get('whole_signal_spectral')
        if spectral_model is not None:
            try:
                aperiodic_params = spectral_model.get_params('aperiodic_params')
                # 'fixed' aperiodic mode yields [offset, exponent]; 'knee' mode
                # yields [offset, knee, exponent]. Offset is first and exponent
                # last in both, so index instead of unpacking exactly two values.
                offset, exponent = aperiodic_params[0], aperiodic_params[-1]
                compiled_results.update({
                    'Offset': float(offset),
                    'Exponent': float(exponent),
                    'R_squared': float(spectral_model.r_squared_)
                })

                # Band-specific SNR (maximum of the model's flattened spectrum
                # inside the band) and the frequency at which it occurs
                frequencies = spectral_model.freqs
                flattened_spectrum = spectral_model._spectrum_flat

                for band_name, band_params in self.frequency_bands.items():
                    low, high = band_params['freq_range']
                    band_mask = (frequencies >= low) & (frequencies <= high)

                    if np.any(band_mask):
                        band_spectrum = flattened_spectrum[band_mask]
                        band_frequencies = frequencies[band_mask]

                        max_idx = np.argmax(band_spectrum)
                        compiled_results[f"{band_name}_SNR"] = float(band_spectrum[max_idx])
                        compiled_results[f"{band_name}_peak_frequency"] = float(band_frequencies[max_idx])
                    else:
                        # Band lies outside the fitted frequency range
                        compiled_results[f"{band_name}_SNR"] = float('nan')
                        compiled_results[f"{band_name}_peak_frequency"] = float('nan')

            except Exception as e:
                # Best effort: a bad whole-signal fit must not lose the burst stats
                self.logger.warning(f"Error extracting spectral parameters: {str(e)}")

        # Burst statistic key -> (column suffix, converter, default when missing).
        # Order matches the original column order of the output.
        burst_stat_columns = {
            'duration_mean': ('Mean_Burst_Duration', float, 0.0),
            'n_bursts': ('Burst_Number', int, 0),
            'bursts_per_second': ('Burst_Frequency', float, 0.0),
            'percent_burst': ('Percent_Burst', float, 0.0),
            'duration_std': ('Duration_Std', float, 0.0),
        }

        # Optional per-burst spectral summary statistics
        spectral_keys = ('mean_offset', 'std_offset', 'mean_exponent', 'std_exponent',
                         'mean_r_squared', 'mean_peak_freq', 'std_peak_freq',
                         'mean_peak_power', 'std_peak_power')

        band_analyses = well_results.get('band_analyses') or {}

        for band_name, band_results in band_analyses.items():
            band_stats = band_results.get('statistics') or {}

            # Basic burst statistics: always emitted under the same column name
            for key, (suffix, convert, default) in burst_stat_columns.items():
                value = convert(band_stats[key]) if key in band_stats else default
                compiled_results[f"{band_name}_{suffix}"] = value

            # Per-burst spectral parameters are only emitted when available
            for key in spectral_keys:
                if key in band_stats:
                    # snake_case -> Title_Case, e.g. 'mean_r_squared' -> 'Mean_R_Squared'
                    formatted_key = key.replace('_', ' ').title().replace(' ', '_')
                    compiled_results[f"{band_name}_Burst_{formatted_key}"] = float(band_stats[key])

        return compiled_results

    def save_plots(self, plots: Dict[str, plt.Figure], 
                  output_folder: Path, 
                  file_prefix: str):
        """
        Save analysis plots as PNG files and close the figures.

        Each figure is written to ``<output_folder>/<file_prefix>_<plot_name>.png``
        at 300 dpi. Figures are closed whether or not saving succeeded, so a
        failed save does not keep the figure open in pyplot.

        Parameters:
        -----------
        plots : dict
            Dictionary of plot names to matplotlib figures
        output_folder : Path or str
            Directory to save plots (created if missing)
        file_prefix : str
            Prefix for plot filenames
        """
        output_folder = Path(output_folder)
        output_folder.mkdir(parents=True, exist_ok=True)

        for plot_name, figure in plots.items():
            plot_path = output_folder / f"{file_prefix}_{plot_name}.png"
            try:
                figure.savefig(plot_path, dpi=300, bbox_inches='tight')
                self.logger.info(f"Saved plot: {plot_path}")
            except Exception as e:
                # Best effort: one failing figure should not abort the rest
                self.logger.error(f"Error saving plot {plot_name}: {str(e)}")
            finally:
                # Release the figure even when savefig raised, avoiding leaks
                plt.close(figure)

    def process_folder(self, 
                      input_folder: Union[str, Path],
                      create_plots: bool = False,
                      plot_patterns: Optional[List[str]] = None) -> Tuple[List[Dict], List[Dict]]:
        """
        Process all pickle files in a folder.

        Files are processed in sorted filename order, so the results list and
        log output are reproducible across runs and filesystems. A failure in
        one file is recorded in the error list and does not stop the others.

        Parameters:
        -----------
        input_folder : str or Path
            Folder containing pickle files
        create_plots : bool
            Whether to generate diagnostic plots
        plot_patterns : list, optional
            List of well patterns to plot (e.g., ['well55', 'well56']).
            A well is plotted when its well number starts with any pattern
            (case-insensitive). If None, plots all wells when create_plots=True.

        Returns:
        --------
        tuple
            (results_list, errors_list)
        """
        input_folder = Path(input_folder)
        results_list = []
        errors_list = []

        # Setup plots directory if needed
        plots_folder = None
        if create_plots:
            plots_folder = self.plots_dir / input_folder.name
            plots_folder.mkdir(parents=True, exist_ok=True)

        # Path.glob yields entries in arbitrary, filesystem-dependent order;
        # sort so output row order is deterministic.
        pickle_files = sorted(input_folder.glob('*.pkl'))
        self.logger.info(f"Processing {len(pickle_files)} files in {input_folder.name}")

        for pickle_file in pickle_files:
            try:
                # Extract experimental information
                experimental_info = self.extract_experimental_info(pickle_file)

                # Determine if we should create plots for this well
                should_plot = False
                if create_plots:
                    if plot_patterns is None:
                        should_plot = True
                    else:
                        well_name = experimental_info['well_number'].lower()
                        should_plot = any(well_name.startswith(pattern.lower())
                                          for pattern in plot_patterns)

                # Analyze the well
                well_results = self.analyze_well_file(pickle_file, create_plots=should_plot)

                # Compile results for DataFrame
                compiled_results = self.compile_results_to_dataframe(
                    well_results, experimental_info
                )
                results_list.append(compiled_results)

                # Save plots if generated
                if should_plot and well_results.get('plots'):
                    self.save_plots(
                        well_results['plots'],
                        plots_folder,
                        pickle_file.stem
                    )

                self.logger.info(f"Successfully processed {pickle_file.name}")

            except Exception as e:
                # Per-file isolation: record the failure and keep going
                error_info = {
                    'file': str(pickle_file),
                    'error': str(e),
                    'error_type': type(e).__name__,
                    'timestamp': datetime.now().isoformat()
                }
                errors_list.append(error_info)
                self.logger.error(f"Error processing {pickle_file.name}: {str(e)}")

        return results_list, errors_list

    def process_experiment(self, 
                          input_base_folder: Union[str, Path],
                          experiment_id: Optional[str] = None,
                          create_plots: bool = False,
                          plot_patterns: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Process an entire experiment with multiple timepoints/conditions.

        Every immediate subfolder of ``input_base_folder`` is handed to
        ``process_folder`` in sorted order. Combined results are written to
        ``<output_dir>/<experiment_id>_burst_analysis_results.csv``. Any
        errors (per file or per folder) are written to
        ``<output_dir>/<experiment_id>_analysis_errors.json``. The output
        directory is created when needed.

        Parameters:
        -----------
        input_base_folder : str or Path
            Base folder containing subfolders with pickle files
        experiment_id : str, optional
            Identifier for the experiment (used in output filenames);
            defaults to the base folder's name
        create_plots : bool
            Whether to generate diagnostic plots
        plot_patterns : list, optional
            List of well patterns to plot

        Returns:
        --------
        dict
            Processing summary: file success/failure counts, number of
            folders that failed as a whole, per-folder statistics and the
            paths of the written output files (None when not written).
        """
        input_base_folder = Path(input_base_folder)

        if experiment_id is None:
            experiment_id = input_base_folder.name

        # Initialize aggregation state
        all_results = []
        all_errors = []
        processing_stats = {}
        folders_failed = 0
        output_file = None
        error_file = None

        # Process each subfolder
        subfolders = sorted(f for f in input_base_folder.iterdir() if f.is_dir())
        self.logger.info(f"Processing experiment '{experiment_id}' with {len(subfolders)} timepoints")

        for subfolder in subfolders:
            try:
                self.logger.info(f"Processing folder: {subfolder.name}")

                folder_results, folder_errors = self.process_folder(
                    subfolder, create_plots, plot_patterns
                )

                all_results.extend(folder_results)
                all_errors.extend(folder_errors)

                processing_stats[subfolder.name] = {
                    'processed': len(folder_results),
                    'failed': len(folder_errors)
                }

            except Exception as e:
                # A whole folder failed; it has no per-file stats, so count it separately
                folders_failed += 1
                self.logger.error(f"Error processing folder {subfolder.name}: {str(e)}")
                all_errors.append({
                    'folder': str(subfolder),
                    'error': str(e),
                    'error_type': type(e).__name__,
                    'timestamp': datetime.now().isoformat()
                })

        # Create comprehensive results DataFrame
        if all_results:
            results_df = pd.DataFrame(all_results)

            # Sort by well number for consistent ordering
            if 'Well' in results_df.columns:
                # Numeric part of the well id ('well10' -> 10) so well2 sorts
                # before well10. Ids without digits become NaN and sort last
                # instead of crashing an int conversion.
                well_digits = results_df['Well'].astype(str).str.extract(r'(\d+)', expand=False)
                results_df['well_sort_key'] = pd.to_numeric(well_digits, errors='coerce')
                sort_columns = ['well_sort_key']
                if 'Timepoint' in results_df.columns:
                    sort_columns.append('Timepoint')
                results_df = results_df.sort_values(sort_columns).drop(columns='well_sort_key')

            # Save results
            output_file = self.output_dir / f"{experiment_id}_burst_analysis_results.csv"
            output_file.parent.mkdir(parents=True, exist_ok=True)
            results_df.to_csv(output_file, index=False)
            self.logger.info(f"Saved combined results to {output_file}")

        # Save error log if there were errors
        if all_errors:
            error_file = self.output_dir / f"{experiment_id}_analysis_errors.json"
            # When every file failed the results branch never ran, so the
            # output directory may not exist yet.
            error_file.parent.mkdir(parents=True, exist_ok=True)
            with open(error_file, 'w', encoding='utf-8') as f:
                json.dump({
                    'experiment_id': experiment_id,
                    'processing_date': datetime.now().isoformat(),
                    'errors': all_errors
                }, f, indent=2)
            self.logger.info(f"Saved error log to {error_file}")

        # Compile summary statistics
        total_processed = sum(stats['processed'] for stats in processing_stats.values())
        total_failed = sum(stats['failed'] for stats in processing_stats.values())

        summary = {
            'experiment_id': experiment_id,
            'total_files_processed': total_processed,
            'total_files_failed': total_failed,
            'total_folders_failed': folders_failed,
            'folder_statistics': processing_stats,
            'output_files': {
                'results_csv': str(output_file) if output_file is not None else None,
                'error_log': str(error_file) if error_file is not None else None
            }
        }

        self.logger.info(f"Experiment processing complete: {total_processed} successful, {total_failed} failed")
        return summary


def main():
    """
    Run the example burst-analysis pipeline over one experiment.

    Adjust the configuration values below (data directory, frequency bands,
    plotting options) to match your experimental setup before running.
    """
    # --- Configuration -------------------------------------------------------
    data_root = Path("your_data_directory_here")  # Update this path
    pickle_root = data_root / "spectral_analysis"  # Folder with pickle files

    # Per-band detection settings. Omit the frequency_bands argument below to
    # let the analyzer use its default bands instead.
    band_settings = {
        'Delta': {'freq_range': (1, 4), 'min_cycles': 8, 'cycles_dont_drop': 3},
        'Theta': {'freq_range': (4, 8), 'min_cycles': 16, 'cycles_dont_drop': 4},
        'Alpha': {'freq_range': (8, 13), 'min_cycles': 24, 'cycles_dont_drop': 5},
        'Beta': {'freq_range': (13, 30), 'min_cycles': 32, 'cycles_dont_drop': 6},
    }

    analyzer = OscillatoryBurstAnalyzer(
        base_path=data_root,
        output_subdir="burst_analysis_results",
        frequency_bands=band_settings,
    )

    # Guard clause: nothing to do until the data path has been configured
    if not pickle_root.exists():
        print(f"Input folder {pickle_root} does not exist. Please update the path.")
        return

    # Diagnostic plots are produced only for the wells in plot_patterns.
    # For a faster run without figures, pass create_plots=False instead.
    summary = analyzer.process_experiment(
        pickle_root,
        experiment_id="drug_screen_experiment",
        create_plots=True,
        plot_patterns=['well55', 'well56'],
    )
    print(f"Processing complete. Summary: {summary}")


if __name__ == "__main__":
    # Script entry point: run the example configuration above.
    main()