# Convert recording and sorting extractor data to TINT format

The Hussaini lab uses the proprietary TINT software from Axona to analyze extracellular electrophysiology data. While we are already able to read various data formats from Axona (`raw` data or `unit` data) into spikeinterface, perform preprocessing, spike sorting and export the data to NWB, we also want to allow to export data to the TINT format. 

The TINT format is essentially the same as the `unit` data, including `.X` and `.pos` files, but also `.cut` or `.clu`. The latter two contain information about the spike sorted units.

The conversion can be facilitated by using the existing tools from the Hussaini lab, which [convert `.bin` data to `.X` and `.pos`](https://github.com/HussainiLab/BinConverter/blob/master/BinConverter/core/ConversionFunctions.py). Some of this code is only relevant for using the GUI, which did not work for me. I cleared out GUI code and ran a conversion from `.bin` to `.X` and `.pos` in this notebook: [explore_hussaini_tools.ipynb](https://github.com/sbuergers/hussaini-lab-to-nwb-notebooks/blob/master/explore_hussaini_tools.ipynb).

They also already wrote a [`write_cut()`](https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/core/writeCut.py) function.

We can test our solutions by reading data with these [Hussaini lab tools](https://github.com/HussainiLab/BinConverter/blob/master/BinConverter/core/Tint_Matlab.py). 

<a id='index'></a>
## Index

* [Testing functions](#testing_functions)
* [Hussaini-lab functions](#hussaini-lab_functions)
* [Convert Recording Extractor to TINT](#Convert_recording_extractor_to_tint)
* [Convert Sorting Extractor to TINT](#Convert_sorting_extractor_to_tint)

In [1]:
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (16, 8)
plt.rcParams.update({'font.size':14})
%matplotlib inline

import spikeextractors as se
import spiketoolkit as st

print(sys.version, sys.platform, sys.executable)

3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0] linux /home/sbuergers/spikeinterface/spikeinterface_new_api/venv/bin/python


In [2]:
# Directories

dir_name = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin')
print('Input directory = ', dir_name)

save_dir = dir_name / 'conversion_to_tint'
save_dir.mkdir(parents=True, exist_ok=True)
print('Output directory = ', save_dir)

Input directory =  /mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin
Output directory =  /mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/conversion_to_tint


In [3]:
# Read cached spikeextractors data

r_cache = se.load_extractor_from_pickle(os.path.join(dir_name, 'cached_unit_data_no_bin_preproc.pkl'))

In [4]:
# Read NWB recording data

nwb_dir = Path(dir_name, 'nwb')
recording_nwb = se.NwbRecordingExtractor(nwb_dir / 'axona_tutorial_re2.nwb')

  warn(msg)


In [5]:
# Read NWB sorting data

sorting_nwb = se.NwbSortingExtractor(nwb_dir / 'axona_se_MS4.nwb', sampling_frequency=48000)

In [6]:
# Show data types of different objects

print(type(r_cache))
print(type(recording_nwb))
print(type(sorting_nwb))

<class 'spikeextractors.extractors.bindatrecordingextractor.bindatrecordingextractor.BinDatRecordingExtractor'>
<class 'spikeextractors.extractors.nwbextractors.nwbextractors.NwbRecordingExtractor'>
<class 'spikeextractors.extractors.nwbextractors.nwbextractors.NwbSortingExtractor'>


<a id="testing_functions"></a>
## Testing functions
[back to index](#index)

As we start exporting to putative TINT format, we will want to check if we can read it back in.

In [7]:
from spikeextractors.extractors.axonaunitrecordingextractor import AxonaUnitRecordingExtractor
import os


def test_axonaunitrecordingextractor(filename):
    '''Reads UNIT data with AxonaUnitRecordingExtractor and
    performs some simple operations as a sanity check. 
    
    Parameters
    ----------
    filename : str or Path
        Full filename of `.set` file (could be any extension actually)
    '''
    re = AxonaUnitRecordingExtractor(filename=filename)
    
    # TEST AXONARECORDINGEXTRACTOR
    # Retrieve some simple recording information and print it
    recording = re
    print('Channel ids = {}'.format(recording.get_channel_ids()))
    print('Num. channels = {}'.format(len(recording.get_channel_ids())))
    print('Sampling frequency = {} Hz'.format(recording.get_sampling_frequency()))
    print('Num. timepoints = {}'.format(recording.get_num_frames()))
    print('Stdev. on third channel = {}'.format(np.std(recording.get_traces(channel_ids=2))))
    print('Location of third electrode = {}'.format(
        recording.get_channel_property(channel_id=2, property_name='location')))
    print('Channel groups = {}'.format(recording.get_channel_groups()))
    
    # TEST NEO_READER (axonaio)
    print(recording.neo_reader.header['signal_channels'])
    
    
def test_tetrode_files(filename):
    '''Reads UNIT data with AxonaUnitRecordingExtractor and
    performs some simple operations as a sanity check. 
    Will only test .X  and .set files (no .clu or .cut, no .pos).
    
    Parameters
    ----------
    filename : str or Path
        Full filename of `.set` file (could be any extension actually)
    '''
    test_axonaunitrecordingextractor(filename)

<a id="hussaini-lab_functions"></a>
## Hussaini-lab functions
[back to index](#index)

`gebaSpike` actually wants already existing `.cut` or `.clu` files, and allows modifying them. So these might not be all that useful for exporting to `.cut` or `.clu`.

In [8]:
# From 
# https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/main.py

def save_function(self):
    """
    this method will save the .cut file
    :return:
    """
    if self.cut_filename.text() == default_filename:
        return

    save_filename = os.path.realpath(self.cut_filename.text())

    if os.path.exists(save_filename):
        self.choice = None
        self.LogError.signal.emit('OverwriteCut!%s' % save_filename)
        while self.choice is None:
            time.sleep(0.1)

        if self.choice != QtWidgets.QMessageBox.Yes:
            return

    if len(self.tetrode_data) == 0:
        return

    # organize the cut data
    n_spikes_expected = self.tetrode_data.shape[1]
    n_spikes = len(np.asarray([item for sublist in self.cell_indices.values() for item in sublist]))

    # check that with the manipulation of the spikes, that we still have the correct number of spikes
    if n_spikes != n_spikes_expected:
        self.choice = None
        self.LogError.signal.emit('cutSizeError')
        while self.choice is None:
            time.sleep(0.1)
        return

    # we will check if we are missing some of the spikes somehow. If we kept track of them, then the indices from
    # the spikes, when sorted, should produce an array from 0 -> N-1 spikes.
    if not np.array_equal(np.sort(np.asarray([item for sublist in self.cell_indices.values() for item in sublist])),
                      np.arange(len(self.cut_data_original))):
        self.choice = None
        self.LogError.signal.emit('cutIndexError')
        while self.choice is None:
            time.sleep(0.1)
        return

    cut_values = np.zeros(n_spikes)
    for cell, cell_indices in self.cell_indices.items():
        cut_values[cell_indices] = cell

    if '.clu.' in save_filename:
        # save the .clu filename
        write_clu(save_filename, cut_values)
        self.choice = None
        self.LogError.signal.emit('saveCompleteClu')
        while self.choice is None:
            time.sleep(0.1)
        self.actions_made = False

    else:
        # save the cut filename
        write_cut(save_filename, cut_values)
        self.choice = None
        self.LogError.signal.emit('saveComplete')
        while self.choice is None:
            time.sleep(0.1)
        self.actions_made = False

In [9]:
# From 
# https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/core/writeCut.py

def write_cut(cut_filename, cut, basename=None):
    if basename is None:
        basename = os.path.basename(os.path.splitext(cut_filename)[0])

    unique_cells = np.unique(cut)

    if 0 not in unique_cells:
        # if it happens that there is no zero cell, add it anyways
        unique_cells = np.insert(unique_cells, 0, 0)  # object, index, value to insert

    n_clusters = len(np.unique(cut))
    n_spikes = len(cut)

    write_list = []  # the list of values to write

    tab = '    '  # the spaces didn't line up with my tab so I just created a string with enough spaces
    empty_space = '               '  # some of the empty spaces don't line up to x tabs

    # we add 1 to n_clusters because zero is the garbage cell that no one uses
    write_list.append('n_clusters: %d\n' % (n_clusters))
    write_list.append('n_channels: 4\n')
    write_list.append('n_params: 2\n')
    write_list.append('times_used_in_Vt:%s' % ((tab + '0') * 4 + '\n'))

    zero_string = (tab + '0') * 8 + '\n'

    for cell_i in np.arange(n_clusters):
        write_list.append(' cluster: %d center:%s' % (cell_i, zero_string))
        write_list.append('%smin:%s' % (empty_space, zero_string))
        write_list.append('%smax:%s' % (empty_space, zero_string))
    write_list.append('\nExact_cut_for: %s spikes: %d\n' % (basename, n_spikes))

    # now the cut file lists 25 values per row
    n_rows = int(np.floor(n_spikes / 25))  # number of full rows

    remaining = int(n_spikes - n_rows * 25)
    cut_string = ('%3u' * 25 + '\n') * n_rows + '%3u' * remaining

    write_list.append(cut_string % (tuple(cut)))

    with open(cut_filename, 'w') as f:
        f.writelines(write_list)

In [10]:
# From 
# https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/core/writeCut.py

def write_clu(clu_filename, data):
    # the .clu files and the .cut files are different since the .clu files are the .cut files (with no manual sorting)
    # without the headers, and the values go from 1 -> N instead of 0 -> N, (1-based numbering instead of 0-based). Thus
    # we add 1 to the .cut data to get the .clu data

    data = np.asarray(data).astype(int)  # ensuring that the data is the integer data-type

    data += 1  # making the data 1-based instead of 0-based

    # calculating the number of clusters
    n_clust = len(np.unique(data))

    # ensuring that the cluster number is the 1st value
    data = np.concatenate(([n_clust], data))

    # saving the data as a column (delimter='\n') and integer format.
    np.savetxt(clu_filename, data, fmt='%d', delimiter='\n')

In [11]:
# From 
# https://github.com/HussainiLab/BinConverter/blob/master/BinConverter/core/ConvertTetrode.py

import os
from BinConverter.core.conversion_utils import get_set_header
import numpy as np
import struct


def write_tetrode(filepath, data, Fs):

    session_path, session_filename = os.path.split(filepath)
    tint_basename = os.path.splitext(session_filename)[0]
    set_filename = os.path.join(session_path, '%s.set' % tint_basename)

    n = len(data)

    header = get_set_header(set_filename)

    with open(filepath, 'w') as f:
        num_chans = 'num_chans 4'
        timebase_head = '\ntimebase %d hz' % (96000)
        bp_timestamp = '\nbytes_per_timestamp %d' % (4)
        # samps_per_spike = '\nsamples_per_spike %d' % (int(Fs*1e-3))
        samps_per_spike = '\nsamples_per_spike %d' % (50)
        sample_rate = '\nsample_rate %d hz' % (Fs)
        b_p_sample = '\nbytes_per_sample %d' % (1)
        # b_p_sample = '\nbytes_per_sample %d' % (4)
        spike_form = '\nspike_format t,ch1,t,ch2,t,ch3,t,ch4'
        num_spikes = '\nnum_spikes %d' % (n)
        start = '\ndata_start'

        write_order = [header, num_chans, timebase_head,
                       bp_timestamp,
                       samps_per_spike, sample_rate, b_p_sample, spike_form, num_spikes, start]

        f.writelines(write_order)

    # rearranging the data to have a flat array of t1, waveform1, t2, waveform2, t3, waveform3, etc....
    spike_times = np.asarray(sorted(data.keys()))

    # the spike times are repeated for each channel so lets tile this
    spike_times = np.tile(spike_times, (4, 1))
    spike_times = spike_times.flatten(order='F')

    spike_values = np.asarray([value for (key, value) in sorted(data.items())])

    # this will create a (n_samples, n_channels, n_samples_per_spike) => (n, 4, 50) sized matrix, we will create a
    # matrix of all the samples and channels going from ch1 -> ch4 for each spike time
    # time1 ch1_data
    # time1 ch2_data
    # time1 ch3_data
    # time1 ch4_data
    # time2 ch1_data
    # time2 ch2_data
    # .
    # .
    # .

    spike_values = spike_values.reshape((n * 4, 50))  # create the 4nx50 channel data matrix

    # make the first column the time values
    spike_array = np.hstack((spike_times.reshape(len(spike_times), 1), spike_values))

    data = None
    spike_times = None
    spike_values = None

    spike_n = spike_array.shape[0]

    t_packed = struct.pack('>%di' % spike_n, *spike_array[:, 0].astype(int))
    spike_array = spike_array[:, 1:]  # removing time data from this matrix to save memory

    spike_data_pack = struct.pack('<%db' % (spike_n*50), *spike_array.astype(int).flatten())

    spike_array = None

    # now we need to combine the lists by alternating

    comb_list = [None] * (2*spike_n)
    comb_list[::2] = [t_packed[i:i + 4] for i in range(0, len(t_packed), 4)]  # breaks up t_packed into a list,
    # each timestamp is one 4 byte integer
    comb_list[1::2] = [spike_data_pack[i:i + 50] for i in range(0, len(spike_data_pack), 50)]  # breaks up spike_data_
    # pack and puts it into a list, each spike is 50 one byte integers

    t_packed = None
    spike_data_pack = None

    write_order = []
    with open(filepath, 'rb+') as f:

        write_order.extend(comb_list)
        write_order.append(bytes('\r\ndata_end\r\n', 'utf-8'))

        f.seek(0, 2)
        f.writelines(write_order)

<a id="Convert_recording_extractor_to_tint"></a>
## Convert Recording extractor to TINT
[back to index](#index)

Hmm, since we are writing to TINT, thereby creating `.X` tetrode files, we throw away all information in-between spikes. There is no point to convert the fake continuous recording used for spike sorting to TINT at all. We really only want to export the spike sorting output!

In [12]:
# Anything to do here?

<a id="Convert_sorting_extractor_to_tint"></a>
## Convert Sorting extractor to TINT
[back to index](#index)

There are several points in the pipeline at which we might want to export to TINT. Ideally it should work for any `SortingExtractor` object!

In [13]:
print('Where do we load data from?\n', dir_name)

Where do we load data from?
 /mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin


From a sorting extractor we can obtain a list unit spike sample arrays. We can convert this to the .clu or .cut type array of unit ID labels for each spike.


In [14]:
cut_filename = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint_1.cut')

basename = os.path.basename(os.path.splitext(cut_filename)[0])

print(basename)

20201004_Tint_1


In [15]:
filename = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.set')
print(filename)

Path(str(filename.with_suffix('')) + '_{}'.format(1) + '.cut')

/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.set


PosixPath('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint_1.cut')

### Write unit labels to .cut and .clu files

In [16]:
def convert_spike_train_to_label_array(spike_train):
    '''Takes a list of arrays, where each array is a series of
    sample points at which a spike occured for a given unit
    (each list item is a unit). Converts to .cut array, i.e.
    orders spike samples from all units and labels each sample
    with the appropriate unit ID.
    
    Parameters
    ----------
    spike_train : List of np.arrays
        Output of `get_units_spike_train()` method of sorting extractor
        
    Return
    ------
    unit_labels_sorted : np.array
        Each entry is the unit ID corresponding to the spike sample that
        occured at this ordinal position
    '''

    # Generate Index array (indexing the unit for a given spike sample)
    unit_labels = []
    for i, l in enumerate(spike_train):
        unit_labels.append(np.ones((len(l),), dtype=int) * i)
    
    # Flatten lists and sort them
    spike_train_flat = np.concatenate(spike_train).ravel()
    unit_labels_flat = np.concatenate(unit_labels).ravel()

    sort_index = np.argsort(spike_train_flat)

    unit_labels_sorted = unit_labels_flat[sort_index]

    return unit_labels_sorted

In [17]:
def write_to_cut_file(cut_filename, unit_labels):
    '''Write spike sorting output to .cut file.
    
    Parameters
    ----------
    cut_filename : str or Path
        Full filename of .cut file to write to. A given .cut file belongs
        to a given tetrode file. For example, for tetrode `my_file.1`, the
        corresponding cut_filename should be `my_file_1.cut`.
    unit_labels : np.array
        Vector of unit labels for each spike sample (ordered by time of 
        occurence)
        
    Example
    -------
    # Given a sortingextractor called sorting_nwb:
    spike_train = sorting_nwb.get_units_spike_train()
    unit_labels = convert_spike_train_to_label_array(spike_train)
    write_to_cut_file(cut_filename, unit_labels)
    
    ---
    Largely based on gebaSpike implementation by Geoff Barrett
    https://github.com/GeoffBarrett/gebaSpike
    '''

    unique_cells = np.unique(unit_labels)

    n_clusters = len(np.unique(unit_labels))
    n_spikes = len(unit_labels)

    write_list = []

    tab = '    '
    empty_space = '               '

    write_list.append('n_clusters: %d\n' % (n_clusters))
    write_list.append('n_channels: 4\n')
    write_list.append('n_params: 2\n')
    write_list.append('times_used_in_Vt:%s' % ((tab + '0') * 4 + '\n'))

    zero_string = (tab + '0') * 8 + '\n'

    for cell_i in np.arange(n_clusters):
        write_list.append(' cluster: %d center:%s' % (cell_i, zero_string))
        write_list.append('%smin:%s' % (empty_space, zero_string))
        write_list.append('%smax:%s' % (empty_space, zero_string))
    write_list.append('\nExact_cut_for: %s spikes: %d\n' % (basename, n_spikes))

    # The unit label array consists of 25 values per row in .cut file
    n_rows = int(np.floor(n_spikes / 25))
    remaining = int(n_spikes - n_rows * 25)

    cut_string = ('%3u' * 25 + '\n') * n_rows + '%3u' * remaining

    write_list.append(cut_string % (tuple(unit_labels)))

    with open(cut_filename, 'w') as f:
        f.writelines(write_list)

In [18]:
def write_to_clu_file(clu_filename, unit_labels):
    ''' .clu files are pruned .cut files, containing only a long vector of unit
    labels, which are 1-indexed, instead of 0-indexed. In addition, the very first
    entry is the total number of units.
    
    Parameters
    ----------
    clu_filename : str or Path
        Full filename of .clu file to write to. A given .clu file belongs
        to a given tetrode file. For example, for tetrode `my_file.1`, the
        corresponding clu_filename should be `my_file_1.clu`.
    unit_labels : np.array
        Vector of unit labels for each spike sample (ordered by time of 
        occurence)
        
    ---
    Largely based on gebaSpike implementation by Geoff Barrett
    https://github.com/GeoffBarrett/gebaSpike
    '''
    unit_labels = np.asarray(unit_labels).astype(int)
    unit_labels += 1

    n_clust = len(np.unique(unit_labels))
    unit_labels = np.concatenate(([n_clust], unit_labels))

    np.savetxt(clu_filename, unit_labels, fmt='%d', delimiter='\n')

In [19]:
def set_cut_filename_from_basename(filename, tetrode_id):
    '''Given a str or Path object, assume the last entry after a slash
    is a filename, strip any file suffix, add tetrode ID label, and
    .cut suffix to name.
    
    Parameters
    ----------
    filename : str or Path
    tetrode_id : int
    '''
    return Path(str(filename).split('.')[0] + '_{}'.format(tetrode_id) + '.cut')

In [20]:
def write_unit_labels_to_file(sorting_extractor, filename):
    '''Write spike sorting output to .cut file, separately for each
    tetrode.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    filename : str or Path
        Full filename of .set file or base-filename (i.e. the part of the
        filename all Axona files have in common). A given .cut file belongs
        to a given tetrode file. For example, for tetrode `my_file.1`, the
        corresponding cut_filename should be `my_file_1.cut`. This will be
        set automatically given the base-filename or set file.
        
    TODO: Any reason one might want to only convert some tetrodes or some
    samples? Should those be parameters?
    '''
    tetrode_ids = sorting_extractor.get_units_property(property_name='group')
    tetrode_ids = np.array(tetrode_ids)
    
    unit_ids = np.array(sorting_extractor.get_unit_ids())
    
    for i in np.unique(tetrode_ids):
        
        print('Converting Tetrode {}'.format(i))

        spike_train = sorting_extractor.get_units_spike_train(unit_ids=unit_ids[tetrode_ids==i])
        unit_labels = convert_spike_train_to_label_array(spike_train)

        # We use Axona conventions for filenames (tetrodes are 1 indexed)
        cut_filename = set_cut_filename_from_basename(filename, i + 1)
        clu_filename = Path(str(cut_filename).replace('.cut', '.clu'))

        write_to_cut_file(cut_filename, unit_labels)
        write_to_clu_file(clu_filename, unit_labels)

In [21]:
# We have sorting data exported in `.nwb` format

nwb_dir = Path(dir_name, 'nwb')
sorting_nwb = se.NwbSortingExtractor(nwb_dir / 'axona_se_MS4.nwb', sampling_frequency=48000)

print(type(sorting_nwb))

<class 'spikeextractors.extractors.nwbextractors.nwbextractors.NwbSortingExtractor'>


In [22]:
print('Sampling frequency:', sorting_nwb.get_sampling_frequency(), 'Hz')

Sampling frequency: 48000 Hz


In [23]:
# Convert all tetrodes from sorting extractor to cut files
write_unit_labels_to_file(sorting_nwb, filename)

Converting Tetrode 0
Converting Tetrode 1
Converting Tetrode 2
Converting Tetrode 3


### Write waveforms to tetrode files (.X)

Here, we need information that is available in the `.set` file.

In [23]:
def parse_generic_header(filename):
    """
    Given a binary file with phrases and line breaks, enters the
    first word of a phrase as dictionary key and the following
    string (without linebreaks) as value. Returns the dictionary.
    
    Parameters
    ----------
    filename : str or Path
        Full filename.
    """
    header = {}
    with open(filename, 'rb') as f:
        for bin_line in f:
            if b'data_start' in bin_line:
                break
            line = bin_line.decode('cp1252').replace('\r\n', '').replace('\r', '').strip()
            parts = line.split(' ')
            key = parts[0]
            value = ' '.join(parts[1:])
            header[key] = value
            
    return header

In [24]:
def get_unit_group_ids(recording, sorting):
    '''Generate group ids (tetrode id - 1) for each unit based on 
    output from get_waveforms(). This is very slow, would be good to find a
    faster soluation!
    
    Parameters
    ----------
    recording : RecordingExtractor
    sorting : SortingExtractor
    
    Returns
    -------
    group_ids : List
        List of groups ids for each Unit in `sorting`.
    '''

    _, _, ch_idxs = st.postprocessing.get_unit_waveforms(
        recording,
        sorting,
        max_spikes_per_unit=1, 
        grouping_property='group',
        recompute_info=True,
        ms_before=0,
        ms_after=0.05,
        return_idxs=True,
        return_scaled=False
    )

    unit_groups = []
    for ch_idx in ch_idxs:
        if str(ch_idx) not in unit_groups:
            unit_groups.append(str(ch_idx))

    unit_group_mapping = {el: i for i, el in enumerate(unit_groups)}

    group_ids = [unit_group_mapping[str(ch_idx)] for ch_idx in ch_idxs]
    
    return group_ids

In [111]:
#%%time

#print(get_unit_group_ids(recording, sorting))

[0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
CPU times: user 7.02 s, sys: 179 ms, total: 7.2 s
Wall time: 9.44 s


In [114]:
def combine_units_on_tetrode(group_spike_samples, group_waveforms):
    '''Write all waveforms of given tetrode in dictionary with the
    corresponding spike samples being the keys (1 sample for each
    waveform).
    
    Parameters
    ----------
    group_spike_samples : list
        As returned by sortingextractor.get_units_spike_train()
    group_waveforms : list
        As returned by spiketoolkit.postprocessing.get_unit_waveforms()
    
    Returns
    -------
    tetrode_spikes : dict
        Keys are spike samples, values are waveforms (ntrls x nch x nsamp)
    '''
    tetrode_spikes = {}

    for i, (samples, waveforms) in enumerate(zip(group_spike_samples, group_waveforms)):

        for sample, waveform in zip(samples, waveforms):

            tetrode_spikes[sample] = waveform
            
    return tetrode_spikes

In [73]:
st.postprocessing.get_unit_waveforms??

In [115]:
def get_waveforms(recording, sorting, unit_ids, header):
    '''Get waveforms for specific tetrode.
    
    Parameters
    ----------
    recording : RecordingExtractor
    sorting : SortingExtractor
    unit_ids : List
        List of unit ids to extract waveforms
    header : dict
        maps parameters from .set file to their values (as strings).
        
    Returns
    -------
    waveforms : List
        List of np.array (n_spikes, n_channels, n_timepoints) with waveforms for each unit
    spike_indexes: list
        List of spike indexes for which waveforms are computed. Returned if 'return_idxs' is True
    channel_indexes: list
        List of max channel indexes
    '''
    sampling_rate = recording.get_sampling_frequency()
    samples_before = int(header['pretrigSamps'])
    samples_after = int(header['spikeLockout'])

    ms_before = samples_before / (sampling_rate / 1000) + 0.001
    ms_after = samples_after / (sampling_rate / 1000) + 0.001

    waveforms, spike_indexes, channel_indexes = st.postprocessing.get_unit_waveforms(
        recording,
        sorting,
        unit_ids=unit_ids,
        max_spikes_per_unit=None, 
        grouping_property='group',
        recompute_info=True,
        ms_before=ms_before,
        ms_after=ms_after,
        return_idxs=True,
        return_scaled=False,
        dtype=np.int8
    )

    return waveforms, spike_indexes, channel_indexes

In [116]:
def write_to_tetrode_files(recording, sorting, group_ids, set_file):
    '''Get spike samples and waveforms for all tetrodes specified in
    `group_ids`. Note that `group_ids` is 0-indexed, whereas tetrodes are
    1-indexed (so if you want tetrodes 1+2, specify group_ids=[0, 1]).
    
    Parameters
    ----------
    recording : RecordingExtractor
    sorting : SortingExtractor
    group_ids : array like
        Tetrodes to include, but 0-indexed (i.e. tetrodeID - 1)
    set_file : Path or str
        .set file location. Used to determine how many samples prior to and
        post spike sample should be cut out for each waveform. .X files will have
        the same base filename as the .set file. So if you do not want to overwrite
        existing .X files in your .set file directory, copy the .set file to a new
        folder and give its new location. The new .X files will appear there.
    '''
    sampling_rate = recording.get_sampling_frequency()
    group_ids = get_unit_group_ids(recording, sorting)
    header = parse_generic_header(set_file)

    for group_id in np.unique(group_ids):

        # get spike samples and waveforms of this group / tetrode
        group_unit_ids = [i for i, gid in enumerate(group_ids) if gid==group_id]
        group_waveforms = get_waveforms(recording, sorting, group_unit_ids, header)
        group_spike_samples = sorting_nwb.get_units_spike_train(unit_ids=group_unit_ids)

        # Assign each waveform to it's spike sample in a dictionary
        spike_waveform_dict = combine_units_on_tetrode(group_spike_samples, group_waveforms)

        # Set tetrode filename
        tetrode_filename = str(set_file).split('.')[0] + '.{}'.format(group_id + 1)
        print(tetrode_filename)

        # Use `BinConverter` function to write to tetrode file
        write_tetrode(tetrode_filename, spike_waveform_dict, sampling_rate)

In [None]:
sampling_rate = recording.get_sampling_frequency()
group_ids = get_unit_group_ids(recording, sorting)
header = parse_generic_header(set_file_to_tint)

for group_id in np.unique(group_ids):

    # get spike samples and waveforms of this group / tetrode
    group_unit_ids = [i for i, gid in enumerate(group_ids) if gid==group_id]
    group_waveforms, _, _ = get_waveforms(recording, sorting, group_unit_ids, header)
    group_spike_samples = sorting_nwb.get_units_spike_train(unit_ids=group_unit_ids)

    # Assign each waveform to it's spike sample in a dictionary
    spike_waveform_dict = combine_units_on_tetrode(group_spike_samples, group_waveforms)

    # Set tetrode filename
    tetrode_filename = str(set_file).split('.')[0] + '.{}'.format(group_id + 1)
    print(tetrode_filename)

    # Use `BinConverter` function to write to tetrode file
    write_tetrode(tetrode_filename, spike_waveform_dict, sampling_rate)

In [117]:
set_file = dir_name / 'axona_sample.set'
print(set_file)

/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/axona_sample.set


In [118]:
from spikeextractors.extractors.axonaunitrecordingextractor import AxonaUnitRecordingExtractor

recording = AxonaUnitRecordingExtractor(filename=set_file)
signal = recording.get_traces(channel_ids=None, start_frame=None, end_frame=None, return_scaled=False)

In [119]:
set_file_to_tint = dir_name / 'conversion_to_tint' / 'axona_sample.set'
print(set_file_to_tint)

/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/conversion_to_tint/axona_sample.set


In [120]:
import pandas as pd

df = pd.DataFrame({
    'channel_ids': recording.get_channel_ids(),
    'channel_groups': recording.get_channel_groups(),
    'tetrode_ids': recording.get_channel_groups() + 1
})
df

Unnamed: 0,channel_ids,channel_groups,tetrode_ids
0,0,0,1
1,1,0,1
2,2,0,1
3,3,0,1
4,4,1,2
5,5,1,2
6,6,1,2
7,7,1,2
8,8,2,3
9,9,2,3


In [121]:
channel_groups = recording.get_channel_groups()
channel_groups

array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])

In [122]:
write_to_tetrode_files(recording, sorting_nwb, channel_groups, set_file_to_tint)

/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/conversion_to_tint/axona_sample.1


  return array(a, dtype, copy=False, order=order)


ValueError: cannot reshape array of size 4 into shape (16,50)

In [42]:
spike_waveform_dict[806].shape

(72, 4, 50)

In [72]:
sorting_nwb.get_shared_unit_property_names()

['firing_rate',
 'group',
 'halfwidth',
 'isi_violation',
 'max_channel',
 'peak_to_valley',
 'peak_trough_ratio',
 'recovery_slope',
 'repolarization_slope',
 'snr',
 'template']

In [None]:
sorting_nwb.get_unit_property()

In [71]:
sorting_nwb.get_unit_ids()

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]

In [79]:
ch_idxs

[array([ 8,  9, 10, 11]),
 array([ 8,  9, 10, 11]),
 array([12, 13, 14, 15]),
 array([12, 13, 14, 15])]

In [None]:
# Get unit by tetrode mapping from sorting extractor

In [81]:
# Get unit by group mapping from get_waveforms


In [95]:
# Set up parameters
sorting = sorting_nwb
group_ids = channel_groups

# write_to_tetrode_files()
sampling_rate = recording.get_sampling_frequency()

group_id = 2

# get spike samples and waveforms of this group / tetrode
group_unit_ids = [i for i, gid in enumerate(group_ids) if gid==group_id]
print(group_unit_ids)

group_waveforms, spk_idxs, ch_idxs = get_waveforms(recording, sorting, group_unit_ids, set_file_to_tint)
group_spike_samples = sorting_nwb.get_units_spike_train(unit_ids=group_unit_ids)

# Assign each waveform to it's spike sample in a dictionary
spike_waveform_dict = combine_units_on_tetrode(group_spike_samples, group_waveforms)

# Set tetrode filename
tetrode_filename = str(set_file_to_tint).split('.')[0] + '.{}'.format(group_id + 1)
print(tetrode_filename)



[8, 9, 10, 11]
/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/conversion_to_tint/axona_sample.3


In [68]:
# Use `BinConverter` function to write to tetrode file
write_tetrode(tetrode_filename, spike_waveform_dict, sampling_rate)


In [73]:
recording.get_channel_ids()[0]

0

In [84]:
self = recording
channel_ids = [0]
start_frame = 0
end_frame = 1
return_scaled=False

timebase_sr = int(self.neo_reader.file_parameters['unit']['timebase'].split(' ')[0])
samples_pre = int(self.neo_reader.file_parameters['set']['file_header']['pretrigSamps'])
samples_post = int(self.neo_reader.file_parameters['set']['file_header']['spikeLockout'])
sampling_rate = self.get_sampling_frequency()

tcmap = self._get_tetrode_channel_table(channel_ids)

traces = self._noise_std * np.random.randn(len(channel_ids), end_frame - start_frame)
if return_scaled:
    traces = traces.astype(np.float32)
else:
    traces = traces.astype(np.int8)

# Loop through tetrodes and include requested channels in traces
itrc = 0
for tetrode_id in np.unique(tcmap[:, 0]):

    channels_oi = tcmap[tcmap[:, 0] == tetrode_id, 2]

    waveforms = self.neo_reader._get_spike_raw_waveforms(
        block_index=0, seg_index=0,
        unit_index=tetrode_id - 1,  # Tetrodes IDs are 1-indexed
        t_start=start_frame / sampling_rate,
        t_stop=end_frame / sampling_rate
    )
    waveforms = waveforms[:, channels_oi, :]
    nch = len(channels_oi)

    spike_train = self.neo_reader._get_spike_timestamps(
        block_index=0, seg_index=0,
        unit_index=tetrode_id - 1,
        t_start=start_frame / sampling_rate,
        t_stop=end_frame / sampling_rate
    )

    # Fill waveforms into traces timestamp by timestamp
    for t, wf in zip(spike_train, waveforms):

        t = int(t // (timebase_sr / sampling_rate))  # timestamps are sampled at higher frequency
        t = t - start_frame
        if (t - samples_pre < 0) and (t + samples_post > traces.shape[1]):
            traces[itrc:itrc + nch, :] = wf[:, samples_pre - t:traces.shape[1] - (t - samples_pre)]
        elif t - samples_pre < 0:
            traces[itrc:itrc + nch, :t + samples_post] = wf[:, samples_pre - t:]
        elif t + samples_post > traces.shape[1]:
            traces[itrc:itrc + nch, t - samples_pre:] = wf[:, :traces.shape[1] - (t - samples_pre)]
        else:
            traces[itrc:itrc + nch, t - samples_pre:t + samples_post] = wf

    itrc += nch

In [83]:
samples_pre

10

In [75]:
recording.get_traces(channel_ids=[recording.get_channel_ids()[0]], start_frame=0, end_frame=1,
                     return_scaled=False).dtype

ValueError: could not broadcast input array from shape (1,40) into shape (1,1)

In [71]:
# retrieve waveforms (all units)
header = parse_generic_header(set_file_to_tint)
sampling_rate = recording.get_sampling_frequency()
samples_before = int(header['pretrigSamps'])
samples_after = int(header['spikeLockout'])

ms_before = samples_before / (sampling_rate / 1000)
ms_after = samples_after / (sampling_rate / 1000)

waveforms, spk_idxs, chn_idxs = st.postprocessing.get_unit_waveforms(
    recording,
    sorting_nwb,
    max_spikes_per_unit=None, 
    grouping_property='group',
    recompute_info=True,
    verbose=True,
    ms_before=ms_before,
    ms_after=ms_after,
    return_idxs=True,
    return_scaled=False,
    dtype=np.int8
)

ValueError: could not broadcast input array from shape (1,40) into shape (1,1)

In [45]:
waveforms

[memmap([[[ -9,  -4,   1, ...,   0,   0,   0],
          [-15, -12,  -9, ...,   4,  -4,  -4],
          [-26, -19,  -8, ...,  -2,  -1,   1],
          [-44, -23,   0, ...,   0,  -1,   0]],
 
         [[  0,  69,  91, ...,   0,  -3,   1],
          [  0,   8,  18, ...,   0,  -2,   7],
          [-16,   0,  17, ...,  -4,   2,   0],
          [  7,  22,  36, ...,   1,   7,   5]],
 
         [[ 44,  29,   9, ...,   3,   0,   0],
          [ 37,  42,  36, ...,   1,   1,  -6],
          [ 45,  46,  36, ...,   0,   0,   0],
          [ 31,  16,  -4, ...,   1,  -3,   2]],
 
         ...,
 
         [[ -3,  -3,  -2, ...,  14,  18,  21],
          [ -1,  -3,  -5, ..., -21,  -8,   7],
          [  0,   2,   2, ...,  -5,   7,  23],
          [ -4,   0,   8, ..., -30, -11,   3]],
 
         [[ 63,  89,  93, ...,   5,   2,  -2],
          [ 14,  41,  71, ...,   1,   0,  -1],
          [ 30,  63,  76, ...,   0,   1,   0],
          [ 46,  70,  72, ...,   1,  -6,   2]],
 
         [[  1,   6,  -5, ...

In [63]:
# retrieve waveforms (all units)
header = parse_generic_header(set_file)
sampling_rate = r_cache.get_sampling_frequency()
samples_before = int(header['pretrigSamps'])
samples_after = int(header['spikeLockout'])

ms_before = samples_before / (sampling_rate / 1000) + 0.001
ms_after = samples_after / (sampling_rate / 1000) + 0.001

waveforms, spk_idxs, chn_idxs = st.postprocessing.get_unit_waveforms(
    recording,
    sorting_nwb,
    max_spikes_per_unit=None, 
    grouping_property='group',
    recompute_info=True,
    verbose=True,
    ms_before=ms_before,
    ms_after=ms_after,
    return_idxs=True,
    return_scaled=False,
    dtype=np.int8
)

ValueError: could not broadcast input array from shape (1,40) into shape (1,1)

In [72]:
len(waveforms)

14

In [74]:
# Do we indeed have max and min values of -127 or +127 now for each waveform?

_ = [print(max(wv_int8.flatten()), min(wv_int8.flatten())) for wv_int8 in waveforms]

127 -128
127 -128
74 -75
74 -75
74 -75
74 -75
99 -100
99 -100
99 -100
99 -100
99 -100
99 -100
99 -100
99 -100


In [73]:
print(len(chn_idxs))
chn_idxs

14


[array([0, 1, 2, 3]),
 array([0, 1, 2, 3]),
 array([4, 5, 6, 7]),
 array([4, 5, 6, 7]),
 array([4, 5, 6, 7]),
 array([4, 5, 6, 7]),
 array([ 8,  9, 10, 11]),
 array([ 8,  9, 10, 11]),
 array([ 8,  9, 10, 11]),
 array([ 8,  9, 10, 11]),
 array([12, 13, 14, 15]),
 array([12, 13, 14, 15]),
 array([12, 13, 14, 15]),
 array([12, 13, 14, 15])]

In [87]:
group_id = 1

In [88]:
[i for i, gid in enumerate(group_ids) if gid==group_id]

[2, 3, 4, 5]

In [83]:
unit_ids[group_ids==group_id]

array([], shape=(0, 14), dtype=int64)

In [75]:
# On the .X file there will be no more unit information, we will write all units from a given
# tetrode for each tetrode channel

spike_samples = sorting_nwb.get_units_spike_train()
for spk_trn, wv in zip(spike_samples, waveforms):
    print(spk_trn.shape, wv.shape)

(72,) (72, 4, 50)
(101,) (101, 4, 50)
(59,) (59, 4, 50)
(15,) (15, 4, 50)
(39,) (39, 4, 50)
(51,) (51, 4, 50)
(38,) (38, 4, 50)
(50,) (50, 4, 50)
(47,) (47, 4, 50)
(73,) (73, 4, 50)
(47,) (47, 4, 50)
(42,) (42, 4, 50)
(67,) (67, 4, 50)
(52,) (52, 4, 50)


In [76]:
spike_samples

[array([  806,  1638,  1854,  2044,  2642,  2929,  4106,  5032,  5982,
         6056,  6711,  7489,  7691,  8229, 10222, 11061, 11961, 13920,
        14137, 16582, 17070, 17403, 17661, 18316, 18751, 18789, 18964,
        20220, 20854, 22096, 24185, 24770, 24788, 25149, 26285, 26382,
        26529, 27031, 27706, 29057, 29884, 30126, 30362, 30739, 31774,
        33709, 34159, 36395, 36676, 37662, 38283, 38612, 39187, 40663,
        41048, 41490, 41602, 42303, 43438, 44502, 44692, 44973, 45147,
        45616, 45747, 45780, 46531, 47303, 49143, 49495, 52031, 55848]),
 array([  472,  1579,  1823,  2255,  2345,  3132,  3869,  3946,  4678,
         5789,  5810,  6242,  6264,  6410,  6699,  6725,  6938,  7077,
         7256,  8059,  9778,  9799, 10239, 10722, 11103, 11946, 12703,
        12769, 12893, 13086, 13648, 14073, 14483, 14972, 15007, 15302,
        15343, 15679, 15711, 15963, 16459, 16756, 17751, 18986, 19221,
        19597, 19910, 20142, 22564, 22645, 22720, 23303, 24087, 24588,
    

In [78]:
group_ids

[0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]

In [93]:
st.postprocessing.get_unit_waveforms?

In [64]:
# Loop through tetrodes, and write all units of tetrode to all channels
# remember that groups are 0 indexed, but tetrodes are 1 indexed (otherwise they are the same)

sampling_rate = recording.get_sampling_frequency()

for group_id in np.unique(group_ids):

    # get spike samples and waveforms of this group / tetrode
    group_unit_ids = [i for i, gid in enumerate(group_ids) if gid==group_id]
    group_waveforms, spk_idxs, chn_idxs = st.postprocessing.get_unit_waveforms(
        recording,
        sorting_nwb,
        unit_ids=group_unit_ids,
        max_spikes_per_unit=None, 
        grouping_property='group',
        recompute_info=True,
        ms_before=ms_before,
        ms_after=ms_after,
        return_idxs=True,
        return_scaled=False,
        dtype=np.int8
    )
    group_spike_samples = sorting_nwb.get_units_spike_train(unit_ids=group_unit_ids)

    # Assign each waveform to it's spike sample in a dictionary
    spike_waveform_dict = combine_units_on_tetrode(group_spike_samples, group_waveforms)

    # Set tetrode filename
    tetrode_filename = dir_name / 'conversion_to_tint' / 'axona_sample.{}'.format(group_id + 1)
    print(tetrode_filename)

    # Use `BinConverter` function to write to tetrode file
    write_tetrode(tetrode_filename, spike_waveform_dict, sampling_rate)

ValueError: could not broadcast input array from shape (1,40) into shape (1,1)

In [102]:
len(tetrode_spikes.keys())

173

In [103]:
tetrode_spikes

{806: memmap([[ -9,  -4,   1,   3,   0, -15, -39, -60, -67, -57, -37, -16,  -2,
            2,   4,  10,  18,  25,  30,  30,  33,  26,  16,   4,   4, -15,
          -16, -12, -11, -14,   0,  -8,   1,  -1,   1,  -1,   0,   1,   0,
           -7,   1,   1,  -1,   0,   1,  -6,   1,   0,   0,   0],
         [-15, -12,  -9,  -4,  -2,  -3, -14, -35, -57, -69, -67, -57, -42,
          -28, -16,  -7,   0,   8,  12,  14,  14,  11,   8,   5,   5,   3,
           -1,  -8, -11, -14,   0,  -1,   3,   1,  -1,   5,  -1,  -2,   1,
           -2,  -8,   2,   0,  -1,   0,   8,   2,   4,  -4,  -4],
         [-26, -19,  -8,   2,  10,  10,  -1, -22, -41, -49, -45, -33, -21,
          -10,  -4,   0,   0,   0,  -3,  -5,  -7,  -8,  -8,  -5,  -5,  -9,
          -12, -16, -17, -16,   1,  -1,  -1,  -1,   0,   2,   0,   0,   0,
            0,  -6,   4,   0,   0,   1,   2,   7,  -2,  -1,   1],
         [-44, -23,   0,  23,  41,  42,  22, -12, -46, -59, -49, -24,  -4,
           -1, -14, -28, -30, -23,  -9,   2,   

In [104]:
tetrode_filename = dir_name / 'conversion_to_tint' / 'axona_sample.1'
tetrode_filename

PosixPath('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/conversion_to_tint/axona_sample.1')

In [105]:
# Use `BinConverter` function to write to tetrode file

write_tetrode(tetrode_filename, tetrode_spikes, sampling_rate)

In [107]:
# Read data back in to see if it worked as expected

from neo import AxonaIO

neoio = AxonaIO(tetrode_filename)

waveforms = neoio.get_spike_raw_waveforms()
print(waveforms.shape)
waveforms[0, :, :]

(173, 4, 50)


memmap([[ 60,  60,  51,  33,  19,  11,   4,  -4, -18, -26, -23,  -8,  -8,
          24,  31,  31,  25,  12,  -2, -14, -17, -15, -14, -14,  -8,   2,
          12,  21,  23,  18,  10,   2,   0,   2,   9,  16,  22,  29,  37,
          38,  30,  17,   1,  -2,  -5,   5,  -3,   4,  -3,  -4],
        [ 60,  70,  70,  55,  31,  11,   0,  -5, -18, -36, -48, -46, -36,
         -24, -16, -10,  -8,  -5,   0,   4,   7,   3,  -7, -23, -36, -36,
         -24,  -9,   0,   2,  -2,  -9, -10,  -3,   4,   9,  10,  15,  23,
          29,  29,  25,   3,  -2,  -1,   0,  -1,  -1,   2,   1],
        [ 37,  43,  43,  37,  30,  24,  18,   2, -23, -52, -65, -58, -36,
         -10,   7,  14,  14,  14,  15,  12,   7,   0,  -9, -16, -14,  -2,
          11,  22,  22,  10,  -5, -21, -28, -24, -14,   0,  11,  23,  32,
          35,  28,  16,  -1,  -1,   2,  -3,   2,  -2,   0,  -1],
        [ 44,  56,  62,  56,  46,  43,  42,  35,  16,  -2,  -9,   1,  23,
          43,  58,  69,  67,  48,  12, -25, -53, -65, -62, -46, -

## Convert recording extractor to tetrode files (.X)

In [23]:
from BinConverter.core.ConvertTetrode import write_tetrode
from BinConverter.core.readBin import (
    get_bin_data, get_raw_pos, get_channel_from_tetrode, get_active_tetrode, get_active_eeg
)
from BinConverter.core.Tint_Matlab import int16toint8

from spikeextractors.extractors.axonaunitrecordingextractor import AxonaUnitRecordingExtractor

import pandas as pd

In [43]:
def parse_generic_header(filename):
    """
    Given a binary file with phrases and line breaks, enters the
    first word of a phrase as dictionary key and the following
    string (without linebreaks) as value. Returns the dictionary.
    """
    header = {}
    with open(filename, 'rb') as f:
        for bin_line in f:
            if b'data_start' in bin_line:
                break
            line = bin_line.decode('cp1252').replace('\r\n', '').replace('\r', '').strip()
            parts = line.split(' ')
            key = parts[0]
            value = ' '.join(parts[1:])
            header[key] = value
            
    return header

In [44]:
def get_channel_from_tetrode(tetrode):
    """
    This function will take the tetrode number and return the Axona
    channel numbers, i.e. Tetrode 1 = Ch0-Ch3, Tetrode 2 = Ch4-Ch7, etc.
    """
    return np.arange(0, 4) + 4 * (int(tetrode) - 1)

In [45]:
set_file = dir_name / 'axona_sample.set'
set_filename = set_file
print(set_file)

/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/axona_sample.set


In [46]:
re = AxonaUnitRecordingExtractor(filename=set_file)

In [47]:
re.neo_reader

AxonaRawIO: /mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/axona_sample
nb_block: 1
nb_segment:  [1]
signal_streams: [stream 0 (chans: 16)]
signal_channels: [1a, 1b, 1c, 1d ... 4a , 4b , 4c , 4d]
spike_channels: [tetrode 1, tetrode 2, tetrode 3, tetrode 4]
event_channels: []

In [48]:
dir_name

PosixPath('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin')

In [49]:
header = parse_generic_header(set_file)

In [50]:
pre_spike_samples = int(header['pretrigSamps'])
post_spike_samples = int(header['spikeLockout'])
rejstart = int(header['rejstart'])
rejthreshtail = int(header['rejthreshtail'])
rejthreshupper = int(header['rejthreshupper'])
rejthreshlower = int(header['rejthreshlower'])

print(pre_spike_samples)
print(post_spike_samples)
print(rejstart)
print(rejthreshtail)
print(rejthreshupper)
print(rejthreshlower)

10
40
30
43
100
-100


In [51]:
r_cache

<spikeextractors.extractors.bindatrecordingextractor.bindatrecordingextractor.BinDatRecordingExtractor at 0x7fb2ff61db20>

In [52]:
tetrode = 1
set_file.stem + '.{}'.format(tetrode)

'axona_sample.1'

In [53]:
tetrode_channels = get_channel_from_tetrode(tetrode)

In [54]:
get_bin_data??

Object `get_bin_data` not found.


In [55]:
tetrode

1

In [56]:
import contextlib
import mmap

def get_bin_data(bin_filename, channels=None, tetrode=None):
    """This function will be used to acquire the actual lfp data given the .bin filename,
    and the tetrode or channels (from 1-64) that you want to get"""

    if tetrode is not None:
        channels = get_channel_from_tetrode(tetrode)
    else:
        channels = np.array(channels)  # just in case it isn't an np.array

    bytes_per_iteration = 432

    with open(bin_filename, 'rb') as f:
        # pass
        with contextlib.closing(mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)) as m:
            num_iterations = int(len(m)/bytes_per_iteration)

            data = np.ndarray((num_iterations,), (np.int16, (1,192)), m, 32, (bytes_per_iteration,)).reshape((-1, 1)).flatten()
            data = samples_to_array(data, channels=channels.tolist())

    return data

def samples_to_array(A, channels=[]):
    """This will take data matrix A, and convert it into a numpy array, there are three samples of
    64 channels in this matrix, however their channels do need to be re-mapped"""

    if channels == []:
        channels = np.arange(64) + 1
    else:
        channels = np.asarray(channels)

    A = np.asarray(A)

    sample_num = int(len(A) / 64)  # get the sample numbers

    sample_array = np.zeros((len(channels), sample_num))  # creating a 64x3 array of zeros (64 channels, 3 samples)

    for i, channel in enumerate(channels):
        sample_array[i, :] = A[get_sample_indices(channel, sample_num)]

    return sample_array

def get_sample_indices(channel_number, samples):
    remap_channel = get_remap_chan(channel_number)

    indices_scalar = np.multiply(np.arange(samples), 64)
    sample_indices = indices_scalar + np.multiply(np.ones(samples), remap_channel)

    # return np.array([remap_channel, 64 + remap_channel, 64*2 + remap_channel])
    return (indices_scalar + np.multiply(np.ones(samples), remap_channel)).astype(int)

def get_remap_chan(chan_num):
    """There is re-mapping, thus to get the correct channel data, you need to incorporate re-mapping
    input will be a channel from 1 to 64, and will return the remapped channel"""

    remap_channels = np.array([32, 33, 34, 35, 36, 37, 38, 39, 0, 1, 2, 3, 4, 5,
                               6, 7, 40, 41, 42, 43, 44, 45, 46, 47, 8, 9, 10, 11,
                               12, 13, 14, 15, 48, 49, 50, 51, 52, 53, 54, 55, 16, 17,
                               18, 19, 20, 21, 22, 23, 56, 57, 58, 59, 60, 61, 62, 63,
                               24, 25, 26, 27, 28, 29, 30, 31])

    return remap_channels[chan_num - 1]

In [57]:
bin_filename = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint/axona_sample.bin')
data = get_bin_data(bin_filename, tetrode=tetrode)

In [58]:
data.shape

(4, 57600)

In [59]:
data

array([[  3574.,   -230.,      0., ...,      0.,      0.,      0.],
       [  3682.,  -3870.,      0., ..., -13700., -14410., -11876.],
       [  1714.,   -188.,      0., ...,  -9006.,  -9980.,  -9502.],
       [ 10480.,   5308.,      0., ...,  -8460.,  -7266.,  -4906.]])

In [60]:
from neo import AxonaIO

In [61]:
neoio = AxonaIO(bin_filename)

In [62]:
bin_data_neo = neoio.get_analogsignal_chunk(channel_indexes=[12, 13, 14, 15]).T

In [63]:
bin_data_neo

array([[ -7572,  -1930,      0, ..., -26920, -22956, -17398],
       [ -5500,   -206,      0, ..., -18238, -17952, -15410],
       [   798,  -9238,      0, ..., -17090, -13344,  -8568],
       [  2378,  -6200,      0, ..., -17912, -15948, -11356]], dtype=int16)

In [64]:
tet1_data_neo = neoio.get_spike_raw_waveforms()

In [65]:
tet1_data_neo

memmap([[[  -7,    2,   18, ...,   33,   26,   15],
         [  -6,   -2,    1, ...,   25,   25,   22],
         [  -2,    3,   14, ...,   30,   24,   14],
         [ -18,   -6,   13, ...,   22,   10,   -3]],

        [[   4,   -3,  -11, ...,  -11,  -10,  -12],
         [  28,   22,    8, ...,   -7,  -10,  -12],
         [  22,   15,    1, ...,  -14,  -15,  -14],
         [  40,   56,   57, ...,   -9,  -15,  -24]],

        [[  15,    6,   -1, ...,  -40,  -25,   -8],
         [  17,   14,    6, ...,  -29,  -29,  -25],
         [  20,   15,    7, ...,  -34,  -24,  -11],
         [   1,  -13,  -27, ...,  -53,  -31,   -5]],

        ...,

        [[ -40,  -71, -108, ...,  -36,  -57,  -71],
         [ -35,  -39,  -43, ...,   14,    0,  -20],
         [ -15,  -17,  -22, ...,    5,  -15,  -46],
         [ -20,  -20,  -26, ...,   12,   -6,  -32]],

        [[   7,    5,    7, ...,    9,    6,    1],
         [  -2,   -7,  -12, ...,  -15,  -13,  -16],
         [  42,   44,   33, ...,   -1,   -

In [66]:
fraw = 3682
f = 4314.84375  # should be 3682

In [48]:
bin(fraw)

'0b111001100010'

In [49]:
bin(f)

TypeError: 'float' object cannot be interpreted as an integer

In [50]:
traces * 1000

NameError: name 'traces' is not defined

In [51]:
type(data[1,1])

numpy.float64

In [52]:
df = pd.DataFrame({
    'channel_ids': r_cache.get_channel_ids(),
    'channel_groups': r_cache.get_channel_groups(),
    'tetrode_ids': r_cache.get_channel_groups() + 1
})
df

Unnamed: 0,channel_ids,channel_groups,tetrode_ids
0,1,0,1
1,2,0,1
2,4,1,2
3,6,1,2
4,7,1,2
5,8,2,3
6,9,2,3
7,10,2,3
8,11,2,3
9,12,3,4


In [53]:
tetrode_channels = df.loc[df['tetrode_ids'] == tetrode, 'channel_ids'].values + 1
tetrode_channels

array([2, 3])

In [54]:
tetrode_channels-1

array([1, 2])

In [55]:
# A tetrode file expects 4 channels. Fill missing channels with zeros.

traces = np.zeros((4, r_cache.get_num_frames()))
traces[tetrode_channels-1, :] = r_cache.get_traces(channel_ids=tetrode_channels-1)
traces.shape

(4, 57600)

In [56]:
r_cache.get_traces?

In [57]:
tetrode

1

In [58]:
tetrode_channels

array([2, 3])

In [59]:
tetrode = int(tetrode)

tetrode_filename = save_dir / Path(set_file.stem + '.{}'.format(tetrode))

tetrode_channels = df.loc[df['tetrode_ids'] == tetrode, 'channel_ids'].values + 1

traces = np.zeros((4, r_cache.get_num_frames()))
traces[(tetrode_channels-1) % 4, :] = r_cache.get_traces(channel_ids=tetrode_channels-1)

In [60]:
Fs = r_cache.get_sampling_frequency()
active_tetrodes = np.unique(r_cache.get_channel_groups()) + 1

pre_spike_samples = int(header['pretrigSamps'])
post_spike_samples = int(header['spikeLockout'])
rejstart = int(header['rejstart'])
rejthreshtail = int(header['rejthreshtail'])
rejthreshupper = int(header['rejthreshupper'])
rejthreshlower = int(header['rejthreshlower'])

df = pd.DataFrame({
    'channel_ids': r_cache.get_channel_ids(),
    'channel_groups': r_cache.get_channel_groups(),
    'tetrode_ids': r_cache.get_channel_groups() + 1
})

for tetrode in active_tetrodes:

    tetrode = int(tetrode)
    
    tetrode_filename = save_dir / Path(set_file.stem + '.{}'.format(tetrode))
    
    tetrode_channels = df.loc[df['tetrode_ids'] == tetrode, 'channel_ids'].values + 1
    
    traces = np.zeros((4, r_cache.get_num_frames()))
    traces[(tetrode_channels - 1) % 4, :] = r_cache.get_traces(channel_ids=tetrode_channels-1)

    n_samples = traces.shape[1]

    # create a time array that represents the 48kHz sampled data times
    t = np.arange(0, n_samples) / Fs  # creates a time array of the signal starting from 0 (in seconds)

    

<spikeextractors.extractors.bindatrecordingextractor.bindatrecordingextractor.BinDatRecordingExtractor at 0x7fb78c216730>

In [61]:
def get_spikes(data, threshold):
    all_spikes = np.array([])

    for i, channel_data in enumerate(data):
        spike_indices = np.where(channel_data >= threshold[i])[0]

        if len(spike_indices) == 0:
            continue

        spike_indices = find_consec(spike_indices)

        spike_indices = np.asarray([value[0] for value in spike_indices])

        if len(all_spikes) == 0:
            # this is the first iteration of the tetrode, no need to sort
            unadded_spikes = spike_indices
        else:
            idx = matching_ind(all_spikes, spike_indices)
            if len(idx) == 0:
                unadded_spikes = spike_indices
            else:
                unadded_spikes = np.setdiff1d(spike_indices, all_spikes[idx])

        if len(all_spikes) != 0:
            all_spikes = np.sort(np.concatenate((all_spikes, unadded_spikes)))
            unadded_spikes = None
        else:
            all_spikes = np.array(unadded_spikes)

    return all_spikes

def find_consec(data):
    '''finds the consecutive numbers and outputs as a list'''
    consecutive_values = []  # a list for the output
    current_consecutive = [data[0]]

    if len(data) == 1:
        return [[data[0]]]

    for index in range(1, len(data)):

        if data[index] == data[index - 1] + 1:
            current_consecutive.append(data[index])

            if index == len(data) - 1:
                consecutive_values.append(current_consecutive)

        else:
            consecutive_values.append(current_consecutive)
            current_consecutive = [data[index]]

            if index == len(data) - 1:
                consecutive_values.append(current_consecutive)
    return consecutive_values

def matching_ind(haystack, needle):
    idx = np.searchsorted(haystack, needle)
    mask = idx < haystack.size
    mask[mask] = haystack[idx[mask]] == needle[mask]
    idx = idx[mask]
    return idx

def validate_spikes(tetrode, spikes, data, t, pre_spike_samples=10, post_spike_samples=40, rejstart=30,
                    rejthreshtail=43, rejthreshupper=100, rejthreshlower=-100):
    latest_spike = None

    spike_count = 0
    percentage_values = [int(value) for value in np.rint(np.linspace(0, len(spikes), num=21)).tolist()]

    n_max = data.shape[1]

    tetrode_spikes = {}

    for spike in sorted(spikes):
        # iterate through each spike and validate to ensure no spikes occur at the same time or within the
        # refractory period

        spike_count += 1

        if spike_count in percentage_values:
            pass

        if spike - pre_spike_samples + 1 < 0:
            continue

        elif spike + post_spike_samples >= n_max:
            continue

        if latest_spike is not None:
            if spike != latest_spike:
                if spike in spike_refractory:
                    # ensures no overlapping spikes
                    continue
        else:
            pass

        latest_spike = spike
        spike_refractory = list(np.arange(spike + 1, spike + post_spike_samples + 1))

        # spike_time = t[int(spike)]
        spike_time = t[int(spike)]

        # waveform_indices = np.where((t>=spike_time-250/1e6) & (t<=spike_time+850/1e6))[0]  # too slow
        waveform_indices = np.arange(spike - pre_spike_samples + 1, spike + post_spike_samples + 1).astype(int)

        # spike_t = t[waveform_indices] - spike_time  # making the times from -200 us to 800 us

        # spike_waveform = np.zeros((len(tetrode_channels), 50))

        spike_waveform = data[:, waveform_indices]

        spike_time = spike_time * 96000  # multiply it by the timebase to get the frame count

        spike_waveform = np.rint(spike_waveform)

        # artifact rejection

        if sum(spike_waveform[:, rejstart:].flatten() > rejthreshtail) > 0:
            # this is 33% above baseline (0)
            continue

        # check if the first sample is well above or well below baseline
        elif sum(spike_waveform[:, 0].flatten() > rejthreshupper) > 0:
            # the first sample is >100
            continue

        elif sum(spike_waveform[:, 0].flatten() < rejthreshlower) > 0:
            # or < -100
            continue

        tetrode_spikes[spike_time] = spike_waveform

        # latest_spike = spike
        # spike_refractory = list(np.arange(spike + 1, spike + post_spike_samples + 1))

    return tetrode_spikes

In [62]:
tetrode_channels

array([13, 14, 15, 16])

In [63]:
threshold = 3

tetrode_spikes = {}  # creates an empty dictionary to hold the spike times
# for each tetrode, find the spikes

k = 0

# data = int16toint8(data)  # converting the data into int8

tetrode_thresholds = []
for channel_index, channel in enumerate(tetrode_channels):
    k += 1
    '''
    Auto thresholding technique incorporated by:
    Quian Quiroga in 2014 - Unsupervised Spike Detection and Sorting with Wavelets and
    Superparamagnetic Clustering
    Thr = 4*sigma, sigma = median(abs(x)/0.6745)
    '''
    standard_deviations = float(threshold)

    sigma_n = np.median(np.divide(np.abs(data[channel_index, :]), 0.6745))
    # threshold = sigma_n / channel_max
    # threshold = standard_deviations * sigma_n
    tetrode_thresholds.append(standard_deviations * sigma_n)

valid_spikes = get_spikes(data, tetrode_thresholds)

# threshold is done in 16 bit values, but the rejection is done in 8bit, so we convert here
# data = int16toint8(data)  # converting the data into int8

data_int16 = int16toint8(data)  # converting the

In [64]:
tetrode_thresholds

[0.0, 20993.32839140104, 15389.177168272796, 16474.425500370642]

In [65]:
tetrode_ids = sorting_nwb.get_units_property(property_name='group')
tetrode_ids = np.array(tetrode_ids)

unit_ids = np.array(sorting_nwb.get_unit_ids())
spike_train = sorting_nwb.get_units_spike_train(unit_ids=unit_ids[tetrode_ids==3])

sorted_spike_train = np.sort(np.concatenate(spike_train))

sorted_spike_train

array([  126,  1386,  1466,  1495,  1697,  1714,  1735,  2635,  2803,
        2824,  2838,  2923,  2942,  4677,  4691,  4708,  5145,  5166,
        5239,  5261,  5784,  5903,  5924,  6413,  6442,  6699,  6712,
        6732,  6950,  7065,  7080,  7102,  7219,  7260,  8222,  8297,
        8894,  8915,  9690,  9709,  9727, 10221, 10247, 10731, 10875,
       11169, 11204, 11386, 11803, 11818, 11958, 11980, 12689, 12702,
       12883, 12898, 12911, 14247, 14276, 14487, 14967, 15129, 15144,
       15420, 15437, 15664, 15682, 16586, 16628, 16651, 16670, 16760,
       16781, 16794, 16824, 16838, 16862, 17386, 17403, 17426, 17759,
       18324, 18343, 18756, 19596, 19614, 19629, 19820, 19833, 19900,
       20139, 20689, 20703, 21325, 21338, 21469, 21482, 21526, 21540,
       22095, 22371, 22392, 22712, 22744, 23041, 23073, 23564, 23804,
       23819, 23835, 23987, 24009, 24127, 24441, 25212, 25374, 25392,
       25417, 25462, 25482, 26015, 26368, 26382, 27043, 27059, 27746,
       27762, 29277,

In [66]:
valid_spikes

array([    0,     2,   464,   584,   627,   691,   751,   787,   788,
        1021,  1088,  1150,  1230,  1238,  1571,  1583,  1632,  1644,
        1826,  2228,  2238,  2322,  2326,  2377,  2634,  2635,  2905,
        2911,  2912,  3118,  3277,  3297,  3653,  3752,  3785,  3787,
        3872,  3926,  4103,  4352,  4353,  4364,  4577,  4673,  4674,
        4934,  4935,  5024,  5025,  5165,  5232,  5233,  5568,  5569,
        5781,  5784,  5785,  5974,  5976,  6019,  6208,  6260,  6410,
        6411,  6706,  6708,  6768,  6788,  6982,  6983,  7061,  7064,
        7255,  7256,  7482,  7494,  7601,  7686,  7687,  7837,  8173,
        8222,  8223,  8435,  8437,  8464,  8465,  8655,  8755,  8756,
        9357,  9366,  9557,  9696, 10217, 10376, 10723, 10844, 10869,
       10870, 11066, 11455, 11456, 11954, 11955, 12441, 12444, 12451,
       12454, 12677, 12775, 13078, 13087, 13642, 13643, 13644, 14112,
       14113, 14241, 14242, 14483, 14484, 14502, 14621, 14681, 14682,
       14850, 14853,

In [67]:
traces

array([[ 0.00442661, -0.00972959, -0.01101197, ..., -0.02006962,
        -0.00963475, -0.00524077],
       [-0.00532733, -0.013263  , -0.00653222, ..., -0.00630658,
        -0.00487059, -0.00842494],
       [ 0.00252745, -0.0046536 , -0.00752856, ..., -0.03252275,
        -0.01737364, -0.00533893],
       [-0.00476253, -0.01768241, -0.01526715, ..., -0.02812541,
        -0.03498014, -0.04006043]])

In [68]:
data_int16

array([[ 13,   0,   0, ...,   0,   0,   0],
       [ 14, -15,   0, ..., -53, -56, -46],
       [  6,   0,   0, ..., -35, -38, -37],
       [ 40,  20,   0, ..., -33, -28, -19]])

In [69]:
data

array([[  3574.,   -230.,      0., ...,      0.,      0.,      0.],
       [  3682.,  -3870.,      0., ..., -13700., -14410., -11876.],
       [  1714.,   -188.,      0., ...,  -9006.,  -9980.,  -9502.],
       [ 10480.,   5308.,      0., ...,  -8460.,  -7266.,  -4906.]])

In [70]:
valid_spikes

array([    0,     2,   464,   584,   627,   691,   751,   787,   788,
        1021,  1088,  1150,  1230,  1238,  1571,  1583,  1632,  1644,
        1826,  2228,  2238,  2322,  2326,  2377,  2634,  2635,  2905,
        2911,  2912,  3118,  3277,  3297,  3653,  3752,  3785,  3787,
        3872,  3926,  4103,  4352,  4353,  4364,  4577,  4673,  4674,
        4934,  4935,  5024,  5025,  5165,  5232,  5233,  5568,  5569,
        5781,  5784,  5785,  5974,  5976,  6019,  6208,  6260,  6410,
        6411,  6706,  6708,  6768,  6788,  6982,  6983,  7061,  7064,
        7255,  7256,  7482,  7494,  7601,  7686,  7687,  7837,  8173,
        8222,  8223,  8435,  8437,  8464,  8465,  8655,  8755,  8756,
        9357,  9366,  9557,  9696, 10217, 10376, 10723, 10844, 10869,
       10870, 11066, 11455, 11456, 11954, 11955, 12441, 12444, 12451,
       12454, 12677, 12775, 13078, 13087, 13642, 13643, 13644, 14112,
       14113, 14241, 14242, 14483, 14484, 14502, 14621, 14681, 14682,
       14850, 14853,

In [None]:
# --------------------------------------Write Tetrode Data----------------------------------------

Fs = get_Fs(set_filename)  # read the sampling frequency from the .set file, most like 48k

active_tetrodes = get_active_tetrode(set_filename)

# converts the data one tetrode at a time so we can eliminate memory errors

pre_spike_samples = int(get_setfile_parameter('pretrigSamps', set_filename))
post_spike_samples = int(get_setfile_parameter('spikeLockout', set_filename))
rejstart = int(get_setfile_parameter('rejstart', set_filename))
rejthreshtail = int(get_setfile_parameter('rejthreshtail', set_filename))
rejthreshupper = int(get_setfile_parameter('rejthreshupper', set_filename))
rejthreshlower = int(get_setfile_parameter('rejthreshlower', set_filename))

for tetrode in active_tetrodes:

    tetrode = int(tetrode)
    # check if this tetrode exists already

    tetrode_filename = os.path.join(directory, '%s.%d' % (tint_basename, tetrode))
    if os.path.exists(tetrode_filename):
        continue

    tetrode_channels = get_channel_from_tetrode(tetrode)  # get the channels (from range of 1->64)

    data = get_bin_data(bin_filename, tetrode=tetrode)  # 16bit, get data associated with the tetrode

    # converting data to uV

    n_samples = data.shape[1]
    # create a time array that represents the 48kHz sampled data times
    t = np.arange(0, n_samples) / Fs  # creates a time array of the signal starting from 0 (in seconds)

    if not os.path.exists(tetrode_filename):

        # ---------------------------Find the spikes in the unit data --------------------------------------

        tetrode_spikes = {}  # creates an empty dictionary to hold the spike times
        # for each tetrode, find the spikes

        k = 0

        # data = int16toint8(data)  # converting the data into int8

        tetrode_thresholds = []
        for channel_index, channel in enumerate(tetrode_channels):
            k += 1
            '''
            Auto thresholding technique incorporated by:
            Quian Quiroga in 2014 - Unsupervised Spike Detection and Sorting with Wavelets and
            Superparamagnetic Clustering
            Thr = 4*sigma, sigma = median(abs(x)/0.6745)
            '''
            standard_deviations = float(threshold)

            sigma_n = np.median(np.divide(np.abs(data[channel_index, :]), 0.6745))
            # threshold = sigma_n / channel_max
            # threshold = standard_deviations * sigma_n
            tetrode_thresholds.append(standard_deviations * sigma_n)

        valid_spikes = get_spikes(data, tetrode_thresholds)

        # threshold is done in 16 bit values, but the rejection is done in 8bit, so we convert here
        # data = int16toint8(data)  # converting the data into int8

        data = int16toint8(data)  # converting the data into int8

        tetrode_spikes = validate_spikes(tetrode, valid_spikes, data, t, pre_spike_samples,
                                         post_spike_samples, rejstart, rejthreshtail, rejthreshupper,
                                         rejthreshlower)

        # write the tetrode data to create the .N file
        write_tetrode(tetrode_filename, tetrode_spikes, Fs)
    else:
        pass

    data = None
    tetrode_spikes = None
    valid_spikes = None

In [None]:
import os
from .conversion_utils import get_set_header
import numpy as np
import struct


def write_tetrode(filepath, data, Fs):

    session_path, session_filename = os.path.split(filepath)
    tint_basename = os.path.splitext(session_filename)[0]
    set_filename = os.path.join(session_path, '%s.set' % tint_basename)

    n = len(data)

    header = get_set_header(set_filename)

    with open(filepath, 'w') as f:
        num_chans = 'num_chans 4'
        timebase_head = '\ntimebase %d hz' % (96000)
        bp_timestamp = '\nbytes_per_timestamp %d' % (4)
        # samps_per_spike = '\nsamples_per_spike %d' % (int(Fs*1e-3))
        samps_per_spike = '\nsamples_per_spike %d' % (50)
        sample_rate = '\nsample_rate %d hz' % (Fs)
        b_p_sample = '\nbytes_per_sample %d' % (1)
        # b_p_sample = '\nbytes_per_sample %d' % (4)
        spike_form = '\nspike_format t,ch1,t,ch2,t,ch3,t,ch4'
        num_spikes = '\nnum_spikes %d' % (n)
        start = '\ndata_start'

        write_order = [header, num_chans, timebase_head,
                       bp_timestamp,
                       samps_per_spike, sample_rate, b_p_sample, spike_form, num_spikes, start]

        f.writelines(write_order)

    # rearranging the data to have a flat array of t1, waveform1, t2, waveform2, t3, waveform3, etc....
    spike_times = np.asarray(sorted(data.keys()))

    # the spike times are repeated for each channel so lets tile this
    spike_times = np.tile(spike_times, (4, 1))
    spike_times = spike_times.flatten(order='F')

    spike_values = np.asarray([value for (key, value) in sorted(data.items())])

    # this will create a (n_samples, n_channels, n_samples_per_spike) => (n, 4, 50) sized matrix, we will create a
    # matrix of all the samples and channels going from ch1 -> ch4 for each spike time
    # time1 ch1_data
    # time1 ch2_data
    # time1 ch3_data
    # time1 ch4_data
    # time2 ch1_data
    # time2 ch2_data
    # .
    # .
    # .

    spike_values = spike_values.reshape((n * 4, 50))  # create the 4nx50 channel data matrix

    # make the first column the time values
    spike_array = np.hstack((spike_times.reshape(len(spike_times), 1), spike_values))

    data = None
    spike_times = None
    spike_values = None

    spike_n = spike_array.shape[0]

    t_packed = struct.pack('>%di' % spike_n, *spike_array[:, 0].astype(int))
    spike_array = spike_array[:, 1:]  # removing time data from this matrix to save memory

    spike_data_pack = struct.pack('<%db' % (spike_n*50), *spike_array.astype(int).flatten())

    spike_array = None

    # now we need to combine the lists by alternating

    comb_list = [None] * (2*spike_n)
    comb_list[::2] = [t_packed[i:i + 4] for i in range(0, len(t_packed), 4)]  # breaks up t_packed into a list,
    # each timestamp is one 4 byte integer
    comb_list[1::2] = [spike_data_pack[i:i + 50] for i in range(0, len(spike_data_pack), 50)]  # breaks up spike_data_
    # pack and puts it into a list, each spike is 50 one byte integers

    t_packed = None
    spike_data_pack = None

    write_order = []
    with open(filepath, 'rb+') as f:

        write_order.extend(comb_list)
        write_order.append(bytes('\r\ndata_end\r\n', 'utf-8'))

        f.seek(0, 2)
        f.writelines(write_order)

In [220]:
def write_to_tetrode_files(sorting_extractor, save_dir):
    '''Given a sorting extractor object create .X (tetrode) files.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    save_dir : str or Path
        Directory where to save the output
    '''
    # TODO ...
    pass

In [232]:
def write_to_tint(sorting_extractor, filename):
    '''Given a sorting extractor object, write appropriate data
    to TINT format (from Axona). Will therefore create .X (tetrode),
    .cut and .clu (spike sorting information) files.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    filename : str or Path
        Full path and base filename shared by all output files 
        (e.g. my_dir/my_file will yield
        my_dir/my_file.1, my_dir/my_file.2, ..., 
        my_dir/my_file_1.cut, my_dir/my_file_2.cut, ...,
        my_dir/my_file_1.clu, my_dir/my_file_2.clu, ...)
        If a file extension is given, it is simply ignored.
        
    Notes
    -----
    For details about the .X file format see:
    http://space-memory-navigation.org/DacqUSBFileFormats.pdf
    '''
    # Make sure directory exists
    filename.parent.absolute().mkdir(parents=True, exist_ok=True)
    
    # writes to .X files for each tetrode
    # TODO...
    write_to_tetrode_file(sorting_extractor, filename)
    
    # writes to .cut and .clu files for each tetrode
    write_unit_labels_to_file(sorting_extractor, filename)
    
    # Position data?
    # TODO ...

In [233]:
filename = Path(
    '/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/spikeextractors_to_tint/20201004_Tint'
)

In [234]:
write_to_tint(sorting_nwb, filename)

Converting Tetrode 0
Converting Tetrode 1
Converting Tetrode 2
Converting Tetrode 3


## Misc

In [None]:
def scale_values(x, maxabs, bound=127):
    '''Scale signal `x` between -`bound` and +`bound`,
    preserves 0 point.
    
    Parameters
    ----------
    x : np.array
    absmax : numeric
        max(|min(x)|, |max(x)|)
    bound : numeric
    
    Return
    ------
    np.array
    '''
    return x / maxabs * bound