Load data

In [None]:
import scipy.signal, scipy.io
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
%matplotlib inline

In [None]:
def plotDataSection(data, title, startTime=0, endTime=None, sampleRate=24e3):
    '''Plot a time window of a sampled signal against time in seconds.

    Parameters
    ----------
    data : sequence or np.ndarray
        Signal samples; sample ``k`` is drawn at ``k / sampleRate`` seconds.
    title : str
        Title of the plot.
    startTime : float, optional
        Start of the window in seconds (default 0). Negative values are
        clamped to 0.
    endTime : float or None, optional
        End of the window in seconds. ``None`` (default) plots up to the
        last sample. Values past the end of ``data`` are clamped.
    sampleRate : float, optional
        Number of samples per second (default 24e3).

    Returns
    -------
    fig, ax
        The matplotlib figure and axes that hold the plot.
    '''
    # Clamp the start so a negative index cannot wrap around to the end of data.
    idxStart = max(int(startTime*sampleRate), 0)
    if endTime is None:
        idxEnd = len(data)
    else:
        idxEnd = min(max(int(endTime*sampleRate), idxStart), len(data))
    data_section = data[idxStart: idxEnd]
    # Derive the time axis from the samples that were actually selected, so
    # it always has the same length as data_section and a spacing of exactly
    # 1/sampleRate.
    time = (idxStart + np.arange(len(data_section))) / sampleRate
    fig, ax = plt.subplots()
    ax.plot(time, data_section)
    ax.set_title(title)
    ax.set_xlabel('time [s]')
    ax.set_ylabel('voltage [uv]')
    ax.grid(True)
    return fig, ax

def plotSectionMeanAndSTD(data, title, sampleRate=24e3):
    '''Plot the mean waveform with a +/- one standard deviation band, and the
    standard deviation itself, computed along the first axis of ``data``.

    Parameters
    ----------
    data : np.ndarray
        Array whose first axis enumerates the sections to be averaged
        (the callers in this notebook pass the output of getPeakSections).
    title : str
        Title of the upper (mean) plot.
    sampleRate : float, optional
        Number of samples per second (default 24e3), used for the time axis.
    '''
    meanTrace = data.mean(axis=0)
    stdTrace = data.std(axis=0)
    numSamples = len(meanTrace)
    time = np.linspace(0, numSamples/sampleRate, num=numSamples)

    fig = plt.figure(figsize=(10, 10))

    # Upper panel: mean waveform (blue line) and mean +/- std (red dots).
    axMean = fig.add_subplot(2, 1, 1)
    axMean.plot(time, meanTrace, color='b')
    for band in (meanTrace + stdTrace, meanTrace - stdTrace):
        axMean.plot(time, band, 'r.')
    axMean.set_title(title)
    axMean.set_xlabel('time [s]')
    axMean.set_ylabel('voltage [uv]')
    axMean.grid(True)

    # Lower panel: standard deviation across sections at each sample.
    axStd = fig.add_subplot(2, 1, 2)
    axStd.plot(time, stdTrace)
    axStd.set_xlabel('time [s]')
    axStd.set_ylabel('STD')
    axStd.grid(True)

def getPeakSections(data, idxPeaks, windowSizeBefore, windowSizeAfter):
    '''Cut a fixed-length window out of ``data`` around every peak index.

    For each index ``i`` in ``idxPeaks`` the window is
    ``data[i - windowSizeBefore: i + windowSizeAfter]``.

    Parameters
    ----------
    data : array-like
        Signal to cut the windows from (indexed along its first axis).
    idxPeaks : array-like of int
        Sample indices of the peaks.
    windowSizeBefore : int
        Number of samples to include before each peak index.
    windowSizeAfter : int
        Number of samples to include from the peak index onwards.

    Returns
    -------
    np.ndarray
        One row per peak, each ``windowSizeBefore + windowSizeAfter`` samples
        long. An empty ``idxPeaks`` yields an array with zero rows but the
        same row length.

    Raises
    ------
    ValueError
        If any window would extend past the start or end of ``data``. Such a
        window would otherwise be silently truncated (or wrap around for a
        negative start), producing rows of unequal length.
    '''
    data = np.asarray(data)
    idxPeaks = np.asarray(idxPeaks, dtype=int).reshape(-1)
    windowLength = windowSizeBefore + windowSizeAfter

    if idxPeaks.size == 0:
        return np.empty((0, windowLength) + data.shape[1:], dtype=data.dtype)

    outOfBound = np.logical_or(
        idxPeaks - windowSizeBefore < 0, idxPeaks + windowSizeAfter > len(data))
    if outOfBound.any():
        raise ValueError(
            f'{int(outOfBound.sum())} peak window(s) extend beyond the data; '
            'remove peaks too close to the beginning or end first')

    # Broadcasting builds one row of sample indices per peak.
    offsets = np.arange(-windowSizeBefore, windowSizeAfter)
    return data[idxPeaks[:, None] + offsets]

