# Detect multi unit events

    Multiunit High Synchrony Event detector. Finds times when the multiunit
    population spiking activity is high relative to the average.

In [None]:
from ripple_detection import multiunit_HSE_detector
import os
import h5py
import pickle
import pandas as pd
import numpy as np
import hdf5storage
from matplotlib import pyplot as plt
import glob
import multiprocessing
from joblib import Parallel, delayed
from ripple_detection.core import get_multiunit_population_firing_rate
from scipy.signal import find_peaks,peak_prominences
from scipy.stats import norm

In [2]:
def load_position(session):
    f = h5py.File(session,'r')
    # load frames [ts x y a s] 
    frames = np.transpose(np.array(f['frames']))
    return pd.DataFrame(frames,columns=['ts', 'x', 'y', 'hd', 'speed'])   

def get_session_path(session):
    f = h5py.File(session,'r')
    return f['session_path'][()].tobytes()[::2].decode()

def get_spikes(filename):
    data = hdf5storage.loadmat(filename,variable_names=['Spikes'])
    spike_times=data['Spikes']
    spike_times=np.squeeze(spike_times)
    for i in range(spike_times.shape[0]):
        spike_times[i]=np.squeeze(spike_times[i])
    return spike_times

def bin_spikes(spike_times,dt,wdw_start,wdw_end):
    """
    Function that puts spikes into bins
    Parameters
    ----------
    spike_times: an array of arrays
        an array of neurons. within each neuron's array is an array containing all the spike times of that neuron
    dt: number (any format)
        size of time bins
    wdw_start: number (any format)
        the start time for putting spikes in bins
    wdw_end: number (any format)
        the end time for putting spikes in bins
    Returns
    -------
    neural_data: a matrix of size "number of time bins" x "number of neurons"
        the number of spikes in each time bin for each neuron
    """
    edges=np.arange(wdw_start,wdw_end,dt) #Get edges of time bins
    num_bins=edges.shape[0]-1 #Number of bins
    num_neurons=spike_times.shape[0] #Number of neurons
    neural_data=np.empty([num_bins,num_neurons]) #Initialize array for binned neural data
    #Count number of spikes in each bin for each neuron, and put in array
    for i in range(num_neurons):
        neural_data[:,i]=np.histogram(spike_times[i],edges)[0]
    return neural_data

def get_peak_ts(high_synchrony_event_times,firing_rate,ts):
    peak_ts = []
    for event in high_synchrony_event_times.itertuples():
        idx = (ts >= event.start_time) & (ts <= event.end_time)
        temp_ts = ts[idx]
        peak_ts.append(temp_ts[np.argmax(firing_rate[idx])])
    return peak_ts

def fastrms(x,window=5):
    window = np.ones(window)
    power = x**2
    rms = np.convolve(power,window,mode='same')
    return  np.sqrt(rms/sum(window));
    
def get_place_fields(ratemap,min_peak_rate=2,min_field_width=2,max_field_width=39,percent_threshold=0.2):
    
    std_rates = np.std(ratemap)
    
    locs,properties = find_peaks(fastrms(ratemap), height=min_peak_rate, width=min_field_width)
    pks = properties['peak_heights']

    exclude = []
    for j in range(len(locs)-1):
        if min(ratemap[locs[j]:locs[j+1]]) > ((pks[j] + pks[j+1]) / 2) * percent_threshold:
            if pks[j] > pks[j+1]:
                exclude.append(j+1)
            elif pks[j] < pks[j+1]:
                exclude.append(j)
       
    if any(ratemap[locs] < std_rates*.5):
        exclude.append(np.where(ratemap[locs] < std_rates*.5))
    if not exclude:
        pks = np.delete(pks, exclude)
        locs = np.delete(locs, exclude)
    
    fields = []
    for j in range(len(locs)):
        Map_Field = (ratemap > pks[j] * percent_threshold)*1;
        start = locs[j]
        stop = locs[j]
        
        while (Map_Field[start] == 1)  & (start > 0):
            start -= 1
        while (Map_Field[stop] == 1)  & (stop < len(Map_Field)-1):
            stop += 1

        if ((stop - start) > min_field_width) & ((stop - start) < max_field_width):
            com = start
            while sum(ratemap[start:stop]) - sum(ratemap[start:com]) > sum(ratemap[start:com])/2:
                com += 1
            fields.append((start,stop,stop - start,pks[j],locs[j],com))
                        
    # add to data frames
    fields = pd.DataFrame(fields, columns=("start", "stop", "width", "peakFR", "peakLoc", "COM"))   
    
    # remove fields with the same field boundaries and keep the one with the highest peak rate
    fields = fields.sort_values(by=['peakFR'],ascending=False)
    fields.drop_duplicates(subset = ['start', 'stop'])

    return fields

def get_place_cell_idx(session):
    """
    find cells to include. At least 1 field from both directions
    """
    data = hdf5storage.loadmat(session,variable_names=['ratemap'])
    include = []
    field = 0
    for i in range(data['ratemap'].shape[0]):
        for d in range(2):
            fields = get_place_fields(data['ratemap'][i,d][0])
            if not fields.empty:
                field += 1
        if field > 0:
            include.append(1)
        else:
            include.append(0)
        field = 0
    return include    

