In [1]:
import numpy as np
import torch
from scipy import  signal
from scipy.io import loadmat
from sklearn.cross_decomposition import CCA

In [2]:
import numpy as np
import math
from scipy import signal
from sklearn.cross_decomposition import CCA


class FBCCA:
    """Filter-bank canonical correlation analysis (FBCCA) classifier.

    Each segment is split into ``Nm`` Chebyshev type-I band-pass sub-bands.
    Every sub-band is correlated (via CCA) with sine/cosine reference signals
    for each candidate frequency. The weighted sum of squared correlations
    across sub-bands selects the predicted target (1-based index).
    """

    def __init__(self):
        super().__init__()
        self.Nh = 22        # not used inside this class (presumably trials per file — confirm)
        self.Fs = 250       # sampling rate (Hz) of the data passed to classify()
        self.Nf = 20        # nominal number of targets; classify() now uses len(targets)
        self.ws = 4 - 0.14  # analysis window length in seconds
        self.Nc = 10        # not used inside this class (presumably channel count — confirm)
        self.Nm = 8         # number of filter-bank sub-bands
        # round() instead of bare int(): int() truncates, so a float product such
        # as 964.9999999 would silently drop a sample.
        self.T = int(round(self.Fs * self.ws))

    def get_reference_signal(self, num_harmonics, targets):
        """Build sine/cosine reference signals for every target frequency.

        Args:
            num_harmonics: harmonics h = 1..num_harmonics to include.
            targets: iterable of candidate frequencies (Hz).

        Returns:
            ndarray of shape (len(targets), 2 * num_harmonics, T). For each
            harmonic h, the rows are sin(2*pi*h*f*t) followed by cos(2*pi*h*f*t).
        """
        # Exactly T sample instants; np.arange with a float step can produce an
        # extra point through rounding.
        t = np.arange(self.T) / self.Fs
        reference_signals = []
        for f in targets:
            reference_f = []
            for h in range(1, num_harmonics + 1):
                phase = 2 * np.pi * h * f * t
                reference_f.append(np.sin(phase))
                reference_f.append(np.cos(phase))
            reference_signals.append(reference_f)
        return np.asarray(reference_signals)

    def find_correlation(self, n_components, x, y):
        """Largest canonical correlation between x and each reference set in y.

        Args:
            n_components: requested number of CCA components. It is clamped to
                what the data supports (rows of x, rows of each y[k]).
            x: array of shape (channels, samples).
            y: array of shape (num_freq, num_ref, samples), e.g. the output of
                get_reference_signal().

        Returns:
            ndarray of shape (num_freq,). Each entry is the maximum Pearson
            correlation over the canonical variate pairs for that frequency.
        """
        x = np.asarray(x)
        num_freq = y.shape[0]
        result = np.zeros(num_freq)
        if num_freq == 0:
            return result
        # CCA cannot produce more components than features on either side, so
        # clamp instead of failing (e.g. num_harmonics == 1 gives only 2 refs).
        n_components = max(1, min(n_components, x.shape[0], y.shape[1]))
        cca = CCA(n_components)
        for freq_idx in range(num_freq):
            ref = y[freq_idx].T
            cca.fit(x.T, ref)
            x_a, y_b = cca.transform(x.T, ref)
            corr = [np.corrcoef(x_a[:, i], y_b[:, i])[0, 1]
                    for i in range(n_components)]
            # Taken once per frequency, after all components are computed.
            result[freq_idx] = np.max(corr)
        return result

    def filter_bank(self, eeg):
        """Split every segment into Nm zero-phase Chebyshev type-I sub-bands.

        Args:
            eeg: array-like of shape (segments, channels, samples) sampled at
                self.Fs. CPU torch tensors are accepted via np.asarray.

        Returns:
            ndarray of shape (segments, Nm, channels, samples). Sub-band i
            passes pass_band[i]..90 Hz and is filtered forward and backward.
        """
        eeg = np.asarray(eeg, dtype=float)
        # Size from the input rather than self.T so any segment length works.
        result = np.zeros((eeg.shape[0], self.Nm, eeg.shape[-2], eeg.shape[-1]))

        nyq = self.Fs / 2

        # Lower pass/stop edges (Hz) per sub-band; all share the same upper edges.
        pass_band = [6, 14, 22, 30, 38, 46, 54, 62, 70, 78]
        stop_band = [4, 10, 16, 24, 32, 40, 48, 56, 64, 72]
        high_cut_pass, high_cut_stop = 90, 100

        # gpass/gstop: passband loss / stopband attenuation (dB) for order
        # selection; rp: passband ripple (dB) of the designed filter.
        gpass, gstop, rp = 2, 40, 0.3

        for i in range(self.Nm):
            wp = np.array([pass_band[i] / nyq, high_cut_pass / nyq])
            ws = np.array([stop_band[i] / nyq, high_cut_stop / nyq])
            n, wn = signal.cheb1ord(wp, ws, gpass, gstop)
            b, a = signal.cheby1(n, rp, wn, 'bandpass')
            result[:, i] = signal.filtfilt(
                b, a, eeg, padlen=3 * (max(len(b), len(a)) - 1))

        return result

    def classify(self, targets, test_data, num_harmonics=3):
        """Predict the target for every segment of test_data.

        Args:
            targets: sequence of candidate frequencies (Hz).
            test_data: array-like (segments, channels, samples) with samples == T.
            num_harmonics: harmonics per reference signal.

        Returns:
            ndarray of 1-based predicted target indices, one per segment.
        """
        reference_signals = self.get_reference_signal(num_harmonics, targets)
        sub_bands = self.filter_bank(test_data)

        # Sub-band weights w(k) = k**-1.25 + 0.25 for k = 1..Nm.
        fb_weight = [math.pow(k, -1.25) + 0.25 for k in range(1, self.Nm + 1)]
        predicted_class = []
        for segment in sub_bands:
            # Sized by the actual number of targets (was self.Nf, which broke
            # for any target list whose length differs from Nf).
            score = np.zeros(len(reference_signals))
            for fb_i in range(self.Nm):
                rho = self.find_correlation(3, segment[fb_i], reference_signals)
                score += fb_weight[fb_i] * rho ** 2
            predicted_class.append(np.argmax(score) + 1)
        return np.array(predicted_class)


