In [None]:
from typing import Union, List
import numpy as np
import h5py
import remfile
import spikeinterface.preprocessing as spre
import spikeinterface as si
import mountainsort5 as ms5

In [None]:
# Source asset on DANDI. The filename is not used to open the data; it only
# determines the name of the output NWB file written at the end of the notebook.
dandi_nwb_fname = 'sub-paired-english_ses-paired-english-m139-200114-222743_ecephys.nwb'
url = 'https://dandiarchive.s3.amazonaws.com/blobs/0e1/534/0e1534a8-7d31-49d6-a926-27c0ee56dee4'

# Open the remote HDF5 file read-only through remfile, with a disk cache
# rooted at /tmp/remfile_cache for the fetched data.
disk_cache = remfile.DiskCache('/tmp/remfile_cache')
remf = remfile.File(url, disk_cache=disk_cache, verbose=True)
f = h5py.File(remf, mode='r')

In [None]:
class NwbRecording(si.BaseRecording):
    """SpikeInterface recording that reads an ElectricalSeries from an open NWB (HDF5) file.

    Trace data are read on demand by ``NwbRecordingSegment``; only metadata
    (sampling rate, channel ids, electrode locations) is read here.

    Parameters
    ----------
    file : h5py.File
        Open NWB file.
    electrical_series_path : str
        HDF5 path of the ElectricalSeries group, e.g. '/acquisition/ElectricalSeries'.

    Raises
    ------
    ValueError
        If the series has neither 'starting_time' nor 'timestamps', so no
        sampling frequency can be determined.
    """
    def __init__(self,
        file: h5py.File,
        electrical_series_path: str
    ) -> None:
        electrical_series = file[electrical_series_path]
        electrical_series_data = electrical_series['data']
        dtype = electrical_series_data.dtype

        sampling_frequency = self._read_sampling_frequency(electrical_series)

        # Rows of the electrodes table used by this series, one per channel (in channel order).
        electrode_indices = electrical_series['electrodes'][:]
        electrodes_table = file['/general/extracellular_ephys/electrodes']
        # Read the whole column once and index it in memory: every element-wise
        # h5py read is a separate request, which is slow on a streamed file.
        channel_ids = list(electrodes_table['id'][:][electrode_indices])

        si.BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype)

        locations = self._read_locations(electrodes_table, electrode_indices)
        if locations is not None:
            self.set_dummy_probe_from_locations(locations)

        recording_segment = NwbRecordingSegment(
            electrical_series_data=electrical_series_data,
            sampling_frequency=sampling_frequency
        )
        self.add_recording_segment(recording_segment)

    @staticmethod
    def _read_sampling_frequency(electrical_series) -> float:
        """Return the sampling rate (Hz) of an ElectricalSeries group.

        Uses the 'rate' attribute of 'starting_time' when present. Otherwise it
        estimates the rate from the median step of the first 1000 'timestamps'.
        """
        if 'starting_time' in electrical_series:
            return float(electrical_series['starting_time'].attrs['rate'])
        if 'timestamps' in electrical_series:
            return float(1 / np.median(np.diff(electrical_series['timestamps'][:1000])))
        raise ValueError(
            "ElectricalSeries has neither 'starting_time' nor 'timestamps'; "
            "cannot determine the sampling frequency"
        )

    @staticmethod
    def _read_locations(electrodes_table, electrode_indices):
        """Return an (n_channels, 2 or 3) float array of electrode positions, or None.

        Absolute 'x'/'y'(/'z') columns are preferred. Relative 'rel_x'/'rel_y'(/'rel_z')
        columns are the fallback, and a z column is included only when it exists.
        Row i of the result is the position of electrode ``electrode_indices[i]``.
        Returns None when neither x column exists.
        """
        for prefix in ('', 'rel_'):
            if prefix + 'x' in electrodes_table:
                axes = ['x', 'y']
                if prefix + 'z' in electrodes_table:
                    axes.append('z')
                # Index each full column by electrode row exactly once.
                columns = [
                    np.asarray(electrodes_table[prefix + axis][:], dtype=float)[electrode_indices]
                    for axis in axes
                ]
                return np.stack(columns, axis=1)
        return None

