# Notebook to analyze spike count within burst

## Imports

In [33]:
import os 
import glob
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy import signal

%load_ext autoreload
%autoreload 2
    
# from matplotlib import rcParams
from mosquito.process_abf import load_processed_data
from mosquito.analyze_bursts import run_spike_detection, load_burst_data
from mosquito.util import iir_notch_filter, butter_highpass_filter, butter_bandpass_filter, moving_avg


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Helper functions

In [66]:
def get_spike_num_fraction(spike_df):
    """
    Function to get the distributions of spike numbers per burst from a spike dataframe 
    (the output of 'mosquito/analyze_bursts/run_spike_detection')

    Returns a 2 arrays, with first giving spike number and the second giving
    the number of instances of that spike number
    """
    # group by burst index and get spike numbers
    tmp = spike_df.groupby(by=['burst_idx']).max()
    spike_nums = tmp['peak_num'].values + 1  # add 1 to account for 0 index

    # get the counts for the two most frequent spike number values (others are likely errors)
    spike_nums_unique, unique_counts = np.unique(spike_nums, return_counts=True)
    sort_idx = np.flip(np.argsort(unique_counts))  # flip to get it in descending order
    stop_ind = min([2, spike_nums_unique.size])
    keep_counts = unique_counts[sort_idx][:stop_ind]  # only take two most frequent values
    keep_nums = spike_nums_unique[sort_idx][:stop_ind]

    return keep_nums, keep_counts
    

In [86]:
def get_spike_freq(spike_df, fs=35087):
    """
    Function to get frequency of spikes within bursts
    
    """
    # initialize some storage
    spike_freqs = list()
    
    # loop over unique burst values
    burst_idx_unique = spike_df['burst_idx'].unique()
    for bidx in burst_idx_unique:
        # get timing differences between peaks corresponding to current burst
        peak_idx_curr = spike_df['peak_idx'].loc[spike_df['burst_idx'] == bidx]
        peak_diff_sec = (1/fs)*np.diff(peak_idx_curr)
        spike_freqs.append(np.mean(peak_diff_sec)**(-1))

    # convert to array
    spike_freq_arr = np.asarray(spike_freqs)

    # remove nans 
    nan_idx = np.isnan(spike_freq_arr)
    spike_freq_arr = spike_freq_arr[~nan_idx]
    
    return spike_freq_arr
    

In [87]:
spike_df

Unnamed: 0,peak_idx,peak_idx_global,burst_idx,peak_num
0,1085,8810,8749,0
1,1178,8903,8749,1
2,1264,8989,8749,2
3,1350,9075,8749,3
4,1430,9155,8749,4
...,...,...,...,...
1147,1177,3260251,3260098,1
1148,1266,3260340,3260098,2
1149,1352,3260426,3260098,3
1150,1439,3260513,3260098,4


## Params

In [88]:
# plot params
plt.style.use('dark_background')
plt.rc('axes', titlesize=18)     # fontsize of the axes title
plt.rc('axes', labelsize=16)    # fontsize of the x and y labels

In [89]:
# data files that we want/have spike info for
data_files = ['19.1', '19.2', '19.5', '19.6', 
                  '22.2', '22.4', '22.5', '22.12',
                  '23.2', '24.7', '24.8', '26.2',
                  '28.0', '28.1', '28.2', '28.9',
                  '38.9', '38.10', '38.11']

# where to look for data
data_root = '/media/sam/SamData/Mosquitoes'
save_path = os.path.join(data_root, 'analysis')


In [90]:
# path to log file (and load)
log_path = os.path.join(data_root, 'experiment_log.xlsx')
log_df = pd.read_excel(log_path)


In [91]:
# do we need to run spike detection for data files?
run_spike_detect_flag = False


In [92]:
# sampling freq
fs=35087  # Hz

## Get spike number for all files

