# SSVEP: Offline processing using Machine Leaning Method

## Step 0: Import neceessary toolboxes

In [57]:
# import require library for preprocess
import mne
import numpy as np
from mne.channels import make_standard_montage
import matplotlib.pyplot as plt
from mne.datasets import eegbci
import scipy
import pickle
import seaborn as sns

# import require library for classification
from sklearn.svm import SVC # SVM library
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # LDA library
from sklearn.neighbors import KNeighborsClassifier # KNN library

from sklearn.metrics import classification_report,confusion_matrix # Result representation

In [58]:
select_feature = "fft"

## Step 1: Read data file

In [59]:
# read biosemi file (bdf)
raw = mne.io.read_raw_bdf("Testdata3.bdf", preload=True, verbose=False) 
eegbci.standardize(raw)

## Step 2: Data preprocessing -- set channel locations/ downsampling/ frequency filtering (bandpass)/ epoching

In [60]:
from scipy.signal import filtfilt
from scipy import signal

# Set channel location
montage = make_standard_montage("biosemi64")
raw.set_montage(montage, on_missing='ignore')

# Downsample data (from 1024 to 512Hz) to save storage space 
raw = raw.resample(512, verbose = False)


#Band pass and notch filter
raw = raw.copy().notch_filter(freqs=50)
raw = raw.copy().filter(l_freq=1, h_freq=40, verbose = False)

# raw = raw.copy().filter(l_freq=2.0, h_freq=40.0, method = 'iir', iir_params= {"order": 4, "ftype":'butter'})

# Get events and timestamps
events = mne.find_events(raw, shortest_event = 0, verbose = False) 

# Create event dictionary 
event_dict =  {'12Hz': 8,
'24Hz': 4,
'6Hz': 10,
'30Hz': 2
}

# Use events and event dictionary to cut data into Epochs
ssvep_chans = ['O1','Oz','PO3','PO4','POz','Pz']  # Reject O2 becuase noisy channel

Epochs = mne.Epochs(raw, events, 
    tmin= -1.0,  
    tmax= 4.0,    
    event_id=event_dict,
    picks = ssvep_chans,
    preload = True,
    event_repeated='drop',
    baseline= (-1,0),
    verbose=False
)

Epochs = Epochs.copy().crop(tmin = 0.0, tmax = 4.0)

train_label = Epochs['12Hz','6Hz', '24Hz', '30Hz'].events[:,-1]

# print(raw.pick(['O1','Oz','PO3','PO4','POz','Pz']).get_data()[:,9735:]* 10e6)

Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 3381 samples (6.604 s)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  72 out of  72 | elapsed:    0.1s finished


## Step3: Feature extraction

## 3.1 Fast Fourier Transform

In [61]:
# Compute FFT for each epoch and return the power spectral density
def compute_fft(epoch_data, sampling_rate):

    num_epochs, num_channels, num_timepoints = epoch_data.shape

    freqs = np.fft.fftfreq(num_timepoints-1, 1 / sampling_rate)
    
    fft_data = np.zeros((num_epochs, num_channels, len(freqs)))

    # Compute FFT for each channel and each epoch
    for epoch_idx in range(num_epochs):
        for ch_idx in range(num_channels):     
            fft_result = scipy.fft.fft(epoch_data[epoch_idx, ch_idx, 0:2048])

            power_spectrum = np.abs(fft_result) ** 2  # Power = |FFT|^2
            fft_data[epoch_idx, ch_idx, :] = power_spectrum

    return fft_data, freqs

# Example usage
fft_out, freqs_out = compute_fft(Epochs.get_data() * 10e6, 512)
print(np.shape(fft_out))

fft_train = np.stack([arr.flatten() for arr in fft_out])
print(fft_train.shape)


(20, 6, 2048)
(20, 12288)


  fft_out, freqs_out = compute_fft(Epochs.get_data() * 10e6, 512)


## 3.2 Power Spectrum Density

In [62]:
psd_epoch = Epochs['12Hz','6Hz', '24Hz', '30Hz'].pick(ssvep_chans).compute_psd(fmin=1.0, fmax=40.0)
print(psd_epoch.shape)

psd_train = np.stack([arr.flatten() for arr in psd_epoch])
print(psd_train.shape)

    Using multitaper spectrum estimation with 7 DPSS windows
(20, 6, 156)
(20, 936)


## Load test set

In [63]:
with open('SSVEP_test_epochs.pkl', 'rb') as f:
    test_epochs = pickle.load(f)

test_epochs = test_epochs.copy().crop(tmin = 0.0, tmax = 4.0)

fft_out, freqs_out = compute_fft(test_epochs['12Hz','6Hz', '24Hz', '30Hz'].pick(ssvep_chans).get_data() * 10e6, 512)
fft_test = np.stack([arr.flatten() for arr in fft_out])

psd_epoch = test_epochs['12Hz','6Hz', '24Hz', '30Hz'].pick(ssvep_chans).compute_psd(fmin=1.0, fmax=40.0)
psd_test = np.stack([arr.flatten() for arr in psd_epoch])

mapping = {2: 8, 4: 4, 8: 10, 10: 2}
test_label = np.vectorize(mapping.get)(test_epochs['12Hz','6Hz', '24Hz', '30Hz'].events[:,-1])


    Using multitaper spectrum estimation with 7 DPSS windows


  fft_out, freqs_out = compute_fft(test_epochs['12Hz','6Hz', '24Hz', '30Hz'].pick(ssvep_chans).get_data() * 10e6, 512)


