### 0. Detect MUA events and insert them into the database. This notebook populates the `HSETimes` table, whose entries are high multi-unit activity events, and the `MUA` table, whose entries are MUA traces.
##### The `HSETimes` and `MUA` tables are defined in shijiegu's GitHub fork at `spyglass/shijiegu/Analysis_SGU.py`; they are not part of the Frank Lab spyglass.
##### Only run this after shijiegu's customized `TrialChoice` table has been populated.

Jun 23, 2024
Shijie Gu

In [5]:
# Automatically reload edited project modules (e.g. spyglass.shijiegu) before each cell runs.
%reload_ext autoreload
%autoreload 2

In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import spikeinterface as si
import pynwb
import xarray as xr
import os

from spyglass.common.common_interval import _intersection
from spyglass.common import (IntervalPositionInfo, IntervalPositionInfoSelection, IntervalList, 
                             ElectrodeGroup, LFP, BrainRegion, LFPBand, Electrode)
from ripple_detection.core import (gaussian_smooth,
                                   get_envelope,
                                   get_multiunit_population_firing_rate,
                                   threshold_by_zscore,
                                   segment_boolean_series,
                                   exclude_close_events,
                                   exclude_movement,
                                   extend_threshold_to_mean,
                                    merge_overlapping_ranges)
from ripple_detection.detectors import multiunit_HSE_detector,_get_event_stats
from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename

from spyglass.spikesorting.v0 import (SortGroup, 
                                    SortInterval,
                                    SpikeSortingPreprocessingParameters,
                                    SpikeSortingRecording, 
                                    SpikeSorterParameters,
                                    SpikeSortingRecordingSelection,
                                    ArtifactDetectionParameters, ArtifactDetectionSelection,
                                    ArtifactRemovedIntervalList, ArtifactDetection,
                                      SpikeSortingSelection, SpikeSorting,)
from spyglass.shijiegu.load import load_spike
from spyglass.shijiegu.Analysis_SGU import DecodeResultsLinear

[2025-06-14 12:48:06,509][INFO]: DataJoint 0.14.4 connected to shijiegu-alt@lmf-db.cin.ucsf.edu:3306


In [7]:
from spyglass.shijiegu.Analysis_SGU import TrialChoice,RippleTimes,EpochPos
from spyglass.shijiegu.helpers import interval_union,interpolate_to_new_time
from spyglass.shijiegu.load import load_LFP,load_position,load_maze_spike
from spyglass.shijiegu.ripple_detection import (loadRippleLFP,ExtendInterSection,InterSection,
                                                plot_ripple,threshold_by_zscore_Gu,
                                                Kay_ripple_detector,Karlsson_ripple_detector,Gu_ripple_detector,multiunit_HSE_detector,
                                                removeDataBeforeTrial1,removeArtifactTime,
                                                loadRippleLFP_OneChannelPerElectrode,ripple_detection_master)
from spyglass.shijiegu.mua_detection import mua_detection_master
from spyglass.shijiegu.Analysis_SGU import TetrodeNumber,MUA,HSETimes
from spyglass.common.common_position import IntervalLinearizedPosition
from spyglass.common.common_task import TaskEpoch

In [18]:
# Raw NWB file to process. Alternative session used with this notebook: 'julio20230731.nwb'.
nwb_file_name = 'klein20231111.nwb'

In [19]:
# Uncomment the next line to open the post-mortem debugger on the most recent exception.
#%debug

In [20]:
nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)

# Every epoch of this session in EpochPos, with its epoch name and position interval.
session_epochs = EpochPos() & {'nwb_file_name': nwb_copy_file_name}
epochs = session_epochs.fetch('epoch')

session_epochs

