In [None]:
!pip install numpy
!pip install pandas
!pip install scipy
!pip install scipy
# !pip install sklearn
!pip install matplotlib
!pip install splearn
!pip install -U scikit-learn



In [None]:
# Colab only: expose Google Drive under /content/drive so the relative
# 'drive/MyDrive/...' data paths used later resolve from the default cwd.
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import numpy as np
import scipy.io as sio
from typing import Tuple
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, sosfiltfilt, freqz
from splearn.fourier import fast_fourier_transform
from sklearn.cross_decomposition import CCA

from scipy.signal import filtfilt, iirnotch
import scipy
from scipy.stats import pearsonr, mode


In [None]:
class Benchmark():
    """
    Single-subject loader for the SSVEP Benchmark dataset.

    Reads ``<root>/<file_prefix><subject_id:02d>.mat`` through ``_load_data``
    and exposes the recording as an indexable collection of trials.

    Attributes:
        root : str
            Directory holding the per-subject ``.mat`` files.
        data : ndarray, shape (trial, channel, time)
            EEG trials as returned by ``_load_data``.
        targets : ndarray of int, shape (trial,)
            Stimulus index of each trial (an index into ``stimulus_frequencies``).
        channel_names : list of str
            Electrode labels returned by ``_load_data``.
        sampling_rate : int
            Sampling frequency in Hz (250).
        stimulus_frequencies : ndarray, shape (40,)
            Flicker frequency in Hz of each stimulus, built from ``_FREQS``.
        targets_frequencies : ndarray, shape (trial,)
            Flicker frequency of each trial's target.
    """

    # Electrode labels of the 64-channel montage.
    _CHANNELS = [
        'FP1','FPZ','FP2','AF3','AF4','F7','F5','F3','F1','FZ','F2',
        'F4','F6','F8','FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6',
        'FT8','T7','C5','C3','C1','Cz','C2','C4','C6','T8','M1','TP7',
        'CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8','M2','P7','P5',
        'P3','P1','PZ','P2','P4','P6','P8','PO7','PO5','PO3','POz','PO4',
        'PO6','PO8','CB1','O1','Oz','O2','CB2'
    ]

    # Flicker frequency (Hz) of each of the 40 stimuli, indexed by target id.
    _FREQS = [
        8, 9, 10, 11, 12, 13, 14, 15,
        8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2,
        8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4,
        8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6,
        8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8
    ]

    # Stimulus phase paired with each entry of _FREQS; not used by this class.
    # NOTE(review): values look like multiples of pi -- confirm before using.
    _PHASES = [
        0, 0.5, 1, 1.5, 0, 0.5, 1, 1.5,
        0.5, 1, 1.5, 0, 0.5, 1, 1.5, 0,
        1, 1.5, 0, 0.5, 1, 1.5, 0, 0.5,
        1.5, 0, 0.5, 1, 1.5, 0, 0.5, 1,
        0, 0.5, 1, 1.5, 0, 0.5, 1, 1.5
    ]

    def __init__(self, root: str, subject_id: int, verbose: bool = False, file_prefix='S') -> None:
        self.root = root
        self.data, self.targets, self.channel_names = _load_data(self.root, subject_id, verbose, file_prefix)
        self.sampling_rate = 250
        # Built from the class table so the frequency list cannot drift from _FREQS.
        self.stimulus_frequencies = np.array(self._FREQS)
        # targets are indices into stimulus_frequencies.
        self.targets_frequencies = self.stimulus_frequencies[self.targets]

    def __getitem__(self, n: int) -> Tuple[np.ndarray, int]:
        """Return the ``(trial_data, target_index)`` pair for trial ``n``."""
        return (self.data[n], self.targets[n])

    def __len__(self) -> int:
        """Number of trials."""
        return len(self.data)

def _load_data(root, subject_id, verbose, file_prefix='S'):
    path = os.path.join(root, file_prefix+str(subject_id).zfill(2)+'.mat')  # Ensure subject ID is zero-padded
    data_mat = sio.loadmat(path)

    raw_data = data_mat['data'].copy()
    raw_data = np.transpose(raw_data, (2,3,0,1))

    data = []
    targets = []
    for target_id in range(raw_data.shape[0]):
        data.extend(raw_data[target_id])
        this_target = [target_id] * raw_data.shape[1]
        targets.extend(this_target)

    data = np.array(data)[:,:,0:1501]  # Adjust indexing as needed
    targets = np.array(targets)

    channel_names =  [
        'FP1','FPZ','FP2','AF3','AF4','F7','F5','F3','F1','FZ','F2',
        'F4','F6','F8','FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6',
        'FT8','T7','C5','C3','C1','Cz','C2','C4','C6','T8','M1','TP7',
        'CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8','M2','P7','P5',
        'P3','P1','PZ','P2','P4','P6','P8','PO7','PO5','PO3','POz','PO4',
        'PO6','PO8','CB1','O1','Oz','O2','CB2'
    ] # List of channel names

    if verbose:
        print('Load path:', path)
        print('Data shape:', data.shape)
        print('Targets shape:', targets.shape)

    return data, targets, channel_names


In [None]:
def butter_bandpass_filter(signal, lowcut, highcut, sampling_rate, order=4, verbose=False):
    r"""
    Zero-phase Butterworth band-pass filter for batched signals.

    The filter is designed as second-order sections and applied forward and
    backward along the last (time) axis with ``sosfiltfilt``, so the output
    has no phase delay.

    Args:
        signal : ndarray, shape (trial,channel,time)
            Input trials in the time domain.
        lowcut : int
            Lower cut-off frequency (Hz).
        highcut : int
            Upper cut-off frequency (Hz).
        sampling_rate : int
            Sampling frequency (Hz).
        order : int, default: 4
            Butterworth order.
        verbose : boolean, default: False
            If True, plot trial 0 / channel 0 before and after filtering, in
            time and in frequency, and print the input/output shapes.
    Returns:
        y : ndarray
            Filtered signal, same shape as ``signal``.
    """
    sos_coeffs = _butter_bandpass(lowcut, highcut, sampling_rate, order=order, output='sos')
    filtered = sosfiltfilt(sos_coeffs, signal, axis=-1)

    if not verbose:
        return filtered

    first_raw = signal[0, 0]
    first_filtered = filtered[0, 0]

    # Time domain: raw, then filtered, each on its own figure.
    plt.plot(first_raw, label='signal')
    plt.show()
    plt.plot(first_filtered, label='Filtered')
    plt.show()

    # Frequency domain: both spectra drawn onto the same axes.
    spec_lo = lowcut - 10 if lowcut > 10 else 0
    spec_hi = highcut + 20
    fast_fourier_transform(
        first_raw, sampling_rate, plot=True, plot_xlim=[spec_lo, spec_hi], plot_label='Signal')
    fast_fourier_transform(
        first_filtered, sampling_rate, plot=True, plot_xlim=[spec_lo, spec_hi], plot_label='Filtered')
    plt.xlim([spec_lo, spec_hi])
    plt.ylim([0, 2])
    plt.legend()
    plt.xlabel('Frequency (Hz)')
    plt.show()

    print('Input: Signal shape', signal.shape)
    print('Output: Signal shape', filtered.shape)
    return filtered

