# SSVEP: Offline processing using Machine Leaning Method

## Step 0: Import neceessary toolboxes

In [1]:
# import require library for preprocess
import mne
import numpy as np
from mne.channels import make_standard_montage
import matplotlib.pyplot as plt
from mne.datasets import eegbci
import scipy
import pickle
import seaborn as sns

# import require library for classification
from sklearn.svm import SVC # SVM library
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # LDA library
from sklearn.neighbors import KNeighborsClassifier # KNN library
from sklearn.model_selection import train_test_split, GridSearchCV, KFold

from sklearn.metrics import classification_report,confusion_matrix # Result representation
SEED = 42

## Step 1: Read data file

In [3]:
# read biosemi file (bdf)
raw1 = mne.io.read_raw_bdf("C:\\Users\\pipo_\\OneDrive\\Desktop\\neuromedia\\group1_block1.bdf", preload=True, verbose=False) 

raw2 = mne.io.read_raw_bdf("C:\\Users\\pipo_\\OneDrive\\Desktop\\neuromedia\\group1_block2.bdf", preload=True, verbose=False) 

raw3 = mne.io.read_raw_bdf("C:\\Users\\pipo_\\OneDrive\\Desktop\\neuromedia\\group1_block3.bdf", preload=True, verbose=False) 

raw4 = mne.io.read_raw_bdf("C:\\Users\\pipo_\\OneDrive\\Desktop\\neuromedia\\group1_block4.bdf", preload=True, verbose=False) 
raw = mne.concatenate_raws([raw1, raw2, raw3, raw4])

## Step 2: Data preprocessing -- set channel locations/ downsampling/ frequency filtering (bandpass)/ epoching

In [4]:
from scipy.signal import filtfilt
from scipy import signal

# Set channel location
montage = make_standard_montage("biosemi64")
raw.set_montage(montage, on_missing='ignore')

# Downsample data (from 1024 to 512Hz) to save storage space 
raw = raw.resample(512, verbose = False)

# Get events and timestamps
events = mne.find_events(raw, shortest_event = 0, verbose = False) 

# Create event dictionary 
event_dict =  {'12Hz': 8,
'24Hz': 4,
'6Hz': 10,
'30Hz': 2
}

# Use events and event dictionary to cut data into Epochs
ssvep_chans = ['O1','Oz','PO3','POz','O2']  # Reject O2 becuase noisy channel

Epochs = mne.Epochs(raw, events, 
    tmin= -0.5,  
    tmax= 4.0,    
    event_id=event_dict,
    picks = ssvep_chans,
    preload = True,
    event_repeated='drop',
    baseline= (-0.5,0),
    verbose=False
)

Epochs = Epochs.copy().crop(tmin = 0.0, tmax = 4.0)

train_label = Epochs['12Hz','6Hz', '24Hz', '30Hz'].events[:,-1]


**Apply Scipy Filter**

In [5]:
from scipy import signal

def butter_bandpass(lowcut,highcut,fs,order):
    nyq = 0.5*fs
    low = lowcut/nyq
    high = highcut/nyq
    b,a = signal.butter(order,[low,high],'bandpass')
    return b,a

def butter_bandpass_filter(data,lowcut = 6, highcut = 30, order = 4, axis = 1):
    b,a = butter_bandpass(lowcut,highcut,512,order)
    y = signal.filtfilt(b,a,data,axis=axis)
    return y

Epochs_data = butter_bandpass_filter(Epochs.get_data(), lowcut = 2, highcut= 40, axis = 2)

  Epochs_data = butter_bandpass_filter(Epochs.get_data(), lowcut = 2, highcut= 40, axis = 2)


## Step3: Feature extraction

## 3.1 Fast Fourier Transform

In [6]:
# Compute FFT for each epoch and return the power spectral density
def compute_fft(epoch_data, sampling_rate):

    num_epochs, num_channels, num_timepoints = epoch_data.shape

    freqs = np.fft.fftfreq(num_timepoints-1, 1 / sampling_rate)
    
    fft_data = np.zeros((num_epochs, num_channels, len(freqs)))

    # Compute FFT for each channel and each epoch
    for epoch_idx in range(num_epochs):
        for ch_idx in range(num_channels):     
            fft_result = scipy.fft.fft(epoch_data[epoch_idx, ch_idx, 0:2048])

            power_spectrum = np.abs(fft_result) ** 2  # Power = |FFT|^2
            fft_data[epoch_idx, ch_idx, :] = power_spectrum

    return fft_data, freqs

# Example usage
fft_out, freqs_out = compute_fft(Epochs_data, 512)
print(np.shape(fft_out))

fft_train = np.stack([arr.flatten() for arr in fft_out])
print(fft_train.shape)


(192, 5, 2048)
(192, 10240)


## Step4: Classification

In [None]:
x_train, x_test, y_train, y_test = train_test_split(fft_train, train_label, test_size=0.3, random_state=SEED)

## 4.1 LDA

In [9]:
def GetConfusionMatrix(models, X_train, X_test, y_train, y_test, target_names):
    y_pred = models.predict(X_train)
    print("Classification TRAIN DATA \n=======================")
    print(classification_report(y_true= y_train, y_pred=y_pred, target_names= target_names))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true= y_train, y_pred=y_pred))

    y_pred = models.predict(X_test)
    print("Classification TEST DATA \n=======================")
    print(classification_report(y_true=y_test, y_pred=y_pred, target_names= target_names))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true=y_test, y_pred=y_pred))

    

In [10]:
lda = LinearDiscriminantAnalysis()

param_grid = {
    'solver': ['svd']
}

