In [9]:
import os.path
import scipy.io
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pywt
import pywt.data
from scipy import signal

import seaborn as sns
%matplotlib inline

In [None]:
# Load the provided raw waveform recording. loadmat returns a dict keyed by
# MATLAB variable name; this cell reads the recording stored under 'spikes_ep2'.
mat = scipy.io.loadmat(file_name='tetrode12.mat')
RawData = mat['spikes_ep2']

In [None]:
# Band-pass filter the raw recording into the spike band before detection.
# (scipy.signal is imported once, in the top import cell.)
FS = 30000                 # sampling rate (Hz)
NYQUIST = FS / 2
BAND_HZ = (600, 6000)      # pass-band edges, same units as FS
FILTER_ORDER = 4

# Butterworth band-pass; butter() expects the edges normalised to Nyquist.
b, a = signal.butter(FILTER_ORDER,
                     [BAND_HZ[0] / NYQUIST, BAND_HZ[1] / NYQUIST],
                     btype='bandpass')

# Filter along axis 0. NOTE(review): presumably RawData is
# (n_samples, n_channels) -- later cells take the first column as one channel.
# lfilter is causal, so the output is phase-shifted relative to the input;
# signal.filtfilt gives a zero-phase result if spike shape matters (the
# detection threshold downstream was set against this lfilter output).
FilteredData = signal.lfilter(b, a, RawData, axis=0)

In [3]:
def data_preprocessing(FilteredData, threshold=60, snippet_length=40, snippet_pre=10):
  '''
  Return the sample indices of fixed-length snippets around threshold crossings.

  Every axis-0 index i where FilteredData exceeds `threshold` is treated as a
  spike centre, and the window [i - snippet_pre, i - snippet_pre + snippet_length)
  around it is emitted, i.e. exactly `snippet_length` samples per crossing.
  Windows are clipped to the valid range [0, FilteredData.shape[0]), so no
  negative index (which would silently wrap around when used to index a numpy
  array) and no out-of-range index is returned.

  This is a naive approach to thresholding: each supra-threshold sample of the
  same spike produces its own window, so the result contains duplicates
  (the callers in this notebook de-duplicate with np.unique). More robust
  approaches would enforce a 1ms interspike interval.

  Args:
    FilteredData: array of samples along axis 0. For a 2-D input, every
      supra-threshold element contributes a window around its row index.
    threshold: detection level; only samples strictly greater than it count.
    snippet_length: number of samples in each window.
    snippet_pre: number of samples included before each crossing.

  Returns:
    list of int sample indices, grouped per crossing in ascending crossing
    order, each group in ascending order.
  '''
  data = np.asarray(FilteredData)
  n_samples = data.shape[0]
  # np.nonzero(...)[0] is the axis-0 index of every supra-threshold element.
  crossings = np.nonzero(data > threshold)[0]
  offsets = np.arange(-snippet_pre, snippet_length - snippet_pre)
  # One row of window indices per crossing, flattened crossing-then-offset.
  windows = (crossings[:, None] + offsets[None, :]).ravel()
  in_range = (windows >= 0) & (windows < n_samples)
  return windows[in_range].tolist()
def swt(signal, level=4, wavelet='haar'):
  '''
  Return the stationary wavelet transform (SWT) detail coefficients.

  Implements the SWT of Figure 3 via pywt.swt. The defaults reproduce the
  original setup: Haar mother wavelet, detail levels 1-4.

  Args:
    signal: 1-D array of samples. pywt.swt requires its length to be a
      multiple of 2**level (it raises ValueError otherwise). This parameter
      name shadows the scipy `signal` module inside this function only.
    level: number of decomposition levels.
    wavelet: wavelet name or pywt.Wavelet instance.

  Returns:
    np.ndarray with one row per level, each row as long as `signal` (the SWT
    is undecimated). Rows follow pywt.swt's ordering, coarsest level first:
    row 0 is the level-`level` detail and the last row is the level-1 detail,
    so `result[-k]` is the level-k detail.
  '''
  coeffs = pywt.swt(signal, wavelet, level=level)
  # Each entry is an (approximation, detail) pair; keep only the details.
  return np.array([detail for _approx, detail in coeffs])

In [None]:
# First column of the filtered data (presumably the first tetrode channel),
# truncated to 50,000 samples -- a multiple of 2**4, as a 4-level SWT needs.
waveform = FilteredData.T[0, :50000]

# SWT detail coefficients; rows are ordered coarsest level first, so the
# last row is the level-1 detail.
wavelet_result = swt(waveform)
dk = wavelet_result[-1]

# Sorted, de-duplicated sample indices covered by threshold-crossing snippets.
spike_times = np.unique(data_preprocessing(waveform))

In [None]:
# Spike-only trace: same length as `waveform`, holding the waveform value at
# every detected snippet sample and 0 everywhere else.
# Indexing straight from spike_times avoids testing `i in spike_times` for
# every sample, which scans the whole array each time (O(n*m)).
n_samples = len(waveform)
AP = [0] * n_samples
for t in spike_times:
  # Snippet windows can extend past either end of the waveform; skip those.
  if 0 <= t < n_samples:
    AP[t] = waveform[t]

In [None]:
# Zoom into samples 12500-13999 of the spike-only trace. The x values are the
# true sample indices, so the "Sample Number" axis label is accurate.
start, stop = 12500, 14000
segment = AP[start:stop]
fig, ax = plt.subplots()
ax.plot(np.arange(start, start + len(segment)), segment)
ax.set_xlabel("Sample Number")
ax.set_ylabel("uV")
ax.set_title("Action Potentials")
plt.show()

In [None]:
# Same zoom window on the filtered waveform, for comparison with the
# spike-only trace. x values are the true sample indices.
start, stop = 12500, 14000
segment = waveform[start:stop]
fig, ax = plt.subplots()
ax.plot(np.arange(start, start + len(segment)), segment)
ax.set_xlabel("Sample Number")
ax.set_ylabel("uV")
ax.set_title("Waveform for original signal")
plt.show()

In [None]:
# Plot the SWT detail coefficients of every level over the same zoom window.
# The inputs do not change between levels, so compute them once.
waveform = np.transpose(FilteredData)[0][0:50000]
wavelet_result = swt(waveform)
# Kept so this cell leaves the same globals as before (not used in the loop).
spike_times = np.unique(data_preprocessing(waveform))

start, stop = 12500, 14000
for i in range(1, wavelet_result.shape[0] + 1):
  # swt() rows are ordered coarsest level first, so row -i is the level-i detail.
  dk = wavelet_result[-i, :]
  segment = dk[start:stop]
  fig, ax = plt.subplots()
  ax.plot(np.arange(start, start + len(segment)), segment)
  ax.set_xlabel("Sample Number")
  ax.set_ylabel("uV")
  ax.set_title("Waveform for Level-{} Detailed coefficients".format(i))
  plt.show()