## Import and plot waveforms from Phy

In [None]:
#!pip install skm_pyutils spikeinterface

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skm_pyutils import py_plot

In [None]:
# Raw strings are required for Windows paths: in a normal literal "\r" (as in
# "\raw.bin") is a carriage return, and "\T", "\M", "\H" are invalid escapes.
path = Path(r"E:\Temp\Matheus\HUMAN 257_newslice2_new.GUI")
recording_path = Path(r"E:\Temp\Matheus\raw.bin")

In [None]:
# Load the per-spike arrays exported by Phy, then the curation labels.
amplitudes, spike_times, spike_clusters = (
    np.load(path / fname)
    for fname in ("amplitudes.npy", "spike_times.npy", "spike_clusters.npy")
)
clusters = pd.read_csv(path / "cluster_group.tsv", sep="\t")
# IDs of every cluster whose label is "good".
groups = clusters.loc[clusters["group"] == "good", "cluster_id"].values


In [24]:
# IDs of the clusters labelled "good" in cluster_group.tsv (shown as the cell output)
groups

array([  7,  13,  23,  26,  28,  30,  33,  35,  36,  38,  39,  44,  50,
        60,  66,  67,  73,  82,  84,  85,  86,  87,  91,  96,  98,  99,
       104, 105], dtype=int64)

In [None]:
# Spike times of the first "good" cluster (cluster 7 for this dataset).
sample_spike = spike_times[spike_clusters == groups[0]]
# Draw up to 100 distinct spikes. min() avoids the ValueError raised when
# replace=False and the cluster has fewer than 100 spikes; the seeded
# generator makes the sample reproducible between runs.
rng = np.random.default_rng(0)
spk_sample = rng.choice(sample_spike, size=min(100, len(sample_spike)), replace=False)

In [None]:
# Show the randomly sampled spike times of the first good cluster (cell output)
spk_sample

In [None]:
def load_phy(raw_data_folder, sorting_folder):
    """Load a recording and its Phy-curated sorting with spikeinterface.

    Parameters
    ----------
    raw_data_folder : path-like
        Passed to ``SpykingCircusRecordingExtractor``.
    sorting_folder : path-like
        Phy output folder passed to ``PhySortingExtractor``.

    Returns
    -------
    tuple
        ``(recording, sorting)``. Clusters labelled "mua" or "noise" are
        excluded from the sorting.
    """
    import spikeinterface.extractors as extractors

    recording = extractors.SpykingCircusRecordingExtractor(raw_data_folder)
    sorting = extractors.PhySortingExtractor(
        sorting_folder, exclude_cluster_groups=["mua", "noise"]
    )
    return recording, sorting


def load_phy_forms(recording_folder, sorting, cache_dir, **kwargs):
    """Extract waveforms with spikeinterface and return the result.

    See ``spikeinterface.extract_waveforms`` for the accepted kwargs.

    Parameters
    ----------
    recording_folder :
        Passed as the first (recording) argument of ``extract_waveforms``;
        the notebook passes the recording object returned by ``load_phy``.
    sorting :
        Sorting whose units are extracted.
    cache_dir : path-like
        Folder the extracted waveforms are written to.
    **kwargs :
        Forwarded to ``extract_waveforms``. Defaults applied here:
        ``ms_before=3.0``, ``ms_after=4.0``, ``max_spikes_per_unit=500``,
        ``load_if_exists=False`` and ``overwrite=True``; the caller may
        override any of them.

    Returns
    -------
    Whatever ``extract_waveforms`` returns (the object the plotting helpers
    consume as ``waveforms``).
    """
    import spikeinterface as si

    ms_before = kwargs.pop("ms_before", 3.0)
    ms_after = kwargs.pop("ms_after", 4.0)
    max_spikes_per_unit = kwargs.pop("max_spikes_per_unit", 500)
    # setdefault instead of hard-coded keywords: callers can override these
    # without a "multiple values for keyword argument" TypeError.
    kwargs.setdefault("load_if_exists", False)
    kwargs.setdefault("overwrite", True)
    # Return the result; previously it was discarded and callers got None.
    return si.extract_waveforms(
        recording_folder,
        sorting,
        cache_dir,
        ms_before=ms_before,
        ms_after=ms_after,
        max_spikes_per_unit=max_spikes_per_unit,
        **kwargs
    )