nwb_file_name  name of the NWB file,epoch  the session epoch for this task and apparatus(1 based),epoch_name  TaskEpoch or IntervalList,position_interval  IntervalPositionInfo
klein20231111_.nwb,1,01_Rev2Sleep1,pos 0 valid times
klein20231111_.nwb,2,02_Rev2Session1,pos 1 valid times
klein20231111_.nwb,3,03_Rev2Sleep2,pos 2 valid times
klein20231111_.nwb,4,04_Rev2Session2,pos 3 valid times
klein20231111_.nwb,5,05_Rev2Sleep3,pos 4 valid times
klein20231111_.nwb,6,06_Rev2Session3,pos 5 valid times
klein20231111_.nwb,7,07_Rev2Sleep4,pos 6 valid times
klein20231111_.nwb,8,08_Rev2Session4,pos 7 valid times
klein20231111_.nwb,9,09_Rev2Sleep5,pos 8 valid times
klein20231111_.nwb,10,10_Rev2Session5,pos 9 valid times


In [21]:
# Uncomment to list every interval list name stored for this session (e.g. to look up position intervals).
#(IntervalList & {'nwb_file_name': nwb_copy_file_name}).fetch('interval_list_name')

In [22]:
# Split the session's epochs into run (maze) sessions and sleep sessions by epoch name:
# e.g. '02_Rev2Session1' -> 'Rev2Session1'[4:8] == 'Sess' -> run; '01_Rev2Sleep1' -> sleep.
# Fetch epoch ids and names in a single query instead of one fetch1 per epoch.
epoch_ids, epoch_names = (EpochPos() & {'nwb_file_name': nwb_copy_file_name}).fetch(
    'epoch', 'epoch_name', order_by='epoch')

run_session = []
sleep_session = []
for e, epoch_name in zip(epoch_ids, epoch_names):
    if epoch_name.split('_')[1][4:8] == 'Sess':
        run_session.append(e)
    else:
        sleep_session.append(e)

In [23]:
# Epoch ids classified as run (maze) sessions.
run_session

[2, 4, 6, 8, 10]

### 1. Run one session. (To run all sessions, skip directly to section 3.)

In [22]:
# Epoch to process in this section; must be one of `epochs` (for klein20231111, 6 is '06_Rev2Session3').
epochID = 6

In [14]:
nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)

# Find the epoch/session name and the position interval name for the chosen epoch.
key = (EpochPos & {'nwb_file_name': nwb_copy_file_name, 'epoch': epochID}).fetch1()
epoch_name = key['epoch_name']
position_interval = key['position_interval']

# Run (maze) sessions are named like '02_Rev2Session1', sleep sessions like '01_Rev2Sleep1'.
# Always assign the flag: setting it only when True leaves it undefined (NameError) for a
# sleep epoch on a fresh kernel, or silently reuses a stale True from an earlier run epoch.
is_run_session = epoch_name.split('_')[1][4:8] == 'Sess'

In [15]:
# Load the multi-unit activity (MUA) for this epoch, then smooth and downsample it.
# Run sessions use the maze loader (`load_maze_spike`); sleep sessions use `load_spike`.
SPIKE_FS = 30000          # Hz; spiking data are sampled at 30000 Hz
MUA_SMOOTH_SIGMA = 0.004  # s; 4 ms Gaussian smoothing, as in Kay, Karlsson
MUA_DOWNSAMPLE = 10       # keep every 10th sample (30000 Hz -> 3000 Hz)

load_mua = load_maze_spike if is_run_session else load_spike
_0, _1, mua_time, mua, _2 = load_mua(nwb_copy_file_name, epoch_name)

mua_smooth = gaussian_smooth(mua, MUA_SMOOTH_SIGMA, SPIKE_FS)
mua_ds = mua_smooth[::MUA_DOWNSAMPLE]
mua_time_ds = mua_time[::MUA_DOWNSAMPLE]

KeyboardInterrupt: 

In [None]:
# Valid time intervals of the position data for this epoch.
position_key = {'nwb_file_name': nwb_copy_file_name,
                'interval_list_name': position_interval}
position_valid_times = (IntervalList & position_key).fetch1('valid_times')

In [None]:
# Trim the position data: remove artifact times and, for run sessions, data before the
# first trial and after the last trial. Artifact removal uses the LFP, where artifact
# times are already noted.