In [3]:
def run_all(session,dt=0.01):
    
    # get data session path from mat file
    path = get_session_path(session)
    
    # load position data from .mat file
    df = load_position(session)
    
    spike_times = get_spikes(session)
    
    # get place cells
    include = get_place_cell_idx(session)
    
    multiunit = bin_spikes(spike_times[include],dt,min(df.ts),max(df.ts))
    ts = np.arange(min(df.ts) + dt/2, max(df.ts) - dt/2, dt)
    
    # interp speed of the animal
    speed = np.interp(ts,df.ts,df.speed)
    speed[np.isnan(speed)] = 0
    
    # detect ripples
    high_synchrony_event_times = multiunit_HSE_detector(ts, multiunit, speed, dt)
    
    # add peak time stamp
    firing_rate = get_multiunit_population_firing_rate(multiunit, dt, 0.015)
    peak_time = get_peak_ts(high_synchrony_event_times,firing_rate,ts)
    high_synchrony_event_times['peak_time'] = peak_time
    
    return high_synchrony_event_times

In [5]:
def main_loop(session,data_path,save_path):
    base = os.path.basename(session)
    os.path.splitext(base)
    save_file = save_path + os.path.splitext(base)[0] + '.pkl'
    
    # check if saved file exists
    if os.path.exists(save_file):
        return
        
    # detect ripples and calc some features
    high_synchrony_event_times = run_all(session)   

    # save file
    with open(save_file, 'wb') as f:
        pickle.dump(high_synchrony_event_times, f)


data_path = 'F:\\Projects\\PAE_PlaceCell\\ProcessedData\\'
save_path = "F:\\Projects\\PAE_PlaceCell\\multiunit_data\\"

# find HPC sessions
df_sessions = pd.read_csv('D:/ryanh/github/harvey_et_al_2020/Rdata_pae_track_cylinder_all_cells.csv')
sessions = pd.unique(df_sessions.session)
sessions = data_path+sessions

# for session in sessions:
#     print(session)
#     main_loop(session,data_path,save_path)
 
num_cores = multiprocessing.cpu_count()         
processed_list = Parallel(n_jobs=num_cores)(delayed(main_loop)(session,data_path,save_path) for session in sessions)


In [6]:
save_path = "F:/Projects/PAE_PlaceCell/multiunit_data/"
sessions = glob.glob(save_path + '*.pkl')

df=pd.DataFrame()
for session in sessions:
    with open(session, 'rb') as f:
        high_synchrony_event_times = pickle.load(f)
   
    # add data frame of ripple features and add session id
    base = os.path.basename(session)
    high_synchrony_event_times['session'] = os.path.splitext(base)[0]
    df = df.append(high_synchrony_event_times,ignore_index=True)

print(df)

      start_time  end_time  peak_time                  session
0        504.125   504.155    504.135  LEM3116_S20180715121821
1        504.175   504.205    504.185  LEM3116_S20180715121821
2        563.575   563.595    563.575  LEM3116_S20180715121821
3        676.335   676.425    676.415  LEM3116_S20180715121821
4       1403.805  1403.835   1403.805  LEM3116_S20180715121821
...          ...       ...        ...                      ...
7015     228.735   228.755    228.735     RH16_S20161207130000
7016     228.865   228.885    228.865     RH16_S20161207130000
7017     247.555   247.575    247.555     RH16_S20161207130000
7018     276.665   276.685    276.665     RH16_S20161207130000
7019     281.965   281.985    281.965     RH16_S20161207130000

[7020 rows x 4 columns]


In [None]:

data_path = 'F:\\Projects\\PAE_PlaceCell\\ProcessedData\\'
dicts = {}
for session in df.session:
    f = h5py.File(data_path+session+'.mat','r')
    ex_ep = []
    for i in range(f['events'].shape[0]):
        ex_ep.append(f['events'][i])
    dicts[session] = ex_ep
    
ep_type = ['pedestal_1','track','pedestal_2','cylinder_1','pedestal_3','cylinder_2','pedestal_4']
df['ep_type'] = np.ones_like(df.session)
# session_df=pd.DataFrame()
for session in np.unique(df.session):
    # stack epoch times
    b = np.hstack(dicts[session])
    
    # add 0 to start to indicate the start of the recording session
    b = np.insert(b,0,0)
    
    # add the ts of the last ripple of the session to indicate end of session
    b = list(b)
    last_rip = max(df.end_time[df.session == session])
    if b[-1] < last_rip:
        b.append(last_rip)
    
    # loop through each epoch and label each ripple
    for ep in range(len(b)-1):
        idx = (df.session == session) & (df.peak_time >= b[ep]) & (df.peak_time <= b[ep+1])
        df['ep_type'][idx] = ep_type[ep]

print(df)

In [9]:
os.mkdir(save_path+'post_processed')
df.to_csv(save_path+'post_processed/mua_df.csv')