_Neural Data Analysis_

Lecturers: Prof. Dr. Philipp Berens, Dr. Alexander Ecker

Tutors: Sarah Strauss, Santiago Cadena

Summer term 2019

Due date: 2019-04-23, 9am

Student names: *FILL IN YOUR NAMES HERE*

# Exercise sheet 1

Download the data file ```nda_ex_1.csv``` from ILIAS and save it in the subfolder ```../data/```.

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
from sklearn.decomposition import PCA
import scipy as sp
import itertools as it
sns.set_style('whitegrid')
%matplotlib inline


## Load data

In [None]:
Fs = 30000     # sampling rate of the signal in Hz
dt = 1/Fs
gain = .5      # gain of the signal
x = pd.read_csv('../data/nda_ex_1.csv', header=0, names=('Ch1', 'Ch2', 'Ch3', 'Ch4'))  

## Task 1: Filter Signal

To detect action potentials, first filter out low-frequency fluctuations (the local field potential, LFP) and high-frequency noise. Determine appropriate filter settings and implement the filtering in the function ```filterSignal()```. A typical choice for this task is a Butterworth filter. Plot a segment of the raw signal and of the filtered signal for all four channels with matching y-axes. The segment you choose should contain spikes. When you apply the function, also test different filter settings.

*Grading: 2 pts*
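
As an illustration of the idea (not the reference solution), a zero-phase Butterworth band-pass can be sketched with `scipy.signal`. The function name `bandpass_sketch`, the filter order of 3 and the synthetic test signal are assumptions made for this example only.

```python
import numpy as np
from scipy import signal

def bandpass_sketch(x, Fs, low, high, order=3):
    """Zero-phase Butterworth band-pass along axis 0 (one column per channel).

    `order` is an assumed default, not a value prescribed by the exercise.
    """
    nyquist = Fs / 2.0
    # butter() expects the band edges as fractions of the Nyquist frequency
    b, a = signal.butter(order, [low / nyquist, high / nyquist], btype='bandpass')
    # filtfilt() runs the filter forwards and backwards, so the net delay is zero
    return signal.filtfilt(b, a, np.asarray(x, dtype=float), axis=0)

# Synthetic check: a slow 5 Hz drift plus a 1 kHz oscillation (1 s of data)
Fs_demo = 30000
t_demo = np.arange(Fs_demo) / Fs_demo
slow = 100 * np.sin(2 * np.pi * 5 * t_demo)
fast = 10 * np.sin(2 * np.pi * 1000 * t_demo)
y_demo = bandpass_sketch((slow + fast)[:, None], Fs_demo, 500, 4000)
```

The sketch returns a plain NumPy array. Wrapping the result as `pd.DataFrame(y, columns=x.columns)` would keep the channel labels that the plotting cell iterates over.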


In [None]:
def filterSignal(x, Fs, low, high):
# Filter raw signal
#   y = filterSignal(x, Fs, low, high) filters the signal x. Each column in x is one
#   recording channel. Fs is the sampling frequency. low and high specify the passband in Hz.
#   The filter delay is compensated in the output y.

        

    return y


In [None]:
xf = filterSignal(x, Fs, 500, 4000)

In [None]:
plt.figure(figsize=(14, 8))

T = 100000
t = np.arange(0,T) * dt 

for i, col in enumerate(xf):
    plt.subplot(4,2,2*i+1)
    plt.plot(t,x[col][0:T],linewidth=.5)
    plt.ylim((-1000, 1000))
    plt.xlim((0,3))
    plt.ylabel('Voltage')
    
    
    plt.subplot(4,2,2*i+2)
    plt.plot(t,xf[col][0:T],linewidth=.5)
    plt.ylim((-400, 250))
    plt.xlim((0,3))
    plt.ylabel('Voltage')
    

## Task 2: Detect action potentials

Action potentials are usually detected by finding large-amplitude deflections in the continuous signal. Choosing a good detection threshold is important: if it is too low, you will detect too many low-amplitude events (noise); if it is too high, you risk missing good spikes. Implement an automatic procedure to obtain a reasonable threshold and detect the times at which spikes occurred in the function ```detectSpikes()```. Plot a segment of the filtered signal for all four channels with matching y-axes and indicate the time points where you detected spikes. Plot the threshold. Are the detected time points well aligned with peaks in the signal?

*Grading: 3 pts*
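
One possible way to set the threshold automatically is to scale a robust estimate of the noise level. The sketch below assumes negative-going spikes, a threshold of 5 noise standard deviations and a 1 ms minimum spacing between spikes. These values, the function name `detect_spikes_sketch` and the synthetic data are assumptions for illustration.

```python
import numpy as np
from scipy import signal

def detect_spikes_sketch(x, Fs, n_sigma=5.0, refractory_ms=1.0):
    """Threshold detection on filtered data x (samples x channels).

    Returns spike times in samples (s), in ms (t), and the per-channel threshold.
    n_sigma and refractory_ms are assumed values, not prescribed ones.
    """
    x = np.asarray(x, dtype=float)
    # Robust noise estimate: for Gaussian noise, median(|x|) / 0.6745 approximates the SD
    sigma = np.median(np.abs(x), axis=0) / 0.6745
    thresh = -n_sigma * sigma  # assumes negative-going spikes
    # Most negative deflection across channels, in units of each channel's noise SD
    z = (x / sigma).min(axis=1)
    # Local minima of z below -n_sigma, at least refractory_ms apart
    min_dist = max(1, int(round(refractory_ms * Fs / 1000.0)))
    s, _ = signal.find_peaks(-z, height=n_sigma, distance=min_dist)
    t = s / Fs * 1000.0  # sample 0 corresponds to 0 ms
    return s, t, thresh

# Synthetic check: unit-variance noise with three large negative events on channel 0
rng = np.random.default_rng(0)
Fs_demo = 30000
x_demo = rng.normal(0.0, 1.0, size=(Fs_demo, 4))
true_s = np.array([3000, 12000, 21000])
x_demo[true_s, 0] -= 20.0
s_demo, t_demo, thresh_demo = detect_spikes_sketch(x_demo, Fs_demo)
```

Because the spike time is taken at the local minimum, detected times are aligned to the trough of each waveform. The returned `thresh` can be drawn with `plt.axhline` on the corresponding channel.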

In [None]:
def detectSpikes(x,Fs):
# Detect spikes
# s, t = detectSpikes(x,Fs) detects spikes in x, where Fs is the sampling
#   rate (in Hz). The outputs s and t are column vectors of spike times in
#   samples and ms, respectively. By convention the time of the zeroth
#   sample is 0 ms.



    return (s, t)


In [None]:
T = xf.shape[0]
s, t = detectSpikes(xf.to_numpy(),Fs)

In [None]:
plt.figure(figsize=(7, 8))

tt = np.arange(0,T) * dt 

for i, col in enumerate(xf):
    plt.subplot(4,1,i+1)
    plt.plot(tt,xf[col],linewidth=.5)
    plt.plot(tt[s],xf[col][s],'r.')
    plt.ylim((-400, 400))
    plt.xlim((0.025,0.075))
    plt.ylabel('Voltage')



## Task 3: Extract waveforms
For later spike sorting we need the waveforms of all detected spikes. Extract the waveform segments (1 ms) on all four channels for each spike time, so that each spike is represented by a 4x30 matrix. Implement this procedure in the function ```extractWaveforms()```. Plot (a) the first 100 spikes you detected and (b) the 100 largest spikes you detected. Are there many very small spikes (likely noise) among your detected spikes? If so, your threshold may be too low. Can you see obvious artifacts that do not look like spikes at all?

*Grading: 2 pts*
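
The windowing step can be sketched with NumPy index arithmetic. The window of 10 samples before and 20 samples after each spike mirrors the time axis `np.arange(-10,20)` used in the plotting cells. The function name `extract_waveforms_sketch` and the choice to drop spikes too close to the recording edges are assumptions.

```python
import numpy as np

def extract_waveforms_sketch(x, s, pre=10, post=20):
    """Cut a (pre + post)-sample window around each spike time.

    x is samples x channels, s holds spike times in samples.
    Returns w with shape (pre + post, n_spikes, n_channels) and the spike
    times that were kept (spikes too close to the edges are dropped).
    """
    x = np.asarray(x)
    s = np.asarray(s)
    keep = (s - pre >= 0) & (s + post <= x.shape[0])
    s = s[keep]
    # Row r, column j of idx is the sample index s[j] + (r - pre)
    idx = s[None, :] + np.arange(-pre, post)[:, None]
    return x[idx, :], s

# Small check on a ramp signal with 100 samples and 4 channels
x_demo = np.arange(400, dtype=float).reshape(100, 4)
w_demo, s_kept = extract_waveforms_sketch(x_demo, [5, 50, 95])
```

If edge spikes are dropped like this, the same selection has to be applied to the spike times used later, otherwise waveforms and spike times no longer line up.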

In [None]:
def extractWaveforms(x, s):
# Extract spike waveforms.
#   w = extractWaveforms(x, s) extracts the waveforms at times s (given in
#   samples) from the filtered signal x using a fixed window around the
#   times of the spikes. The return value w is a 3d array of size
#   length(window) x #spikes x #channels.


    return w
    

In [None]:
w = extractWaveforms(xf.to_numpy(),s)

Plot first 100 spike waveforms

In [None]:
t = np.arange(-10,20) * dt * 1000 

plt.figure(figsize=(11, 8))

for i, col in enumerate(xf):
    plt.subplot(2,2,i+1)
    plt.plot(t,w[:,0:100,i],'k', linewidth=1)
    plt.ylim((-500, 250))
    plt.xlim((-0.33,0.66))
    plt.ylabel('Voltage')


Plot largest 100 spike waveforms

In [None]:
idx = np.argsort(np.min(np.min(w,axis=2),axis=0))


t = np.arange(-10,20) * dt * 1000 

plt.figure(figsize=(11, 8))
for i, col in enumerate(xf):
    plt.subplot(2,2,i+1)
    plt.plot(t,w[:,idx[0:100],i],'k', linewidth=1)
    plt.ylim((-1000, 500))
    plt.xlim((-0.33,0.66))
    plt.ylabel('Voltage')


## Task 4: Extract features using PCA
Compute the first three PCA features on each channel separately in ```extractFeatures()```. You can use an available PCA implementation or implement it yourself. After that, each spike is represented by a 12-element vector. Compute the fraction of variance captured by these three PCs.
Plot scatter plots for all pairwise combinations of the first PCs. Do you see clusters visually? 

*Grading: 2+1 pts*
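
A minimal sketch of per-channel PCA, here written with a plain SVD instead of a library PCA class, could look as follows. The function name `extract_features_sketch`, the second return value and the random test data are assumptions. Features are ordered channel by channel, so columns 0, 3, 6 and 9 hold the first PC of each channel.

```python
import numpy as np

def extract_features_sketch(w, n_components=3):
    """PCA per channel on waveforms w (window x spikes x channels).

    Returns b (spikes x n_components * channels) and, per channel, the
    fraction of variance captured by the retained components.
    """
    n_channels = w.shape[2]
    feats, var_explained = [], []
    for ch in range(n_channels):
        X = w[:, :, ch].T                 # spikes x window samples
        Xc = X - X.mean(axis=0)           # centre each time point
        # Rows of Vt are the principal axes, ordered by decreasing variance
        _, sv, Vt = np.linalg.svd(Xc, full_matrices=False)
        feats.append(Xc @ Vt[:n_components].T)
        var_explained.append(np.sum(sv[:n_components] ** 2) / np.sum(sv ** 2))
    return np.hstack(feats), np.array(var_explained)

# Shape check on random data: 200 spikes, 30 samples, 4 channels
rng = np.random.default_rng(0)
w_demo = rng.normal(size=(30, 200, 4))
b_demo, var_demo = extract_features_sketch(w_demo)
```

The squared singular values are proportional to the variance along each principal axis, which is why their ratio gives the explained-variance fraction directly.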


In [None]:
def extractFeatures(w):
# Extract features for spike sorting.
#   b = extractFeatures(w) extracts features for spike sorting from the
#   waveforms in w, which is a 3d array of size length(window) x #spikes x
#   #channels. The output b is a matrix of size #spikes x #features.
#   The implementation should do PCA on the waveforms of each channel
#   separately and use the first three principal components. Thus, we get
#   a total of 12 features. Also, the variance explained by the 3 features per channel
#   should be computed.
    

    

    return b



In [None]:
b = extractFeatures(w)

In [None]:
plt.figure(figsize=(10, 6))
plt.suptitle('Scatter plots',fontsize=20)

idx = [0, 3, 6, 9]
p = 1
labels = ['Ch1','Ch2','Ch3','Ch4']
for i in np.arange(0,4):
    for j in np.arange(i+1,4):
        ax = plt.subplot(2,3,p, aspect='equal')
        plt.plot(b[:,idx[i]],b[:,idx[j]],'.k', markersize=.7) 
        plt.xlabel(labels[i])
        plt.ylabel(labels[j])
        plt.xlim((-1500,1500))
        plt.ylim((-1500,1500))
        ax.set_xticks([])
        ax.set_yticks([])
        p = p+1

In [None]:
np.save('../data/nda_ex_1_features',b)
np.save('../data/nda_ex_1_spiketimes',s)
np.save('../data/nda_ex_1_waveforms',w)

In [None]:
import math
from numba import jit

@jit(nopython=True)
def mog_old(x, k, m, S, p, y, ind, arr):
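# Fit a Mixture of Gaussians model with EM (2-D data only)
#   ind, m, S, p = mog_old(x, k, m, S, p, y, ind, arr) fits a Mixture of
#   Gaussians model with k components to the data in x. m, S and p hold the
#   initial parameters and are updated in place; y, ind and arr are
#   preallocated work buffers. The output ind contains the MAP assignments of
#   the data points in x to the found clusters. The outputs m, S, p contain
#   the model parameters. The covariance inverse is written out by hand for
#   D = 2, so x must have exactly two columns.
#
#   x:     N by D
#   y:     N by k   (posterior responsibilities)
#   arr:   k        (scratch buffer for the argmax)
#
#   ind:   N by 1
#   m:     k by D
#   S:     D by D by k
#   p:     k by 1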

    N = x.shape[0]  # number of data points
            
    iteration = 100;
    iter_count = 0;
    
    threshold = 1e-30
    log_likelihood_old = 0;
    log_likelihood_new = 1;
    
    while iter_count < iteration and np.abs(log_likelihood_old - log_likelihood_new) > threshold:
              
        for j in range(N):           
            sum_tmp = 0;
            for i in range(k): 
                minus = x[j] - m[i];
                
                matmul_1 = (minus[0] * S[1][1][i] + minus[1] * (-1*S[0][1][i])) / np.linalg.det(S[:, :, i]);
                matmul_2 = (minus[0] * (-1*S[1][0][i]) + minus[1] * S[0][0][i]) / np.linalg.det(S[:, :, i]);
                
                matmul_tmp = matmul_1 * minus[0] + matmul_2 * minus[1];     
                                
                exp_tmp = np.exp(-0.5 * matmul_tmp);
                posterior = (p[i] / (2 * math.pi * np.sqrt(np.linalg.det(S[:, :, i])))) * exp_tmp;
                sum_tmp += posterior;
                y[j][i] = posterior;
                
            for i in range(k):
                y[j][i] = y[j][i] /sum_tmp;   

        N_k = np.zeros(k);
        
        for i in range(k):
            for j in range(N):
                N_k[i] += y[j][i];
                
        # Update the mixing proportions of all k components
        for i in range(k):
            p[i] = N_k[i] / N
                
        for i in range(k):      
            sum_tmp_x = 0;     
            sum_tmp_y = 0;
            
            for j in range(N):               
                sum_tmp_x += y[j][i] * x[j][0]; 
                sum_tmp_y += y[j][i] * x[j][1];    
            
            m[i][0] = sum_tmp_x/N_k[i];
            m[i][1] = sum_tmp_y/N_k[i];
                                 
            sum_tmp_11 = 0;     
            sum_tmp_12 = 0;   
            sum_tmp_21 = 0;     
            sum_tmp_22 = 0;
            
            for j in range(N):      
                
                tmp_arr = x[j] - m[i];  
                
                sum_tmp_11 += y[j][i] * (tmp_arr[0]*tmp_arr[0]);
                sum_tmp_12 += y[j][i] * (tmp_arr[0]*tmp_arr[1]);
                sum_tmp_21 += y[j][i] * (tmp_arr[0]*tmp_arr[1]);
                sum_tmp_22 += y[j][i] * (tmp_arr[1]*tmp_arr[1]);
                   
            S[0, 0, i] = sum_tmp_11/N_k[i];
            S[0, 1, i] = sum_tmp_12/N_k[i];
            
            S[1, 0, i] = sum_tmp_21/N_k[i];
            S[1, 1, i] = sum_tmp_22/N_k[i];

        iter_count += 1; 
        
        log_likelihood_old = log_likelihood_new;        
        log_likelihood_new = 0;
        
        for j in range(N): 
            
            log_sum = 0;
            
            for i in range(k):
                minus = x[j] - m[i];
                
                matmul_1 = (minus[0] * S[1][1][i] + minus[1] * (-1*S[0][1][i])) / np.linalg.det(S[:, :, i]);
                matmul_2 = (minus[0] * (-1*S[1][0][i]) + minus[1] * S[0][0][i]) / np.linalg.det(S[:, :, i]);
                
                matmul_tmp = matmul_1 * minus[0] + matmul_2 * minus[1];     
                
                exp_tmp = np.exp(-0.5 * matmul_tmp);
                posterior = (p[i] / (2 * math.pi * np.sqrt(np.linalg.det(S[:, :, i])))) * exp_tmp;
                
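                log_sum += posterior  # accumulate the mixture density p(x_j)
                
            # the log-likelihood is the sum over data points of log p(x_j)
            log_likelihood_new += math.log(log_sum)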
            
    for j in range(N): 
        for i in range(k):
            arr[i] = y[j][i];
        ind[j] = np.argmax(arr);
            
    return (ind, m, S, p)
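
The convergence test in `mog_old` depends on the mixture log-likelihood, i.e. the sum over data points of the log of the weighted component densities. As a cross-check for the hand-written loops, the same quantity can be sketched in vectorised form with SciPy. The function name `mog_log_likelihood_sketch` and the toy parameters are assumptions. The array shapes follow the comment block of `mog_old`.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def mog_log_likelihood_sketch(x, m, S, p):
    """Log-likelihood and MAP assignments of a Gaussian mixture.

    x: N by D, m: k by D, S: D by D by k, p: length k.
    """
    k = m.shape[0]
    # log(p_i) + log N(x_j | m_i, S_i) for every point j and component i (N by k)
    log_joint = np.column_stack([
        np.log(p[i]) + multivariate_normal.logpdf(x, mean=m[i], cov=S[:, :, i])
        for i in range(k)
    ])
    # logsumexp over components avoids underflow for points far from all means
    log_lik = logsumexp(log_joint, axis=1).sum()
    ind = np.argmax(log_joint, axis=1)
    return log_lik, ind

# Toy check: two well-separated unit-covariance components in 2-D
m_demo = np.array([[0.0, 0.0], [10.0, 10.0]])
S_demo = np.stack([np.eye(2), np.eye(2)], axis=2)
p_demo = np.array([0.5, 0.5])
x_demo = np.array([[0.0, 0.0], [10.0, 10.0], [0.1, -0.1]])
ll_demo, ind_demo = mog_log_likelihood_sketch(x_demo, m_demo, S_demo, p_demo)

# Single standard-normal component evaluated at the origin (two points)
ll_single, _ = mog_log_likelihood_sketch(
    np.zeros((2, 2)), np.zeros((1, 2)), np.eye(2)[:, :, None], np.array([1.0])
)
```

Working in the log domain also sidesteps the case where `np.exp` underflows to zero for every component, which would make the normalisation in the E-step divide by zero.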
