# Notebook to plot EMG traces with detected spikes, to visually check spike-detection accuracy

## Imports

In [8]:
import os 
import glob
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from matplotlib import rcParams
from mosquito.process_abf import load_processed_data

## Params

In [9]:
# path to data; override with the MOSQUITO_DATA_ROOT environment variable instead of editing this cell
data_root = os.environ.get('MOSQUITO_DATA_ROOT', '/media/sam/SamData/Mosquitoes')
# folder the spike-check plots are written to; override with SPIKE_PLOT_SAVE_PATH
save_path = os.environ.get('SPIKE_PLOT_SAVE_PATH', '/home/sam/Desktop/temp_spikes')

# experiment folder numbers to check (alternatively e.g. np.arange(51, 75))
folder_numbers = [75]


In [10]:
# (start, end) time windows in seconds to show for each muscle type:
# a short window for 'steer' recordings, a long one for 'power' recordings
t_range_steer, t_range_power = (0, 0.35), (0, 10)


## Helper functions

In [11]:
def plot_spike_detection(data, t_range=(0,10)):
    """
    Plot filtered EMG activity with the detected spikes overlaid (red 'x'), for checking spike detection.

    Parameters
    ----------
    data : dict
        Must contain 'time' (time vector), 'emg_filt' (filtered EMG: either a single array,
        or a list/tuple with one array per channel) and 'spike_idx' (indices into 'time' of the
        detected spikes; in the multichannel case a list/tuple with one index array per channel).
    t_range : tuple, optional
        (start, end) time window to show, in the same units as data['time'] (axes are labelled seconds).

    Returns
    -------
    fig : matplotlib Figure
    ax_list : numpy.ndarray
        1D array of Axes, one per channel (length 1 for single-channel data).

    Raises
    ------
    IndexError
        If spike indices do not fit the data. The partially drawn figure is closed before the
        exception propagates, so callers that catch it do not leak open figures.
    """
    # read out data
    t = data['time']
    emg_filt = data['emg_filt']
    spike_idx = data['spike_idx']

    # define time range
    mask = (t >= t_range[0]) & (t <= t_range[1])

    # deal with single vs multichannel: multichannel data is a sequence of per-channel arrays
    if isinstance(emg_filt, (list, tuple)):
        figure_height = 4
    else:
        emg_filt = [emg_filt]
        spike_idx = [spike_idx]
        figure_height = 7

    # initialize figure; squeeze=False always yields a 2D grid of Axes, so a single
    # channel needs no special-casing. Take the one column as a 1D array.
    fig, axs = plt.subplots(len(emg_filt), 1, figsize=(11, figure_height), squeeze=False)
    ax_list = axs[:, 0]

    try:
        # loop over channels
        for ax, idx, emg in zip(ax_list, spike_idx, emg_filt):
            # spike times/amplitudes, restricted to the current time range
            t_spikes = t[idx]
            emg_spikes = emg[idx]
            mask_spikes = (t_spikes >= t_range[0]) & (t_spikes <= t_range[1])

            # plot
            ax.plot(t[mask], emg[mask])
            ax.plot(t_spikes[mask_spikes], emg_spikes[mask_spikes], 'rx')

            ax.set_xlabel('time (s)')
            ax.set_ylabel('emg (V)')
            ax.set_xlim(t_range)
    except Exception:
        # the caller never receives `fig` on failure, so close it here to avoid leaking figures
        plt.close(fig)
        raise

    return fig, ax_list

## Loop over data files and make plots

In [12]:
# get list of experiment folders
# get list of experiment folders: subfolders of data_root named '<number>_<suffix>' (e.g. '75_20250221').
# Besides the two-leading-digits filter, require the whole prefix before the first '_' to be
# decimal so the int() parse below cannot fail on names such as '12abc_x'.
expr_folders = sorted(
    f for f in os.listdir(data_root)
    if os.path.isdir(os.path.join(data_root, f))
    and f[:2].isdigit()
    and f.split('_')[0].isdecimal()
)
# experiment number of each folder, aligned with expr_folders
expr_folder_inds = [int(f.split('_')[0]) for f in expr_folders]

In [13]:
# loop over experiments folders we want to look at
for folder_num in folder_numbers:
    # get current folder
    ith = expr_folder_inds.index(folder_num)
    expr_folder = expr_folders[ith]

    # get all processed data files in that folder
    data_files = glob.glob(os.path.join(data_root, expr_folder, '**', '*_processed.pkl'))

    for data_file in data_files:
        # see if we have a more processed version of the file
        if os.path.exists(data_file.replace('processed', 'spikes')):
            data_file = data_file.replace('processed', 'spikes')
            
        # load current data 
        data = pickle.load(open(data_file, "rb"))

        # determine which time range to look at 
        muscle_type = data['muscle_type']
        if muscle_type.lower() == 'steer':
            t_range = t_range_steer
        elif muscle_type.lower() == 'power':
            t_range = t_range_power
        else:
            raise ValueError(f'{muscle_type} is not a valid muscle type')
            
        # make plot
        try:
            fig, ax = plot_spike_detection(data, t_range=t_range)
        except IndexError:
            print(f'Failed to plot {data_file} -- reanalyze!')
            continue
            
        # save
        # save_path, _ = os.path.split(data_file)
        _, data_fn = os.path.split(data_file)
        data_fn_no_ext, _ = os.path.splitext(data_fn)
        axo_num_str = data_fn_no_ext.split('_')[-2] 
        save_name = f'expr_{folder_num:02d}_axo_{axo_num_str}_spikes.png'
        fig.savefig(os.path.join(save_path, save_name))

        plt.close(fig)



In [14]:
# Sanity check: display the last data file visited by the plotting loop above.
# This relies on the loop variable still being defined after the loop, so on a fresh
# kernel it raises NameError if the loop found no files.
data_file

'/media/sam/SamData/Mosquitoes/75_20250221/2025_02_21_0002/2025_02_21_0002_processed.pkl'