class NwbRecordingSegment(si.BaseRecordingSegment):
    """Recording segment that slices an ElectricalSeries ``data`` dataset on demand.

    The dataset is indexed as ``data[frame, channel]``: axis 0 is samples and
    axis 1 is channels.

    Parameters
    ----------
    electrical_series_data : h5py.Dataset
        The ElectricalSeries ``data`` dataset; it is kept and sliced per request.
    sampling_frequency : float
        Sampling rate in Hz.
    """
    def __init__(self, electrical_series_data: h5py.Dataset, sampling_frequency: float) -> None:
        self._electrical_series_data = electrical_series_data
        si.BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency)

    def get_num_samples(self) -> int:
        """Return the number of samples (length of axis 0 of the dataset)."""
        return self._electrical_series_data.shape[0]

    def get_traces(self, start_frame: int, end_frame: int, channel_indices: Union[List[int], None]=None) -> np.ndarray:
        """Return traces for frames ``[start_frame, end_frame)`` as a (frames, channels) array.

        ``channel_indices`` may be None (all channels), a slice, a boolean mask, or
        a sequence/array of channel indices in any order, possibly with repeats.
        The returned columns follow the requested order.
        """
        data = self._electrical_series_data
        if channel_indices is None:
            return data[start_frame:end_frame, :]
        if isinstance(channel_indices, slice):
            return data[start_frame:end_frame, channel_indices]

        requested = np.asarray(channel_indices)
        if requested.dtype == bool:
            # Boolean channel mask -> explicit integer indices.
            requested = np.flatnonzero(requested)
        # h5py list selections must be strictly increasing (no repeats). SpikeInterface
        # passes indices in requested-channel order, so read the sorted unique channels
        # and restore the requested order/repeats in memory.
        unique_channels, inverse = np.unique(requested, return_inverse=True)
        traces = data[start_frame:end_frame, unique_channels.tolist()]
        return traces[:, inverse]

# Wrap the raw acquisition ElectricalSeries as a SpikeInterface recording
# (trace data are sliced from the file only when requested).
recording = NwbRecording(f, '/acquisition/ElectricalSeries')

In [None]:
# Lazy preprocessing chain: bandpass filter, then whitening with float32 output.
BANDPASS_FREQ_MIN_HZ = 300
BANDPASS_FREQ_MAX_HZ = 6000

recording_filtered = spre.bandpass_filter(
    recording,
    freq_min=BANDPASS_FREQ_MIN_HZ,
    freq_max=BANDPASS_FREQ_MAX_HZ,
)
recording_preprocessed: si.BaseRecording = spre.whiten(recording_filtered, dtype='float32')

In [None]:
# Spot-check: pull a 102,000-frame window of preprocessed traces for channel ids 0-3.
SNIPPET_START_FRAME = 20530101
SNIPPET_END_FRAME = 20632101
x = recording_preprocessed.get_traces(
    start_frame=SNIPPET_START_FRAME,
    end_frame=SNIPPET_END_FRAME,
    channel_ids=[0, 1, 2, 3],
)

In [None]:
# MountainSort5 scheme 2. Both detection radii share one value; its units presumably
# match those of the probe locations set on the recording — confirm.
DETECT_CHANNEL_RADIUS = 50
sorting_params = ms5.Scheme2SortingParameters(
    phase1_detect_channel_radius=DETECT_CHANNEL_RADIUS,
    detect_channel_radius=DETECT_CHANNEL_RADIUS,
)
sorting = ms5.sorting_scheme2(recording=recording_preprocessed, sorting_parameters=sorting_params)

In [None]:
# Unit ids reported by MountainSort5 (last expression -> shown as the cell output).
found_unit_ids = sorting.get_unit_ids()
found_unit_ids

In [None]:
from typing import List


class NwbSorting(si.BaseSorting):
    """SpikeInterface sorting whose spike trains come from an NWB units table.

    Keyword-only parameters
    -----------------------
    file : open HDF5 file containing the units table.
    units_path : HDF5 path of the units table, e.g. '/units'.
    sampling_frequency : frame rate (Hz) handed to the segment to turn spike times into frames.
    """
    def __init__(self, *, file, units_path: str, sampling_frequency: float):
        # Stored on the instance before base-class initialization, as before.
        self._file, self._units_path = file, units_path

        # Unit ids are the table's 'id' column, loaded eagerly.
        ids_from_table = file[units_path]['id'][:]
        si.BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=ids_from_table)

        # One segment covers the whole units table.
        self.add_sorting_segment(
            NWbSortingSegment(
                file=file,
                unit_ids=ids_from_table,
                units_path=units_path,
                sampling_frequency=sampling_frequency,
            )
        )

