In [1]:
from typing import Union, List
import numpy as np
import h5py
import remfile
import spikeinterface.preprocessing as spre
import spikeinterface as si
import mountainsort5 as ms5

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the NWB file blob on the DANDI archive S3 bucket
url = 'https://dandiarchive.s3.amazonaws.com/blobs/a4c/525/a4c525eb-8f49-4161-a56a-ca0e99bedfbb'

# Open the remote file through remfile (with a local disk cache) and hand the
# file-like object to h5py in read-only mode
disk_cache = remfile.DiskCache('/tmp/remfile_cache')
remote_file = remfile.File(url, verbose=True, disk_cache=disk_cache)
f = h5py.File(remote_file, 'r')

# Load the neurodata object (the ElectricalSeries group)
X = f['/acquisition/ElectricalSeries']

# 'starting_time' is a scalar dataset whose 'rate' attribute holds the sampling rate
starting_time_dataset = X['starting_time']
starting_time = starting_time_dataset[()]
rate = starting_time_dataset.attrs['rate']
data = X['data']

# Quick summary of what was loaded
print(f'starting_time: {starting_time}')
print(f'rate: {rate}')
print(f'data shape: {data.shape}')

starting_time: 0.0
rate: 30000.0
data shape: (30532101, 32)


In [3]:
class NwbRecording(si.BaseRecording):
    """SpikeInterface recording backed by an ElectricalSeries of an open NWB (HDF5) file.

    The constructor reads the sampling frequency, channel ids and (when the
    electrodes table provides them) electrode locations, then registers a
    single NwbRecordingSegment wrapping the series' 'data' dataset.

    Parameters
    ----------
    file : h5py.File
        An already opened NWB file. It must stay open for as long as the
        recording is used, because the 'data' dataset is kept and sliced on demand.
    electrical_series_path : str
        HDF5 path of the ElectricalSeries group, e.g. '/acquisition/ElectricalSeries'.

    Raises
    ------
    ValueError
        If the series has neither 'starting_time' nor 'timestamps', or if
        'timestamps' has fewer than two samples, so that no sampling frequency
        can be determined.
    """
    def __init__(self,
        file: h5py.File,
        electrical_series_path: str
    ) -> None:
        electrical_series = file[electrical_series_path]
        electrical_series_data = electrical_series['data']
        dtype = electrical_series_data.dtype

        # Get sampling frequency
        if 'starting_time' in electrical_series.keys():
            sampling_frequency = electrical_series['starting_time'].attrs['rate']
        elif 'timestamps' in electrical_series.keys():
            # Estimate the rate from the median spacing of the first timestamps
            first_timestamps = electrical_series['timestamps'][:1000]
            if len(first_timestamps) < 2:
                raise ValueError(
                    f'Cannot estimate sampling frequency for {electrical_series_path}: '
                    'timestamps has fewer than two samples'
                )
            sampling_frequency = 1 / np.median(np.diff(first_timestamps))
        else:
            # Previously this case fell through and failed later with an
            # unbound-variable error; fail early with a clear message instead.
            raise ValueError(
                f'Cannot determine sampling frequency for {electrical_series_path}: '
                'neither starting_time nor timestamps is present'
            )

        # Get channel ids
        # electrode_indices holds, for each data column, a row index into the electrodes table
        electrode_indices = electrical_series['electrodes'][:]
        electrodes_table = file['/general/extracellular_ephys/electrodes']
        # Read the id column once instead of issuing one read per electrode
        all_electrode_ids = electrodes_table['id'][:]
        channel_ids = [all_electrode_ids[i] for i in electrode_indices]

        si.BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype)

        # Set electrode locations: prefer x/y(/z), fall back to rel_x/rel_y(/rel_z)
        if 'x' in electrodes_table:
            location_columns = ['x', 'y', 'z']
        elif 'rel_x' in electrodes_table:
            location_columns = ['rel_x', 'rel_y', 'rel_z']
        else:
            location_columns = None
        if location_columns is not None:
            # The third coordinate is optional; without it the probe is 2D
            if location_columns[2] not in electrodes_table:
                location_columns = location_columns[:2]
            locations = np.zeros((len(electrode_indices), len(location_columns)), dtype=float)
            for dim, column_name in enumerate(location_columns):
                column_values = np.asarray(electrodes_table[column_name][:])
                # Row i of `locations` belongs to channel i, whose table row is
                # electrode_indices[i]. (The earlier code indexed the already
                # per-channel lists by the table row again, which is only
                # correct when electrode_indices is exactly 0..n-1.)
                locations[:, dim] = column_values[electrode_indices]
            self.set_dummy_probe_from_locations(locations)

        recording_segment = NwbRecordingSegment(
            electrical_series_data=electrical_series_data,
            sampling_frequency=sampling_frequency
        )
        self.add_recording_segment(recording_segment)