In [3]:
samples = loadmat('S1B1.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found.
# A hard-coded 22 raised IndexError on more onsets and added all-zero trials
# on fewer.
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip,
    # cf. FBCCA.ws = 4 - 0.14 s — confirm).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz — confirm raw rate).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)

In [4]:
# Candidate target frequencies: 20 values from 8.0 to 13.7 Hz, 0.3 Hz apart.
freqlist = np.linspace(start=8, stop=13.7, num=20)
fbcca = FBCCA()
result = fbcca.classify(targets=freqlist, test_data=eeg_data)
print(result)

[ 1 13 15 14 12 20 11 17 20 18  9  9  4 19 18  3 20 19 10 16 18 19]


In [5]:
samples = loadmat('S1B2.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[12 14 19  5  1 14  7 11 17 10 15  3 18  4  2 16 19 20 13 13  6 18]


In [6]:
samples = loadmat('S2B1.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[10 13  4 11  3  9 17 16  5 12 19 14  2  8 15 18 20  7  6  1 16  9]


In [7]:
samples = loadmat('S2B2.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[13 17  4 19  2  8 12 11 14 18  7  1  6 10  1 16  3 20  9  5 16  9]


In [8]:
samples = loadmat('S3B1.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[17  9 19 11  7 20 12 14  2  1  2  4 18  3 16  8  2 10 13  5 12  4]


In [9]:
samples = loadmat('S3B2.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[ 3 11  1  2  5  9 19  7  6 18  4  8 16 13 14  1 12 20 17 10  6  5]


In [10]:
samples = loadmat('S4B1.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[16 12  2 19 14  8 18  3 15  4  5  9  3 13 20  7 11  1  6 10 14  6]


In [11]:
samples = loadmat('S4B2.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[13 20 15  7  9  6  1  2  3 11 12 14 10 18 19  5 15 17  4  8  6 14]


In [12]:
samples = loadmat('S5B1.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)



[10 12  7 18 16 17 13  1  9  4  8  5  3 20  2 14  3 11 15  4 20 14]


In [13]:
samples = loadmat('S5B2.mat')
data = samples['data1']
# Row 10 is used as the label channel; rows 0-9 are the signals classified.
label = data[10, :].copy()
data = data[0:10, :]

# One epoch per sample with label == 1, sized by the onsets actually found
# (a hard-coded 22 crashed on more onsets and added all-zero trials on fewer).
onsets = np.flatnonzero(label == 1)
assert onsets.size > 0, 'no label == 1 onsets found'
eeg_data = np.zeros((onsets.size, data.shape[0], 3860))
for i, onset in enumerate(onsets):
    # 3860 samples starting 140 after the onset (presumably a latency skip).
    eeg_data[i] = data[:, onset + 140:onset + 4000]
# Downsample by 4 (presumably to FBCCA.Fs = 250 Hz).
eeg_data = signal.decimate(eeg_data, 4, axis=2).copy()
eeg_data = torch.from_numpy(eeg_data)
result = fbcca.classify(freqlist, eeg_data)
print(result)

[ 7 16  8  2  4 13  6 18  9 12 11 19  5  6 15 10 20 14 17  1  6  6]