def removeOutOfBoundIdx(idx, lowBound, upBound):
    '''Keep only the indices strictly between lowBound and upBound.

    Indices too close to the beginning or the end of the data are dropped
    so that cutting a window around each remaining index cannot run out
    of bounds.
    '''
    insideBounds = (idx > lowBound) & (idx < upBound)
    return idx[insideBounds]

In [None]:
# Notebook-wide parameters; later cells read these names directly.

# Samples per second; used to convert sample indices to seconds.
sampleRate = 24e3
# Minimum peak height passed to scipy.signal.find_peaks
# (presumably in the same units as the recording -- confirm).
heightThreshold = 55
# Minimum distance between neighbouring peaks, in samples (find_peaks `distance`).
distanceThreshold = 10
# Number of samples cut before each peak index ...
windowSizeBefore = 30
# ... and from the peak index onwards, when extracting a peak window.
windowSizeAfter = 70
# Number of clusters requested from KMeans.
numSpikeTypes = 2
# MATLAB file loaded with scipy.io.loadmat.
dataFile = 'sample_4.mat'
# Number of principal components kept by PCA.
numPC = 4

In [None]:
# Load the recording together with its ground-truth spike annotations.
data = scipy.io.loadmat(dataFile)
origData = data['data'].flatten()
# NOTE(review): the [0][0][0] indexing presumably unwraps MATLAB struct/cell
# nesting -- confirm against the layout of the .mat file.
spike_times = data['spike_times'][0][0][0]
spike_class = data['spike_class'][0][0][0]

# Ground-truth windows are cut 80 samples longer than windowSizeAfter in the
# following cells, so the upper bound leaves room for that longer window.
idxUpperBound = len(origData) - windowSizeAfter - 80
spike_times_type0, spike_times_type1, spike_times_type2 = (
    removeOutOfBoundIdx(
        spike_times[spike_class == classLabel], windowSizeBefore, idxUpperBound)
    for classLabel in (0, 1, 2))

print(f'dataset loaded, origial data shape: {origData.shape}')

recLength = len(origData) / sampleRate
print(f'recording length {recLength:8.2f} seconds.')

In [None]:
# Cut a window around every ground-truth spike, 80 samples longer after the
# spike index than the window used for peak detection.
longWindowSizeAfter = windowSizeAfter + 80
spikeSectionType0, spikeSectionType1, spikeSectionType2 = (
    getPeakSections(origData, spikeIdx, windowSizeBefore, longWindowSizeAfter)
    for spikeIdx in (spike_times_type0, spike_times_type1, spike_times_type2))

# Average waveform and spread for each ground-truth spike type.
plotSectionMeanAndSTD(spikeSectionType0, title='avg waveform type 0')
plotSectionMeanAndSTD(spikeSectionType1, title='avg waveform type 1')
plotSectionMeanAndSTD(spikeSectionType2, title='avg waveform type 2')

Visualize raw data

In [None]:
# Show half a second of the raw recording (0.5 s to 1 s).
plotDataSection(
    origData, 'raw data', startTime=0.5, endTime=1, sampleRate=sampleRate)

Filtering

In [None]:
# 8th-order Butterworth high-pass filter.
# Because `fs` is passed to butter, the critical frequency is interpreted in
# the same units as fs (Hz), not in rad/s. Multiplying by 2*pi would move the
# cutoff from 40 Hz to roughly 251 Hz.
highpassCutoffHz = 40
sos = scipy.signal.butter(
    8, highpassCutoffHz, 'highpass', fs=sampleRate, output='sos')
filteredData = scipy.signal.sosfilt(sos, origData)

# Cut the same ground-truth windows as for the raw data (80 samples longer
# than windowSizeAfter) out of the filtered signal.
fiteredSectionType0 = getPeakSections(filteredData, spike_times_type0, 
    windowSizeBefore, windowSizeAfter + 80)
