<CENTER><H1>Spike classification based on extracellular waveforms</H1></CENTER>

<HR>   
<img src="https://raw.githubusercontent.com/JoseGuzman/minibrain/master/img/spikes.png">

<P>This notebook classifies extracellular electrical waveforms from the brain called spikes. Spikes are the electrical signals that neurons use to communicate, and different spike waveforms can reflect different cell types or different neuron morphologies.</P>



# Table of Contents

1. [Importing necessary libraries](#section1)
2. [Loading the dataset](#section2)
3. [Visualize waveforms](#section3)
4. [Dimensionality reduction and K-means clustering](#section4)
      * [4.1 Principal Components Analysis (PCA)](#section4.1)
      * [4.2 K-means clustering](#section4.2)
5. [Hierarchical clustering](#section5)
      * [5.1 Elbow method](#section5.1)
      * [5.2 Visualization](#section5.2)


<a id="section1"></a>
# 1. Importing necessary libraries

The [Waveforms dataset](https://www.kaggle.com/joseguzman/waveforms) is freely available on Kaggle and also [on GitHub](https://github.com/JoseGuzman/minibrain). I use the GitHub repository to load a custom plotting style.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
style.use('https://raw.githubusercontent.com/JoseGuzman/minibrain/master/minibrain/paper.mplstyle') # my custom plots
import pandas as pd

<a id="section2"></a>
# 2. Loading the dataset

The dataset contains spike waveforms recorded from different neural tissues. In electrophysiology, *samples* are the voltage values recorded at successive time points; here, each sample plays the role of a *feature* in machine learning.

In [None]:
# Load waveforms
mypath = '../input/waveforms/waveforms.csv'
waveforms = pd.read_csv(mypath, index_col = 'uid')

print(f'{waveforms.shape[0]} observations and {waveforms.shape[1]-1} samples (features)')

waveforms.organoid.value_counts()

<a id="section3"></a>
# 3. Visualize waveforms

To tell apart waveforms from different tissues, we assign a color to each tissue in a Python dictionary called *mycolors*.

In [None]:
# define some custom colors for visualization
mycolors = {'TSCp5_30s':      '#FFA500', # orange
            'TSCp5_32s':      '#4169E1', # royalblue
            'TSCp5_30s_CHIR': '#FF4500', # orange red
            'TSCp5_32s_CHIR': '#9400D3', # darkviolet
            'DLX_bluered':    '#32CD32', # limegreen
            'DLX_Cheriff':    '#228B22', # forestgreen
            'DLX_H9':         '#006400', # darkgreen
            'AP_ctrl':        '#00BFFF', # deepskyblue
            'AP_drug':        '#DC143C'  # crimson
           }

We add a column *color* to our Pandas DataFrame with the corresponding tissue color as a hexadecimal RGB string.

In [None]:
# Add color to the DataFrame
waveforms['color'] = waveforms['organoid'].apply(lambda orgID: mycolors[orgID])


In [None]:
def plot_waveform(index, ax = None):
    """
    Returns an axis with the normalized waveform,
    together with its color and tissue name. It 
    requires trace (2D NumPy array with the vectors
    of spikes) and waveforms (pandas DataFrame).
    
    Arguments
    ---------
    index (int)
        the index of the waveform to be selected in the
        Pandas object.
    ax (matplotlib axis object)
        If None, use current axis object
    """
    if ax is None:
        myax = plt.gca()
    else:
        myax = ax
    
    # waveforms is a DataFrame and must be previously defined
    df = waveforms.drop(['organoid', 'color'], axis = 1, inplace=False)
    trace = df.values # 2D array
    
    myax.plot(trace[index], color = waveforms.iloc[index].color)
    myax.text(x = 5, y= 0.75, s=waveforms.iloc[index].organoid, fontsize=10, color=waveforms.iloc[index].color)
    
    return myax

    

In [None]:
fig, ax = plt.subplots(1,8, figsize=(12,3))
fig.suptitle('Example waveforms')

ax[0] = plot_waveform(index = 15,  ax = ax[0])
ax[1] = plot_waveform(index = 497, ax = ax[1])
ax[2] = plot_waveform(index = 400, ax = ax[2])
ax[3] = plot_waveform(index = 300, ax = ax[3])
ax[4] = plot_waveform(index = 321, ax = ax[4])
ax[5] = plot_waveform(index = 93,  ax = ax[5])
ax[6] = plot_waveform(index = 262,  ax = ax[6])
ax[7] = plot_waveform(index = 351,  ax = ax[7])

ax[7].text(x = 120, y= -1, s='1 ms', fontsize=10)
ax[7].hlines(y = -1.1, xmin =120, xmax = 150, lw=2, color='k') # 30 samples -> 1ms

for myax in ax:
    myax.set_ylim(-1.3,1.3)
    myax.grid()
    myax.axis('off')
#plt.savefig('spikes.png', dpi = 300)

<a id="section4"></a>
# 4. Dimensionality reduction and K-means clustering

Every waveform (observation) contains 150 samples (features). We want to reduce the number of measurements to the minimum necessary for a subsequent classification, ideally keeping 95% of the variance. Before the PCA, we discard the first 30 samples (1 ms of baseline).
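scikit-learn's `PCA` also accepts a float between 0 and 1 for `n_components`; with `svd_solver='full'` it then keeps the smallest number of components whose cumulative explained variance exceeds that fraction. The sketch below is an illustration on synthetic data, not the waveform dataset: the matrix `X` is built from three hypothetical latent signals plus a little noise, so about three components should be retained.

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic stand-in for the waveform matrix: 500 observations x 120 samples,
# generated from 3 latent signals plus small Gaussian noise (not the real data).
rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 3))
mixing = rng.normal(size=(3, 120))
X = latent @ mixing + 0.05 * rng.normal(size=(500, 120))

# A float n_components keeps just enough components to explain >= 95% of the variance
pca95 = PCA(n_components=0.95, svd_solver='full')
pca95.fit(X)

print(pca95.n_components_)                    # number of components retained
print(pca95.explained_variance_ratio_.sum())  # cumulative variance actually retained
```

Fixing the variance fraction instead of the number of components makes the threshold explicit; the cumulative-variance plot below shows the same information graphically.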

<a id="section4.1"></a>
## 4.1 Principal Components Analysis (PCA)

In [None]:
# we select numerical values 
df = waveforms.drop(['organoid', 'color'], axis = 1)
trace = df.values[:, 30:] # remove first 30 samples (1 ms) of waveform baseline

In [None]:
from sklearn.decomposition import PCA

N_components = 10
mypca = PCA(n_components=N_components)
mypca.fit(trace) # fit() returns the PCA object itself, not the projections

fig, ax = plt.subplots(figsize=(4,2))
x = np.arange(1, N_components+1, step=1)
y = np.cumsum(mypca.explained_variance_ratio_)


plt.plot(x, y, '-', color='gray', alpha = 0.4)
plt.plot(x, y, 'o', color='k')

plt.ylim(0.0,1.2)

plt.xlabel('Principal components')
plt.xticks(np.arange(0, N_components+1, step=1)) # one tick per component number (x is already 1-based)
plt.ylabel('Cumulative variance')

plt.axhline(y = 0.95, color = 'brown', linestyle = '--')
plt.hlines(y = y[1], xmin = 0, xmax = 2, color = 'C0', linestyle = '--')
plt.text(0.5, 1, r'95% $\sigma^2$', color = 'brown', fontsize=12) # raw string avoids an invalid escape warning
plt.text(0.5, y[1]+0.05, s=f'{y[1]*100:1.2f}%', color = 'C0', fontsize= 12)
plt.show()

Let's project the data onto the first two principal components for visualization.

In [None]:
mypca = PCA(n_components=2)
PC = mypca.fit_transform(trace)

waveforms['PC1'] = PC[:,0]
waveforms['PC2'] = PC[:,1]
var1, var2 = mypca.explained_variance_ratio_*100 # variance in percentage
print('Explained variance of the first two principal components: {:2.2f}%, {:2.2f}%'.format(var1, var2))
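The coordinates returned by `fit_transform` (stored above in `PC1` and `PC2`) are the projections of each waveform onto the principal axes. A minimal sketch on random data (a stand-in for `trace`; `X`, `scores` and `manual` are illustrative names) checks that these scores equal the mean-centered observations multiplied by the transposed `components_` matrix:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 20))  # synthetic stand-in data, 200 observations x 20 features

pca = PCA(n_components=2)
scores = pca.fit_transform(X)

# Manual projection: subtract the per-feature mean, then project onto the principal axes
manual = (X - pca.mean_) @ pca.components_.T

print(np.allclose(scores, manual))
```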

In [None]:
# visualize
fig, ax = plt.subplots(1,1, figsize=(4,4))

ax.scatter(x = waveforms.PC1, y = waveforms.PC2, s=4, c=waveforms.color)
ax.set_xlabel(f'PC$_1$ = {var1:2.1f} %');
ax.set_ylabel(f'PC$_2$ = {var2:2.1f} %');
ax.set_xlim(-3,3), ax.set_ylim(-3,3)
ax.set_yticks([-2,0,2]), ax.set_xticks([-2,0,2])

markers = [plt.Line2D([],[], color= i, marker = 'o', linestyle = '') for i in mycolors.values()]

ax.legend(markers, mycolors.keys(), numpoints = 10,  title = 'Tissue', bbox_to_anchor=(1.05, 1), loc='upper left')

In [None]:
# visualize
fig, ax = plt.subplots(2,4, figsize=(8,4), sharex=True, sharey = True)
fig.tight_layout()
fig.suptitle('Tissues')

def plot_ax(organoidtype, loc=(0,0)):
    """
    Auxiliary plot function to return subplot axis
    of the two first PCA projections of a tissue.
    
    Arguments
    ---------
    organoidtype (str)
        The type of tissue we want to plot.
    
    loc (tuple)
        The (row, column) position of the axis object in the subplot grid.
    """
    x = waveforms[waveforms.organoid==organoidtype].PC1
    y = waveforms[waveforms.organoid==organoidtype].PC2
    
    i,j = loc
    return ax[i,j].scatter(x, y, color=mycolors[organoidtype], s = 5, label=organoidtype)


plot_ax('DLX_Cheriff', loc=(0,0))
plot_ax('DLX_bluered', loc=(1,0))

plot_ax('TSCp5_30s', loc=(0,1))
plot_ax('TSCp5_32s', loc=(1,1))

plot_ax('TSCp5_30s_CHIR', loc=(0,2))
plot_ax('TSCp5_32s_CHIR', loc=(1,2))

plot_ax('AP_ctrl', loc=(0,3))
plot_ax('AP_drug', loc=(1,3))


#x = waveforms[waveforms.organoid=='DLX_H9'].PC1
#y = waveforms[waveforms.organoid=='DLX_H9'].PC2
#ax[0,1].scatter(x, y, color=mycolors['DLX_H9'], s = 5, label='DLX_H9')


for myax in ax.flatten():
    myax.scatter(x = waveforms.PC1, y = waveforms.PC2, c='gray', s = 4, alpha = 0.2)
    #myax.set_xlabel(f'PC$_1$ = {var1:2.1f} %');
    #myax.set_ylabel(f'PC$_2$ = {var2:2.1f} %');
    myax.legend(loc='upper center', fontsize=10, frameon=False)
    myax.axis('off')

<a id="section4.2"></a>
## 4.2 K-means clustering

We now use the first two principal components to perform an unsupervised K-means clustering. First, we estimate the minimal number of clusters needed for the classification.

In [None]:
from sklearn.cluster import KMeans



### Cluster calculation
To determine the optimal number of clusters, we look for the number of clusters after which the distortion (inertia) starts decreasing linearly, the so-called elbow. The distortion or inertia of a K-means clustering result is the sum of squared distances between each observation and its corresponding centroid.

**Inertia** is the sum of squared errors, that is, the sum of the squared Euclidean distances from each point to its closest centroid:
<math>
$$ I = \sum_{i = 1}^{n} d\left(x_i, c_{k(i)}\right)^2 ,$$
</math>

where $n$ is the number of points, $c_{k(i)}$ is the centroid of the cluster $k(i)$ to which point $x_i$ is assigned, and $d()$ is the Euclidean distance.
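To make the formula concrete, here is a small sketch that computes the inertia by hand and compares it with the `inertia_` attribute of a fitted scikit-learn `KMeans` model. The three Gaussian blobs are synthetic stand-ins for the PC projections, and `manual_inertia` is an illustrative name.

```python
import numpy as np
from sklearn.cluster import KMeans

# Three well-separated synthetic 2D blobs (stand-in for the PC1/PC2 projections)
rng = np.random.default_rng(2)
centers_true = np.array([[-2.0, 0.0], [2.0, 0.0], [0.0, 2.5]])
X = np.vstack([c + 0.3 * rng.normal(size=(100, 2)) for c in centers_true])

km = KMeans(n_clusters=3, n_init=10, random_state=42).fit(X)

# Inertia by hand: squared Euclidean distance of each point to its assigned centroid c_k(i)
assigned = km.cluster_centers_[km.labels_]
manual_inertia = np.sum((X - assigned) ** 2)

print(manual_inertia, km.inertia_)
```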

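To detect this elbow, we approximate the second derivative of the inertia curve with a second-order finite difference (the code below applies it to the square root of the inertia) and take the number of clusters where it is largest.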
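The index bookkeeping in this step is easy to get wrong, so here is a minimal sketch of the rule on made-up inertia values (`elbow_k` and `toy_inertia` are hypothetical names, not part of the original notebook). `np.diff(values, 2)` has two fewer entries than `values`, and its entry `j` is centred on point `j + 1` (0-based), which corresponds to `k = j + 2` clusters.

```python
import numpy as np

def elbow_k(values):
    """Return the 1-based k where the second finite difference of a
    decreasing curve (inertia or merge distance vs. k) is largest."""
    accel = np.diff(np.asarray(values, dtype=float), 2)
    # accel[j] is centred on 0-based point j + 1, i.e. on k = j + 2 clusters
    return int(accel.argmax()) + 2

# Hypothetical inertia values for k = 1..6 (made up for illustration)
toy_inertia = [100, 40, 20, 15, 12, 10]
print(elbow_k(toy_inertia))  # -> 2
```

The notebook computes the same quantity inline (`acceleration.argmax() + 2`), once on the square root of the K-means inertia and once on the dendrogram merge distances in section 5.1.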



In [None]:
inertia = list() # sum of squared error

for k in range(1,21):
    km = KMeans(n_clusters=k, init='k-means++', n_init=10, random_state = 42) # n_init set explicitly; its default changed in scikit-learn 1.4
    km = km.fit(PC)
    inertia.append(km.inertia_)

In [None]:
idx = np.arange(1, len(inertia) + 1)

fig, ax = plt.subplots(2, 1, figsize=(4,4), sharex=True)

ax[0].plot(idx, inertia, '-', color='gray', alpha = 0.4)
ax[0].plot(idx, inertia, 'ko', ms=6)
ax[0].set_xticks([1,2,3,4,5, 10, 15])

ax[0].set_ylabel('Inertia');

acceleration = np.diff(np.sqrt(inertia), 2)  # 2nd finite difference of sqrt(inertia)

ax[1].plot(idx[:-2] + 1, acceleration, 'g-', ms=8)
ax[1].set_ylabel('Acceleration')
ax[1].set_xlabel('K-clusters')
#ax[1].set_yticks(np.arange(-1,7))

# calculate clusters

k = acceleration.argmax() + 2  # if idx 0 is the max of this we want 2 clusters
ax[0].vlines(x = k, ymin = 0, ymax = inertia[k-1], color='brown', linestyle='--')
ax[1].vlines(x = k, ymin = 0, ymax = acceleration[k-2], color='gray', linestyle='--', alpha=0.7)

print(f'Number of k-clusters:{k}')


In [None]:
# define colors according to the cluster label
myKmeans = KMeans(n_clusters=3, init='k-means++', n_init=10, random_state = 42) # the color maps below assume exactly 3 clusters
myKmeans.fit(PC)
kcolors = pd.DataFrame(myKmeans.labels_, columns=['Ktype'])

kcolors['color'] = kcolors['Ktype'].map({1: 'green', 0:'orange', 2:'purple'})


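We use `scipy.spatial.distance.cdist` to compute the distance between each pair of points taken from two collections of inputs. Here it gives the distance from every observation in a cluster to its centroid, and the largest of these distances is used as the cluster radius.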

In [None]:
from scipy.spatial.distance import cdist

def plot_radii(mykmeans, X, ax = None):
    """
    Plots k-centroids from a kmeans model
    
    Arguments
    ---------
    kmeans (Kmeans object)
        (sklearn.sklearn.cluster.Kmeans model)
    
    X (2D-Numpy array)
        The observations ith the observations)
    
    ax (an axis object)
    """
    # plot the input data
    ax = ax or plt.gca()
    
    labels = mykmeans.labels_ # read labels

    centers = mykmeans.cluster_centers_ # centroids
    radii = [cdist(X[labels == i], [center]).max() for i, center in enumerate(centers)]
    
    for c, r in zip(centers, radii):
        ax.add_patch(plt.Circle(c, r, fc='gray', lw=1, alpha=0.05, zorder=1))
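A toy example of the radius computation used in `plot_radii`: `cdist` returns an array of shape `(n_points, n_centers)`, so passing a single centroid wrapped in a list yields one column of distances, and its maximum is the radius. The points and the centroid below are made-up values for illustration.

```python
import numpy as np
from scipy.spatial.distance import cdist

# Three points assigned to one cluster, and that cluster's centroid (toy values)
points = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]])
center = np.array([0.0, 0.0])

d = cdist(points, [center])  # shape (3, 1): Euclidean distance of each point to the centroid
radius = d.max()             # farthest point defines the circle radius
print(d.ravel(), radius)     # -> [0. 5. 1.] 5.0
```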

In [None]:
fig, ax = plt.subplots(1,1, figsize=(4,4))

ax.scatter(x = waveforms.PC1, y = waveforms.PC2, s=4, c=kcolors.color)
ax.set_xlabel(f'PC$_1$ = {var1:2.1f} %');
ax.set_ylabel(f'PC$_2$ = {var2:2.1f} %');
ax.set_xlim(-3,3), ax.set_ylim(-3,3)
ax.set_yticks([-2,0,2]), ax.set_xticks([-2,0,2]);
plot_radii(mykmeans = myKmeans, X = PC, ax = ax)

We now calculate the percentage of waveforms in every cluster.

In [None]:
# calculate percentages (just for fun with different NumPy methods :P)
prop_0 = 100 * np.count_nonzero(myKmeans.labels_ == 0) / len(myKmeans.labels_)
prop_1 = 100 * np.count_nonzero(myKmeans.labels_ == 1) / len(myKmeans.labels_)
prop_2 = 100 * np.sum(myKmeans.labels_ == 2) / len(myKmeans.labels_)
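One more NumPy method for the same task: `np.bincount` counts how often each non-negative integer label occurs, so all cluster percentages come out of a single call. The labels below are toy values standing in for `myKmeans.labels_`.

```python
import numpy as np

# Toy cluster labels (stand-in for myKmeans.labels_)
labels = np.array([0, 1, 1, 2, 2, 2])

# percent[i] is the percentage of observations carrying label i
percent = 100 * np.bincount(labels) / labels.size
print(percent)
```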


In [None]:
# visualize
# To learn how to use multiple plots, use this https://jakevdp.github.io/PythonDataScienceHandbook/04.08-multiple-subplots.html

grid = plt.GridSpec(1, 5, wspace=0.4, hspace=0.4)

fig = plt.figure(figsize=(12, 4))
pca_plot = fig.add_subplot(grid[0, 0:2])

pca_plot.scatter(x = waveforms.PC1, y = waveforms.PC2, s=4, c=kcolors.color)
pca_plot.set_xlabel(f'PC$_1$ = {var1:2.1f} %');
pca_plot.set_ylabel(f'PC$_2$ = {var2:2.1f} %');
pca_plot.set_yticks([-2,0,2]), pca_plot.set_xticks([-2,0,2]);
pca_plot.set_xlim(-3,3), pca_plot.set_ylim(-3,3)
plot_radii(myKmeans, PC, ax = pca_plot)

wave0 = fig.add_subplot(grid[0, 2])
wave = list()
idx = np.where(kcolors.Ktype==1)
for i in idx[0]:
    mytrace = trace[i]
    wave0.plot(mytrace, lw=0.5, color='green', alpha=0.05)
    wave.append(mytrace)
wave0.plot(np.mean(wave, axis=0), color='green')
wave0.text(x = 0, y = .5, s = f'{prop_1:2.2f}%', color='green')


wave1 = fig.add_subplot(grid[0, 3])
wave = list()
idx = np.where(kcolors.Ktype==0)
for i in idx[0]:
    mytrace = trace[i]
    wave1.plot(mytrace, lw=0.5, color='orange', alpha=0.05)
    wave.append(mytrace)
wave1.plot(np.mean(wave, axis=0), color='orange')
wave1.text(x = 0, y = .5, s = f'{prop_0:2.2f}%', color='orange')


wave2 = fig.add_subplot(grid[0, 4])
wave = list()
idx = np.where(kcolors.Ktype==2)
for i in idx[0]:
    mytrace = trace[i]
    wave2.plot(mytrace, lw=0.5, color='purple', alpha=0.05)
    wave.append(mytrace)
wave2.plot(np.mean(wave, axis=0), color='indigo')
wave2.text(x = 0, y = .5, s = f'{prop_2:2.2f}%', color='indigo')


wave2.text(x = 75, y= -1, s='1 ms', fontsize=10)
wave2.hlines(y = -1.1, xmin =75, xmax = 105, lw=2, color='k') 


for wave in [wave0,wave1,wave2]:
    wave.set_ylim(-1.3,1.0)
    wave.axis('off')

<a id="section5"></a>
# 5. Hierarchical clustering

We plot the distances at which clusters are merged during hierarchical clustering as a dendrogram. The details of this method are explained in [this blog entry](https://joernhees.de/blog/2015/08/26/scipy-hierarchical-clustering-and-dendrogram-tutorial/).
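Both the dendrogram and the elbow method below read the linkage matrix `Z` returned by `scipy.cluster.hierarchy.linkage`. A minimal sketch on four toy 1D points shows its structure: each row is one merge, `[cluster i, cluster j, merge distance, size of the new cluster]`, and indices equal to or greater than the number of observations refer to clusters created by earlier rows. With `method='ward'`, the merge distance of multi-point clusters is the Ward distance rather than the plain distance between points (here the final merge is about 14.1, not 10).

```python
import numpy as np
from scipy.cluster.hierarchy import linkage

# Four 1D points forming two obvious pairs (toy data)
X = np.array([[0.0], [1.0], [10.0], [11.0]])
Z = linkage(X, method='ward')

# n observations produce n - 1 merges; the last column counts the points in each new cluster
print(Z)
```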

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage

A simple heuristic is to cut the dendrogram at 70% of the maximal merge distance; each branch crossing that line corresponds to one cluster.

In [None]:
Z = linkage(PC, method = 'ward')

pmax = Z[:,2].max()
fig, ax = plt.subplots(1, 1, figsize = (12,3))
d = dendrogram(Z, truncate_mode = 'lastp', p = 50, leaf_rotation = 90, leaf_font_size = 8, show_contracted = True, ax = ax)
ax.set_ylabel('Euclidean distances'), ax.set_xlabel('Spike Waveforms')
ax.set_yticks(np.arange(0,25,5))
ax.axhline(y = 0.7*pmax, color='k', linestyle= '--')
ax.text(y = 0.7*pmax +1, x = 10, s='70% of max. distance')
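The horizontal line only marks the cut visually. To turn the 70% heuristic into cluster labels, `fcluster` with `criterion='distance'` assigns observations to the same flat cluster when their cophenetic distance does not exceed the threshold `t`. A sketch on toy data (`threshold` and `labels` are illustrative names):

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

# Toy data: two tight groups of 1D points
X = np.array([[0.0], [1.0], [10.0], [11.0]])
Z = linkage(X, method='ward')

# Cut the tree at 70% of the largest merge distance
threshold = 0.7 * Z[:, 2].max()
labels = fcluster(Z, t=threshold, criterion='distance')

print(labels, len(np.unique(labels)))
```

Section 5.2 instead fixes the number of clusters with `criterion='maxclust'`; both options cut the same tree.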


<a id="section5.1"></a>
## 5.1 Elbow method 

In [None]:
distance = Z[-20:, 2] # merge distances of the last 20 merges
selection = distance[::-1]
idx = np.arange(1, len(selection) + 1)

fig, ax = plt.subplots(2, 1, figsize=(4,4), sharex=True)
ax[0].plot(idx, selection, '-', color='gray', alpha = 0.4)
ax[0].plot(idx, selection, 'ko', ms = 6)
ax[0].set_xticks([1,2,3,4,5, 10, 15])


acceleration = np.diff(selection, 2)  # 2nd derivative of the distances

ax[1].plot(idx[:-2] + 1, acceleration, 'g-', ms=8)
ax[1].set_ylabel('Acceleration')
ax[1].set_xlabel('Clusters')
#ax[1].set_yticks(np.arange(-1,7))
ax[0].set_ylabel('Distance')

k = acceleration.argmax() + 2  # if idx 0 is the max of this we want 2 clusters
ax[0].vlines(x = k, ymin = 0, ymax = selection[k-1], color='brown', linestyle='--')
ax[1].vlines(x = k, ymin = np.min(acceleration), ymax = acceleration[k-2], color='gray', linestyle='--', alpha=0.7)

print(f'clusters:{k}')

<a id="section5.2"></a>
## 5.2 Visualization

In [None]:
from scipy.cluster.hierarchy import fcluster

In [None]:
mylabels = fcluster(Z, 3, criterion='maxclust')-1 # zero based

In [None]:
# calculate percentages (just for fun with different NumPy methods :P)
prop_0 = 100 * np.sum(mylabels == 0) / len(mylabels)
prop_1 = 100 * np.sum(mylabels == 1) / len(mylabels)
prop_2 = 100 * np.sum(mylabels == 2) / len(mylabels)
prop_0, prop_1, prop_2


In [None]:
# adjust colors
dcolors = pd.DataFrame(mylabels, columns=['Dtype'])

dcolors['color'] = dcolors['Dtype'].map({0: 'green', 1:'orange', 2:'purple'})

In [None]:
# visualize

grid = plt.GridSpec(1, 5, wspace=0.4, hspace=0.4)

fig = plt.figure(figsize=(12, 4))
pca_plot = fig.add_subplot(grid[0, 0:2])

pca_plot.scatter(x = waveforms.PC1, y = waveforms.PC2, s=4, c=dcolors.color)
pca_plot.set_xlabel(f'PC$_1$ = {var1:2.1f} %');
pca_plot.set_ylabel(f'PC$_2$ = {var2:2.1f} %');
pca_plot.set_yticks([-2,0,2]), pca_plot.set_xticks([-2,0,2]);
pca_plot.set_xlim(-3,3), pca_plot.set_ylim(-3,3)
plot_radii(myKmeans, PC, ax = pca_plot) # circles show the K-means clusters, for comparison

wave0 = fig.add_subplot(grid[0, 2])
wave = list()
idx = np.where(mylabels==0)
for i in idx[0]:
    mytrace = trace[i]
    wave0.plot(mytrace, lw=0.5, color='green', alpha=0.05)
    wave.append(mytrace)
wave0.plot(np.mean(wave, axis=0), color='green')
wave0.text(x = 0, y = .5, s = f'{prop_0:2.2f}%', color='green')


wave1 = fig.add_subplot(grid[0, 3])
wave = list()
idx = np.where(mylabels==1)
for i in idx[0]:
    mytrace = trace[i]
    wave1.plot(mytrace, lw=0.5, color='orange', alpha=0.05)
    wave.append(mytrace)
wave1.plot(np.mean(wave, axis=0), color='orange')
wave1.text(x = 0, y = .5, s = f'{prop_1:2.2f}%', color='orange')


wave2 = fig.add_subplot(grid[0, 4])
wave = list()
idx = np.where(mylabels==2)
for i in idx[0]:
    mytrace = trace[i]
    wave2.plot(mytrace, lw=0.5, color='purple', alpha=0.05)
    wave.append(mytrace)
wave2.plot(np.mean(wave, axis=0), color='indigo')
wave2.text(x = 0, y = .5, s = f'{prop_2:2.2f}%', color='indigo')


wave2.text(x = 75, y= -1, s='1 ms', fontsize=10)
wave2.hlines(y = -1.1, xmin =75, xmax = 105, lw=2, color='k') 


for wave in [wave0,wave1,wave2]:
    wave.set_ylim(-1.3,1.0)
    wave.axis('off')