def butter_bandpass_filter_signal_1d(signal, lowcut, highcut, sampling_rate, order=4, verbose=False):
    r"""
    Zero-phase Butterworth band-pass filter for a single 1-D signal.

    Uses the (b, a) transfer-function form of the design with ``filtfilt``
    (forward and backward pass), so the output has no phase delay.

    Args:
        signal : ndarray, shape (time,)
            Signal in the time domain.
        lowcut : int
            Lower cut-off frequency (Hz).
        highcut : int
            Upper cut-off frequency (Hz).
        sampling_rate : int
            Sampling frequency (Hz).
        order : int, default: 4
            Butterworth order.
        verbose : boolean, default: False
            If True, plot the filter's magnitude response, the signal before
            and after filtering in time and frequency, and print the shapes.
    Returns:
        y : ndarray
            Filtered signal.
    """
    b, a = _butter_bandpass(lowcut, highcut, sampling_rate, order)
    filtered = filtfilt(b, a, signal)

    if not verbose:
        return filtered

    # Magnitude response, with the -3 dB level (sqrt(0.5)) as a reference line.
    w, h = freqz(b, a)
    plt.plot((sampling_rate * 0.5 / np.pi) * w, abs(h), label="order = %d" % order)
    plt.plot([0, 0.5 * sampling_rate], [np.sqrt(0.5), np.sqrt(0.5)], '--', label='sqrt(0.5)')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Gain')
    plt.grid(True)
    plt.legend(loc='best')
    margin = sampling_rate / 100
    plt.xlim([max(0, lowcut - margin), highcut + margin])
    plt.ylim([0, 1.2])
    plt.title('Frequency response of filter - lowcut:' + str(lowcut) + ', highcut:' + str(highcut))
    plt.show()

    # Time domain: raw, then filtered, each on its own figure.
    for series, label, title in ((signal, 'Signal', 'Signal'),
                                 (filtered, 'Filtered', 'Bandpass filtered')):
        plt.plot(series, label=label)
        plt.title(title)
        plt.show()

    # Frequency domain: both spectra drawn onto the same axes.
    spec_lo = lowcut - 10 if lowcut > 10 else 0
    spec_hi = highcut + 20
    fast_fourier_transform(
        signal, sampling_rate, plot=True, plot_xlim=[spec_lo, spec_hi], plot_label='Signal')
    fast_fourier_transform(
        filtered, sampling_rate, plot=True, plot_xlim=[spec_lo, spec_hi], plot_label='Filtered')
    plt.xlim([spec_lo, spec_hi])
    plt.ylim([0, 2])
    plt.legend()
    plt.xlabel('Frequency (Hz)')
    plt.show()

    print('Input: Signal shape', signal.shape)
    print('Output: Signal shape', filtered.shape)
    return filtered

def _butter_bandpass(lowcut, highcut, sampling_rate, order=4, output='ba'):
    r"""
    Create a Butterworth bandpass filter
    Design an Nth-order digital or analog Butterworth filter and return the filter coefficients.
    Args:
        lowcut : int
            Lower bound filter
        highcut : int
            Upper bound filter
        sampling_rate : int
            Sampling frequency
        order : int, default: 4
            Order of the filter
        output : string, default: ba
            Type of output {‘ba’, ‘zpk’, ‘sos’}
    Returns:
        butter : ndarray
            Butterworth filter
    Dependencies:
        butter : scipy.signal.butter
    """
    nyq = sampling_rate * 0.5
    low = lowcut / nyq
    high = highcut / nyq
    return butter(order, [low, high], btype='bandpass', output=output)


In [None]:
def cca_spatial_filtering(signal, reference_frequencies, n_components=10):
    r"""
    Project each trial onto its CCA directions against each reference template.

    For every (template, trial) pair a CCA is fitted between the trial's
    channels and the template's sine/cosine rows, and the trial is replaced by
    its canonical components (spatially filtered signals).
    Read more: https://github.com/jinglescode/papers/issues/90, https://github.com/jinglescode/papers/issues/89

    Args:
        signal : ndarray, shape (trial,channel,time)
            Input signal in the time domain.
        reference_frequencies : ndarray, shape (n_refs, 2*num_harmonics, time)
            Sinusoidal reference templates; ``time`` must match ``signal``.
        n_components : int, default 10
            Requested number of canonical components. It is capped at
            min(channel, 2*num_harmonics, time), the largest value sklearn's
            CCA accepts.
    Returns:
        filtered_signal : ndarray, shape (n_refs, trial, k, time)
            ``k`` is the capped component count. When
            ``n_components == channel`` (and it fits the cap) this equals the
            original (n_refs, trial, channel, time) layout.
    Dependencies:
        np : numpy package
        CCA : sklearn.cross_decomposition.CCA
    """
    n_refs, n_ref_features = reference_frequencies.shape[0], reference_frequencies.shape[1]
    n_trials, n_channels, n_times = signal.shape

    # CCA rejects n_components above min(n_samples, n_features_X, n_features_Y).
    n_keep = min(n_components, n_channels, n_ref_features, n_times)

    # Size the output by the components actually produced, not by the channel count.
    filtered_signal = np.zeros((n_refs, n_trials, n_keep, n_times))
    cca = CCA(n_components=n_keep, max_iter=20)

    for target_i in range(n_refs):
        Y = reference_frequencies[target_i].T  # (time, 2*num_harmonics)
        for trial_i in range(n_trials):
            X = signal[trial_i].T  # (time, channel)
            cca.fit(X, Y)
            X_c = cca.transform(X)  # (time, n_keep)
            filtered_signal[target_i, trial_i] = X_c.T

    return filtered_signal