filtered_lfps, filtered_lfps_t, CA1TetrodeInd, CCTetrodeInd = loadRippleLFP_OneChannelPerElectrode(
        nwb_copy_file_name,epoch_name,position_valid_times)

# Put the position data on the LFP time base so the LFP-based artifact mask can be applied.
position_info = load_position(nwb_copy_file_name,position_interval)
position_info_upsample = interpolate_to_new_time(position_info, filtered_lfps_t)
position_info_upsample = removeArtifactTime(position_info_upsample, filtered_lfps)

if is_run_session:
    # The epoch number is parsed from the two-digit prefix of the epoch name (e.g. '06_Rev2Session3' -> 6).
    StateScript = pd.DataFrame(
        (TrialChoice & {'nwb_file_name':nwb_copy_file_name,'epoch':int(epoch_name[:2])}).fetch1('choice_reward')
    )
    # NOTE(review): the first trial is taken at index label 1 and the last at label len-1.
    # If the trial index is 1-based, len-1 is the second-to-last trial (perhaps intentional,
    # e.g. to skip an incomplete final trial); if 0-based, label 1 is the second trial — confirm.
    trial_1_t = StateScript.loc[1].timestamp_O
    trial_last_t = StateScript.loc[len(StateScript)-1].timestamp_O
    position_info_upsample = removeDataBeforeTrial1(position_info_upsample,trial_1_t,trial_last_t)
    
# Resample the trimmed position data onto the downsampled MUA time base for the HSE detector.
position_info_upsample2 = interpolate_to_new_time(position_info_upsample, mua_time_ds)

In [None]:
# Detect high multi-unit activity events on the downsampled MUA trace.
# NOTE: `multiunit_HSE_detector` is imported both from ripple_detection.detectors and from
# spyglass.shijiegu.ripple_detection; the later (spyglass.shijiegu) import is the one called here.
# The positional 3000 is presumably the sampling rate (Hz) of the downsampled trace — confirm.
head_speed = np.array(position_info_upsample2.head_speed)
hse_times, firing_rate_raw, mua_mean, mua_std = multiunit_HSE_detector(
    mua_time_ds,
    mua_ds,
    head_speed,
    3000,
    speed_threshold=4.0,
    zscore_threshold=0,
    use_speed_threshold_for_zscore=True,
)

In [None]:
# Save the detected events and the MUA trace to disk, then register the file paths in the
# HSETimes and MUA tables (replace=True overwrites an existing entry for this epoch).
animal = nwb_copy_file_name[:5]
decoding_dir = f'/cumulus/shijie/recording_pilot/{animal}/decoding'
file_stem = nwb_copy_file_name + '_' + epoch_name
base_key = {'nwb_file_name': nwb_copy_file_name, 'interval_list_name': epoch_name}

# Insert into HSETimes table.
# NOTE(review): the events are written with `to_csv`, but the file name ends in '.nc'
# (the netCDF suffix). The name is kept so existing table entries and files stay consistent;
# read the file back with pd.read_csv, not as netCDF.
hse_path = os.path.join(decoding_dir, file_stem + '_hse_times.nc')
hse_times.to_csv(hse_path)
HSETimes().insert1({**base_key, 'hse_times': hse_path}, replace=True)

# Insert into MUA table for future plotting: the raw firing rate on the downsampled time base.
mua_frame = pd.DataFrame(data=firing_rate_raw, index=mua_time_ds, columns=['mua'])
mua_frame.index.name = 'time'
mua_xr = xr.Dataset.from_dataframe(mua_frame)

mua_path = os.path.join(decoding_dir, file_stem + '_mua.nc')
mua_xr.to_netcdf(mua_path)
MUA().insert1({**base_key, 'mua_trace': mua_path, 'mean': mua_mean, 'sd': mua_std},
              replace=True)

In [27]:
# Inspect the MUA table entries stored for this session.
MUA & {'nwb_file_name': nwb_copy_file_name}