## Step4: Classification

## 4.0 Select Feature as train and test set

In [64]:
if select_feature == "fft":
    x_train = fft_train
    x_test = fft_test

elif select_feature == 'psd':
    x_train = psd_train
    x_test = psd_test 

## 4.1 LDA

In [65]:
def GetConfusionMatrix(models, X_train, X_test, y_train, y_test, target_names):
    y_pred = models.predict(X_train)
    print("Classification TRAIN DATA \n=======================")
    print(classification_report(y_true= y_train, y_pred=y_pred, target_names= target_names))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true= y_train, y_pred=y_pred))

    y_pred = models.predict(X_test)
    print("Classification TEST DATA \n=======================")
    print(classification_report(y_true=y_test, y_pred=y_pred, target_names= target_names))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true=y_test, y_pred=y_pred))

    

In [66]:
y_train = train_label # Get true label
y_test = test_label

lda = LinearDiscriminantAnalysis(solver= "svd")
lda.fit(x_train, y_train)

print('accuracy', lda.score(x_train, y_train))
label_names = ['12Hz', '6Hz', '24Hz', '30Hz']

with open("LDA_model.pkl", "wb") as file:
    pickle.dump(lda, file)

GetConfusionMatrix(lda, x_train, x_test, y_train, y_test, label_names)

accuracy 0.6
Classification TRAIN DATA 
              precision    recall  f1-score   support

        12Hz       0.57      0.80      0.67         5
         6Hz       0.67      0.40      0.50         5
        24Hz       0.60      0.60      0.60         5
        30Hz       0.60      0.60      0.60         5

    accuracy                           0.60        20
   macro avg       0.61      0.60      0.59        20
weighted avg       0.61      0.60      0.59        20

Confusion matrix 
[[4 1 0 0]
 [1 2 1 1]
 [1 0 3 1]
 [1 0 1 3]]
Classification TEST DATA 
              precision    recall  f1-score   support

        12Hz       0.36      0.67      0.47         6
         6Hz       0.38      0.60      0.46         5
        24Hz       0.00      0.00      0.00         5
        30Hz       0.33      0.17      0.22         6

    accuracy                           0.36        22
   macro avg       0.27      0.36      0.29        22
weighted avg       0.28      0.36      0.29        22

C

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## 4.2 SVM

In [69]:
y_train = train_label # Get true label
y_test = test_label

svm_model = SVC(C = 1, kernel= 'rbf')  # Using a linear kernel
svm_model.fit(x_train, y_train)

print(x_train.shape)

print('accuracy', svm_model.score(x_train, y_train))
label_names = ['12Hz', '6Hz', '24Hz', '30Hz']

with open("SVM_model.pkl", "wb") as file:
    pickle.dump(svm_model, file)

GetConfusionMatrix(svm_model, x_train, x_test, y_train, y_test, label_names)

(20, 12288)
accuracy 1.0
Classification TRAIN DATA 
              precision    recall  f1-score   support

        12Hz       1.00      1.00      1.00         5
         6Hz       1.00      1.00      1.00         5
        24Hz       1.00      1.00      1.00         5
        30Hz       1.00      1.00      1.00         5

    accuracy                           1.00        20
   macro avg       1.00      1.00      1.00        20
weighted avg       1.00      1.00      1.00        20

Confusion matrix 
[[5 0 0 0]
 [0 5 0 0]
 [0 0 5 0]
 [0 0 0 5]]
Classification TEST DATA 
              precision    recall  f1-score   support

        12Hz       0.75      0.50      0.60         6
         6Hz       0.50      0.20      0.29         5
        24Hz       1.00      1.00      1.00         5
        30Hz       0.45      0.83      0.59         6

    accuracy                           0.64        22
   macro avg       0.68      0.63      0.62        22
weighted avg       0.67      0.64      0.62 

## 4.3 KNN

In [71]:
y_train = train_label # Get true label
y_test = test_label

knn = KNeighborsClassifier(n_neighbors= 5, weights = "distance")
knn.fit(x_train, y_train)

print('accuracy', knn.score(x_train, y_train))
label_names = ['12Hz', '6Hz', '24Hz', '30Hz']

with open("KNN_model.pkl", "wb") as file:
    pickle.dump(knn, file)

GetConfusionMatrix(knn, x_train, x_test, y_train, y_test, label_names)

accuracy 1.0
Classification TRAIN DATA 
              precision    recall  f1-score   support

        12Hz       1.00      1.00      1.00         5
         6Hz       1.00      1.00      1.00         5
        24Hz       1.00      1.00      1.00         5
        30Hz       1.00      1.00      1.00         5

    accuracy                           1.00        20
   macro avg       1.00      1.00      1.00        20
weighted avg       1.00      1.00      1.00        20

Confusion matrix 
[[5 0 0 0]
 [0 5 0 0]
 [0 0 5 0]
 [0 0 0 5]]
Classification TEST DATA 
              precision    recall  f1-score   support

        12Hz       1.00      0.17      0.29         6
         6Hz       0.40      0.80      0.53         5
        24Hz       1.00      0.40      0.57         5
        30Hz       0.33      0.50      0.40         6

    accuracy                           0.45        22
   macro avg       0.68      0.47      0.45        22
weighted avg       0.68      0.45      0.44        22

C