In [None]:
def pick_channels(data: np.ndarray,
                  channel_names: [str],
                  selected_channels: [str],
                  verbose: bool = False) -> np.ndarray:
    """
    Keep only the channels whose names appear in ``selected_channels``.

    The channel axis is taken to be the second-to-last axis. That covers
    (channel, time), (trial, channel, time) and (x, trial, channel, time).
    Channels keep their order from ``channel_names``. Names not found in
    ``channel_names`` are silently ignored (see ``pick_channels_mne``).

    Args:
        data : ndarray with at least 2 dimensions, channel axis at -2.
        channel_names : sequence of str
            Label of each channel along the channel axis.
        selected_channels : iterable of str
            Labels to keep (case-sensitive).
        verbose : bool, default False
            Print the available names and the picked indices.
    Returns:
        ndarray with the channel axis reduced to the picked channels.
    Raises:
        ValueError: if ``data`` has fewer than 2 dimensions.
    """
    if data.ndim < 2:
        raise ValueError('pick_channels expects data with a channel axis at -2, '
                         'got shape %s' % (data.shape,))

    picked_ch = pick_channels_mne(channel_names, selected_channels)
    # The ellipsis handles 2-D input too; that case was previously returned unpicked.
    data = data[..., picked_ch, :]

    if verbose:
        print('picking channels: channel_names',
              len(channel_names), channel_names)
        print('picked_ch', picked_ch)
        print()

    return data


def pick_channels_mne(ch_names, include, exclude=[], ordered=False):
    """Return the indices of the channels of ``ch_names`` selected by name.

    Adapted from https://github.com/mne-tools/mne-python/blob/master/mne/io/pick.py

    Parameters
    ----------
    ch_names : list of str
        All available channel names; must be unique.
    include : iterable of str
        Names to keep. When empty, every channel is a candidate.
    exclude : iterable of str
        Names to drop even if they are in ``include``. Defaults to [].
    ordered : bool
        False (default): ``include`` is treated as a set. The result follows
        the order of ``ch_names`` and unknown names are ignored.
        True: the result follows the order of ``include``, and any name
        missing from ``ch_names`` raises ``ValueError``.

    Returns
    -------
    sel : ndarray of int
        Indices into ``ch_names``.

    Raises
    ------
    RuntimeError
        If ``ch_names`` contains duplicates.
    ValueError
        If ``ordered`` is True and some ``include`` names are not in ``ch_names``.
    """
    if len(np.unique(ch_names)) != len(ch_names):
        raise RuntimeError('ch_names is not a unique list, picking is unsafe')

    if ordered:
        include = include if isinstance(include, list) else list(include)
        if not include:
            include = list(ch_names)
        exclude = exclude if isinstance(exclude, list) else list(exclude)

        sel, missing = [], []
        for name in include:
            if name not in ch_names:
                missing.append(name)
            elif name not in exclude:
                sel.append(ch_names.index(name))
        if missing:
            raise ValueError('Missing channels from ch_names required by '
                             'include:\n%s' % (missing,))
    else:
        include = include if isinstance(include, set) else set(include)
        exclude = exclude if isinstance(exclude, set) else set(exclude)
        sel = [idx for idx, name in enumerate(ch_names)
               if (not include or name in include) and name not in exclude]

    return np.array(sel, int)

In [None]:
def notch_filter(data, sampling_rate=250, notch_freq=50.0, quality_factor=30.0):
    """
    Suppress a narrow band around ``notch_freq`` with a zero-phase IIR notch.

    Args:
        data : array_like
            Signal(s); filtered along the last axis (``filtfilt`` default).
        sampling_rate : float, default 250
            Sampling frequency in Hz.
        notch_freq : float, default 50.0
            Centre frequency to remove, in Hz.
        quality_factor : float, default 30.0
            Notch Q. Higher values give a narrower notch.
    Returns:
        ndarray of the same shape as ``data``.
    """
    notch_coeffs = iirnotch(notch_freq, quality_factor, sampling_rate)  # (b, a)
    return filtfilt(*notch_coeffs, data)

In [None]:
from sklearn.decomposition import FastICA
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# print(picked_data.shape)

import numpy as np
from sklearn.decomposition import FastICA

def perform_ica_decomposition(eeg_data, num_components=10):
    """
    Fit FastICA on EEG samples pooled across trials.

    Args:
        eeg_data : ndarray, shape (trial, channel, time)
        num_components : int, default 10
            Requested number of sources. FastICA may cap this at the number
            of channels (with a warning).
    Returns:
        ica : fitted sklearn FastICA
        ica_signals : ndarray, shape (trial * time, n_sources)
            Source activations. Row ``trial_i * time + t`` is sample ``t`` of
            trial ``trial_i``.
    """
    trials, channels, time_points = eeg_data.shape
    # ICA needs one row per sample and one column per channel. Move the channel
    # axis last before flattening; a bare reshape of (trial, channel, time)
    # would interleave channel and time values.
    samples_by_channel = np.transpose(eeg_data, (0, 2, 1)).reshape(trials * time_points, channels)
    ica = FastICA(n_components=num_components, random_state=0)
    ica_signals = ica.fit_transform(samples_by_channel)
    return ica, ica_signals

def plot_ica_components(ica, channels, n_components=10, figsize=(15, 6)):
    """
    Plot the channel weights (rows of ``ica.components_``) of the first ICA components.

    Args:
        ica : fitted FastICA
        channels : sequence of str
            Channel labels. Currently unused; kept for API compatibility.
        n_components : int, default 10
            Number of components to plot, capped at the number available.
        figsize : tuple, default (15, 6)
    """
    n_components = min(n_components, ica.components_.shape[0])
    n_cols = 5
    n_rows = (n_components + n_cols - 1) // n_cols  # ceil, so no empty trailing row
    plt.figure(figsize=figsize)
    for i in range(n_components):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.plot(ica.components_[i])
        plt.title(f'Component {i}')
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    plt.show()

import numpy as np
from scipy.signal import welch
import matplotlib.pyplot as plt