cv_splitter = KFold(n_splits=5, shuffle=True, random_state=SEED)
tuned_clf_lda = GridSearchCV(estimator=lda, param_grid=param_grid,
                    scoring='accuracy', refit='accuracy', cv=cv_splitter)

tuned_clf_lda.fit(x_train, y_train)
print(f"Best parameters: {tuned_clf_lda.best_params_}")
print(f"Best cross-validation score: {tuned_clf_lda.best_score_:.3f}")
label_names = ['12Hz', '6Hz', '24Hz', '30Hz']

with open("trained_model/LDA_model.pkl", "wb") as file:
    pickle.dump(tuned_clf_lda, file)

GetConfusionMatrix(tuned_clf_lda, x_train, x_test, y_train, y_test, label_names)

Best parameters: {'solver': 'svd'}
Best cross-validation score: 0.567
Classification TRAIN DATA 
              precision    recall  f1-score   support

        12Hz       0.66      0.66      0.66        35
         6Hz       0.77      0.79      0.78        29
        24Hz       0.83      0.71      0.77        35
        30Hz       0.72      0.80      0.76        35

    accuracy                           0.74       134
   macro avg       0.74      0.74      0.74       134
weighted avg       0.74      0.74      0.74       134

Confusion matrix 
[[23  3  4  5]
 [ 4 23  0  2]
 [ 4  2 25  4]
 [ 4  2  1 28]]
Classification TEST DATA 
              precision    recall  f1-score   support

        12Hz       0.71      0.77      0.74        13
         6Hz       0.74      0.74      0.74        19
        24Hz       0.78      0.54      0.64        13
        30Hz       0.69      0.85      0.76        13

    accuracy                           0.72        58
   macro avg       0.73      0.72    

## 4.2 SVM

In [11]:
param_grid = {
    'C':  [10, 100],
    'kernel': ['rbf', 'poly']
}

svm_model =  SVC(random_state=SEED, probability= False)
cv_splitter = KFold(n_splits=5, shuffle=True, random_state=SEED)
tuned_clf_svm = GridSearchCV(estimator=svm_model, param_grid=param_grid,
                    scoring='accuracy', refit='accuracy', cv=cv_splitter)

tuned_clf_svm.fit(x_train, y_train)

print(f"Best parameters: {tuned_clf_svm.best_params_}")
print(f"Best cross-validation score: {tuned_clf_svm.best_score_:.3f}")
label_names = ['12Hz', '6Hz', '24Hz', '30Hz']

with open("trained_model/SVM_model.pkl", "wb") as file:
    pickle.dump(tuned_clf_svm, file)

GetConfusionMatrix(tuned_clf_svm, x_train, x_test, y_train, y_test, label_names)

Best parameters: {'C': 10, 'kernel': 'rbf'}
Best cross-validation score: 0.672
Classification TRAIN DATA 
              precision    recall  f1-score   support

        12Hz       1.00      1.00      1.00        35
         6Hz       1.00      1.00      1.00        29
        24Hz       1.00      1.00      1.00        35
        30Hz       1.00      1.00      1.00        35

    accuracy                           1.00       134
   macro avg       1.00      1.00      1.00       134
weighted avg       1.00      1.00      1.00       134

Confusion matrix 
[[35  0  0  0]
 [ 0 29  0  0]
 [ 0  0 35  0]
 [ 0  0  0 35]]
Classification TEST DATA 
              precision    recall  f1-score   support

        12Hz       0.77      0.77      0.77        13
         6Hz       0.88      0.74      0.80        19
        24Hz       0.42      0.62      0.50        13
        30Hz       0.80      0.62      0.70        13

    accuracy                           0.69        58
   macro avg       0.72     

## 4.3 KNN

In [12]:
knn = KNeighborsClassifier(n_neighbors= 10, weights = "uniform")

param_grid = {
    'n_neighbors': [1, 3, 5, 7, 9],
    'weights': ['uniform', 'distance'],
    'metric': ['euclidean', 'manhattan', 'chebyshev']
}


cv_splitter = KFold(n_splits=5, shuffle=True, random_state=SEED)
tuned_clf_knn = GridSearchCV(estimator=knn, param_grid=param_grid,
                    scoring='accuracy', refit='accuracy', cv=cv_splitter)

tuned_clf_knn.fit(x_train, y_train)

print(f"Best parameters: {tuned_clf_knn.best_params_}")
print(f"Best cross-validation score: {tuned_clf_knn.best_score_:.3f}")

label_names = ['12Hz', '6Hz', '24Hz', '30Hz']

with open("trained_model/KNN_model.pkl", "wb") as file:
    pickle.dump(tuned_clf_knn, file)

GetConfusionMatrix(tuned_clf_knn, x_train, x_test, y_train, y_test, label_names)

Best parameters: {'metric': 'manhattan', 'n_neighbors': 9, 'weights': 'uniform'}
Best cross-validation score: 0.635
Classification TRAIN DATA 
              precision    recall  f1-score   support

        12Hz       0.74      0.74      0.74        35
         6Hz       1.00      0.31      0.47        29
        24Hz       0.71      0.71      0.71        35
        30Hz       0.62      0.97      0.76        35

    accuracy                           0.70       134
   macro avg       0.77      0.68      0.67       134
weighted avg       0.76      0.70      0.68       134

Confusion matrix 
[[26  0  3  6]
 [ 3  9  7 10]
 [ 5  0 25  5]
 [ 1  0  0 34]]
Classification TEST DATA 
              precision    recall  f1-score   support

        12Hz       0.92      0.85      0.88        13
         6Hz       0.75      0.16      0.26        19
        24Hz       0.50      0.77      0.61        13
        30Hz       0.45      0.77      0.57        13

    accuracy                           0.59  