# Convert recording and sorting extractor data to TINT format

The Hussaini lab uses the proprietary TINT software from Axona to analyze extracellular electrophysiology data. We can already read various Axona data formats (`raw` data or `unit` data) into spikeinterface, preprocess and spike sort them, and export the data to NWB. In addition, we want to be able to export data to the TINT format.

The TINT format is essentially the same as the `unit` data, including `.X` and `.pos` files, but it also contains `.cut` or `.clu` files. These two hold the spike sorting results, i.e. which sorted unit each spike belongs to.

The conversion can be facilitated by the existing Hussaini lab tools, which [convert `.bin` data to `.X` and `.pos`](https://github.com/HussainiLab/BinConverter/blob/master/BinConverter/core/ConversionFunctions.py). Part of this code is only relevant for the GUI, which did not work for me, so I removed the GUI code and ran a conversion from `.bin` to `.X` and `.pos` in this notebook: [explore_hussaini_tools.ipynb](https://github.com/sbuergers/hussaini-lab-to-nwb-notebooks/blob/master/explore_hussaini_tools.ipynb).

They have also already written a [`write_cut()`](https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/core/writeCut.py) function (part of Geoff Barrett's `gebaSpike`).

We can test our solutions by reading data with these [Hussaini lab tools](https://github.com/HussainiLab/BinConverter/blob/master/BinConverter/core/Tint_Matlab.py). 

<a id='index'></a>
## Index

* [Testing functions](#testing_functions)
* [Hussaini-lab functions](#hussaini-lab_functions)
* [Convert Recording Extractor to TINT](#Convert_recording_extractor_to_tint)
* [Convert Sorting Extractor to TINT](#Convert_sorting_extractor_to_tint)

In [1]:
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (16, 8)
plt.rcParams.update({'font.size':14})
%matplotlib inline

import spikeextractors as se

print(sys.version, sys.platform, sys.executable)

3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0] linux /home/sbuergers/spikeinterface/spikeinterface_new_api/venv/bin/python


In [2]:
# Directories

dir_name = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin')
print('Input directory = ', dir_name)

save_dir = dir_name / 'conversion_to_tint'
save_dir.mkdir(parents=True, exist_ok=True)
print('Output directory = ', save_dir)

Input directory =  /mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin
Output directory =  /mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin/conversion_to_tint


In [3]:
# Read cached spikeextractors data

r_cache = se.load_extractor_from_pickle(os.path.join(dir_name, 'cached_unit_data_no_bin_preproc.pkl'))

In [4]:
# Read NWB recording data

nwb_dir = Path(dir_name, 'nwb')
recording_nwb = se.NwbRecordingExtractor(nwb_dir / 'axona_tutorial_re2.nwb')



In [5]:
# Read NWB sorting data

sorting_nwb = se.NwbSortingExtractor(nwb_dir / 'axona_se_MS4.nwb', sampling_frequency=48000)

In [6]:
# Show data types of different objects

print(type(r_cache))
print(type(recording_nwb))
print(type(sorting_nwb))

<class 'spikeextractors.extractors.bindatrecordingextractor.bindatrecordingextractor.BinDatRecordingExtractor'>
<class 'spikeextractors.extractors.nwbextractors.nwbextractors.NwbRecordingExtractor'>
<class 'spikeextractors.extractors.nwbextractors.nwbextractors.NwbSortingExtractor'>


<a id="testing_functions"></a>
## Testing functions
[back to index](#index)

As we start exporting to a putative TINT format, we want to check whether we can read it back in.

In [7]:
from spikeextractors.extractors.axonaunitrecordingextractor import AxonaUnitRecordingExtractor
import os


def test_axonaunitrecordingextractor(filename):
    '''Reads UNIT data with AxonaUnitRecordingExtractor and
    performs some simple operations as a sanity check. 
    
    Parameters
    ----------
    filename : str or Path
        Full filename of `.set` file (could be any extension actually)
    '''
    recording = AxonaUnitRecordingExtractor(filename=filename)
    
    # TEST AXONAUNITRECORDINGEXTRACTOR
    # Retrieve some simple recording information and print it
    print('Channel ids = {}'.format(recording.get_channel_ids()))
    print('Num. channels = {}'.format(len(recording.get_channel_ids())))
    print('Sampling frequency = {} Hz'.format(recording.get_sampling_frequency()))
    print('Num. timepoints = {}'.format(recording.get_num_frames()))
    print('Stdev. on third channel = {}'.format(np.std(recording.get_traces(channel_ids=2))))
    print('Location of third electrode = {}'.format(
        recording.get_channel_property(channel_id=2, property_name='location')))
    print('Channel groups = {}'.format(recording.get_channel_groups()))
    
    # TEST NEO_READER (axonaio)
    print(recording.neo_reader.header['signal_channels'])
    
    
def test_tetrode_files(filename):
    '''Reads UNIT data with AxonaUnitRecordingExtractor and
    performs some simple operations as a sanity check. 
    Will only test .X and .set files (no .clu or .cut, no .pos).
    
    Parameters
    ----------
    filename : str or Path
        Full filename of `.set` file (could be any extension actually)
    '''
    test_axonaunitrecordingextractor(filename)

<a id="hussaini-lab_functions"></a>
## Hussaini-lab functions
[back to index](#index)

`gebaSpike` expects already existing `.cut` or `.clu` files and lets you modify them, so its GUI code might not be all that useful for exporting to `.cut` or `.clu`. Its file-writing functions (`write_cut()`, `write_clu()`) can still serve as templates.

In [54]:
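# Copied from gebaSpike for reference only; not runnable on its own, since it is
# a GUI method that needs `self`, `time`, `QtWidgets` and `default_filename`.
# https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/main.py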

def save_function(self):
    """
    this method will save the .cut file
    :return:
    """
    if self.cut_filename.text() == default_filename:
        return

    save_filename = os.path.realpath(self.cut_filename.text())

    if os.path.exists(save_filename):
        self.choice = None
        self.LogError.signal.emit('OverwriteCut!%s' % save_filename)
        while self.choice is None:
            time.sleep(0.1)

        if self.choice != QtWidgets.QMessageBox.Yes:
            return

    if len(self.tetrode_data) == 0:
        return

    # organize the cut data
    n_spikes_expected = self.tetrode_data.shape[1]
    n_spikes = len(np.asarray([item for sublist in self.cell_indices.values() for item in sublist]))

    # check that with the manipulation of the spikes, that we still have the correct number of spikes
    if n_spikes != n_spikes_expected:
        self.choice = None
        self.LogError.signal.emit('cutSizeError')
        while self.choice is None:
            time.sleep(0.1)
        return

    # we will check if we are missing some of the spikes somehow. If we kept track of them, then the indices from
    # the spikes, when sorted, should produce an array from 0 -> N-1 spikes.
    if not np.array_equal(np.sort(np.asarray([item for sublist in self.cell_indices.values() for item in sublist])),
                      np.arange(len(self.cut_data_original))):
        self.choice = None
        self.LogError.signal.emit('cutIndexError')
        while self.choice is None:
            time.sleep(0.1)
        return

    cut_values = np.zeros(n_spikes)
    for cell, cell_indices in self.cell_indices.items():
        cut_values[cell_indices] = cell

    if '.clu.' in save_filename:
        # save the .clu filename
        write_clu(save_filename, cut_values)
        self.choice = None
        self.LogError.signal.emit('saveCompleteClu')
        while self.choice is None:
            time.sleep(0.1)
        self.actions_made = False

    else:
        # save the cut filename
        write_cut(save_filename, cut_values)
        self.choice = None
        self.LogError.signal.emit('saveComplete')
        while self.choice is None:
            time.sleep(0.1)
        self.actions_made = False

In [55]:
# From 
# https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/core/writeCut.py

def write_cut(cut_filename, cut, basename=None):
    if basename is None:
        basename = os.path.basename(os.path.splitext(cut_filename)[0])

    unique_cells = np.unique(cut)

    if 0 not in unique_cells:
        # if it happens that there is no zero cell, add it anyways
        unique_cells = np.insert(unique_cells, 0, 0)  # object, index, value to insert

    n_clusters = len(np.unique(cut))
    n_spikes = len(cut)

    write_list = []  # the list of values to write

    tab = '    '  # the spaces didn't line up with my tab so I just created a string with enough spaces
    empty_space = '               '  # some of the empty spaces don't line up to x tabs

    # we add 1 to n_clusters because zero is the garbage cell that no one uses
    write_list.append('n_clusters: %d\n' % (n_clusters))
    write_list.append('n_channels: 4\n')
    write_list.append('n_params: 2\n')
    write_list.append('times_used_in_Vt:%s' % ((tab + '0') * 4 + '\n'))

    zero_string = (tab + '0') * 8 + '\n'

    for cell_i in np.arange(n_clusters):
        write_list.append(' cluster: %d center:%s' % (cell_i, zero_string))
        write_list.append('%smin:%s' % (empty_space, zero_string))
        write_list.append('%smax:%s' % (empty_space, zero_string))
    write_list.append('\nExact_cut_for: %s spikes: %d\n' % (basename, n_spikes))

    # now the cut file lists 25 values per row
    n_rows = int(np.floor(n_spikes / 25))  # number of full rows

    remaining = int(n_spikes - n_rows * 25)
    cut_string = ('%3u' * 25 + '\n') * n_rows + '%3u' * remaining

    write_list.append(cut_string % (tuple(cut)))

    with open(cut_filename, 'w') as f:
        f.writelines(write_list)

In [56]:
# From 
# https://github.com/GeoffBarrett/gebaSpike/blob/967097ec28592182ef9783d2d391930e1c63ca58/gebaSpike/core/writeCut.py

def write_clu(clu_filename, data):
    # the .clu files and the .cut files are different since the .clu files are the .cut files (with no manual sorting)
    # without the headers, and the values go from 1 -> N instead of 0 -> N, (1-based numbering instead of 0-based). Thus
    # we add 1 to the .cut data to get the .clu data

    data = np.asarray(data).astype(int)  # ensuring that the data is the integer data-type

    data += 1  # making the data 1-based instead of 0-based

    # calculating the number of clusters
    n_clust = len(np.unique(data))

    # ensuring that the cluster number is the 1st value
    data = np.concatenate(([n_clust], data))

    # saving the data as a column (delimiter='\n') and integer format.
    np.savetxt(clu_filename, data, fmt='%d', delimiter='\n')

In [None]:
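# Stripped-down copy of the gebaSpike GUI class, kept for reference only
# (importing it needs PyQt5, pyqtgraph and gebaSpike to be installed)
import sys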
import pyqtgraph as pg
from PyQt5 import QtCore, QtGui, QtWidgets
from pyqtgraph.widgets.MatplotlibWidget import MatplotlibWidget
from gebaSpike.core.gui_utils import validate_session, Communicate, validate_cut, find_tetrodes
from gebaSpike.core.default_parameters import project_name, default_filename, defaultXAxis, defaultYAxis, defaultZAxis, openGL, \
    default_move_channel, max_spike_plots, alt_action_button
# from core.Tint_Matlab import find_tetrodes
from gebaSpike.core.plot_functions import plot_session, cut_cell, get_index_from_roi
from gebaSpike.core.waveform_cut_functions import moveToChannel, maxSpikesChange
from gebaSpike.core.undo import undo_function
from gebaSpike.core.feature_plot import feature_name_map
from gebaSpike.core.PopUpCutting import PopUpCutWindow
from gebaSpike.core.writeCut import write_cut, write_clu
import pyqtgraph.opengl as gl
import os
import json
import time
import numpy as np


version = "1.1.0"


class gebaSpikeObj():

    def __init__(self):
        """
        initializes many of the variables
        """

        # initializing attributes
        self.plotted_tetrode = None
        self.change_set_with_tetrode = True
        self.multiple_files = False
        self.cut_filename = None
        self.choose_cut_filename_btn = None
        self.feature_win = None
        self.quit_btn = None
        self.filename = None
        self.choose_filename_btn = None
        self.x_axis_cb = None
        self.y_axis_cb = None
        self.z_axis_cb = None
        self.tetrode_cb = None
        self.choice = None  # the current choice for the error popups
        self.plot_btn = None
        self.feature_plot = None

        # initialize list of actions to undo
        self.latest_actions = {}

        # bool for if an action has been made, I suppose we could take the length of the actions attribute
        self.actions_made = False

        # bool representing if the user is dragging the mouse (for drawing the line segments on the graphs)
        self.drag_active = False

        # the graph index that was last dragged upon with the mouse
        self.last_drag_index = None

        self.feature_data = None
        self.tetrode_data = None
        self.tetrode_data_loaded = False
        self.cut_data = None
        self.cut_data_loaded = False
        self.cut_data_original = None
        self.spike_times = None
        self.scatterItem = None
        self.glViewWidget = None
        self.feature_plot_added = False
        self.samples_per_spike = None

        # keep a list of the positions that the cells are plotted in
        self.unit_positions = {}

        # not all the spikes are plotted at once, so we will keep a dict of which subsample is plotted
        self.cell_subsample_i = {}

        self.unit_drag_lines = {}
        self.active_ROI = []
        self.unit_data = {}

        self.xline = None
        self.yline = None
        self.zline = None

        self.max_spike_plots = None

        self.n_channels = None

        self.invalid_channel = None

        self.cell_indices = {}

        self.spike_colors = None

        self.original_cell_count = {}

        self.unit_plots = {}
        self.vb = {}
        self.plot_lines = {}
        self.avg_plot_lines = {}
        self.unit_rows = 0
        self.unit_cols = 0

    def save_function(self):
        """
        this method will save the .cut file
        :return:
        """
        save_filename = os.path.realpath(self.cut_filename.text())

        # organize the cut data
        n_spikes_expected = self.tetrode_data.shape[1]
        n_spikes = len(np.asarray([item for sublist in self.cell_indices.values() for item in sublist]))

        # check that with the manipulation of the spikes, that we still have the correct number of spikes
        if n_spikes != n_spikes_expected:
            self.choice = None
            self.LogError.signal.emit('cutSizeError')
            while self.choice is None:
                time.sleep(0.1)
            return

        # we will check if we are missing some of the spikes somehow. If we kept track of them, then the indices from
        # the spikes, when sorted, should produce an array from 0 -> N-1 spikes.
        if not np.array_equal(np.sort(np.asarray([item for sublist in self.cell_indices.values() for item in sublist])),
                          np.arange(len(self.cut_data_original))):
            self.choice = None
            self.LogError.signal.emit('cutIndexError')
            while self.choice is None:
                time.sleep(0.1)
            return

        cut_values = np.zeros(n_spikes)
        for cell, cell_indices in self.cell_indices.items():
            cut_values[cell_indices] = cell

        if '.clu.' in save_filename:
            # save the .clu filename
            write_clu(save_filename, cut_values)
            self.choice = None
            self.LogError.signal.emit('saveCompleteClu')
            while self.choice is None:
                time.sleep(0.1)
            self.actions_made = False

        else:
            # save the cut filename
            write_cut(save_filename, cut_values)
            self.choice = None
            self.LogError.signal.emit('saveComplete')
            while self.choice is None:
                time.sleep(0.1)
            self.actions_made = False

    def set_cut_filename(self):
        """
        When you choose a session filename, it will trigger this function which will automatically set the .cut
        filename for the user.
        :return:
        """

        if not self.multiple_files:
            filename = self.filename.text()

            try:
                tetrode = int(self.tetrode_cb.currentText())
            except ValueError:
                return

            cut_filename = '%s_%d.cut' % (
                os.path.splitext(filename)[0], tetrode)
            clu_filename = '%s.clu.%d' % (
                os.path.splitext(filename)[0], tetrode)
            if os.path.exists(cut_filename):
                self.cut_filename.setText(os.path.realpath(cut_filename))
            elif os.path.exists(clu_filename):
                self.cut_filename.setText(os.path.realpath(clu_filename))

    def tetrode_changed(self):
        """
        Upon changing of the tetrode drop-menu, the .cut file will also need to be changed (trigger that change)
        :return:
        """
        # we will update the cut_filename
        if self.change_set_with_tetrode:
            self.set_cut_filename()
            # self.reset_parameters()

    def filename_changed(self):
        """
        This method will run when the filename LineEdit has been changed.
        It will essentially find the active tetrodes and populate the drop-down menu.
        """

        filename = self.filename.text()

        if self.multiple_files:
            filename = filename.split(', ')

        else:
            filename = [filename]

        # ensure that the files exist
        for file in filename:
            if not os.path.exists(file):
                return

        tetrodes = []

        self.tetrode_cb.clear()

        tetrode_list = find_tetrodes(self, self.filename.text())

        # get the extension value (excluding the .) so we can create a list of tetrode integers

        for file in tetrode_list:
            tetrode = os.path.splitext(file)[-1][1:]
            tetrodes.append(tetrode)

        # make a list of added tetrodes
        added_tetrodes = []
        for tetrode in sorted(tetrodes):
            # check if the tetrode value has been added already
            if tetrode in added_tetrodes:
                # continue if already added
                continue

            # add the item to the list containing the tetrode value
            self.tetrode_cb.addItem(tetrode)

            # add the tetrode value to the added_tetrodes list
            added_tetrodes.append(tetrode)

        # set the cut_filename
        self.set_cut_filename()



In [19]:
from gebaSpike.core.Tint_Matlab import is_tetrode, read_clu
from collections import Counter


def find_tetrodes(set_fullpath):
    """Find the tetrode files belonging to a given .set file for which a matching
    .cut (`<session>_<N>.cut`) or .clu (`<session>.clu.<N>`) file exists.

    The multi-session logic below (keeping only tetrodes present in all sessions)
    is kept from the original, but here only a single .set file is passed in."""

    set_files = [set_fullpath]

    num_files = len(set_files)

    tetrode_files = {}

    # finds all the tetrode files
    for file in set_files:
        tetrode_path, session = os.path.split(file)
        session, _ = os.path.splitext(session)

        # getting all the files in that directory
        file_list = os.listdir(tetrode_path)

        # acquiring only a list of tetrodes that belong to that set file
        tetrode_list = [os.path.join(tetrode_path, file) for file in file_list
                        if is_tetrode(file, session)]

        # if the .cut or .clu.X file doesn't exist remove from list
        tetrode_list = [file for file in tetrode_list if (
            os.path.exists(
                os.path.join(tetrode_path, '%s_%s.cut' % (
                    os.path.splitext(file)[0], os.path.splitext(file)[1][1:]))) or
            os.path.exists(
                os.path.join(tetrode_path, '%s.clu.%s' % (
                    os.path.splitext(file)[0], os.path.splitext(file)[1][1:])))
        )]

        tetrode_files[file] = tetrode_list

    # count the files to ensure that we have the same amount of tetrode files as we do sessions

    tetrode_count = Counter()
    for tet_files in tetrode_files.values():
        for file in tet_files:
            ext = os.path.splitext(file)[-1]

            tetrode_count[ext] += 1  # Counter starts missing keys at 0

    # we will only include the tetrodes that are existing across all sessions provided
    tetrode_list = []
    for ext in sorted(tetrode_count):
        if tetrode_count[ext] == num_files:
            # this could likely be optimized, but I figure it won't take too long to iterate through these files anyways
            for values in tetrode_files.values():
                ext_files = [file for file in values if os.path.splitext(file)[-1] == ext]
                tetrode_list.extend(ext_files)

    return tetrode_list

In [20]:
set_fullpath = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.set')

tetrode_list = find_tetrodes(set_fullpath)

In [21]:
tetrode_list

['/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.1',
 '/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.2',
 '/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.3',
 '/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.4']

<a id="Convert_recording_extractor_to_tint"></a>
## Convert Recording extractor to TINT
[back to index](#index)

Hmm, since writing to TINT means creating `.X` tetrode files, we throw away all information in between spikes. So there is no point in converting the "fake" continuous recording used for spike sorting to TINT at all. We really only want to export the spike sorting output!
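To make this concrete, the sketch below shows the kind of reduction that happens: only fixed-length windows around each spike survive. It is a generic NumPy illustration, not the actual `.X` writer. `extract_spike_snippets` is a hypothetical helper, and the default window (10 samples before and 40 after each spike, 50 in total) is an assumption chosen for illustration that should be checked against the real Axona tetrode format.

```python
import numpy as np


def extract_spike_snippets(traces, spike_frames, n_before=10, n_after=40):
    """Cut fixed-length waveform snippets around each spike.

    traces : array of shape (n_channels, n_frames), e.g. one tetrode
    spike_frames : 1D sequence of spike sample indices
    Returns an array of shape (n_spikes, n_channels, n_before + n_after).
    Spikes whose window would run past either end of the recording are dropped.
    """
    traces = np.asarray(traces)
    spike_frames = np.asarray(spike_frames, dtype=int)
    n_frames = traces.shape[1]
    # Keep only spikes whose full window lies inside the recording
    valid = (spike_frames - n_before >= 0) & (spike_frames + n_after <= n_frames)
    spike_frames = spike_frames[valid]
    # Index matrix of shape (n_spikes, window), used to gather all snippets at once
    offsets = np.arange(-n_before, n_after)
    idx = spike_frames[:, None] + offsets[None, :]
    snippets = traces[:, idx]  # shape (n_channels, n_spikes, window)
    return np.transpose(snippets, (1, 0, 2))
```

Everything outside these windows is discarded, which is why exporting the sorting output (spike times, labels and waveforms) is what matters for TINT.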

In [None]:
# Anything to do here?

<a id="Convert_sorting_extractor_to_tint"></a>
## Convert Sorting extractor to TINT
[back to index](#index)

There are several points in the pipeline at which we might want to export to TINT. Ideally it should work for any `SortingExtractor` object!

In [43]:
print('Where do we load data from?\n\n', dir_name)

Where do we load data from?

 /mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/sample_bin_to_tint_no_bin


From a sorting extractor we can obtain a list of spike sample arrays, one per unit. We can convert these into a single `.clu`/`.cut`-style array that labels each spike, in temporal order, with the unit it belongs to.


In [40]:
def convert_spike_train_to_label_array(spike_train):
    '''Takes a list of arrays, where each array is a series of
    sample points at which a spike occurred for a given unit
    (each list item is a unit). Converts to a .cut array, i.e.
    orders spike samples from all units by time and labels each
    sample with the index of its unit in `spike_train`.
    
    Parameters
    ----------
    spike_train : List of np.arrays
        Output of `get_units_spike_train()` method of sorting extractor
        
    Returns
    -------
    unit_labels_sorted : np.array
        Each entry is the label (list index) of the unit whose spike
        occurred at this ordinal position
    '''

    # Generate index array (indexing the unit for a given spike sample)
    unit_labels = []
    for i, l in enumerate(spike_train):
        unit_labels.append(np.full(len(l), i, dtype=int))
    
    # Flatten lists and sort them by spike time; a stable sort keeps the
    # order reproducible when two units spike on the same sample
    spike_train_flat = np.concatenate(spike_train).ravel()
    unit_labels_flat = np.concatenate(unit_labels).ravel()

    sort_index = np.argsort(spike_train_flat, kind='stable')

    unit_labels_sorted = unit_labels_flat[sort_index]

    return unit_labels_sorted
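A vectorized variant of the same idea, as a sketch: `np.repeat` builds all labels in one call, and the function also returns the time-sorted spike frames so that labels and spike times stay aligned. `spike_trains_to_labels` is a hypothetical name. Like the function above, it labels units by their position in the list (not by their unit ID) and uses a stable sort. It assumes at least one (possibly empty) spike train is passed, since `np.concatenate` fails on an empty list.

```python
import numpy as np


def spike_trains_to_labels(spike_trains):
    """Merge per-unit spike trains into one time-ordered label array.

    spike_trains : list of 1D arrays of spike frames, one per unit
    Returns (frames_sorted, labels_sorted); labels are the list positions
    (0, 1, 2, ...) of the unit each spike came from.
    """
    lengths = [len(st) for st in spike_trains]
    # One label per spike: unit i repeated len(spike_trains[i]) times
    labels = np.repeat(np.arange(len(spike_trains)), lengths)
    frames = np.concatenate([np.asarray(st) for st in spike_trains])
    # Stable sort: simultaneous spikes keep their unit order
    order = np.argsort(frames, kind="stable")
    return frames[order], labels[order]
```

For example, trains `[10, 30]`, `[20]` and `[30, 5]` give frames `[5, 10, 20, 30, 30]` and labels `[2, 0, 1, 0, 2]`.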

In [85]:
cut_filename = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint_1.cut')

basename = os.path.basename(os.path.splitext(cut_filename)[0])

print(basename)

20201004_Tint_1


In [66]:
def write_to_cut_file(cut_filename, unit_labels, basename=None):
    '''Write spike sorting output to .cut file.
    
    Parameters
    ----------
    cut_filename : str or Path
        Full filename of .cut file to write to. A given .cut file belongs
        to a given tetrode file. For example, for tetrode `my_file.1`, the
        corresponding cut_filename should be `my_file_1.cut`.
    unit_labels : np.array
        Vector of unit labels for each spike sample (ordered by time of 
        occurrence)
    basename : str, optional
        Name written after `Exact_cut_for:` in the header. Defaults to
        `cut_filename` without directory and suffix, as in gebaSpike.
        
    Example
    -------
    # Given a sortingextractor called sorting_nwb:
    spike_train = sorting_nwb.get_units_spike_train()
    unit_labels = convert_spike_train_to_label_array(spike_train)
    write_to_cut_file(cut_filename, unit_labels)
    
    ---
    Largely based on gebaSpike implementation by Geoff Barrett
    https://github.com/GeoffBarrett/gebaSpike
    '''
    # Derive the header name from the file being written
    if basename is None:
        basename = os.path.basename(os.path.splitext(cut_filename)[0])

    unique_cells = np.unique(unit_labels)

    n_clusters = len(unique_cells)
    n_spikes = len(unit_labels)

    write_list = []

    tab = '    '
    empty_space = '               '

    write_list.append('n_clusters: %d\n' % (n_clusters))
    write_list.append('n_channels: 4\n')
    write_list.append('n_params: 2\n')
    write_list.append('times_used_in_Vt:%s' % ((tab + '0') * 4 + '\n'))

    zero_string = (tab + '0') * 8 + '\n'

    for cell_i in np.arange(n_clusters):
        write_list.append(' cluster: %d center:%s' % (cell_i, zero_string))
        write_list.append('%smin:%s' % (empty_space, zero_string))
        write_list.append('%smax:%s' % (empty_space, zero_string))
    write_list.append('\nExact_cut_for: %s spikes: %d\n' % (basename, n_spikes))

    # The unit label array consists of 25 values per row in .cut file
    n_rows = int(np.floor(n_spikes / 25))
    remaining = int(n_spikes - n_rows * 25)

    cut_string = ('%3u' * 25 + '\n') * n_rows + '%3u' * remaining

    write_list.append(cut_string % (tuple(unit_labels)))

    with open(cut_filename, 'w') as f:
        f.writelines(write_list)
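In line with the goal of reading our output back in, here is a minimal sketch of a parser for the files written above. It assumes exactly the layout produced by `write_to_cut_file()`: a header, a line `Exact_cut_for: <basename> spikes: <n>` (with no spaces in `<basename>`), then the labels as whitespace-separated integers. It also assumes labels below 100, since `%3u` leaves no separator between 3-digit values. `parse_cut_labels` and `read_cut_labels` are hypothetical helpers, not part of gebaSpike or spikeextractors.

```python
import re
from pathlib import Path

import numpy as np


def parse_cut_labels(text):
    """Return the per-spike cluster labels stored in .cut file contents."""
    match = re.search(r"Exact_cut_for:\s*\S+\s+spikes:\s*(\d+)", text)
    if match is None:
        raise ValueError("no 'Exact_cut_for ... spikes: N' line found")
    n_spikes = int(match.group(1))
    # Everything after that line is the label block (25 labels per row)
    labels = np.array([int(tok) for tok in text[match.end():].split()], dtype=int)
    if len(labels) != n_spikes:
        raise ValueError("expected %d labels, found %d" % (n_spikes, len(labels)))
    return labels


def read_cut_labels(cut_filename):
    """Read a .cut file from disk and return its labels."""
    return parse_cut_labels(Path(cut_filename).read_text())
```

After writing a file, a round-trip check could then be `assert np.array_equal(read_cut_labels(cut_filename), unit_labels)`.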

In [84]:
unit_ids = sorting_nwb.get_unit_ids()
print('Unit ids:', unit_ids)

tetrode_id = sorting_nwb.get_units_property(property_name='group')
print('Tetrode ids:', tetrode_id)

Unit ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Tetrode ids: [0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]


In [164]:
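def set_cut_filename_from_basename(filename, tetrode_id):
    '''Given a .set filename or base filename (str or Path), strip any
    file suffix from the file name, then append the tetrode ID label and
    the .cut suffix.
    
    Parameters
    ----------
    filename : str or Path
    tetrode_id : int
    
    Returns
    -------
    Path
        E.g. `/path/to/my_file_1.cut` for `/path/to/my_file.set` and tetrode 1
    '''
    filename = Path(filename)
    # Only split the file name itself, so dots in directory names are kept
    stem = filename.name.split('.')[0]
    return filename.with_name('{}_{}.cut'.format(stem, tetrode_id))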

In [161]:

filename = Path('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.set')
print(filename)

Path(str(filename.with_suffix('')) + '_{}'.format(1) + '.cut')

/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint.set


PosixPath('/mnt/d/freelance-work/catalyst-neuro/hussaini-lab-to-nwb/Axona_Tint_1ms/20201004_Tint_1.cut')
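The `.clu` files will need a matching naming helper. Based on the names gebaSpike's `set_cut_filename()` looks for (`<base>_<N>.cut` and `<base>.clu.<N>`), a sketch returning both paths for one tetrode could look like this. `tint_sorting_filenames` is a hypothetical name, and `tetrode_id` is assumed to already be 1-based.

```python
from pathlib import Path


def tint_sorting_filenames(filename, tetrode_id):
    """Return (cut_path, clu_path) for one tetrode of an Axona session.

    filename : the .set file or the session base name (str or Path)
    tetrode_id : tetrode number, assumed 1-based as in Axona file names
    """
    path = Path(filename)
    # Drop every suffix from the file name only; dots in directory names stay
    base = path.name.split(".")[0]
    cut_path = path.with_name("%s_%d.cut" % (base, tetrode_id))
    clu_path = path.with_name("%s.clu.%d" % (base, tetrode_id))
    return cut_path, clu_path
```

For `.../Axona_Tint_1ms/20201004_Tint.set` and tetrode 1 this gives `20201004_Tint_1.cut` and `20201004_Tint.clu.1` in the same directory.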

In [168]:
# We want to write .cut files separately for each tetrode, so we need
# to get the spike_trains separately for each tetrode!

def convert_sorting_extractor_to_cut(sorting_extractor, filename):
    '''Write spike sorting output to .cut file, separately for each
    tetrode.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    filename : str or Path
        Full filename of .set file or base-filename (i.e. the part of the
        filename all Axona files have in common). A given .cut file belongs
        to a given tetrode file. For example, for tetrode `my_file.1`, the
        corresponding cut_filename should be `my_file_1.cut`. This will be
        set automatically given the base-filename or set file.
        
    TODO: Any reason one might want to only convert some tetrodes or some
    samples? Should those be parameters?
    '''
    tetrode_ids = sorting_extractor.get_units_property(property_name='group')
    tetrode_ids = np.array(tetrode_ids)
    
    unit_ids = np.array(sorting_extractor.get_unit_ids())
    
    for i in np.unique(tetrode_ids):
        
        print('Converting Tetrode {}'.format(i))

        spike_train = sorting_extractor.get_units_spike_train(unit_ids=unit_ids[tetrode_ids==i])
        unit_labels = convert_spike_train_to_label_array(spike_train)

        # Follow Axona filename conventions (tetrodes are 1-indexed)
        cut_filename = set_cut_filename_from_basename(filename, i + 1)
        write_to_cut_file(cut_filename, unit_labels)
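Because `convert_sorting_extractor_to_cut()` needs a real `SortingExtractor`, its grouping logic is hard to test in isolation. The sketch below reproduces the same per-tetrode grouping with plain NumPy and a dict of spike trains, so it can be checked on toy data. `labels_per_tetrode` is a hypothetical helper. As in the function above, tetrode keys are the 0-based `group` values plus one, and labels are unit positions within each tetrode.

```python
import numpy as np


def labels_per_tetrode(unit_ids, groups, spike_trains):
    """Build one time-ordered label array per tetrode.

    unit_ids : sequence of unit ids
    groups : sequence of 0-based tetrode ids ('group' property), aligned with unit_ids
    spike_trains : dict mapping unit id -> 1D array of spike frames
    Returns {1-based tetrode number: labels}, where each label is the
    position of the unit within its tetrode (0, 1, 2, ...).
    """
    unit_ids = np.asarray(unit_ids)
    groups = np.asarray(groups)
    labels_by_tetrode = {}
    for group in np.unique(groups):
        trains = [np.asarray(spike_trains[u]) for u in unit_ids[groups == group]]
        # Label each spike with the position of its unit in this tetrode
        labels = np.repeat(np.arange(len(trains)), [len(t) for t in trains])
        # Order spikes in time (stable, for reproducible ties)
        order = np.argsort(np.concatenate(trains), kind="stable")
        # Axona file names number tetrodes from 1
        labels_by_tetrode[int(group) + 1] = labels[order]
    return labels_by_tetrode
```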

In [169]:
convert_sorting_extractor_to_cut(sorting_nwb, filename)

Converting Tetrode 0
Converting Tetrode 1
Converting Tetrode 2
Converting Tetrode 3


In [None]:
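# Given a sortingextractor called sorting_nwb (note: this puts the units of
# *all* tetrodes into a single .cut file; use convert_sorting_extractor_to_cut()
# to get one .cut file per tetrode):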
spike_train = sorting_nwb.get_units_spike_train()
unit_labels = convert_spike_train_to_label_array(spike_train)
write_to_cut_file(cut_filename, unit_labels)

In [72]:
sorting_nwb.get_shared_unit_property_names()

['firing_rate',
 'group',
 'halfwidth',
 'isi_violation',
 'max_channel',
 'peak_to_valley',
 'peak_trough_ratio',
 'recovery_slope',
 'repolarization_slope',
 'snr',
 'template']


In [32]:
# We saved the sorting output as a MountainSort `.mda` file

sorting_MS4 = se.MdaSortingExtractor(
    file_path=os.path.join(dir_name, 'mountainsort4.mda'),
    sampling_frequency=48000
)
print('Unit ids = {}'.format(sorting_MS4.get_unit_ids()))
spike_train = sorting_MS4.get_unit_spike_train(unit_id=1)
print('Num. events for unit 1 = {}'.format(len(spike_train)))
# At 48 kHz, 30000 frames correspond to 0.625 s (one second would be 48000 frames)
spike_train1 = sorting_MS4.get_unit_spike_train(unit_id=1, start_frame=0, end_frame=30000)
print('Num. events in first 30000 frames (0.625 s) of unit 1 = {}'.format(len(spike_train1)))

Unit ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Num. events for unit 1 = 101
Num. events in first 30000 frames (0.625 s) of unit 1 = 59


In [35]:
# We also have sorting data exported in `.nwb` format

nwb_dir = Path(dir_name, 'nwb')
sorting_nwb = se.NwbSortingExtractor(nwb_dir / 'axona_se_MS4.nwb', sampling_frequency=48000)

In [36]:
print(type(sorting_MS4))
print(type(sorting_nwb))

<class 'spikeextractors.extractors.mdaextractors.mdaextractors.MdaSortingExtractor'>
<class 'spikeextractors.extractors.nwbextractors.nwbextractors.NwbSortingExtractor'>


In [37]:
sorting_nwb

<spikeextractors.extractors.nwbextractors.nwbextractors.NwbSortingExtractor at 0x7f7ce3b9c910>

In [38]:
sorting_MS4

<spikeextractors.extractors.mdaextractors.mdaextractors.MdaSortingExtractor at 0x7f7ce3fcaa30>

I am not sure how best to package the conversion to TINT in the end; for now we just need the individual conversion methods!

In [46]:
sorting_nwb.get_sampling_frequency()

48000

In [None]:
def write_to_tetrode_file(sorting_extractor, save_dir):
    '''Given a sorting extractor object create .X (tetrode) files.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    save_dir : str or Path
        Directory where to save the output
    '''

In [None]:
def write_to_clu_file(sorting_extractor, save_dir):
    '''Given a sorting extractor object create .clu files.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    save_dir : str or Path
        Directory where to save the output
    '''
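A possible body for this stub, sketched under the conventions described in gebaSpike's `write_clu()` above: labels become 1-based and the first line holds the number of clusters. The helpers below (hypothetical names) only build and parse the file contents, so they can be tested without touching disk. Writing would then be, e.g., `Path(clu_filename).write_text(clu_text_from_labels(unit_labels))`, with one `.clu` file per tetrode.

```python
import numpy as np


def clu_text_from_labels(unit_labels):
    """Build .clu file contents from 0-based (.cut-style) labels."""
    labels = np.asarray(unit_labels, dtype=int) + 1  # .clu labels are 1-based
    n_clusters = len(np.unique(labels))
    # First line: number of clusters, then one label per line
    values = np.concatenate(([n_clusters], labels))
    return "".join("%d\n" % v for v in values)


def labels_from_clu_text(text):
    """Parse .clu file contents back into 0-based labels."""
    values = [int(tok) for tok in text.split()]
    n_clusters, labels = values[0], np.array(values[1:], dtype=int) - 1
    if n_clusters != len(np.unique(labels)):
        raise ValueError("cluster count in header does not match the labels")
    return labels
```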
    

In [None]:
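# NOTE: this stub has the same name as write_to_cut_file(cut_filename, unit_labels)
# above, so running this cell would shadow it. convert_sorting_extractor_to_cut()
# already covers this step.
def write_to_cut_file(sorting_extractor, save_dir):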
    '''Given a sorting extractor object create .cut files.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    save_dir : str or Path
        Directory where to save the output
    '''

In [None]:
def write_to_tint(sorting_extractor, save_dir):
    '''Given a sorting extractor object, write appropriate data
    to TINT format (from Axona). Will therefore create .X (tetrode)
    and .clu (spike sorting information) files.
    
    Parameters
    ----------
    sorting_extractor : spikeextractors.SortingExtractor
    save_dir : str or Path
        Directory where to save the output
    '''
    
    write_to_tetrode_file(sorting_extractor, save_dir)
    write_to_clu_file(sorting_extractor, save_dir)