def plot_ica_psd(ica_signals, sfreq, n_components=10):
    """
    Plot the power spectral density (PSD) of ICA components.

    Parameters:
        ica_signals (numpy.ndarray): ICA components with dimensions (samples, components).
        sfreq (float): Sampling frequency of the data, in Hz.
        n_components (int): Number of components to plot, capped at the
            number of columns of ``ica_signals``.
    """
    n_components = min(n_components, ica_signals.shape[1])
    n_cols = 5
    n_rows = (n_components + n_cols - 1) // n_cols
    # Show the whole spectrum, up to Nyquist.
    fmin, fmax = 0, sfreq / 2

    plt.figure(figsize=(15, 6))
    for i in range(n_components):
        # Welch's method with 2048-sample segments.
        freqs, psd = welch(ica_signals[:, i], sfreq, nperseg=2048)

        plt.subplot(n_rows, n_cols, i + 1)
        plt.semilogy(freqs, psd, label=f'Component {i}')
        plt.title(f'Component {i}')
        plt.xlabel('Frequency (Hz)')
        # semilogy shows linear power on a log axis; the values are not in dB.
        plt.ylabel('Power Spectral Density (log scale)')
        plt.xlim([fmin, fmax])
        plt.legend()

    plt.tight_layout()
    plt.show()



def remove_artifacts(eeg_data, ica, ica_signals, artifact_indices):
    """
    Zero the selected ICA sources and reconstruct the EEG.

    Args:
        eeg_data : ndarray, shape (trial, channel, time)
            The data the ICA was fitted on; only its shape is used.
        ica : fitted FastICA
        ica_signals : ndarray, shape (trial * time, n_sources)
            Output of ``perform_ica_decomposition``. It is not modified.
        artifact_indices : sequence of int
            Source columns to remove.
    Returns:
        ndarray, shape (trial, channel, time)
    """
    trials, channels, time_points = eeg_data.shape
    # Work on a copy so the caller's source matrix stays intact for re-use.
    kept_sources = np.array(ica_signals, copy=True)
    kept_sources[:, artifact_indices] = 0
    cleaned = ica.inverse_transform(kept_sources)  # (trial * time, channel)
    # Undo the (trial, time, channel) flattening used by perform_ica_decomposition.
    cleaned_eeg_data = cleaned.reshape(trials, time_points, channels).transpose(0, 2, 1)
    return np.ascontiguousarray(cleaned_eeg_data)

# # Step 1: Perform ICA Decomposition
# ica, ica_signals = perform_ica_decomposition(picked_data, num_components=20)

# # Step 2: Plot ICA Components for Visual Inspection
# plot_ica_components(ica, selected_channels)
# plot_ica_psd(ica_signals, 250)

# # Manually identify artifact indices from plots
# artifact_indices = []

# # Step 3: Remove Artifacts
# cleaned_eeg_data = remove_artifacts(picked_data, ica, ica_signals, artifact_indices)
# print(cleaned_eeg_data.shape)

In [None]:
def create_reference_signals(frequencies, num_harmonics, sample_length, sampling_rate):
    """
    Build sine/cosine reference templates for CCA-based SSVEP detection.

    Time runs from 0 in steps of 1/sampling_rate (``linspace`` without endpoint).

    Args:
        frequencies : sequence of float
            Stimulus frequencies in Hz.
        num_harmonics : int
            Harmonics per frequency (1 .. num_harmonics).
        sample_length : int
            Samples per template.
        sampling_rate : float
            Sampling frequency in Hz.
    Returns:
        ndarray, shape (len(frequencies), 2 * num_harmonics, sample_length)
            For harmonic h, row 2*(h-1) is sin(2*pi*f*h*t) and row 2*(h-1)+1
            is the matching cosine.
    """
    t = np.linspace(0, sample_length / sampling_rate, sample_length, endpoint=False)
    templates = np.zeros((len(frequencies), 2 * num_harmonics, sample_length))

    for row, freq in enumerate(frequencies):
        for harmonic in range(1, num_harmonics + 1):
            phase = 2 * np.pi * freq * harmonic * t
            templates[row, 2 * harmonic - 2] = np.sin(phase)
            templates[row, 2 * harmonic - 1] = np.cos(phase)

    return templates

In [None]:
def filterbank(eeg, fs, idx_fb):
    """
    Apply sub-band ``idx_fb`` of the FBCCA filter bank.

    Sub-band k is a Chebyshev type-I band-pass with pass band
    [passband[k], 90] Hz and stop edges [stopband[k], 100] Hz. The order comes
    from ``cheb1ord`` (at most 3 dB pass-band loss, at least 40 dB stop-band
    attenuation), the design uses 0.5 dB ripple, and the filter is applied
    with zero-phase ``filtfilt``.

    Args:
        eeg : ndarray
            (time,), (channel, time) filtered along time, or
            (time, channel, trial) filtered along axis 0.
        fs : float
            Sampling rate in Hz. Must exceed 200 Hz so the 100 Hz stop edge
            lies below Nyquist.
        idx_fb : int or None
            Sub-band index, 0 <= idx_fb <= 8. ``None`` falls back to 0 with a
            warning.
    Returns:
        y : ndarray of float, same shape as ``eeg``.
    Raises:
        ValueError: if ``idx_fb`` is outside [0, 8] or ``eeg`` has more than 3 dims.
    """
    import warnings

    # Lower pass/stop edges (Hz); sub-band k starts at (k + 1) * 8 Hz.
    passband = [8, 16, 24, 32, 40, 48, 56, 64, 72]
    stopband = [7, 15, 23, 31, 39, 47, 55, 63, 71]
    max_idx = len(passband) - 1

    if idx_fb is None:
        warnings.warn('stats:filterbank:MissingInput '
                      'Missing filter index. Default value (idx_fb = 0) will be used.')
        idx_fb = 0
    elif idx_fb < 0 or idx_fb > max_idx:
        raise ValueError('stats:filterbank:InvalidInput '
                         'The number of sub-bands must be 0 <= idx_fb <= %d.' % max_idx)

    if eeg.ndim == 2:
        time_axis = 1   # (channel, time)
    elif eeg.ndim in (1, 3):
        time_axis = 0   # (time,) or (time, channel, trial)
    else:
        raise ValueError('filterbank expects 1-D, 2-D or 3-D input, got shape %s' % (eeg.shape,))

    # Nyquist frequency = fs / 2; scipy expects edges normalised by it.
    nyq = fs / 2
    Wp = [passband[idx_fb] / nyq, 90 / nyq]
    Ws = [stopband[idx_fb] / nyq, 100 / nyq]
    N, Wn = scipy.signal.cheb1ord(Wp, Ws, 3, 40)
    B, A = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass')

    # filtfilt runs independently on every 1-D slice along time_axis.
    return np.asarray(scipy.signal.filtfilt(B, A, eeg, axis=time_axis), dtype=float)

