In [None]:
import numpy as np
import datetime
import os
from os.path import join as pjoin
from spks.raw import *


In [None]:
from spks.utils import *
from pathlib import Path
# Sorting output folder used by the rest of the notebook.
folder = Path('/scratch/ks25_sorting_20230905_131941_tr9v')

# Rewrite the `dat_path` entry of params.py so that it points to the filtered
# binary; every other line is written back unchanged.
params_path = folder/'params.py'
params = [
    "dat_path = 'filtered_recording.ap.bin'" if line.startswith('dat_path') else line
    for line in params_path.read_text().split('\n')
]
with open(params_path,'w') as f:
    f.write("\n".join(params))



In [None]:
class SortedSpikes():
    '''Access the spike sorting results stored in a phy-format folder.

    Attributes loaded from disk (optional ones are ``None`` when the file is missing):
        spike_times, clusters          - required; ``spike_times.npy`` and ``spike_clusters.npy``
        unique_clusters                - sorted unique values of ``clusters``
        channel_positions, channel_map - ``channel_positions.npy``, ``channel_map.npy``
        spike_template_amplitudes      - ``amplitudes.npy``
        spike_templates, templates     - ``spike_templates.npy``, ``templates.npy``
        whitening_matrix               - transpose of ``whitening_mat_inv.npy``
        cluster_groups                 - ``cluster_group.tsv`` read as a DataFrame

    Derived attributes (always present, ``None`` until they can be computed):
        templates_raw, templates_amplitude, templates_position (also available
        as ``template_position``), spike_amplitudes, spike_positions,
        duplicate_spike_indices.
    '''
    def __init__(self,folder, compute_raw_templates=True, remove_duplicate_spikes = True):
        '''Object to access the spike sorting results in phy

        Parameters
        ----------
        folder : str or path-like
            Folder with the sorting output files.
        compute_raw_templates : bool
            Compute the raw templates and the template/spike amplitudes and positions.
        remove_duplicate_spikes : bool
            Run the overlapping-spike detection (needs the raw templates).
        '''
        # Path() accepts both strings and Path objects
        self.folder = Path(folder)
        # load spiketimes
        self.spike_times = self._load_required('spike_times.npy')
        # load each spike cluster number
        self.clusters = self._load_required('spike_clusters.npy')
        self.unique_clusters = np.sort(np.unique(self.clusters))
        # load the channel locations
        self.channel_positions =  self._load_optional('channel_positions.npy')
        self.channel_map =  self._load_optional('channel_map.npy')
        # Load the amplitudes used to fit the template
        self.spike_template_amplitudes = self._load_optional('amplitudes.npy')
        # load spike templates (which template was fitted) for each spike
        self.spike_templates = self._load_optional('spike_templates.npy')
        # load the templates used to extract the spikes
        self.templates =  self._load_optional('templates.npy')
        # load the whitening matrix (to correct for the data having been whitened)
        self.whitening_matrix = self._load_optional('whitening_mat_inv.npy')
        if self.whitening_matrix is not None:
            self.whitening_matrix = self.whitening_matrix.T
        self.cluster_groups = self._load_optional('cluster_group.tsv')

        # make sure the derived attributes exist whatever flags are passed
        self._reset_template_attributes()
        self.duplicate_spike_indices = None

        # compute the raw templates and the position of each cluster based on the template position
        if compute_raw_templates:
            self._compute_template_amplitudes()

        if remove_duplicate_spikes:
            self._remove_duplicate_spikes()

    def __getitem__(self, index):
        ''' returns the spiketimes for a set of clusters

        ``index`` selects entries of ``unique_clusters`` (by position, not by
        cluster id) and can be an integer, a slice or anything numpy accepts
        as an index. A single array is returned when exactly one cluster is
        selected, otherwise a list with one array per cluster.
        '''
        if isinstance(index, (int, np.integer)):
            index = [index]
        elif isinstance(index, slice):
            index = np.arange(*index.indices(len(self)))
        spikes = [self.spike_times[self.clusters == iclu]
                  for iclu in self.unique_clusters[index]]
        if len(spikes) == 1:
            return spikes[0]
        return spikes

    def __len__(self):
        '''Number of unique clusters.'''
        return len(self.unique_clusters)

    def __iter__(self):
        '''Iterate over the spike times of each cluster, in ``unique_clusters`` order.'''
        for iclu in self.unique_clusters:
            yield self.spike_times[self.clusters == iclu]


    def _remove_duplicate_spikes(self):
        '''Detect overlapping spikes using the data of this object.

        Does nothing when the raw templates or the channel positions are not
        available. The value returned by ``get_overlapping_spikes_indices`` is
        stored in ``self.duplicate_spike_indices``.
        '''
        # NOTE(review): the spikes are not dropped here because the layout of the
        # value returned by get_overlapping_spikes_indices is not visible from this
        # file - TODO confirm and apply the removal.
        self.duplicate_spike_indices = None
        templates_raw = getattr(self, 'templates_raw', None)
        if templates_raw is None or self.channel_positions is None:
            return
        self.duplicate_spike_indices = get_overlapping_spikes_indices(
            self.spike_times, self.clusters, templates_raw, self.channel_positions)

    def _reset_template_attributes(self):
        '''Set every attribute derived from the templates to None.'''
        self.templates_raw = None
        self.templates_amplitude = None
        self.templates_position = None
        self.template_position = None  # same array as templates_position (kept for compatibility)
        self.spike_amplitudes = None
        self.spike_positions = None

    def _compute_template_amplitudes(self):
        '''Compute the raw templates, their amplitudes and positions.

        Needs ``templates``, ``whitening_matrix`` and ``channel_positions``; the
        per-spike amplitudes and positions additionally need ``spike_templates``
        and ``spike_template_amplitudes``. Whatever can not be computed is left
        as None.
        '''
        self._reset_template_attributes()

        if (self.templates is None or
            self.whitening_matrix is None or
            self.channel_positions is None):
            return
        # the raw templates are the dot product of the templates by the whitening matrix
        self.templates_raw = np.dot(self.templates,self.whitening_matrix)
        # compute the peak to peak of each template
        templates_peak_to_peak = (self.templates_raw.max(axis = 1) - self.templates_raw.min(axis = 1))
        # the amplitude of each template is the max of the peak difference for all channels
        self.templates_amplitude = templates_peak_to_peak.max(axis=1)
        # compute the center of mass (X,Y) of the templates: channel positions weighted
        # by the peak to peak on each channel (an all-zero template gives NaN)
        total_peak_to_peak = templates_peak_to_peak.sum(axis = 1, keepdims = True)
        self.templates_position = np.dot(templates_peak_to_peak,self.channel_positions)/total_peak_to_peak
        self.template_position = self.templates_position

        if self.spike_templates is None or self.spike_template_amplitudes is None:
            return
        # get the spike positions and amplitudes from the average templates
        self.spike_amplitudes = self.templates_amplitude[self.spike_templates]*self.spike_template_amplitudes
        self.spike_positions = self.templates_position[self.spike_templates,:].squeeze()

    def _load_required(self,file):
        '''Load ``file`` (npy) from the sorting folder; fails if it does not exist.'''
        path = self.folder / file
        assert path.exists(), f'[SortedSpikes] - {path} doesnt exist'
        return np.load(path)

    def _load_optional(self,file):
        '''Load ``file`` from the sorting folder if it exists.

        ``.npy`` files are read with numpy and ``.tsv`` files with pandas; a
        missing file or any other extension returns None.
        '''
        path = self.folder / file
        if path.exists():
            if path.suffix == '.npy':
                return np.load(path)
            elif path.suffix == '.tsv':
                # imported here so the class does not rely on a star import providing `pd`
                import pandas as pd
                return pd.read_csv(path,sep = '\t')
        return None

    # class timestamps():
    #     def __init__(self)