# NOTE(review): relies on the older ``get_unit_property`` unit-property API and
# has not been checked against current spikeinterface releases.
def plot_all_forms(sorting, waveforms, out_loc, channels_per_group=64):
    """Plot all waveforms from a spikeinterface sorting object.

    For each unit, draws every extracted spike waveform on
    ``channels_per_group`` stacked axes (one row per channel) and saves the
    figure as ``out_loc / "tet{group}_unit{unit_id}_forms.png"``.

    Parameters
    ----------
    sorting :
        Sorting object providing ``get_unit_ids()`` and a "group" (or
        "ch_group") unit property via ``get_unit_property``.
    waveforms :
        Object providing ``get_waveforms(unit_id)``. A 3-D result is indexed
        as ``wf[:, channel, :]`` and a 2-D result as ``wf[channel, :]``.
        NOTE(review): recent spikeinterface returns
        (num_spikes, num_samples, num_channels); confirm the axis order for
        the version in use.
    out_loc : path-like
        Existing folder the PNG files are written to (str or Path).
    channels_per_group : int
        Number of channel rows drawn per unit.

    Returns
    -------
    None. Stops early, after printing the available property names, if a
    unit has neither a "group" nor a "ch_group" property.
    """
    out_loc = Path(out_loc)
    for i, unit_id in enumerate(sorting.get_unit_ids()):
        try:
            tetrode = sorting.get_unit_property(unit_id, "group")
        except Exception:
            try:
                tetrode = sorting.get_unit_property(unit_id, "ch_group")
            except Exception:
                print("Unable to find cluster group or group in units")
                print(sorting.get_shared_unit_property_names())
                return

        # Fetch per unit so only one unit's waveforms are in memory at a time.
        wf = waveforms.get_waveforms(unit_id)
        # squeeze=False keeps ``axes`` 2-D even when channels_per_group == 1,
        # where plt.subplots would otherwise return a bare Axes.
        fig, axes = plt.subplots(channels_per_group, squeeze=False)
        for j in range(channels_per_group):
            # Explicit rank check instead of a catch-all except, so a bad
            # channel index raises rather than silently plotting another axis.
            wave = wf[:, j, :] if np.ndim(wf) == 3 else wf[j, :]
            axes[j, 0].plot(wave.T, color="k", lw=0.3)
        o_loc = out_loc / "tet{}_unit{}_forms.png".format(tetrode, unit_id)
        print("Saving waveform {} on tetrode {} to {}".format(i, tetrode, o_loc))
        fig.savefig(o_loc, dpi=200)
        # Close only this figure, not every open figure in the session.
        plt.close(fig)

# NOTE(review): expects a spikeinterface WaveformExtractor-style object that
# provides ``get_template(unit_id)``.
def plot_all_templates(sorting, waveforms, out_loc):
    """Plot the template of every unit, overlaying all channels on one axis.

    One PNG per unit is written to ``out_loc / "unit{unit_id}_template.png"``.

    Parameters
    ----------
    sorting :
        Sorting object providing ``get_unit_ids()``.
    waveforms :
        Object providing ``get_template(unit_id)``, which returns a 2-D array
        indexed as ``template[:, channel]`` (presumably
        (num_samples, num_channels); confirm for the spikeinterface version
        in use).
    out_loc : path-like
        Existing folder the PNG files are written to (str or Path).
    """
    out_loc = Path(out_loc)
    # Iterate unit ids directly. The original zip() unpacked the pair in
    # swapped order, so ``wf`` held the unit id and ``wf.shape`` crashed.
    for unit_id in sorting.get_unit_ids():
        template = waveforms.get_template(unit_id)
        n_channels = template.shape[1]
        fig, ax = plt.subplots()
        colors = py_plot.ColorManager(n_channels, "rgb")
        for channel in range(n_channels):
            ax.plot(template[:, channel], color=colors.get_next_color(), lw=3)
        # One file per unit. The old name embedded the last loop channel
        # index, which was meaningless and undefined for a zero-channel template.
        o_loc = out_loc / "unit{}_template.png".format(unit_id)
        print("Saving unit {} template to {}".format(unit_id, o_loc))
        fig.savefig(o_loc, dpi=200)
        plt.close(fig)

In [None]:
# Output locations inside the Phy folder; only the figure folder is created here.
figure_dir = path.joinpath("figures")
figure_dir.mkdir(exist_ok=True)
waveform_dir = path.joinpath("waveforms")

In [23]:
import spikeinterface.extractors as se

# Keep only curated units: clusters labelled "mua" or "noise" are dropped.
to_exclude = ["mua", "noise"]
sorting = se.PhySortingExtractor(path, exclude_cluster_groups=to_exclude)

PhySortingExtractor: 28 units - 1 segments - 25.0kHz

In [30]:
# Cross-check: the spike train spikeinterface reports for the first good
# cluster should equal the one taken from the raw Phy arrays. The printed
# arrays are truncated ("..."), so also compare them element-wise.
check_unit = groups[0]
si_train = sorting.get_unit_spike_train(check_unit)
phy_train = spike_times[spike_clusters == check_unit]
print(sorting.get_unit_ids())
print(si_train)
print(phy_train)
print("Spike trains match:", np.array_equal(np.ravel(si_train), np.ravel(phy_train)))

[  7  13  23  26  28  30  33  35  36  38  39  44  50  60  66  67  73  82
  84  85  86  87  91  96  98  99 104 105]
[ 22566332  97155567 100627241 ... 167850112 167885734 168057746]
[ 22566332  97155567 100627241 ... 167850112 167885734 168057746]


In [None]:
# Load the raw recording and curated Phy sorting, then extract waveforms
# (ms_before=1, ms_after=2, at most 600 spikes per unit).
recording, sorting = load_phy(raw_data_folder=recording_path, sorting_folder=path)
waveforms = load_phy_forms(
    recording_folder=recording,
    sorting=sorting,
    cache_dir=waveform_dir,
    ms_before=1,
    ms_after=2,
    max_spikes_per_unit=600,
    n_jobs=1,
    chunk_size=30000,
)

# Per-spike waveform plots; plot_all_forms has not been verified against the
# current spikeinterface API, so it is left disabled.
# plot_all_forms(sorting, waveforms, figure_dir, channels_per_group=4)

# One template figure per unit.
plot_all_templates(sorting, waveforms, figure_dir)