fiteredSectionType1 = getPeakSections(filteredData, spike_times_type1, 
    windowSizeBefore, windowSizeAfter + 80)
fiteredSectionType2 = getPeakSections(filteredData, spike_times_type2, 
    windowSizeBefore, windowSizeAfter + 80)

# Show a few individual filtered type 0 waveforms.
plotDataSection(fiteredSectionType0[3, :], 'filtered waveform type 0')
plotDataSection(fiteredSectionType0[4, :], 'filtered waveform type 0')
plotDataSection(fiteredSectionType0[5, :], 'filtered waveform type 0')

Find Peaks

In [None]:
idxPeaks, peakProperties = scipy.signal.find_peaks(
    origData, height=heightThreshold, distance=distanceThreshold)
# Drop peaks too close to either end of the recording. The window cut around
# each peak spans [i - windowSizeBefore, i + windowSizeAfter), so a peak near
# the end would otherwise give a truncated window and rows of unequal length.
# Note: peakProperties still describes all detected peaks, not the filtered set.
idxPeaks = removeOutOfBoundIdx(
    idxPeaks, windowSizeBefore, len(origData) - windowSizeAfter)
print(f'number of peaks found: {len(idxPeaks)}')

Cut out a small section of data around each peak; the remaining data is treated as background noise.

In [None]:
peakSections = getPeakSections(origData, idxPeaks, windowSizeBefore, windowSizeAfter)

# Flag every sample that falls inside a peak window, then keep the rest.
# A boolean mask avoids writing a sentinel value into the signal: it works
# for any dtype, cannot collide with real sample values, and does not rely
# on the `np.Inf` alias that was removed in NumPy 2.0.
isPeakSample = np.zeros(len(origData), dtype=bool)
for i in idxPeaks:
    isPeakSample[i - windowSizeBefore: i + windowSizeAfter] = True
backgroundNoise = origData[~isPeakSample]


Visualize the background noise

In [None]:
# The peak windows have been removed, so the time axis is that of the
# concatenated remaining samples, not the original recording time.
plotDataSection(data=backgroundNoise, title='background noise')

Apply PCA to the peak sections

In [None]:
# Fix the seed so the projection is reproducible even when scikit-learn
# selects a randomized SVD solver for this data size.
pca = PCA(n_components=numPC, random_state=0)
pca.fit(peakSections)
# One row per detected peak, numPC principal-component scores per row.
peakPC = pca.transform(peakSections)

Apply K-Means clustering to the PCA features

In [None]:
# Fix the seed: K-Means initialisation is random, so without it the cluster
# labels (0/1) can swap between runs, and the later cells refer to the
# clusters by label.
kMeanCluster = KMeans(n_clusters=numSpikeTypes, n_init=10, random_state=0)
kMeanCluster.fit(peakPC)

# Map each cluster centre from PCA space back to a waveform and plot it.
for i in range(numSpikeTypes):
    plotDataSection(
        pca.inverse_transform(kMeanCluster.cluster_centers_[i]), f'spike type {i}')

Compare the clustering result to ground truth

In [None]:
# Number of detected peaks per cluster versus number of annotated spikes
# per ground-truth class.
print('compare number of peaks found through clustering to ground truth')

clusterLabels = kMeanCluster.labels_
numCluster0 = np.count_nonzero(clusterLabels == 0)
numCluster1 = np.count_nonzero(clusterLabels == 1)
print(f'{numCluster0} type 0 spikes were found through clutering')
print(f'{numCluster1} type 1 spikes were found through clutering')

numTruthType0 = len(spike_times_type0)
numTruthType1 = len(spike_times_type1)
numTruthType2 = len(spike_times_type2)
print(f'{numTruthType0} type 0 spikes in ground truth')
print(f'{numTruthType1} type 1 spikes in ground truth')
print(f'{numTruthType2} type 2 spikes in ground truth')

Confusion matrix