# Load the sorting results found in `folder`, with the constructor defaults
# (compute_raw_templates=True, remove_duplicate_spikes=True).
sp = SortedSpikes(folder)



In [None]:

pos, peak = waveforms_position(sp.templates_raw,sp.channel_positions)
peak_to_peak = (sp.templates_raw.max(axis = 1) - sp.templates_raw.min(axis = 1)).max(axis=1)

###
import pylab as plt
%matplotlib widget
plt.figure()
plt.plot(sp.channel_positions[:,0],sp.channel_positions[:,1],'ko',color='lightgray')
plt.scatter(pos[20,0],pos[20,1],30,peak_to_peak[20],alpha = 0.5,cmap='hot')
plt.colorbar()
plt.plot(sp.channel_positions[peak[20],0],sp.channel_positions[peak[20],1],'x')
from spks.viz import plot_footprints
plot_footprints(sp.templates_raw[20],sp.channel_positions,gain=[5,0.3]);

In [None]:
# NOTE(review): SortedSpikes.__getitem__ uses the key to index `unique_clusters`
# (integers, slices or index arrays). A tuple of strings like this one is not a
# valid index there, so this cell is expected to raise - TODO confirm the intent
# (looks like a column selection meant for another object) or remove the cell.
sp['unit selection','probe','shank','unit']

In [None]:
# Module-level aliases for the per-spike cluster ids and the spike times.
spike_clusters, spike_times = sp.clusters, sp.spike_times


In [None]:
# Shape of the raw templates (templates multiplied by the stored whitening matrix);
# presumably (n_templates, n_samples, n_channels) - TODO confirm.
sp.templates_raw.shape

In [None]:


# Run the overlapping-spike detection on the loaded sorting and display what it returns.
# NOTE(review): get_overlapping_spikes_indices is not defined in this notebook; presumably
# it comes from one of the `spks` star imports - confirm, along with its return format.
get_overlapping_spikes_indices(sp.spike_times,sp.clusters, sp.templates_raw, sp.channel_positions)

In [None]:
# List the entries directly inside the sorting output folder (non-recursive glob).
list(folder.glob("*"))

In [None]:
from spks.sorting import run_ks25

# AP-band binary files of the two sessions, passed together to run_ks25.
sessionfiles = [
    '/home/data/JC131/20230901_113844/ephys_g0/ephys_g0_imec0/ephys_g0_t0.imec0.ap.bin',
    '/home/data/JC131/20230901_115632/ephys_g1/ephys_g1_imec0/ephys_g1_t0.imec0.ap.bin',
]

run_ks25(sessionfiles)