class NWbSortingSegment(si.BaseSortingSegment):
    """Sorting segment that reads ragged spike times from an NWB units table.

    All units' spike times are concatenated in ``<units_path>/spike_times``.
    Entry k of ``<units_path>/spike_times_index`` is the end offset of unit k's
    slice, so unit k spans ``spike_times[index[k-1]:index[k]]`` (starting at 0 for k == 0).
    """
    def __init__(self, *, file, unit_ids, units_path: str, sampling_frequency: float):
        si.BaseSortingSegment.__init__(self)
        self._file = file
        self._unit_ids = unit_ids
        self._units_path = units_path
        self._sampling_frequency = sampling_frequency
        # Loaded on first use and then reused. Re-reading the whole column on every
        # spike-train request costs O(n_units) reads per unit on a streamed file.
        self._spike_times_index = None

    def _get_spike_times_index(self) -> np.ndarray:
        """Return the units table's spike_times_index column, reading it at most once."""
        if self._spike_times_index is None:
            self._spike_times_index = self._file[self._units_path]['spike_times_index'][:]
        return self._spike_times_index

    def get_unit_spike_train(
        self,
        unit_id,
        start_frame: Union[int, None] = None,
        end_frame: Union[int, None] = None,
    ) -> np.ndarray:
        """Return the int64 spike frames of ``unit_id`` within ``[start_frame, end_frame)``.

        Spike times (seconds) are converted to frames as ``round(t * sampling_frequency)``.
        A None ``start_frame`` means 0; a None ``end_frame`` means no upper bound.

        Raises
        ------
        ValueError
            If ``unit_id`` is not one of this segment's unit ids. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still catch it.
        """
        if start_frame is None:
            start_frame = 0
        if end_frame is None:
            end_frame = np.inf

        matches = np.flatnonzero(self._unit_ids == unit_id)
        if matches.size == 0:
            raise ValueError('Unit does not exist: {}'.format(unit_id))
        ind = int(matches[0])

        spike_times_index = self._get_spike_times_index()
        i1 = 0 if ind == 0 else int(spike_times_index[ind - 1])
        i2 = int(spike_times_index[ind])

        times = self._file[self._units_path]["spike_times"][i1:i2]

        frames = np.round(times * self._sampling_frequency).astype("int64")
        return frames[(frames >= start_frame) & (frames < end_frame)]
    

# Ground-truth units stored in the same NWB file, expressed on the recording's frame clock.
sorting_true = NwbSorting(
    units_path='/units',
    file=f,
    sampling_frequency=recording.get_sampling_frequency(),
)

In [None]:
import pynwb
from uuid import uuid4
from pynwb.misc import Units

def _add_sorting_to_units_table(units_table, sorting_extractor) -> None:
    """Append every unit of ``sorting_extractor`` to ``units_table`` via ``add_unit``.

    Spike trains are converted from frames to seconds by dividing by the sorting's
    sampling frequency. ``units_table`` may be an ``NWBFile`` (its main units table)
    or a ``pynwb.misc.Units``; both are called as ``add_unit(id=..., spike_times=...)``.
    """
    sampling_frequency = sorting_extractor.get_sampling_frequency()
    for unit_id in sorting_extractor.get_unit_ids():
        units_table.add_unit(
            id=unit_id,
            spike_times=sorting_extractor.get_unit_spike_train(unit_id) / sampling_frequency,
        )


def _copy_subject(source_subject):
    """Return a new pynwb Subject copying the source's fields, or None when there is no subject."""
    if source_subject is None:
        return None
    return pynwb.file.Subject(
        subject_id=source_subject.subject_id,
        age=source_subject.age,
        date_of_birth=source_subject.date_of_birth,
        sex=source_subject.sex,
        species=source_subject.species,
        description=source_subject.description
    )


# Output name: the DANDI filename with a 'desc-ms5-units' entity inserted before its last '_' part.
_fname_parts = dandi_nwb_fname.split('_')
sorting_fname = '_'.join(_fname_parts[:-1] + ['desc-ms5-units'] + _fname_parts[-1:])

# Reader and writer get distinct names so the writer does not shadow the reader
# inside the reader's own `with` block.
with pynwb.NWBHDF5IO(file=h5py.File(remf, 'r'), mode='r') as read_io:
    nwbfile_rec = read_io.read()

    nwbfile = pynwb.NWBFile(
        session_description=nwbfile_rec.session_description,
        identifier=str(uuid4()),
        session_start_time=nwbfile_rec.session_start_time,
        experimenter=nwbfile_rec.experimenter,
        experiment_description=nwbfile_rec.experiment_description,
        lab=nwbfile_rec.lab,
        institution=nwbfile_rec.institution,
        subject=_copy_subject(nwbfile_rec.subject),
        session_id=nwbfile_rec.session_id,
        keywords=nwbfile_rec.keywords
    )

    # MountainSort5 units go into the file's main units table.
    _add_sorting_to_units_table(nwbfile, sorting)

    # Ground-truth units from the source file go into a separate Units table
    # inside an 'ecephys' processing module.
    nwbfile.create_processing_module(name='ecephys', description='ground_truth_units')
    gtunits = Units(name='ground_truth_units')
    _add_sorting_to_units_table(gtunits, sorting_true)
    nwbfile.processing['ecephys'].add(gtunits)

    # Written while the source file is still open. Presumably some copied fields
    # (e.g. keywords) may still be backed by the source HDF5 file — confirm before moving.
    with pynwb.NWBHDF5IO(sorting_fname, 'w') as write_io:
        write_io.write(nwbfile, cache_spec=True)