In [9]:
import numpy as np
import pandas as pd
import os
import joblib
from elephant.gpfa import GPFA
import quantities as pq
import neo
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
from kneed import KneeLocator

ks_dir = 'Z:/KH282_BL6/ephys/061921/2021-06-19_17-26-36/Record Node 101/experiment1/recording1/continuous/Neuropix-PXI-100.0/'
    
class load_kilsort:

    def __init__(self, path):
        self.path = path

    def load_spike_times(self):
        return np.load(str.join('', (self.path, '/spike_times.npy')))

    def load_spike_clusters(self):
        return np.load(str.join('', (self.path, '/spike_clusters.npy')))

    def load_cluster_labels(self):
        return pd.read_csv(str.join('', (self.path, '/cluster_group.tsv')), sep="\t")

    def load_cluster_info(self):
        return pd.read_csv(str.join('', (self.path, '/cluster_info.tsv')), sep="\t")

    def load_channel_map(self):
        return np.load(str.join('', (self.path, '/channel_map.npy')))

    def load_templates(self):
        return np.load(str.join('', (self.path, '/templates_ind.npy')))
    
    
class kilosort2_spike_trains:

    def __init__(self):
        self.path = ks_dir
        spike_loader = load_kilsort(self.path)
        self.spike_times = spike_loader.load_spike_times()
        self.clusters = spike_loader.load_spike_clusters()
        self.cluster_labels = spike_loader.load_cluster_labels()
        self.channel_map = spike_loader.load_channel_map()
        self.cluster_info = spike_loader.load_cluster_info()
        self.templates = spike_loader.load_templates()
        self.good_clusters = []
        try:
            with open(os.path.join(self.path, '_removed_neuron.sav'), 'rb') as fr:
                removed_neurons = joblib.load(fr)
                self.removed_neurons = removed_neurons[0]
        except FileNotFoundError:
            self.removed_neurons = []

    def find_neurons(self):
        mua = 'No'
        if mua == 'Yes':
            for i, neuron in enumerate(np.unique(self.cluster_labels['cluster_id'])):
                if self.cluster_labels['group'][i] == 'good' or self.cluster_labels['group'][i] == 'mua' \
                        and neuron not in self.removed_neurons:
                    self.good_clusters.append(neuron)
        elif mua == 'No':
            for i, neuron in enumerate(np.unique(self.cluster_labels['cluster_id'])):
                if self.cluster_labels['group'][i] == 'good':
                    self.good_clusters.append(neuron)

    def show_templates(self):
        fig, axs = plt.subplots(len(self.good_clusters) // 5 + 1, 5, figsize=(10, 8))
        axs = axs.ravel()
        for i, neuron in enumerate(self.good_clusters):
            axs[i].plot(self.templates[neuron][:, np.where(
                self.channel_map[:][0] == self.cluster_info['ch'][np.where(self.cluster_info['id'] == neuron)[
                                              0][0]])[0][0]], c='black', label=neuron)
            axs[i].text(1, 0, '{}'.format(self.cluster_info['fr'][np.where(self.cluster_info['id'] == neuron)[
                                              0][0]]), horizontalalignment='center',
                        verticalalignment='center', transform=axs[i].transAxes)
            axs[i].legend()
        plt.show()

    def main(self):
        self.find_neurons()
#         self.show_templates()
        return self.spike_times, self.clusters, self.good_clusters
 
    
    



In [10]:
spike_trains_loader = kilosort2_spike_trains()
spike_times, clusters, good_clusters = spike_trains_loader.main()

In [15]:
spike_loader = load_kilsort(ks_dir)
cluster_info = spike_loader.load_cluster_info()
print(cluster_info)

      id  Amplitude  ContamPct KSLabel         amp   ch   depth        fr  \
0      0     1304.3      100.0     mua   28.357384    1    20.0  0.006937   
1      1     1427.6      100.0     mua   46.117104    0    20.0  0.774558   
2      2     1425.0      100.0     mua   43.346992    9   100.0  0.138180   
3      3     1250.1      100.0     mua   40.695259    7    80.0  1.277241   
4      4     1459.2      100.0     mua   43.949116    9   100.0  1.951540   
..   ...        ...        ...     ...         ...  ...     ...       ...   
398  452     2456.3      100.0    good  103.900543   29   300.0  0.023724   
399  453     2900.8      100.0     mua  155.482651   93   940.0  1.272108   
400  454     1082.9      100.0    good   54.037693  303  3060.0  3.816741   
401  456     2233.6      100.0    good  100.561836  105  1060.0  0.047262   
402  457     3210.7      100.0     mua  173.500870   81   820.0  0.529552   

     group  n_spikes  sh  
0      NaN       150   0  
1    noise     16749 

In [17]:
for i, neuron in enumerate(good_clusters):
    print(i, neuron)

0 9
1 43
2 54
3 131
4 158
5 159
6 160
7 178
8 200
9 203
10 204
11 205
12 216
13 241
14 290
15 294
16 299
17 308
18 335
19 351
20 353
21 369
22 371
23 373
24 375
25 381
26 382
27 383
28 385
29 386
30 390
31 393
32 397
33 398
34 406
35 408
36 409
37 416
38 420
39 421
40 424
41 428
42 430
43 433
44 435
45 453
46 454
47 456
48 457


In [19]:
np.where(cluster_info['id'] == neuron)[0][0]

402

In [29]:
len(np.unique(cluster_info['id']))

403

In [27]:
templates = spike_loader.load_templates()
print(templates.shape)

(452, 82, 383)


In [30]:
channel_map = spike_loader.load_channel_map()

In [31]:
print(channel_map)

[[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
   18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
   36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
   54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
   72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
   90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
  108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
  126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
  144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
  162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
  180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
  198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
  216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
  234 235 236 237 238 239 240 241 242 