def cca_reference(list_freqs, fs, num_smpls, num_harms=9):
    """
    Build sine/cosine reference signals for standard and filter-bank CCA.

    Args:
        list_freqs : sequence of float
            Stimulus frequencies in Hz.
        fs : float
            Sampling rate in Hz.
        num_smpls : int
            Samples per reference.
        num_harms : int, default 9
            Number of harmonics (1 .. num_harms).
    Returns:
        y_ref : ndarray, shape (len(list_freqs), 2 * num_harms, num_smpls)
            For harmonic h, row 2*(h-1) is sin(2*pi*h*f*t) and row 2*(h-1)+1 is
            the cosine, with t = 1/fs, 2/fs, ... (1-based, MATLAB-style).
    """
    tidx = np.arange(1, num_smpls + 1) / fs  # time index
    y_ref = np.zeros((len(list_freqs), 2 * num_harms, num_smpls))
    for freq_i, stim_freq in enumerate(list_freqs):
        for harm_i in range(1, num_harms + 1):
            arg = 2 * np.pi * tidx * harm_i * stim_freq
            y_ref[freq_i, 2 * harm_i - 2] = np.sin(arg)
            y_ref[freq_i, 2 * harm_i - 1] = np.cos(arg)
    # No debug print here: this runs once per fbcca call.
    return y_ref

def fbcca(eeg, list_freqs, fs, num_harms=2, num_fbs=9, verbose=True):
    """
    Filter-bank CCA (FBCCA) SSVEP classification.

    For every trial, each of ``num_fbs`` sub-band versions (``filterbank``) is
    correlated with the sine/cosine references of every candidate frequency
    (``cca_reference``) through a one-component CCA. Negative or NaN
    correlations count as 0. Sub-band scores are combined as
    ``sum_k w_k * r_k`` with ``w_k = k**-1.25 + 0.25`` (k = 1..num_fbs), and
    the highest-scoring frequency is predicted.

    Args:
        eeg : ndarray, shape (trial, channel, time)
        list_freqs : sequence of float
            Candidate stimulus frequencies in Hz.
        fs : float
            Sampling rate in Hz.
        num_harms : int, default 2
            Harmonics in the reference signals.
        num_fbs : int, default 9
            Number of sub-bands; ``filterbank`` supports at most 9.
        verbose : bool, default True
            Print per-trial progress and predictions.
    Returns:
        list
            Predicted frequency (an element of ``list_freqs``) for each trial.
    """
    fb_coefs = np.power(np.arange(1, num_fbs + 1), (-1.25)) + 0.25

    num_targs = len(list_freqs)
    events, _, num_smpls = eeg.shape
    y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms)
    cca = CCA(n_components=1)

    r = np.zeros((num_fbs, num_targs))
    results = []

    for event in range(events):
        if verbose:
            print(f"Processing event {event + 1}/{events}")
        # Index without squeeze so single-channel data stays (channel, time).
        test_tmp = eeg[event]
        for fb_i in range(num_fbs):
            # The sub-band signal does not depend on the candidate frequency,
            # so filter once per band instead of once per (band, frequency).
            testdata = filterbank(test_tmp, fs, fb_i)
            for class_i in range(num_targs):
                refdata = y_ref[class_i]
                test_C, ref_C = cca.fit_transform(testdata.T, refdata.T)
                r_tmp, _ = pearsonr(np.squeeze(test_C), np.squeeze(ref_C))
                if np.isnan(r_tmp) or r_tmp < 0:
                    r_tmp = 0
                r[fb_i, class_i] = r_tmp

        rho = np.dot(fb_coefs, r)  # weighted sum over sub-bands
        result = np.argmax(rho)
        if verbose:
            print(f"Result for event {event + 1}: Frequency {list_freqs[result]}, Correlation: {rho[result]}")
        results.append(list_freqs[result])

    return results



In [None]:
# Subject files live here (relative to Colab's default cwd, /content).
root_dir = 'drive/MyDrive/BCI_data/SSVEP/'
total_subjects = 5  # subjects S01 .. S05

# Subject id (as str) -> shape of that subject's trial array.
# Benchmark(verbose=True) prints the load path and shapes for each subject.
datasets = dict()
for subject_id in range(1, total_subjects + 1):
    dataset = Benchmark(root_dir, subject_id, verbose=True)
    datasets[f"{subject_id}"] = dataset.data.shape



Load path: drive/MyDrive/BCI_data/SSVEP/S01.mat
Data shape: (240, 64, 1500)
Targets shape: (240,)
Load path: drive/MyDrive/BCI_data/SSVEP/S02.mat
Data shape: (240, 64, 1500)
Targets shape: (240,)
Load path: drive/MyDrive/BCI_data/SSVEP/S03.mat
Data shape: (240, 64, 1500)
Targets shape: (240,)
Load path: drive/MyDrive/BCI_data/SSVEP/S04.mat
Data shape: (240, 64, 1500)
Targets shape: (240,)
Load path: drive/MyDrive/BCI_data/SSVEP/S05.mat
Data shape: (240, 64, 1500)
Targets shape: (240,)


In [None]:
# Occipito-parietal channels used for SSVEP classification. Labels must match
# Benchmark._CHANNELS exactly (matching is case-sensitive): the midline
# parietal electrode is labelled 'PZ', so 'Pz' would silently be dropped.
selected_channels = np.array(['PO5', 'PO3', 'POz', 'PO4',
                              'PO6', 'PZ', 'O1', 'Oz', 'O2'])

# pick_channels ignores names it cannot find, so fail loudly on typos here.
_unknown_channels = sorted(str(c) for c in set(selected_channels) - set(Benchmark._CHANNELS))
if _unknown_channels:
    raise ValueError(f"Unknown channel names in selected_channels: {_unknown_channels}")


for subject_id, dataset_shape in datasets.items():
    # Load the dataset for the current subject
    dataset = Benchmark(root=root_dir, subject_id=int(subject_id), verbose=False)

    # 7-90 Hz zero-phase band-pass (order 20), then keep the selected channels.
    filtered_data = butter_bandpass_filter(
        dataset.data, 7, 90, dataset.sampling_rate, 20)
    picked_data = pick_channels(filtered_data, dataset.channel_names, selected_channels)
    dataset.data = picked_data

    # Replace the stored shape with the preprocessed dataset object.
    datasets[subject_id] = dataset

    print(f"Subject ID: {subject_id}")
    print(f"Original Data Shape: {dataset_shape}")
    print(f"Filtered Data Shape: {dataset.data.shape}")
    print(f"Number of Trials: {len(dataset)}")
    print(f"Example Trial Shape: {dataset[0][0].shape}")
    print("-" * 30)


