# In [None]:
"""
Neural Activity Analysis for Multi-Electrode Array (MEA) Data

This module processes spike train data from MEA plates, detecting bursts and
analyzing neural activity patterns. Designed for drug screening experiments
using electrophysiological recordings.

Dependencies:
    - numpy, pandas, matplotlib
    - mat73 (for MATLAB v7.3 file loading)
    - scipy
    - pathlib, logging, json
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import os
import json
import mat73
import pickle
from pathlib import Path
import logging
from datetime import datetime
from scipy.signal import find_peaks, hilbert


class MEADataProcessor:
    """
    Processes Multi-Electrode Array (MEA) data for neural activity analysis.

    This class handles spike train analysis, burst detection, and visualization
    for electrophysiological data stored in MATLAB v7.3 format. The .mat file is
    expected to hold a ``'Plate'`` variable indexed as
    ``Plate[well_row][well_col][channel_row][channel_col]``, where each channel
    holds the spike times (seconds) recorded on that electrode.

    Class attributes
    ----------------
    PLATE_SHAPE : tuple of int
        (rows, cols) of wells processed by ``process_plate_file``. Defaults to
        the 6x8 layout; override on a subclass or instance for other plates.
    CHANNEL_GRID : tuple of int
        (rows, cols) of electrodes per well. Defaults to 4x4 (16 channels).
    """

    PLATE_SHAPE = (6, 8)
    CHANNEL_GRID = (4, 4)

    class _NumpyEncoder(json.JSONEncoder):
        """JSON encoder that converts NumPy arrays and scalars to built-in types."""

        def default(self, obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.bool_):
                return bool(obj)
            return super().default(obj)

    def __init__(self, base_path=None, output_subdir="processed_data"):
        """
        Initialize the MEA data processor.

        Parameters
        ----------
        base_path : str or Path, optional
            Base directory for data processing and the log file. If None, uses
            the current working directory. It is created if missing when this
            instance sets up file logging.
        output_subdir : str
            Subdirectory of ``base_path`` that receives output files.
        """
        self.base_path = Path(base_path) if base_path else Path.cwd()
        self.output_dir = self.base_path / output_subdir
        self.setup_logging()

    def setup_logging(self):
        """
        Configure root logging with a timestamped log file and a console handler.

        ``logging.basicConfig`` does nothing when the root logger already has
        handlers (configured by the host application or an earlier instance).
        Building a ``FileHandler`` in that case would open an empty log file
        that is never written or closed, so handlers are only created when the
        root logger has none. ``self.log_file`` is the path of the log file this
        instance configured, or None if logging was already configured.
        """
        self.logger = logging.getLogger(__name__)
        self.log_file = None
        if logging.getLogger().handlers:
            return

        # FileHandler raises FileNotFoundError when the directory is missing.
        self.base_path.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = self.base_path / f'mea_processing_{timestamp}.log'

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )
        self.log_file = log_file

    def _load_plate(self, filepath):
        """Load a .mat file with mat73 and return its ``'Plate'`` variable."""
        return mat73.loadmat(str(filepath))['Plate']

    @staticmethod
    def _as_spike_array(value):
        """
        Normalize one channel's raw value to a 1-D float array of spike times.

        Accepts 0-d arrays and plain numbers (a single spike), row/column
        matrices, and lists. Returns None when the value cannot represent spike
        times (``None``, a string, a dict, or a ragged/nested structure), so the
        caller can log the channel as empty and substitute an empty array.
        """
        if value is None or isinstance(value, (str, bytes)):
            return None
        try:
            # ravel() turns a 0-d array (single spike) into shape (1,) and
            # flattens (N, 1) / (1, N) matrices.
            return np.asarray(value, dtype=float).ravel()
        except (TypeError, ValueError):
            return None

    def load_spike_data(self, filepath, well_row, well_col, channel_row, channel_col):
        """
        Extract spike times from a single channel in a MEA plate.

        Parameters
        ----------
        filepath : str or Path
            Path to the .mat file containing plate data
        well_row, well_col : int
            Row and column indices of the well (0-indexed)
        channel_row, channel_col : int
            Row and column indices of the channel within the well (0-indexed)

        Returns
        -------
        pd.DataFrame
            DataFrame with a 'Time (s)' column containing spike times. A
            single-spike channel yields one row. An empty or non-numeric channel
            yields an empty frame with a warning. Any loading error is logged
            and also yields an empty frame.
        """
        try:
            well = self._load_plate(filepath)[well_row][well_col]
            spike_times = self._as_spike_array(well[channel_row][channel_col])
            if spike_times is None:
                spike_times = np.array([])
                self.logger.warning(f'Empty channel at ({channel_row}, {channel_col})')
            return pd.DataFrame({'Time (s)': spike_times})

        except Exception as e:
            self.logger.error(f"Error loading spike data: {str(e)}")
            return pd.DataFrame({'Time (s)': np.array([])})

    def detect_bursts(self, spiketrain, max_begin_isi=0.17, max_end_isi=0.3,
                      min_ibi=0.2, min_burst_duration=0.01, min_spikes_in_burst=5):
        """
        Detect bursts in spike train data using the MaxInterval method.

        This three-phase algorithm:
        1. Finds candidate bursts. A burst begins at the first spike of an ISI
           shorter than ``max_begin_isi``. It continues while consecutive ISIs
           stay at or below ``max_end_isi``. A burst still open at the end of
           the train is kept, including one that begins on the final ISI.
        2. Merges consecutive bursts whose inter-burst interval (IBI) is below
           ``min_ibi``. The IBI is the first spike of the later burst minus the
           last spike of the earlier one. The merged burst contains every spike
           from the first burst's start to the last burst's end, including
           isolated spikes that lay in the gap.
        3. Filters out bursts that don't meet duration/spike count criteria.

        Parameters
        ----------
        spiketrain : array-like
            Spike times in seconds, sorted ascending. Lists, ndarrays, pandas
            Series and the one-column DataFrame from ``load_spike_data`` are
            accepted; the input is flattened to 1-D floats and read by position.
        max_begin_isi : float
            Maximum ISI to begin a burst (seconds)
        max_end_isi : float
            Maximum ISI to continue a burst (seconds)
        min_ibi : float
            Minimum inter-burst interval for separate bursts (seconds)
        min_burst_duration : float
            Minimum duration for a valid burst (seconds)
        min_spikes_in_burst : int
            Minimum number of spikes for a valid burst

        Returns
        -------
        tuple
            (burst_data_dict, num_rejected_bursts)
            burst_data_dict: {0..k-1: ndarray of spike times} in time order.
            Each array is an independent copy of the input.
            num_rejected_bursts: number of bursts rejected in phase 3.
        """
        spikes = np.asarray(spiketrain, dtype=float).ravel()
        if spikes.size < 2:
            return {}, 0

        # Phase 1: candidate bursts as inclusive (first_index, last_index) pairs.
        candidates = []
        start = None
        for n in range(1, spikes.size):
            isi = spikes[n] - spikes[n - 1]
            if start is None:
                if isi < max_begin_isi:
                    start = n - 1
            elif isi > max_end_isi:
                candidates.append((start, n - 1))
                start = None
        if start is not None:
            candidates.append((start, spikes.size - 1))

        if not candidates:
            return {}, 0

        # Phase 2: merge bursts separated by short IBIs. Working with index
        # ranges means a merge keeps any spikes lying between the fragments.
        merged = [candidates[0]]
        for first, last in candidates[1:]:
            prev_first, prev_last = merged[-1]
            if spikes[first] - spikes[prev_last] < min_ibi:
                merged[-1] = (prev_first, last)
            else:
                merged.append((first, last))

        # Phase 3: quality control.
        accepted = {}
        rejected_count = 0
        for first, last in merged:
            burst = spikes[first:last + 1].copy()
            burst_duration = burst[-1] - burst[0]
            if burst.size >= min_spikes_in_burst and burst_duration >= min_burst_duration:
                accepted[len(accepted)] = burst
            else:
                rejected_count += 1

        return accepted, rejected_count

    def analyze_well(self, filename, well_row, well_col, plot_spikes=False, plot_bursts=False):
        """
        Analyze all channels in a single well of a MEA plate.

        The whole .mat file is loaded on every call. To analyze many wells of
        one plate, use ``process_plate_file``, which loads the file once.

        Parameters
        ----------
        filename : str or Path
            Path to the .mat file
        well_row, well_col : int
            Well coordinates (0-indexed)
        plot_spikes : bool
            Whether to create spike raster plots
        plot_bursts : bool
            Whether to highlight bursts in plots

        Returns
        -------
        tuple
            (channel_spikes, well_bursts, aggregated_spikes). Any error is
            logged and yields ``([], [], empty array)``.
        """
        try:
            well = self._load_plate(filename)[well_row][well_col]
            return self._analyze_well_data(well, well_row, well_col,
                                           plot_spikes, plot_bursts)

        except Exception as e:
            self.logger.error(f"Error analyzing well ({well_row}, {well_col}): {str(e)}")
            return [], [], np.array([])

    def _analyze_well_data(self, well, well_row, well_col, plot_spikes=False, plot_bursts=False):
        """
        Analyze an already-loaded well: normalize channels, detect bursts, plot.

        ``well`` is indexed as ``well[channel_row][channel_col]`` over
        ``CHANNEL_GRID``. It is not modified. Exceptions propagate to the caller.

        Returns
        -------
        tuple
            (channel_spikes, well_bursts, aggregated_spikes):
            - channel_spikes: per-channel spike arrays in row-major channel order;
            - well_bursts: every burst detected on every channel;
            - aggregated_spikes: the concatenation of all channels' spikes.
        """
        n_rows, n_cols = self.CHANNEL_GRID
        channel_spikes = []
        channel_bursts = []

        for i in range(n_rows):
            for j in range(n_cols):
                spike_times = self._as_spike_array(well[i][j])
                if spike_times is None:
                    spike_times = np.array([])
                    self.logger.warning(f'Empty channel at ({i}, {j}) in well ({well_row}, {well_col})')
                channel_spikes.append(spike_times)

                if len(spike_times) > 1:
                    burst_data, _ = self.detect_bursts(spike_times)
                    channel_bursts.append(list(burst_data.values()))
                else:
                    channel_bursts.append([])

        well_bursts = [burst for bursts in channel_bursts for burst in bursts]
        all_spikes = np.concatenate(channel_spikes) if channel_spikes else np.array([])

        if plot_spikes or plot_bursts:
            # Reuse the bursts detected above instead of re-running detection.
            self._create_well_visualization(channel_spikes, channel_bursts, well_row, well_col,
                                            all_spikes, plot_spikes, plot_bursts)

        return channel_spikes, well_bursts, all_spikes

    def _create_well_visualization(self, channel_spikes, channel_bursts, well_row, well_col,
                                   all_spikes, plot_spikes, plot_bursts):
        """
        Draw a population histogram above a per-channel raster for one well.

        Parameters
        ----------
        channel_spikes : list of ndarray
            Normalized spike times per channel, in row-major channel order.
            Channel k is drawn on raster row k + 1.
        channel_bursts : list of list of ndarray
            Bursts detected on each channel, parallel to ``channel_spikes``.
        well_row, well_col : int
            Well coordinates, used in the title.
        all_spikes : ndarray
            All spikes in the well, used for the histogram.
        plot_spikes, plot_bursts : bool
            Whether to draw spikes (raster ticks + histogram) and/or burst
            highlights. The figure is labelled and shown in either case.
        """
        n_channels = len(channel_spikes)
        _, (ax_hist, ax_raster) = plt.subplots(
            2, 1, figsize=(15, 8),
            gridspec_kw={'height_ratios': [1, 2]},
            sharex=True
        )

        for channel_idx, (spike_times, bursts) in enumerate(zip(channel_spikes, channel_bursts)):
            if plot_bursts:
                for burst in bursts:
                    if len(burst) == 0:
                        continue
                    burst_start, burst_end = burst[0], burst[-1]

                    # Raster highlight covers only this channel's row.
                    ax_raster.add_patch(Rectangle(
                        (burst_start, channel_idx + 0.5), burst_end - burst_start, 1,
                        alpha=0.5, edgecolor='turquoise', facecolor='turquoise'
                    ))
                    # axvspan spans the full axis height whatever the
                    # histogram's density scale turns out to be.
                    ax_hist.axvspan(burst_start, burst_end, alpha=0.075, color='turquoise')

            if plot_spikes and len(spike_times) > 0:
                ax_raster.vlines(
                    spike_times, channel_idx + 0.55, channel_idx + 1.45,
                    linewidth=0.4, color='black', alpha=0.4
                )

        if plot_spikes and len(all_spikes) > 0:
            ax_hist.hist(all_spikes, bins=min(1200, len(all_spikes)),
                         density=True, color='black')

        ax_hist.set_ylabel('Frequency')
        ax_hist.set_title(f'Neural Activity Summary - Well [{well_row}, {well_col}]')

        ax_raster.set_yticks(range(1, n_channels + 1))
        ax_raster.set_ylim(0.5, n_channels + 0.5)
        ax_raster.set_ylabel('Channels')
        ax_raster.set_xlabel('Time (s)')

        plt.tight_layout()
        plt.show()

    def process_plate_file(self, filename, condition_name=None):
        """
        Process an entire plate file (all wells in ``PLATE_SHAPE``).

        The file is loaded once and every well is analyzed from memory.

        Parameters
        ----------
        filename : str or Path
            Path to the .mat file
        condition_name : str, optional
            Experimental condition name for labeling (defaults to the file stem)

        Returns
        -------
        dict
            ``{well_id: result}``. A successful well maps to its spikes, bursts
            and metadata. A well that fails to analyze maps to
            ``{"error": message}``.

        Raises
        ------
        Exception
            Whatever the loader raises if the file cannot be read (logged first).
        """
        filename = Path(filename)
        if condition_name is None:
            condition_name = filename.stem

        try:
            plate = self._load_plate(filename)
        except Exception as e:
            self.logger.error(f"Error processing plate {filename}: {str(e)}")
            raise

        results = {}
        n_rows, n_cols = self.PLATE_SHAPE
        for row in range(n_rows):
            for col in range(n_cols):
                well_id = f"well_{row:02d}_{col:02d}_{condition_name}"
                self.logger.info(f"Processing {well_id}")

                try:
                    spikes, bursts, all_spikes = self._analyze_well_data(plate[row][col], row, col)
                    results[well_id] = {
                        "channel_spikes": spikes,
                        "detected_bursts": bursts,
                        "aggregated_spikes": all_spikes,
                        "metadata": {
                            "well_position": (row, col),
                            "condition": condition_name,
                            "num_channels": len(spikes),
                            "total_spikes": len(all_spikes),
                            "num_bursts": len(bursts)
                        }
                    }

                except Exception as e:
                    # Recorded (not swallowed) so the summary counts it as failed.
                    self.logger.error(f"Error processing {well_id}: {str(e)}")
                    results[well_id] = {"error": str(e)}

        return results

    def save_results(self, results, output_name):
        """
        Save processing results as one JSON file per well plus a summary.

        Files go to ``self.output_dir / output_name`` (created if needed):
        ``<well_id>.json`` for each entry of ``results`` and
        ``processing_summary.json``. A failure to write one well is logged and
        does not stop the others. A failure to write the summary propagates.

        Parameters
        ----------
        results : dict
            Results from process_plate_file
        output_name : str
            Base name for output files
        """
        output_path = self.output_dir / output_name
        output_path.mkdir(parents=True, exist_ok=True)

        for well_id, data in results.items():
            file_path = output_path / f"{well_id}.json"
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, cls=self._NumpyEncoder, indent=2)
                self.logger.info(f"Saved {well_id} to {file_path}")
            except Exception as e:
                self.logger.error(f"Error saving {well_id}: {str(e)}")

        summary = self._generate_summary(results)
        summary_path = output_path / "processing_summary.json"
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, cls=self._NumpyEncoder, indent=2)

    def _generate_summary(self, results):
        """
        Generate summary statistics from processing results.

        Wells containing an ``"error"`` key count as failed. Spike and burst
        totals are summed from the ``"metadata"`` of the successful wells.
        """
        successful = [r for r in results.values() if "error" not in r]
        return {
            "total_wells_processed": len(results),
            "successful_wells": len(successful),
            "failed_wells": len(results) - len(successful),
            "total_spikes": sum(r["metadata"]["total_spikes"] for r in successful),
            "total_bursts": sum(r["metadata"]["num_bursts"] for r in successful),
            "processing_timestamp": datetime.now().isoformat()
        }

    def process_folder(self, folder_path):
        """
        Process and save every .mat file directly inside ``folder_path``.

        Files are handled in sorted name order. A file that fails is logged and
        skipped.

        Parameters
        ----------
        folder_path : str or Path
            Path to folder containing .mat files
        """
        folder_path = Path(folder_path)
        # glob() order is filesystem-dependent; sort for reproducible runs/logs.
        mat_files = sorted(folder_path.glob('*.mat'))

        self.logger.info(f"Found {len(mat_files)} .mat files in {folder_path}")

        for file_path in mat_files:
            try:
                self.logger.info(f"Processing {file_path.name}")
                results = self.process_plate_file(file_path)
                self.save_results(results, file_path.stem)
                self.logger.info(f"Completed {file_path.name}")
            except Exception as e:
                self.logger.error(f"Failed to process {file_path.name}: {str(e)}")

# In [None]:

def main():
    """
    Example usage of the MEADataProcessor.

    Modify the paths below to match your data organization. The input folder is
    checked before the processor is created, because constructing the processor
    sets up a log file under ``base_directory``. An unedited placeholder path
    should stop at the friendly message instead of reaching that step.

    To process a single file instead of a whole folder::

        single_file = input_folder / "your_plate_file.mat"
        results = processor.process_plate_file(single_file, "control_condition")
        processor.save_results(results, "single_plate_analysis")
    """
    # Configuration
    base_directory = Path("your_data_directory_here")  # Update this path
    input_folder = base_directory / "raw_data"         # Folder with .mat files

    if not input_folder.exists():
        print(f"Input folder {input_folder} does not exist. Please update the path.")
        return

    # Initialize processor
    processor = MEADataProcessor(
        base_path=base_directory,
        output_subdir="processed_results"
    )

    # Process all files in the input folder
    processor.process_folder(input_folder)


if __name__ == "__main__":
    main()