In [None]:
from spks import *
# from pathlib import Path
folder = Path('/scratch/ks25_sorting_20230905_131941_tr9v')

# Rewrite params.py in place so that every line starting with `dat_path`
# refers to the filtered binary by its bare file name.
params_file = folder/'params.py'
with open(params_file,'r') as f:
    params = f.read().split('\n')
params = ["dat_path = 'filtered_recording.ap.bin'" if p.startswith('dat_path') else p
          for p in params]
with open(params_file,'w') as f:
    f.write("\n".join(params))


In [None]:
%matplotlib widget

from spks import *

class Clusters():
    '''Object to access the spike sorting results in phy.

    The constructor reads the ``.npy``/``.tsv`` files found in ``folder``.
    ``spike_times.npy`` and ``spike_clusters.npy`` are required, every other
    file is optional and the matching attribute is ``None`` when it is missing.

    Indexing and iteration go over the sorted unique cluster ids: ``self[i]``
    returns the spike times of the i-th unique cluster, ``len(self)`` is the
    number of unique clusters.
    '''
    def __init__(self,folder,spike_times = None,
                spike_clusters = None,
                channels_positions = None, compute_raw_templates=True):
        '''Load the sorting results from ``folder``.

        Parameters
        ----------
        folder : str or Path
            Folder that holds the sorting output files.
        spike_times, spike_clusters, channels_positions :
            Accepted for interface compatibility; currently unused, the data
            are always read from ``folder``.
        compute_raw_templates : bool
            When True, compute the un-whitened templates and the per-template
            and per-spike amplitudes/positions right away.

        Duplicate spikes are not removed here; call ``remove_duplicate_spikes``.
        '''
        if isinstance(folder, str):
            folder = Path(folder)
        self.folder = folder
        # load spiketimes
        self.spike_times = self._load_required('spike_times.npy')
        # load each spike cluster number
        self.spike_clusters = self._load_required('spike_clusters.npy')
        # np.unique already returns the ids sorted
        self.unique_clusters = np.unique(self.spike_clusters)
        # load the channel locations
        self.channel_positions = self._load_optional('channel_positions.npy')
        self.channel_map = self._load_optional('channel_map.npy')
        # load the amplitudes used to fit the template
        self.spike_template_amplitudes = self._load_optional('amplitudes.npy')
        # load spike templates (which template was fitted) for each spike
        self.spike_templates = self._load_optional('spike_templates.npy')
        # load the templates used to extract the spikes
        self.templates = self._load_optional('templates.npy')
        # load the whitening matrix (to correct for the data having been whitened)
        self.whitening_matrix = self._load_optional('whitening_mat_inv.npy')
        if self.whitening_matrix is not None:
            self.whitening_matrix = self.whitening_matrix.T
        self.cluster_groups = self._load_optional('cluster_group.tsv')

        # make sure the derived attributes always exist, even when the
        # computation below is skipped
        self._reset_template_amplitudes()
        # compute the raw templates and the position of each cluster based on the template position
        if compute_raw_templates:
            self._compute_template_amplitudes()

    def __getitem__(self, index):
        '''Return the spike times for a set of clusters.

        ``index`` addresses ``unique_clusters`` by position and can be an
        integer (python or numpy), a slice, or anything numpy accepts as an
        index. A single selected cluster returns one array, several clusters
        return a list of arrays.
        '''
        if isinstance(index, (int, np.integer)):
            index = [index]
        elif isinstance(index, slice):
            index = np.arange(*index.indices(len(self)))
        sp = [self.spike_times[self.spike_clusters == iclu]
              for iclu in self.unique_clusters[index]]
        if len(sp) == 1:
            return sp[0]
        return sp

    def __len__(self):
        '''Number of unique clusters.'''
        return len(self.unique_clusters)

    def __iter__(self):
        '''Yield the spike times of each unique cluster, in sorted id order.'''
        for iclu in self.unique_clusters:
            yield self.spike_times[self.spike_clusters == iclu]


    def remove_duplicate_spikes(self,overwrite_phy = False):
        '''Drop the spikes flagged by ``get_overlapping_spikes_indices``.

        All per-spike arrays are trimmed along their first axis so they stay
        aligned and keep their dimensionality. When ``overwrite_phy`` is True
        the trimmed arrays are written back to ``self.folder``.
        '''
        from spks.postprocess import get_overlapping_spikes_indices
        if self.templates_raw is None:
            # e.g. the object was created with compute_raw_templates=False
            self._compute_template_amplitudes()
        doubled = get_overlapping_spikes_indices(self.spike_times,self.spike_clusters, self.templates_raw, self.channel_positions)
        if not len(doubled):
            return
        per_spike = ('spike_times', 'spike_clusters', 'spike_amplitudes',
                     'spike_positions', 'spike_templates', 'spike_template_amplitudes')
        for name in per_spike:
            values = getattr(self, name)
            if values is not None:
                # axis=0 removes whole rows; without it np.delete flattens
                # 2-D arrays such as spike_positions before deleting
                setattr(self, name, np.delete(values, doubled, axis=0))
        if overwrite_phy:
            self.export_phy(self.folder)

    def export_phy(self,folder):
        '''Save the spike times/clusters (and, when loaded, the template
        amplitudes and spike templates) as ``.npy`` files in ``folder``.'''
        if isinstance(folder, str):
            folder = Path(folder)
        np.save(folder/'spike_times.npy',self.spike_times)
        np.save(folder/'spike_clusters.npy',self.spike_clusters)
        if self.spike_template_amplitudes is not None:
            np.save(folder/'amplitudes.npy', self.spike_template_amplitudes)
        if self.spike_templates is not None:
            np.save(folder/'spike_templates.npy', self.spike_templates)

    def _reset_template_amplitudes(self):
        '''Set every attribute derived from the templates to None.'''
        self.templates_raw = None
        self.templates_amplitude = None
        self.templates_position = None
        # legacy spelling, kept so code reading `template_position` still works
        self.template_position = None
        self.spike_amplitudes = None
        self.spike_positions = None

    def _compute_template_amplitudes(self):
        '''Compute the raw templates and the amplitudes/positions derived from them.

        Nothing is computed unless the templates, the whitening matrix and the
        channel positions were all loaded. The per-spike values additionally
        need ``spike_templates`` (and ``spike_template_amplitudes`` for the
        amplitudes); they stay None when those files are missing.
        '''
        self._reset_template_amplitudes()
        if (self.templates is None or
            self.whitening_matrix is None or
            self.channel_positions is None):
            return
        # the raw templates are the dot product of the templates by the whitening matrix
        self.templates_raw = np.dot(self.templates,self.whitening_matrix)
        # compute the peak to peak of each template
        # (assumes templates are (template, sample, channel) - TODO confirm)
        templates_peak_to_peak = (self.templates_raw.max(axis = 1) - self.templates_raw.min(axis = 1))
        # the amplitude of each template is the max of the peak difference for all channels
        self.templates_amplitude = templates_peak_to_peak.max(axis=1)
        # compute the center of mass of the templates, one column per
        # column of channel_positions, weighted by the peak to peak
        total = np.sum(templates_peak_to_peak,axis = 1)
        weighted = [templates_peak_to_peak*pos for pos in self.channel_positions.T]
        self.templates_position = np.vstack([np.sum(t,axis = 1)/total for t in weighted]).T
        self.template_position = self.templates_position
        # get the spike positions and amplitudes from the average templates
        if self.spike_templates is None:
            return
        self.spike_positions = self.templates_position[self.spike_templates,:].squeeze()
        if self.spike_template_amplitudes is not None:
            self.spike_amplitudes = self.templates_amplitude[self.spike_templates]*self.spike_template_amplitudes

    def _load_required(self,file,var = None):
        '''Load ``file`` from the sorting folder; fail if it does not exist.

        ``var`` is accepted but not used.
        '''
        path = self.folder / file
        assert path.exists(), f'[SortedSpikes] - {path} doesnt exist'
        return np.load(path)

    def _load_optional(self,file):
        '''Load ``file`` from the sorting folder when it exists.

        ``.npy`` files are read with numpy and ``.tsv`` files with pandas
        (tab separated). Returns None for a missing file or any other suffix.
        '''
        path = self.folder / file
        if path.exists():
            if path.suffix == '.npy':
                return np.load(path)
            elif path.suffix == '.tsv':
                return pd.read_csv(path,sep = '\t')
        return None

    # class timestamps():
    #     def __init__(self)