class NwbRecordingSegment(si.BaseRecordingSegment):
    """Single recording segment that reads traces on demand from an HDF5 dataset.

    Parameters
    ----------
    electrical_series_data : h5py.Dataset
        The ElectricalSeries 'data' dataset; it is indexed as [frame, channel].
    sampling_frequency : float
        Sampling frequency forwarded to the base segment.
    """
    def __init__(self, electrical_series_data: h5py.Dataset, sampling_frequency: float) -> None:
        self._electrical_series_data = electrical_series_data
        si.BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency)

    def get_num_samples(self) -> int:
        """Return the number of frames, i.e. the length of the dataset's first axis."""
        return self._electrical_series_data.shape[0]

    def get_traces(
        self,
        start_frame: Union[int, None] = None,
        end_frame: Union[int, None] = None,
        channel_indices: Union[List[int], None] = None
    ) -> np.ndarray:
        """Read traces for frames [start_frame, end_frame) on the requested channels.

        Parameters
        ----------
        start_frame, end_frame : int or None
            Slice bounds along the frame axis; None means "from the start" /
            "to the end", as in ordinary slicing.
        channel_indices : list of int, slice or None
            Columns to read. None selects every channel.

        Returns
        -------
        np.ndarray
            Array with one row per frame and one column per requested channel,
            in the order the channels were requested.
        """
        dataset = self._electrical_series_data
        if channel_indices is None:
            return dataset[start_frame:end_frame, :]

        if not isinstance(channel_indices, slice):
            requested = np.asarray(channel_indices)
            is_int_list = requested.ndim == 1 and requested.dtype.kind in 'iu'
            # Cast before diff so unsigned index arrays cannot wrap around
            if is_int_list and requested.size > 1 and np.any(np.diff(requested.astype(np.int64)) <= 0):
                # h5py fancy indexing requires strictly increasing coordinates,
                # so read the sorted unique columns and restore the requested
                # order (and any repeats) in memory.
                unique_sorted, inverse = np.unique(requested, return_inverse=True)
                return dataset[start_frame:end_frame, unique_sorted][:, inverse]

        return dataset[start_frame:end_frame, channel_indices]

# Wrap the ElectricalSeries of the already opened file `f` as a recording object
recording = NwbRecording(file=f, electrical_series_path='/acquisition/ElectricalSeries')

In [4]:
# Preprocess the recording before sorting (lazy preprocessing):
# band-pass filter first, then whiten the filtered signal as float32.
recording_filtered = spre.bandpass_filter(
    recording,
    freq_min=300,
    freq_max=6000,
)
recording_preprocessed: si.BaseRecording = spre.whiten(
    recording_filtered,
    dtype='float32',
)

In [5]:
# Spot check: fetch a window of preprocessed traces for channel ids 0-3
x = recording_preprocessed.get_traces(
    start_frame=20530101,
    end_frame=20632101,
    channel_ids=[0, 1, 2, 3],
)

In [6]:
# Display the shape of the fetched traces (frames, channels)
x.shape

(102000, 4)

In [7]:
# Display the total number of frames in the preprocessed recording
recording_preprocessed.get_num_frames()

30532101

In [8]:
# Parameters for MountainSort5 sorting scheme 2; both detection radii are set to 50
sorting_params = ms5.Scheme2SortingParameters(
    phase1_detect_channel_radius=50,
    detect_channel_radius=50
)

# Run the sorter and keep its return value. Previously `sorting` was bound to
# the parameter object and the actual sorting result was discarded.
sorting = ms5.sorting_scheme2(
    recording=recording_preprocessed,
    sorting_parameters=sorting_params
)

Number of channels: 32
Number of timepoints: 30532101
Sampling frequency: 30000.0 Hz
Channel 0: [0. 0.]
Channel 1: [-18.   12.5]
Channel 2: [18.  12.5]
Channel 3: [ 0. 25.]
Channel 4: [-18.   37.5]
Channel 5: [18.  37.5]
Channel 6: [ 0. 50.]
Channel 7: [-18.   62.5]
Channel 8: [18.  62.5]
Channel 9: [ 0. 75.]
Channel 10: [-18.   87.5]
Channel 11: [18.  87.5]
Channel 12: [  0. 100.]
Channel 13: [-18.  112.5]
Channel 14: [ 18.  112.5]
Channel 15: [  0. 125.]
Channel 16: [-18.  137.5]
Channel 17: [ 18.  137.5]
Channel 18: [  0. 150.]
Channel 19: [-18.  162.5]
Channel 20: [ 18.  162.5]
Channel 21: [  0. 175.]
Channel 22: [-18.  187.5]
Channel 23: [ 18.  187.5]
Channel 24: [  0. 200.]
Channel 25: [-18.  212.5]
Channel 26: [ 18.  212.5]
Channel 27: [  0. 225.]
Channel 28: [-18.  237.5]
Channel 29: [ 18.  237.5]
Channel 30: [  0. 250.]
Channel 31: [  0. 275.]
Loading traces
Loading 2 chunks starting at 1153 (0.2048 million bytes)
Loading 4 chunks starting at 1155 (0.4096 million bytes)
Loadin

KeyboardInterrupt: 