Subject ID: 1
Original Data Shape: (240, 64, 1500)
Filtered Data Shape: (240, 8, 1500)
Number of Trials: 240
Example Trial Shape: (8, 1500)
------------------------------
Subject ID: 2
Original Data Shape: (240, 64, 1500)
Filtered Data Shape: (240, 8, 1500)
Number of Trials: 240
Example Trial Shape: (8, 1500)
------------------------------
Subject ID: 3
Original Data Shape: (240, 64, 1500)
Filtered Data Shape: (240, 8, 1500)
Number of Trials: 240
Example Trial Shape: (8, 1500)
------------------------------
Subject ID: 4
Original Data Shape: (240, 64, 1500)
Filtered Data Shape: (240, 8, 1500)
Number of Trials: 240
Example Trial Shape: (8, 1500)
------------------------------
Subject ID: 5
Original Data Shape: (240, 64, 1500)
Filtered Data Shape: (240, 8, 1500)
Number of Trials: 240
Example Trial Shape: (8, 1500)
------------------------------


In [None]:
def cca_reference(list_freqs, fs, num_smpls, num_harms=2):
    """
    Sine/cosine reference signals for CCA, one set per stimulus frequency.

    Args:
        list_freqs : sequence of float
            Stimulus frequencies in Hz.
        fs : float
            Sampling rate in Hz.
        num_smpls : int
            Samples per reference.
        num_harms : int, default 2
            Number of harmonics (1 .. num_harms).
    Returns:
        y_ref : ndarray, shape (len(list_freqs), 2 * num_harms, num_smpls)
            Rows alternate sin, cos for harmonics 1..num_harms. Time starts at
            1/fs (1-based sample index).
    """
    tidx = np.arange(1, num_smpls + 1) / fs
    y_ref = np.zeros((len(list_freqs), 2 * num_harms, num_smpls))

    for freq_i, stim_freq in enumerate(list_freqs):
        pairs = []
        for harm_i in range(1, num_harms + 1):
            arg = 2 * np.pi * tidx * harm_i * stim_freq
            pairs.extend((np.sin(arg), np.cos(arg)))
        y_ref[freq_i] = pairs
    return y_ref

def simple_cca(eeg, list_freqs, fs, num_harms=2, verbose=True):
    """
    Standard (single-band) CCA SSVEP classifier.

    Each trial is correlated with the sine/cosine references of every candidate
    frequency through a one-component CCA. Negative or NaN correlations count
    as 0, and the best-scoring frequency is predicted (first one on ties).

    Args:
        eeg : ndarray, shape (trial, channel, time)
        list_freqs : sequence of float
            Candidate stimulus frequencies in Hz.
        fs : float
            Sampling rate in Hz.
        num_harms : int, default 2
            Harmonics in the references.
        verbose : bool, default True
            Print per-trial progress and predictions.
    Returns:
        list of int
            Index into ``list_freqs`` of the predicted frequency for each trial.
    """
    num_targs = len(list_freqs)
    events, _, num_smpls = eeg.shape
    y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms)
    cca = CCA(n_components=1)

    results = []

    for event in range(events):
        if verbose:
            print(f"Processing event {event + 1}/{events}")
        # Index without squeeze so single-channel data stays (channel, time).
        test_tmp = eeg[event]
        r_corr_avg = []
        for class_i in range(num_targs):
            refdata = y_ref[class_i]
            test_C, ref_C = cca.fit_transform(test_tmp.T, refdata.T)
            r_tmp, _ = pearsonr(np.squeeze(test_C), np.squeeze(ref_C))
            if np.isnan(r_tmp) or r_tmp < 0:
                r_tmp = 0
            r_corr_avg.append(r_tmp)

        # argmax returns the first maximum, matching list.index(max(...)).
        max_corr_idx = int(np.argmax(r_corr_avg))
        max_corr = r_corr_avg[max_corr_idx]
        results.append(max_corr_idx)
        if verbose:
            print(f"Highest correlation for event {event + 1}: Frequency {list_freqs[max_corr_idx]} with correlation {max_corr}\n")

    if verbose:
        print("All events processed.")
    return results

In [None]:
def calculate_accuracy(ground_truth_labels, predicted_labels):
    """
    Fraction of predictions that equal the ground truth.

    Args:
        ground_truth_labels : iterable
            True labels.
        predicted_labels : iterable
            Predicted labels, same length as ``ground_truth_labels``.
    Returns:
        float in [0, 1].
    Raises:
        ValueError: if the inputs are empty or differ in length. Without this
            check, ``zip`` would silently ignore the unmatched tail.
    """
    truth = list(ground_truth_labels)
    preds = list(predicted_labels)
    if len(truth) != len(preds):
        raise ValueError(
            f"Label count mismatch: {len(truth)} ground-truth vs {len(preds)} predicted.")
    if not truth:
        raise ValueError("Cannot compute accuracy of an empty label list.")

    correct_predictions = sum(1 for true, pred in zip(truth, preds) if true == pred)
    return correct_predictions / len(truth)

In [None]:
def calculate_itr(num_targets, accuracy, selection_time):
    """Information transfer rate in bits per minute (Wolpaw formula).

    ITR = (log2 N + P*log2 P + (1 - P)*log2((1 - P) / (N - 1))) * 60 / T

    Args:
        num_targets: number of selectable targets N (must be > 1).
        accuracy: classification accuracy P in (0, 1].
        selection_time: time per selection T in seconds (must be > 0).

    Returns:
        ITR in bits per minute.

    Raises:
        ValueError: for N <= 1, P <= 0, T <= 0, or a NaN P/T
            ("Invalid input values ..."), and for P > 1.
    """
    # `not x > 0` (rather than `x <= 0`) also rejects NaN, which would
    # otherwise propagate silently into a NaN ITR.
    if num_targets <= 1 or not accuracy > 0 or not selection_time > 0:
        raise ValueError("Invalid input values for ITR calculation.")
    if accuracy > 1:
        raise ValueError("Accuracy cannot be greater than 1.")

    bits_per_selection = np.log2(num_targets)
    if accuracy < 1:
        # Both correction terms vanish at P == 1; evaluating them there would
        # hit log2(0), so they are only added below perfect accuracy.
        bits_per_selection += accuracy * np.log2(accuracy)
        bits_per_selection += (1 - accuracy) * np.log2((1 - accuracy) / (num_targets - 1))

    # Convert bits per selection to bits per minute.
    return bits_per_selection * (60 / selection_time)