nwb_file_name  name of the NWB file,interval_list_name  descriptive name of this interval list,mua_trace  file name for MUA trace,mean  mean,sd  sd
klein20231111_.nwb,02_Rev2Session1,/cumulus/shijie/recording_pilot/klein/decoding/klein20231111_.nwb_02_Rev2Session1_mua.nc,5672.997762016841,3218.704654322044
klein20231111_.nwb,04_Rev2Session2,/cumulus/shijie/recording_pilot/klein/decoding/klein20231111_.nwb_04_Rev2Session2_mua.nc,5699.27036175627,3385.9310437090644
klein20231111_.nwb,06_Rev2Session3,/cumulus/shijie/recording_pilot/klein/decoding/klein20231111_.nwb_06_Rev2Session3_mua.nc,5380.896644841476,2987.67025025481
klein20231111_.nwb,08_Rev2Session4,/cumulus/shijie/recording_pilot/klein/decoding/klein20231111_.nwb_08_Rev2Session4_mua.nc,5486.888672325063,2948.6066797408766
klein20231111_.nwb,10_Rev2Session5,/cumulus/shijie/recording_pilot/klein/decoding/klein20231111_.nwb_10_Rev2Session5_mua.nc,5650.699007327018,3265.956911861291


### 3. Run all sessions.

In [24]:
# Epoch ids classified as sleep sessions.
sleep_session

[1, 3, 5, 7, 9, 11]

In [25]:
# Epochs processed by the batch loop: currently the run sessions only.
# Set this to `epochs` to process every epoch (run and sleep sessions).
epochs_to_run = run_session
print(epochs_to_run)

[2, 4, 6, 8, 10]


In [None]:
# Detect MUA events and insert them into the database for every epoch in `epochs_to_run`.
for epoch_id in epochs_to_run:
    mua_detection_master(nwb_copy_file_name, epoch_id)

# To read the events back for one epoch: HSETimes stores a file path, not the data itself.
# In the single-session cells the file is written with `to_csv`; presumably
# mua_detection_master stores it the same way — confirm.
# hse_path = (HSETimes() & {'nwb_file_name': nwb_copy_file_name,
#                           'interval_list_name': epoch_name}).fetch1('hse_times')
# hse_df = pd.read_csv(hse_path)
# len(hse_df)

Using LFP from these eletrodes: 
[ 0  8 10 11 12 15 16 28 30 36 45 52 53 54 63]




  is_start_time = (~series.shift(1).fillna(False)) & series
  is_end_time = series & (~series.shift(-1).fillna(False))


In [None]:
# Inspect which electrodes went into the ripple-band (150-250 Hz) filtered LFP.
# NOTE(review): `interval_list_name`, `LFPBandArtifact` and `AnalysisNwbfile` are not defined or
# imported in the cells visible in this notebook, so this cell raises NameError on a fresh kernel.
# TODO: import them (AnalysisNwbfile is presumably spyglass.common.AnalysisNwbfile) and set
# `interval_list_name` explicitly (possibly `epoch_name` — confirm).
fieldname = "filtered data"
key = {'nwb_file_name': nwb_copy_file_name,
           'target_interval_list_name': interval_list_name,
           'filter_name': 'Ripple 150-250 Hz'}

ripple_nwb_file_name = (LFPBandArtifact & key).fetch1('analysis_nwb_file_name')
analysisNWBFilePath = AnalysisNwbfile.get_abs_path(ripple_nwb_file_name)

# Open the analysis NWB file and read the electrode table of its "filtered data" scratch entry.
with pynwb.NWBHDF5IO(analysisNWBFilePath, 'r',load_namespaces=True) as io:
    ripple_nwb = io.read()

    # Optional: uncomment to also load the filtered LFP timestamps.
    #filtered_t=np.array(ripple_nwb.scratch[fieldname].timestamps)
    electrodes=ripple_nwb.scratch[fieldname].electrodes.to_dataframe()

In [None]:
# Inspect the MUA entries of another session (hard-coded; independent of `nwb_file_name`).
MUA & {'nwb_file_name': "julio20230731_.nwb"}