In [None]:
def getConfusionMatrix(pred, groundTruth, tolerance):
    '''Greedily match predicted indices to ground-truth indices and count
    true positives, false positives and false negatives.

    Predictions are processed in order. Each one is matched to the closest
    ground-truth index that has not been matched yet, provided the distance
    is at most ``tolerance``; every ground-truth index is used at most once.
    A summary of the counts is printed.

    Parameters
    ----------
    pred : iterable of int
        Predicted sample indices.
    groundTruth : array-like of int
        Ground-truth sample indices (any sequence is accepted).
    tolerance : int or float
        Maximum absolute index difference for a prediction to count as a
        match.

    Returns
    -------
    tuple
        ``(numTruePositive, numFalsePositive, numFalseNegative,
        listFalsePositive, listFalseNegative)`` where ``listFalsePositive``
        is a list of the unmatched predictions and ``listFalseNegative`` is
        an array of the unmatched ground-truth indices.
    '''
    groundTruth = np.asarray(groundTruth).reshape(-1)
    # Track matched entries with a mask instead of deleting them one by one.
    isMatched = np.zeros(len(groundTruth), dtype=bool)
    numTruePositive = 0
    listFalsePositive = []
    for singlePred in pred:
        candidateIdx = np.flatnonzero(~isMatched)
        if candidateIdx.size > 0:
            distance = np.abs(groundTruth[candidateIdx] - singlePred)
            closest = np.argmin(distance)
            if distance[closest] <= tolerance:
                isMatched[candidateIdx[closest]] = True
                numTruePositive += 1
                continue
        listFalsePositive.append(singlePred)

    numFalsePositive = len(listFalsePositive)
    listFalseNegative = groundTruth[~isMatched]
    numFalseNegative = len(listFalseNegative)
    print(f'    number of True Positive: {numTruePositive} / {len(groundTruth)}')
    print(f'    number of False Positive: {numFalsePositive}, False Negative: {numFalseNegative}')
    return (numTruePositive, numFalsePositive, 
        numFalseNegative, listFalsePositive, listFalseNegative)

In [None]:
# Split the detected peaks by the cluster they were assigned to.
clusterLabels = kMeanCluster.labels_
type0PeakIdx = idxPeaks[clusterLabels == 0]
type1PeakIdx = idxPeaks[clusterLabels == 1]
# A detection counts as a match if it lies within one full peak window of a
# ground-truth spike index.
tolerance = windowSizeBefore + windowSizeAfter

print('confusion matrix for type0 (clusting) and type1 (ground truth):')
(numTruePositive, numFalsePositive, numFalseNegative,
 listFalsePositive, listFalseNegative) = getConfusionMatrix(
    type0PeakIdx, spike_times_type1, tolerance)

# The results of this second comparison overwrite those of the first, so
# later cells see the false positives/negatives of cluster 1 vs. type 2.
print('confusion matrix for type1 (clusting) and type2 (ground truth):')
(numTruePositive, numFalsePositive, numFalseNegative,
 listFalsePositive, listFalseNegative) = getConfusionMatrix(
    type1PeakIdx, spike_times_type2, tolerance)

In [None]:
def plotPeak(origData, peakIdx, title, windowSizeBefore, windowSizeAfter, sampleRate=24e3):
    '''Plot the window of ``origData`` around a single peak index.

    The window spans ``windowSizeBefore`` samples before ``peakIdx`` and
    ``windowSizeAfter`` samples from ``peakIdx`` onwards, clipped to the
    bounds of ``origData``. It is drawn with plotDataSection, so the time
    axis starts at 0 rather than at the peak's position in the recording.

    Parameters
    ----------
    origData : sequence or np.ndarray
        Signal the window is cut from.
    peakIdx : int
        Sample index of the peak.
    title : str
        Title of the plot.
    windowSizeBefore, windowSizeAfter : int
        Window extent in samples before / from the peak index.
    sampleRate : float, optional
        Number of samples per second (default 24e3).
    '''
    # Clamp the start: a negative slice start would wrap around to the end of
    # the data and silently select the wrong (usually empty) range.
    idxStart = max(int(peakIdx) - windowSizeBefore, 0)
    idxEnd = int(peakIdx) + windowSizeAfter
    data = origData[idxStart: idxEnd]
    plotDataSection(data, title, sampleRate=sampleRate)

# Show the first few false positives and false negatives of the most recent
# getConfusionMatrix call.
numVisualize = 2
for i in range(numVisualize):
    # Either list may hold fewer than numVisualize entries (or be empty);
    # skip the missing cases instead of raising an IndexError.
    if i < len(listFalsePositive):
        plotPeak(origData, listFalsePositive[i], f'False Positive case {i}', 
            windowSizeBefore, windowSizeAfter, sampleRate)
    if i < len(listFalseNegative):
        plotPeak(origData, listFalseNegative[i], f'False Negative case {i}', 
            windowSizeBefore, windowSizeAfter, sampleRate)