In [None]:
# Stimulus frequencies in target order. The table is taken from the Benchmark class
# so that the decoder's class indices cannot drift from the dataset definition.
frequencies = np.array(Benchmark._FREQS)

num_harmonics = 4  # Number of harmonics in the reference signals (adjust as needed)
sample_length = 1500  # Number of samples in each EEG epoch
fs = 250  # Sampling frequency in Hz

reference_frequencies = create_reference_signals(frequencies, num_harmonics, sample_length, fs)
print(reference_frequencies.shape)

(40, 8, 1500)


In [None]:
num_fbs = 9  # Number of filter-bank sub-bands passed to fbcca

# Decode every subject with FBCCA; predictions are keyed by str(subject_id).
fbcca_results = {}
for subject_id in range(1, total_subjects + 1):
    dataset = Benchmark(root=root_dir, subject_id=subject_id, verbose=True)
    # Band-pass filter (arguments 7, 90, <rate>, 20 — see butter_bandpass_filter).
    filtered_data = butter_bandpass_filter(
        dataset.data, 7, 90, dataset.sampling_rate, 20)
    picked_data = pick_channels(filtered_data, dataset.channel_names, selected_channels)
    print(picked_data.shape)
    # Record only the shape of the channel-selected data; the data itself is not kept.
    datasets[str(subject_id)] = picked_data.shape
    # Use the same sampling rate as the filter above so the reference signals
    # match the data instead of relying on the global `fs`.
    results = fbcca(picked_data, frequencies, dataset.sampling_rate, num_harmonics, num_fbs)

    fbcca_results[str(subject_id)] = results

Load path: drive/MyDrive/BCI_data/SSVEP/S01.mat
Data shape: (240, 64, 1500)
Targets shape: (240,)
(240, 8, 1500)
(40, 8, 1500)
Processing event 1/240
Result for event 1: Frequency 8.0, Correlation: 1.4640639785974936
Processing event 2/240
Result for event 2: Frequency 8.0, Correlation: 1.2728500744936342
Processing event 3/240
Result for event 3: Frequency 8.0, Correlation: 1.3736643836141174
Processing event 4/240
Result for event 4: Frequency 8.0, Correlation: 1.3285256199848734
Processing event 5/240
Result for event 5: Frequency 8.0, Correlation: 1.4528175206539669
Processing event 6/240
Result for event 6: Frequency 8.0, Correlation: 1.3682048117161438
Processing event 7/240
Result for event 7: Frequency 9.0, Correlation: 1.566364749084168
Processing event 8/240
Result for event 8: Frequency 9.0, Correlation: 1.4953350408279469
Processing event 9/240
Result for event 9: Frequency 9.0, Correlation: 1.4853009181271413
Processing event 10/240
Result for event 10: Frequency 9.0, Corr



Result for event 94: Frequency 15.2, Correlation: 2.0038172728324093
Processing event 95/240
Result for event 95: Frequency 15.2, Correlation: 1.8844916215039833
Processing event 96/240
Result for event 96: Frequency 15.2, Correlation: 1.7046561892580128
Processing event 97/240
Result for event 97: Frequency 8.4, Correlation: 1.6972250382024823
Processing event 98/240
Result for event 98: Frequency 8.4, Correlation: 1.837784891723868
Processing event 99/240
Result for event 99: Frequency 8.4, Correlation: 1.8380750026984396
Processing event 100/240
Result for event 100: Frequency 8.4, Correlation: 1.809914609356651
Processing event 101/240
Result for event 101: Frequency 8.4, Correlation: 1.8319048268653828
Processing event 102/240
Result for event 102: Frequency 8.4, Correlation: 1.6609386246090052
Processing event 103/240
Result for event 103: Frequency 9.4, Correlation: 1.7200276299982393
Processing event 104/240
Result for event 104: Frequency 9.4, Correlation: 1.6717192700928591
P

In [None]:
# Decode every subject with standard CCA; predictions (class indices into
# `frequencies`) are keyed by str(subject_id).
cca_results = {}
for subject_id in range(1, total_subjects + 1):
    dataset = Benchmark(root=root_dir, subject_id=subject_id, verbose=True)
    # Band-pass filter (arguments 7, 90, <rate>, 20 — see butter_bandpass_filter).
    filtered_data = butter_bandpass_filter(
        dataset.data, 7, 90, dataset.sampling_rate, 20)
    picked_data = pick_channels(filtered_data, dataset.channel_names, selected_channels)
    print(picked_data.shape)
    # Record only the shape of the channel-selected data; the data itself is not kept.
    datasets[str(subject_id)] = picked_data.shape
    # Run standard (single-band) CCA, using the same sampling rate as the filter
    # above so the reference signals match the data.
    results = simple_cca(picked_data, frequencies, dataset.sampling_rate, num_harmonics)

    cca_results[str(subject_id)] = results