In [93]:
# loop over data files and save spike info
if run_spike_detect_flag:
        
    save_path = '/home/sam/Desktop/temp_bursts'
    save_name = 'spike_df.pkl'
    
    # loop over data files
    # get paths to experiment folders
    expr_folders = sorted([f for f in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, f)) and f[:2].isdigit()])
    expr_folder_inds = [f.split('_')[0] for f in expr_folders]
       
    # turn data file numbers into paths
    for file_str in data_files:
        # get experiment and axo number from data file number
        expr_num = int(file_str.split('.')[0])
        axo_num = int(file_str.split('.')[1])
    
        # get expr folder matching expr number
        expr_folder = expr_folders[expr_folder_inds.index(str(expr_num))]
        
        # load data
        data = load_processed_data(expr_num, axo_num)
    
        # do spike detection
        spike_df = run_spike_detection(data, viz_flag=False)
    
        # save data
        save_folder_search = glob.glob(os.path.join(data_root, expr_folder, f'*{axo_num:04d}'))
        if len(save_folder_search) == 1:
            save_name_full = os.path.join(save_folder_search[0], save_name)
            spike_df.to_pickle(save_name_full)
            print(save_name_full)
        

## Make a dictionary containing spike counts for each fly

In [94]:
# initialize dictionary
fly_dict = dict()

# intialize with some empty lists
fly_dict['expr_num'] = list()
fly_dict['axo_num'] = list()
fly_dict['muscle_target'] = list()
fly_dict['sex'] = list()
fly_dict['species'] = list()
fly_dict['fly_num'] = list()
fly_dict['electrode_num'] = list()
fly_dict['spike_nums'] = list()
fly_dict['spike_num_counts'] = list()
fly_dict['peak_times'] = list()
fly_dict['spike_freqs'] = list()

# get paths to experiment folders
expr_folders = sorted([f for f in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, f)) and f[:2].isdigit()])
expr_folder_inds = [f.split('_')[0] for f in expr_folders]
   
# turn data file numbers into paths
for ith, file_str in enumerate(data_files):
    # get experiment and axo number from data file number
    expr_num = int(file_str.split('.')[0])
    axo_num = int(file_str.split('.')[1])

    # get expr folder matching expr number
    expr_folder = expr_folders[expr_folder_inds.index(str(expr_num))]

    # get identifying info from log file
    row_idx = (log_df['Day'] == expr_folder) & (log_df['Axo Num'] == axo_num)
    muscle_target = log_df.loc[row_idx]['Target Muscle'].values[0]
    sex = log_df.loc[row_idx]['Sex'].values[0]
    species = log_df.loc[row_idx]['Species'].values[0]
    fly_num = log_df.loc[row_idx]['Fly Num'].values[0]
    electrode_num = log_df.loc[row_idx]['Electrode Num'].values[0]
    
    # load and read data
    spike_df = load_burst_data(expr_folder, axo_num)

    # get spike numbers/counts
    spike_nums, spike_num_counts = get_spike_num_fraction(spike_df)
    
    # get timing of spikes relative to burst onset
    peak_times = (1/fs)*(spike_df['peak_idx_global'] - spike_df['burst_idx'])
    peak_times = peak_times[peak_times > 0]
    
    # get overall spike frequency
    spike_freqs = get_spike_freq(spike_df, fs=fs)

    # append things to dict
    fly_dict['expr_num'].append(expr_num)
    fly_dict['axo_num'].append(axo_num)
    fly_dict['muscle_target'].append(muscle_target)
    fly_dict['sex'].append(sex)
    fly_dict['species'].append(species)
    fly_dict['fly_num'].append(fly_num)
    fly_dict['electrode_num'].append(int(electrode_num))
    fly_dict['spike_nums'].append(spike_nums)
    fly_dict['spike_num_counts'].append(spike_num_counts)
    fly_dict['peak_times'].append(peak_times.values)
    fly_dict['spike_freqs'].append(spike_freqs)

    # print update
    print(f'completed {expr_folder}, {axo_num}')
    
# save results
save_name = 'burst_analysis_dict.pkl'
save_path_full = os.path.join(save_path, save_name)
pickle.dump(fly_dict, open(save_path_full, "wb"))

completed 19_20240510, 1
completed 19_20240510, 2
completed 19_20240510, 5
completed 19_20240510, 6
completed 22_20240516, 2
completed 22_20240516, 4


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


completed 22_20240516, 5
completed 22_20240516, 12
completed 23_20240517, 2
completed 24_20240520, 7
completed 24_20240520, 8
completed 26_20240524, 2
completed 28_20240529, 0
completed 28_20240529, 1
completed 28_20240529, 2
completed 28_20240529, 9
completed 38_20240711, 9
completed 38_20240711, 10
completed 38_20240711, 11
