# Notebook for plotting detected spikes on the filtered EMG to check spike-detection accuracy

## Imports

In [9]:
import os 
import glob
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from matplotlib import rcParams
from mosquito.process_abf import load_processed_data

## Params

In [10]:
# Root directory that holds the numbered experiment folders
data_root = '/media/sam/SamData/Mosquitoes'

# Where the spike-detection figures are written
save_path = '/home/sam/Desktop/temp_spikes'

# Experiment folder numbers to plot: 11 through 31, inclusive
folder_numbers = np.arange(11, 31 + 1)


## Helper functions

In [11]:
def plot_spike_detection(data, trange=(0, 10), ax=None):
    """
    Plot the filtered EMG trace with the detected spikes marked as red 'x's.

    Parameters
    ----------
    data : dict-like
        Processed recording. Must provide the keys 'time', 'emg_filt' and
        'spike_idx'. 'spike_idx' is used to index into 'time' and 'emg_filt'
        (presumably integer sample indices -- TODO confirm against the
        processing code that produces these files).
    trange : (float, float), optional
        Inclusive (start, end) window to display, in the units of
        data['time'] (the axis label assumes seconds). Default (0, 10).
    ax : matplotlib Axes, optional
        Axes to draw into, e.g. one panel of an existing subplot grid.
        If None (default), a new 11 x 4 inch figure and axes are created.

    Returns
    -------
    fig, ax
        The figure containing the plot and the axes that were drawn on.
    """
    # read out data; asarray lets the boolean masks below work on list input too
    t = np.asarray(data['time'])
    emg_filt = np.asarray(data['emg_filt'])
    spike_idx = data['spike_idx']

    # samples inside the display window
    mask = (t >= trange[0]) & (t <= trange[1])

    # time / amplitude at each detected spike, looked up once and then
    # restricted to the same window
    spike_t = t[spike_idx]
    spike_emg = emg_filt[spike_idx]
    mask_spikes = (spike_t >= trange[0]) & (spike_t <= trange[1])

    # make plot (new figure unless the caller supplied axes)
    if ax is None:
        fig, ax = plt.subplots(figsize=(11, 4))
    else:
        fig = ax.get_figure()

    ax.plot(t[mask], emg_filt[mask])
    ax.plot(spike_t[mask_spikes], spike_emg[mask_spikes], 'rx')

    ax.set_xlabel('time (s)')
    ax.set_ylabel('emg (V)')
    ax.set_xlim(trange)

    return fig, ax

## Loop over data files and make plots

In [12]:
# Get the experiment folders in data_root: directories whose names start with
# at least two digits. The leading '_'-separated token is parsed as the
# experiment number. Requiring that whole token to be numeric skips folders
# like '12a_test' instead of crashing int() below.
expr_folders = sorted(
    f for f in os.listdir(data_root)
    if os.path.isdir(os.path.join(data_root, f))
    and f[:2].isdigit()
    and f.split('_')[0].isdigit()
)

# experiment number of each folder, aligned index-for-index with expr_folders
expr_folder_inds = [int(f.split('_')[0]) for f in expr_folders]

In [16]:
# make sure the output directory exists before saving any figures into it
os.makedirs(save_path, exist_ok=True)

# loop over the experiment folders we want to look at
for folder_num in folder_numbers:
    # find the experiment folder with this number; skip it (instead of
    # aborting the whole run) if no such folder exists under data_root
    try:
        ith = expr_folder_inds.index(folder_num)
    except ValueError:
        print(f'No experiment folder numbered {folder_num:02d} in {data_root}; skipping')
        continue
    expr_folder = expr_folders[ith]

    # get all processed data files anywhere below that folder. Without
    # recursive=True, glob treats '**' like '*' and only searches exactly one
    # subdirectory level. Sorting makes the processing order deterministic.
    search_pattern = os.path.join(data_root, expr_folder, '**', '*_processed.pkl')
    data_files = sorted(glob.glob(search_pattern, recursive=True))

    for data_file in data_files:
        # load current data (pickle: only use on trusted, locally produced files);
        # the context manager closes the file handle
        with open(data_file, 'rb') as f:
            data = pickle.load(f)

        # make plot
        fig, ax = plot_spike_detection(data)

        # output name uses the second-to-last '_' token of the file name,
        # e.g. '2024_06_11_0007_processed' -> '0007'
        # NOTE(review): files with the same axo number in different
        # subfolders of one experiment would overwrite each other's plot.
        _, data_fn = os.path.split(data_file)
        data_fn_no_ext, _ = os.path.splitext(data_fn)
        axo_num_str = data_fn_no_ext.split('_')[-2]
        save_name = f'expr_{folder_num:02d}_axo_{axo_num_str}_spikes.png'

        # save, and always close the figure so open figures don't pile up
        try:
            fig.savefig(os.path.join(save_path, save_name))
        finally:
            plt.close(fig)


In [15]:
# Debug cell: display the (extension-less) name of the last processed file
# handled by the plotting loop, to check the '_'-split used to get the axo number
data_fn_no_ext

'2024_06_11_0007_processed'