[1;30;43mStreaming output truncated to the last 5000 lines.[0m

Processing event 25/240
Highest correlation for event 25: Frequency 12.0 with correlation 0.35826019892814476

Processing event 26/240
Highest correlation for event 26: Frequency 12.0 with correlation 0.4169190441192555

Processing event 27/240
Highest correlation for event 27: Frequency 12.0 with correlation 0.3577245917288777

Processing event 28/240
Highest correlation for event 28: Frequency 12.0 with correlation 0.4277253185605591

Processing event 29/240
Highest correlation for event 29: Frequency 12.0 with correlation 0.3173152556993407

Processing event 30/240
Highest correlation for event 30: Frequency 11.0 with correlation 0.4434598335095645

Processing event 31/240
Highest correlation for event 31: Frequency 13.0 with correlation 0.4194441017015418

Processing event 32/240
Highest correlation for event 32: Frequency 13.0 with correlation 0.3678943688991436

Processing event 33/240
Highest correlation for event

In [None]:
# Concatenate the per-subject CCA predictions in subject order, echoing
# each subject's predictions along the way.
all_results_cca = []
for subject_id, results in cca_results.items():
    print(f"Subject {subject_id}: {results}")
    all_results_cca += results
print("All results:", all_results_cca)

Subject 1: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 18, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 26, 14, 14, 14, 14, 14, 14, 15, 26, 15, 15, 15, 15, 16, 16, 16, 10, 16, 16, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 20, 20, 20, 10, 20, 20, 21, 21, 21, 21, 21, 21, 18, 26, 18, 22, 22, 22, 23, 23, 18, 23, 23, 23, 24, 24, 24, 10, 24, 24, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 30, 26, 18, 18, 30, 30, 18, 18, 31, 31, 31, 4, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 26, 39, 39, 39, 39, 26, 39]
Subject 2: [0, 0, 0, 0, 0, 0, 1, 1, 34, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 

In [None]:
# Concatenate the per-subject FBCCA predictions in subject order, echoing
# each subject's predictions along the way.
all_results_fbcca = []
for subject_id, results in fbcca_results.items():
    print(f"Subject {subject_id}: {results}")
    all_results_fbcca += results
print("All results:", all_results_fbcca)

Subject 1: [8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 14.0, 14.0, 14.0, 14.0, 14.0, 14.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 8.2, 8.2, 8.2, 8.2, 8.2, 8.2, 9.2, 9.2, 9.2, 9.2, 9.2, 9.2, 10.2, 10.2, 10.2, 10.2, 10.2, 10.2, 11.2, 11.2, 11.2, 11.2, 11.2, 11.2, 12.2, 12.2, 12.2, 12.2, 12.2, 12.2, 13.2, 13.2, 13.2, 13.2, 13.2, 13.2, 14.2, 14.2, 14.2, 14.2, 14.2, 14.2, 15.2, 15.2, 15.2, 15.2, 15.2, 15.2, 8.4, 8.4, 8.4, 8.4, 8.4, 8.4, 9.4, 9.4, 9.4, 9.4, 9.4, 9.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 11.4, 11.4, 11.4, 11.4, 11.4, 11.4, 12.4, 12.4, 12.4, 12.4, 12.4, 12.4, 13.4, 13.4, 13.4, 13.4, 13.4, 13.4, 14.4, 14.4, 14.4, 14.4, 14.4, 14.4, 15.4, 15.4, 15.4, 15.4, 15.4, 15.4, 8.6, 8.6, 8.6, 8.6, 8.6, 8.6, 9.6, 9.6, 9.6, 9.6, 9.6, 9.6, 10.6, 10.6, 10.6, 10.6, 10.6, 10.6, 11.6, 11.6, 11.6, 11.6, 11.6, 11.6, 12.6, 12.6, 12.6, 12.6, 12.6

In [None]:
# Map every predicted class index (simple_cca output) to its stimulus frequency.
frequencies_results = [frequencies[idx] for idx in all_results_cca]
print("Frequencies:", frequencies_results)

num_epochs_per_subject = 240
num_subjects = len(all_results_cca) // num_epochs_per_subject
num_targets = len(frequencies)  # Number of choices in the BCI task
selection_time = 5  # Average time for selection in seconds

# Ground-truth frequency of every epoch, concatenated in the same subject order
# as `all_results_cca` (the iteration order of `cca_results`).
ground_truth_labels_all_subjects = []
for subject_key in cca_results:
    subject_targets = np.asarray(
        Benchmark(root=root_dir, subject_id=int(subject_key)).targets)
    # NOTE(review): assumes `targets` holds 0-based indices into `frequencies`,
    # like the class indices returned by simple_cca — confirm against _load_data.
    # Any other encoding fails loudly here (float or out-of-range index).
    ground_truth_labels_all_subjects.extend(frequencies[subject_targets])

if len(ground_truth_labels_all_subjects) != len(frequencies_results):
    raise ValueError(
        f"{len(ground_truth_labels_all_subjects)} ground-truth labels but "
        f"{len(frequencies_results)} predictions; the subject order or epoch count differs."
    )

for subject_index in range(num_subjects):
    start_index = subject_index * num_epochs_per_subject
    # Exclusive slice bound: includes all num_epochs_per_subject epochs
    # (the previous `- 1` silently dropped each subject's last epoch).
    end_index = start_index + num_epochs_per_subject
    ground_truth_labels = ground_truth_labels_all_subjects[start_index:end_index]
    predicted_labels = frequencies_results[start_index:end_index]

    # Calculate accuracy
    accuracy = calculate_accuracy(ground_truth_labels, predicted_labels)
    print(f"Subject {subject_index + 1} Accuracy: {accuracy * 100:.2f}%")

    # Calculate ITR
    itr = calculate_itr(num_targets, accuracy, selection_time)
    print(f"Subject {subject_index + 1} ITR: {itr:.2f} bits/min")

Frequencies: [8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 13.0, 13.0, 13.0, 13.0, 13.0, 13.0, 14.0, 10.4, 14.0, 14.0, 14.0, 14.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 8.2, 8.2, 8.2, 8.2, 8.2, 8.2, 9.2, 9.2, 9.2, 9.2, 9.2, 9.2, 10.2, 10.2, 10.2, 10.2, 10.2, 10.2, 11.2, 11.2, 11.2, 11.2, 11.2, 11.2, 12.2, 12.2, 12.2, 12.2, 12.2, 12.2, 13.2, 13.2, 13.2, 13.2, 13.2, 10.6, 14.2, 14.2, 14.2, 14.2, 14.2, 14.2, 15.2, 10.6, 15.2, 15.2, 15.2, 15.2, 8.4, 8.4, 8.4, 10.2, 8.4, 8.4, 9.4, 9.4, 9.4, 9.4, 9.4, 9.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 11.4, 11.4, 11.4, 11.4, 11.4, 11.4, 12.4, 12.4, 12.4, 10.2, 12.4, 12.4, 13.4, 13.4, 13.4, 13.4, 13.4, 13.4, 10.4, 10.6, 10.4, 14.4, 14.4, 14.4, 15.4, 15.4, 10.4, 15.4, 15.4, 15.4, 8.6, 8.6, 8.6, 10.2, 8.6, 8.6, 9.6, 9.6, 9.6, 9.6, 9.6, 9.6, 10.6, 10.6, 10.6, 10.6, 10.6, 10.6, 11.6, 11.6, 11.6, 11.6, 11.6, 11.6, 12.6, 12.6, 12.6, 12.6, 

ZeroDivisionError: division by zero