folder = Path('/scratch/ks25_sorting_20230905_131941_tr9v')

# Clusters.__init__ does not accept a `remove_duplicate_spikes` keyword,
# so the duplicates are removed with an explicit call after loading.
sp = Clusters(folder)
sp.remove_duplicate_spikes()


In [None]:

# folder
# Load the metadata first: map_binary needs the channel count stored in it.
meta = load_dict_from_h5(folder/'filtered_recording.ap.metadata.hdf')
dat2 = map_binary(folder/'filtered_recording.ap.bin',meta['nchannels'])
# NOTE(review): `dat` is only created in a later cell (RawRecording); that
# cell has to run before this comparison can be displayed.
meta['file_offsets'],dat.file_sample_offsets,dat.shape,dat2.shape

In [None]:

folder = Path('/scratch/ks25_sorting_20230905_131941_tr9v')
sessionfiles = [
    '/home/data/JC131/20230901_113844/ephys_g0/ephys_g0_imec0/ephys_g0_t0.imec0.ap.bin',
    '/home/data/JC131/20230901_115632/ephys_g1/ephys_g1_imec0/ephys_g1_t0.imec0.ap.bin',
]

# Open both session files as one recording, without preprocessing.
dat = RawRecording(sessionfiles, return_preprocessed=False)

# Map the filtered binary using the channel count from its metadata file.
meta = load_dict_from_h5(folder/'filtered_recording.ap.metadata.hdf')
dat2 = map_binary(folder/'filtered_recording.ap.bin', meta['nchannels'])

sp = Clusters(folder)

# `sp` is passed as the spike times; iterating it yields one array of
# spike times per unique cluster.
mwaves = extract_waveform_set(
    spike_times=sp,
    data=dat2,
    chmap=dat.channel_info.channel_idx.values,
    max_n_spikes=1000,
    chunksize=10,
)

In [None]:
# Key each entry of `mwaves` by the cluster id at the same position.
waveforms = dict(zip(sp.unique_clusters,mwaves))

import h5py

# Store the dictionary next to the sorting results.
save_dict_to_h5(folder/'cluster_waveforms.hdf',waveforms)


In [None]:


# Read the waveforms dictionary back from cluster_waveforms.hdf in the sorting folder.
waveforms = load_dict_from_h5(folder/'cluster_waveforms.hdf')

In [None]:
# %matplotlib widget 
import pylab as plt
from spks import *

plt.figure()

# Position in `mwaves` of the cluster to display.
iclus = 200
channel_xy = np.stack(dat.channel_info.channel_coord.values)

# Thin traces: the first 20 entries of this cluster's waveform set.
for mw in mwaves[iclus][:20]:
    plot_footprints(waves = mw,
                    channel_xy = channel_xy, gain=[15,0.07],lw = 0.1);

# Thick red trace: the mean over the same axis.
plot_footprints(waves = mwaves[iclus].mean(axis=0),
                channel_xy = channel_xy, gain=[15,0.07],lw = 1,color = 'r');

In [None]:

pos, peak = waveforms_position(sp.templates_raw,sp.channel_positions)
peak_to_peak = (sp.templates_raw.max(axis = 1) - sp.templates_raw.min(axis = 1)).max(axis=1)

###
import pylab as plt
%matplotlib widget
plt.figure()
plt.plot(sp.channel_positions[:,0],sp.channel_positions[:,1],'o',color='lightgray')
plt.scatter(pos[20,0],pos[20,1],30,peak_to_peak[20],alpha = 0.5,cmap='hot')
plt.colorbar()
plt.plot(sp.channel_positions[peak[20],0],sp.channel_positions[peak[20],1],'x')
from spks.viz import plot_footprints
plot_footprints(sp.templates_raw[20],sp.channel_positions,gain=[5,0.3]);
plt.show()

In [None]:
# NOTE(review): Clusters.__getitem__ uses its argument to index `unique_clusters`
# by position, so a tuple of strings looks incompatible with it. Presumably this
# cell was written for a different `sp` object - confirm or remove.
sp['unit selection','probe','shank','unit']

In [None]:
# Clusters stores the per-spike cluster ids as `spike_clusters`
# (it has no `clusters` attribute).
spike_clusters = sp.spike_clusters
spike_times = sp.spike_times


In [None]:
# Display the shape of the raw templates array.
sp.templates_raw.shape

In [None]:


# Same import Clusters.remove_duplicate_spikes uses, so the cell also runs on a fresh kernel.
from spks.postprocess import get_overlapping_spikes_indices

# Clusters stores the per-spike cluster ids as `spike_clusters` (not `clusters`).
get_overlapping_spikes_indices(sp.spike_times,sp.spike_clusters, sp.templates_raw, sp.channel_positions)

In [None]:
# List every entry directly inside the sorting folder.
list(folder.glob("*"))

In [None]:
from spks.sorting import run_ks25
# Binary files of the two sessions handed to the sorter
# (the same two paths are listed in the waveform-extraction cell).
sessionfiles = [
    '/home/data/JC131/20230901_113844/ephys_g0/ephys_g0_imec0/ephys_g0_t0.imec0.ap.bin',
    '/home/data/JC131/20230901_115632/ephys_g1/ephys_g1_imec0/ephys_g1_t0.imec0.ap.bin',
]

run_ks25(sessionfiles)
