# Task-Related Component Analysis

Task-related component analysis (TRCA) is a spatial-filtering-based classifier, originally developed for detecting steady-state visual evoked potentials (SSVEPs).

Taken from the [paper](http://ieeexplore.ieee.org/document/7904641/) abstract:
> Task-related component analysis (TRCA), which can enhance reproducibility of SSVEPs across multiple trials, was employed to improve the signal-to-noise ratio (SNR) of SSVEP signals by removing background electroencephalographic (EEG) activities. An ensemble method was further developed to integrate TRCA filters corresponding to multiple stimulation frequencies.
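
The core of TRCA fits in a few lines. For one target, it looks for the spatial filter `w` that maximizes the summed covariance between all pairs of distinct trials (`S`) relative to the covariance of all trials concatenated in time (`Q`). That filter is the generalized eigenvector of `S w = λ Q w` with the largest eigenvalue. The minimal sketch below mirrors the idea of the `trca` function defined later in this notebook, but it is only an illustration under its own assumptions: the name `trca_filter` is hypothetical, trials are laid out as `(n_trials, n_samples, n_chans)`, `S` is computed in closed form instead of a double loop, and the data are synthetic.

```python
import numpy as np
from scipy.linalg import eigh


def trca_filter(X):
    """Hypothetical helper: TRCA spatial filter for one target's trials.

    X has shape (n_trials, n_samples, n_chans).
    """
    n_trials, n_samples, n_chans = X.shape
    # Center each trial channel-wise before computing inter-trial covariances
    Xc = X - X.mean(axis=1, keepdims=True)
    # Sum of X_i^T X_j over all ordered pairs i != j:
    # (sum_i X_i)^T (sum_i X_i) minus the same-trial terms X_i^T X_i
    total = Xc.sum(axis=0)
    S = total.T @ total - np.einsum("tsc,tsd->cd", Xc, Xc)
    # Covariance of all trials concatenated along time
    U = X.reshape(n_trials * n_samples, n_chans)
    U = U - U.mean(axis=0)
    Q = U.T @ U
    # Generalized symmetric eigenproblem S w = lambda Q w (eigenvalues ascending)
    _, eigvecs = eigh(S, Q)
    return eigvecs[:, -1]  # filter with the largest eigenvalue


# Synthetic trials: a shared 10 Hz component mixed into 5 channels, plus noise
rng = np.random.default_rng(0)
n_trials, n_samples, n_chans = 6, 200, 5
t = np.arange(n_samples) / 250
source = np.sin(2 * np.pi * 10 * t)
mixing = rng.normal(size=n_chans)
X = source[None, :, None] * mixing[None, None, :] + rng.normal(size=(n_trials, n_samples, n_chans))
w = trca_filter(X)
print(w.shape)  # (5,)
```

Because `S` and `Q` are symmetric and `Q` is positive definite for noisy data, `scipy.linalg.eigh` can solve the generalized problem directly and returns real eigenvectors.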

In [None]:
#@title 
!git clone https://github.com/jinglescode/python-signal-processing.git
%cd python-signal-processing
!pip install -r requirements.txt --quiet

Cloning into 'python-signal-processing'...
remote: Enumerating objects: 198, done.[K
remote: Counting objects: 100% (198/198), done.[K
remote: Compressing objects: 100% (139/139), done.[K
remote: Total 198 (delta 111), reused 118 (delta 50), pack-reused 0[K
Receiving objects: 100% (198/198), 22.08 MiB | 36.88 MiB/s, done.
Resolving deltas: 100% (111/111), done.
/content/python-signal-processing


In [None]:
import sys
sys.path.append("..")

from splearn.cross_decomposition.trca import TRCA # https://github.com/jinglescode/python-signal-processing/blob/main/splearn/cross_decomposition/trca.py
from splearn.data.sample_ssvep import SampleSSVEPData # https://github.com/jinglescode/python-signal-processing/blob/main/splearn/data/sample_ssvep.py
from splearn.cross_validate.leave_one_out import leave_one_block_evaluation # https://github.com/jinglescode/python-signal-processing/blob/main/splearn/cross_validate/leave_one_out.py
from splearn.cross_decomposition.cca import CCA # https://github.com/jinglescode/python-signal-processing/blob/main/splearn/cross_decomposition/cca.py

import numpy as np
from sklearn.metrics import accuracy_score

## Load data

In this tutorial, we load a 40-target steady-state visual evoked potential (SSVEP) dataset recorded from a single subject. It contains 6 blocks, and each block consists of 40 trials, one per target. Each electroencephalogram (EEG) trial has 9 channels and 1250 samples.

Read more about this dataset: https://www.pnas.org/content/early/2015/10/14/1508080112.abstract.

In [None]:
data = SampleSSVEPData()
eeg = data.get_data()
labels = data.get_targets()
print("eeg.shape:", eeg.shape)
print("labels.shape:", labels.shape)

eeg.shape: (6, 40, 9, 1250)
labels.shape: (6, 40)


## Leave-One-Block-Out cross-validation

We use the Leave-One-Block-Out cross-validation approach to determine TRCA's classification performance.
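
Conceptually, leave-one-block-out cross-validation gives each of the 6 blocks one turn as the test set while the classifier is trained on the remaining 5. The sketch below shows that loop for any object with scikit-learn-style `fit`/`predict` methods. `leave_one_block_out` and the toy `NearestTemplate` classifier are hypothetical stand-ins written for illustration, not splearn's `leave_one_block_evaluation`, and the data are synthetic.

```python
import numpy as np


def leave_one_block_out(classifier, X, Y):
    """Hypothetical helper. X: (n_blocks, n_targets, n_chans, n_samples); Y: (n_blocks, n_targets)."""
    n_blocks = X.shape[0]
    accuracies = []
    for test_block in range(n_blocks):
        train_blocks = [b for b in range(n_blocks) if b != test_block]
        # Merge the training blocks into one flat set of trials
        x_train = X[train_blocks].reshape(-1, *X.shape[2:])
        y_train = Y[train_blocks].reshape(-1)
        classifier.fit(x_train, y_train)
        pred = classifier.predict(X[test_block])
        accuracies.append(float(np.mean(pred == Y[test_block])))
    return accuracies


class NearestTemplate:
    """Toy classifier: assign each trial to the closest class-average template."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.templates_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        # Squared Euclidean distance between every trial and every template
        dist = ((X[:, None] - self.templates_[None]) ** 2).sum(axis=(2, 3))
        return self.classes_[np.argmin(dist, axis=1)]


# Synthetic data: 6 blocks x 4 targets, 2 channels, 50 samples
rng = np.random.default_rng(0)
templates = rng.normal(size=(4, 2, 50))
X = templates[None] + 0.1 * rng.normal(size=(6, 4, 2, 50))
Y = np.tile(np.arange(4), (6, 1))
accuracies = leave_one_block_out(NearestTemplate(), X, Y)
print(accuracies)
```

Since each block holds exactly one trial per target, no test trial is ever seen during training, and every target is tested once per fold.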

In [None]:
trca_classifier = TRCA(sampling_rate=data.sampling_rate)
test_accuracies = leave_one_block_evaluation(classifier=trca_classifier, X=eeg, Y=labels)

Block: 1 | Train acc: 100.00% | Test acc: 97.50%
Block: 2 | Train acc: 100.00% | Test acc: 100.00%


KeyboardInterrupt: 

### Comparing to CCA
Let's also test the classification performance of [CCA](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Canonical%20Correlation%20Analysis.ipynb) and compare the accuracies.
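
Standard CCA-based SSVEP detection needs no training data: it compares each trial with sine/cosine reference signals at every candidate stimulus frequency (and its harmonics) and picks the frequency with the largest canonical correlation. The sketch below illustrates that idea on one synthetic trial. The helpers `sincos_reference`, `max_canonical_corr` and `cca_detect` are hypothetical; they compute the first canonical correlation from QR and SVD decompositions rather than calling splearn's `CCA`.

```python
import numpy as np


def sincos_reference(freq, n_samples, sfreq, n_harmonics):
    """Sine/cosine reference signals, shape (n_samples, 2 * n_harmonics)."""
    t = np.arange(n_samples) / sfreq
    refs = []
    for h in range(1, n_harmonics + 1):
        refs.append(np.sin(2 * np.pi * h * freq * t))
        refs.append(np.cos(2 * np.pi * h * freq * t))
    return np.stack(refs, axis=1)


def max_canonical_corr(X, Y):
    """Largest canonical correlation between X (n_samples, p) and Y (n_samples, q)."""
    # Orthonormal bases of the centered column spaces
    qx, _ = np.linalg.qr(X - X.mean(axis=0))
    qy, _ = np.linalg.qr(Y - Y.mean(axis=0))
    # Canonical correlations are the singular values of qx^T qy
    return np.linalg.svd(qx.T @ qy, compute_uv=False)[0]


def cca_detect(eeg, freqs, sfreq, n_harmonics=2):
    """eeg: (n_samples, n_chans). Return the index of the best-matching frequency."""
    scores = [max_canonical_corr(eeg, sincos_reference(f, eeg.shape[0], sfreq, n_harmonics))
              for f in freqs]
    return int(np.argmax(scores))


# Synthetic 1-s trial at 250 Hz: a 10 Hz response on 3 channels plus noise
sfreq, n_samples = 250, 250
t = np.arange(n_samples) / sfreq
rng = np.random.default_rng(1)
eeg = np.outer(np.sin(2 * np.pi * 10 * t), [1.0, 0.8, 0.5]) + 0.5 * rng.normal(size=(n_samples, 3))
freqs = [8.0, 10.0, 12.0]
print(freqs[cca_detect(eeg, freqs, sfreq)])
```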

In [None]:
cca = CCA(
    sampling_rate=data.sampling_rate, 
    target_frequencies=data.get_stimulus_frequencies(), 
    signal_size=eeg.shape[3], 
    num_harmonics=1
)

test_accuracies = leave_one_block_evaluation(classifier=cca, X=eeg, Y=labels)

Block: 1 | Test acc: 100.00%
Block: 2 | Test acc: 100.00%
Block: 3 | Test acc: 100.00%
Block: 4 | Test acc: 100.00%
Block: 5 | Test acc: 100.00%
Block: 6 | Test acc: 100.00%
Mean test accuracy: 100.0%


Both TRCA and CCA classify this dataset almost perfectly, so it cannot reveal a difference between the two methods. We use a more challenging dataset below.

## Using `.fit` and `.predict`

In this example, we select the first 2 blocks for training and the remaining 4 blocks for testing. 
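
`fit` expects a flat set of trials, so the selected blocks have to be merged: `(blocks, targets, channels, samples)` becomes `(blocks * targets, channels, samples)`, and the labels are flattened the same way. Writing the product as `blocks * targets` avoids precedence slips: an expression like `blocks-1*targets` evaluates as `blocks - targets`, not `(blocks-1)*targets`. A shape check on zero-filled stand-in arrays sized like two blocks of this dataset:

```python
import numpy as np

# Zero-filled stand-ins shaped like two training blocks of this dataset
x_train = np.zeros((2, 40, 9, 1250))
y_train = np.tile(np.arange(40), (2, 1))

blocks, targets, channels, samples = x_train.shape
# Merge the block and target axes into a single trial axis
x_flat = x_train.reshape(blocks * targets, channels, samples)
y_flat = y_train.reshape(blocks * targets)
print(x_flat.shape, y_flat.shape)  # (80, 9, 1250) (80,)
```

Trial `k` of the flattened array comes from block `k // targets` and target `k % targets`, so the labels stay aligned with the data.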

In [None]:
trca_classifier = TRCA(sampling_rate=data.sampling_rate)

x_train = eeg[0:2]
y_train = labels[0:2]

blocks, targets, channels, samples = x_train.shape
x_train = x_train.reshape((blocks*targets, channels, samples))
y_train = y_train.reshape((blocks*targets))

print("Train shape:", x_train.shape, y_train.shape)
trca_classifier.fit(x_train, y_train)

for block_i in range(2, 6):

    test_x = eeg[block_i]
    test_y = labels[block_i]

    # Shuffle the test set
    arrangement = np.arange(40)
    np.random.shuffle(arrangement)
    test_x = test_x[arrangement, :,:]
    test_y = test_y[arrangement]

    # Predict
    pred = trca_classifier.predict(test_x)
    acc = accuracy_score(test_y, pred)

    print(f'Block: {block_i+1} | accuracy: {acc*100:.2f}%')

Train shape: (80, 9, 1250) (80,)
Block: 3 | accuracy: 100.00%
Block: 4 | accuracy: 100.00%
Block: 5 | accuracy: 97.50%
Block: 6 | accuracy: 100.00%


## Another dataset, HS-SSVEP

Because the previous dataset does not separate the two methods, we now evaluate them on data from a single subject of the [Tsinghua SSVEP benchmark dataset](https://ieeexplore.ieee.org/document/7740878).

In the following code blocks, we will download and prepare the data and labels.

In [None]:
!wget -r --no-parent ftp://anonymous@sccn.ucsd.edu/pub/ssvep_benchmark_dataset/S33.mat

--2021-08-18 13:12:13--  ftp://anonymous@sccn.ucsd.edu/pub/ssvep_benchmark_dataset/S33.mat
           => ‘sccn.ucsd.edu/pub/ssvep_benchmark_dataset/.listing’
Resolving sccn.ucsd.edu (sccn.ucsd.edu)... 169.228.38.2
Connecting to sccn.ucsd.edu (sccn.ucsd.edu)|169.228.38.2|:21... connected.
Logging in as anonymous ... Logged in!
==> SYST ... done.    ==> PWD ... done.
==> TYPE I ... done.  ==> CWD (1) /pub/ssvep_benchmark_dataset ... done.
==> PASV ... done.    ==> LIST ... done.

sccn.ucsd.edu/pub/s     [ <=>                ]   2.78K  --.-KB/s    in 0s      

2021-08-18 13:12:14 (261 MB/s) - ‘sccn.ucsd.edu/pub/ssvep_benchmark_dataset/.listing’ saved [2850]

Removed ‘sccn.ucsd.edu/pub/ssvep_benchmark_dataset/.listing’.
--2021-08-18 13:12:14--  ftp://anonymous@sccn.ucsd.edu/pub/ssvep_benchmark_dataset/S33.mat
           => ‘sccn.ucsd.edu/pub/ssvep_benchmark_dataset/S33.mat’
==> CWD not required.
==> PASV ... done.    ==> RETR S33.mat ... done.
Length: 106223727 (101M)


2021-08-18 13:12:17

In [None]:
from scipy.io import loadmat

# select channels
ch_names = ['FP1','FPZ','FP2','AF3','AF4','F7','F5','F3','F1','FZ','F2','F4','F6','F8','FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6','FT8','T7','C5','C3','C1','Cz','C2','C4','C6','T8','M1','TP7','CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8','M2','P7','P5','P3','P1','PZ','P2','P4','P6','P8','PO7','PO5','PO3','POz','PO4','PO6','PO8','CB1','O1','Oz','O2','CB2']
ch_index = [47,53,54,55,56,57,60,61,62]

sampling_rate = 250

folder = 'sccn.ucsd.edu/pub/ssvep_benchmark_dataset'
data = loadmat(f"{folder}/S33.mat")
eeg = data['data']
eeg = eeg.transpose((3, 2, 0, 1))  # -> (blocks, targets, channels, samples)
eeg = eeg[:, :, ch_index, 250:500]  # keep 9 channels and 250 samples (1 s at 250 Hz)
print("Data shape:", eeg.shape)

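blocks, targets, channels, samples = eeg.shape
y_train = np.tile(np.arange(0, targets), (1, blocks-1)).squeeze()
y_test = np.arange(0, targets)
# Labels for leave-one-block-out evaluation: target indices 0..39 in every block
labels = np.tile(np.arange(targets), (blocks, 1))
print("Label shape:", y_train.shape, y_test.shape)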

Data shape: (6, 40, 9, 250)
Label shape: (200,) (40,)
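
The indices in `ch_index` select nine parietal and occipital electrodes, and the transpose moves the axes of the `.mat` array from `(channels, samples, targets, blocks)` (the layout implied by the transpose and the printed shape `(6, 40, 9, 250)`) to `(blocks, targets, channels, samples)`. A quick check, using a small random array as a stand-in for `data['data']`:

```python
import numpy as np

ch_names = ['FP1','FPZ','FP2','AF3','AF4','F7','F5','F3','F1','FZ','F2','F4','F6','F8',
            'FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6','FT8','T7','C5','C3','C1','Cz',
            'C2','C4','C6','T8','M1','TP7','CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8',
            'M2','P7','P5','P3','P1','PZ','P2','P4','P6','P8','PO7','PO5','PO3','POz','PO4',
            'PO6','PO8','CB1','O1','Oz','O2','CB2']
ch_index = [47, 53, 54, 55, 56, 57, 60, 61, 62]
print([ch_names[i] for i in ch_index])
# ['PZ', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'O1', 'Oz', 'O2']

# Small stand-in for data['data'] with layout (channels, samples, targets, blocks)
raw = np.random.default_rng(0).normal(size=(64, 20, 4, 3))
eeg = raw.transpose((3, 2, 0, 1))[:, :, ch_index, 5:15]
print(eeg.shape)  # (3, 4, 9, 10)
# eeg[b, k, c, s] comes from raw[ch_index[c], 5 + s, k, b]
assert eeg[2, 1, 0, 3] == raw[ch_index[0], 5 + 3, 1, 2]
```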


## Classification with TRCA

In [None]:
trca_classifier = TRCA(sampling_rate=sampling_rate)
test_accuracies = leave_one_block_evaluation(classifier=trca_classifier, X=eeg, Y=labels)

Block: 1 | Train acc: 100.00% | Test acc: 70.00%
Block: 2 | Train acc: 100.00% | Test acc: 47.50%
Block: 3 | Train acc: 100.00% | Test acc: 67.50%
Block: 4 | Train acc: 100.00% | Test acc: 47.50%
Block: 5 | Train acc: 100.00% | Test acc: 70.00%
Block: 6 | Train acc: 100.00% | Test acc: 62.50%
Mean test accuracy: 60.8%


### Comparing to CCA

In [None]:
stimulus_frequencies = np.array([8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,8.2,9.2,10.2,11.2,12.2,13.2,14.2,15.2,8.4,9.4,10.4,11.4,12.4,13.4,14.4,15.4,8.6,9.6,10.6,11.6,12.6,13.6,14.6,15.6,8.8,9.8,10.8,11.8,12.8,13.8,14.8,15.8])

cca = CCA(
    sampling_rate=sampling_rate, 
    target_frequencies=stimulus_frequencies,
    signal_size=eeg.shape[3], 
    num_harmonics=2
)

test_accuracies = leave_one_block_evaluation(classifier=cca, X=eeg, Y=labels)

Block: 1 | Train acc: 20.50% | Test acc: 35.00%
Block: 2 | Train acc: 26.00% | Test acc: 7.50%
Block: 3 | Train acc: 22.00% | Test acc: 27.50%
Block: 4 | Train acc: 23.50% | Test acc: 20.00%
Block: 5 | Train acc: 23.50% | Test acc: 20.00%
Block: 6 | Train acc: 22.00% | Test acc: 27.50%
Mean test accuracy: 22.900000000000002%
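
Accuracy alone ignores the number of targets and the time needed per selection. The information transfer rate (ITR) of Cheng et al. (2002), implemented as `itr` in the utilities cell below, combines all three: `B = log2(N) + P*log2(P) + (1 - P)*log2((1 - P)/(N - 1))` bits per selection, multiplied by `60/T` to get bits per minute. The sketch below applies it to the mean accuracies above. `itr_bits_per_min` is a hypothetical re-implementation that returns 0 at or below chance instead of raising, and the 1.5 s selection time (1 s of EEG plus an assumed 0.5 s gaze shift) is an illustrative value, not one measured in this notebook.

```python
import numpy as np


def itr_bits_per_min(n, p, t):
    """ITR in bits/min for n targets, accuracy p, and t seconds per selection."""
    if not 0 <= p <= 1:
        raise ValueError("Accuracy must be between 0 and 1.")
    if p <= 1 / n:
        # At or below chance level the formula is not meaningful; report 0
        return 0.0
    bits = np.log2(n) + p * np.log2(p)
    if p < 1:
        bits += (1 - p) * np.log2((1 - p) / (n - 1))
    return bits * 60 / t


selection_time = 1.5  # hypothetical: 1 s of EEG + 0.5 s gaze shift
for name, acc in [("TRCA", 0.608), ("CCA", 0.229)]:
    print(f"{name}: {itr_bits_per_min(40, acc, selection_time):.1f} bits/min")
```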


In [None]:
"""TRCA utils."""
import numpy as np

from scipy.signal import filtfilt, cheb1ord, cheby1
from scipy import stats


def round_half_up(num, decimals=0):
    """Round a number half up at the given number of decimals.
    Digits 0 to 4 round down and digits 5 to 9 round up.
    Parameters
    ----------
    num : float
        Number to round.
    decimals : int
        Number of decimals to keep (default=0).
    Returns
    -------
    float
        `num` rounded half up.
    """
    multiplier = 10 ** decimals
    return np.floor(num * multiplier + 0.5) / multiplier


def normfit(data, ci=0.95):
    """Compute the mean, std and confidence interval for them.
    Parameters
    ----------
    data : array-like, shape=(n,)
        Input data.
    ci : float
        Confidence interval (default=0.95).
    Returns
    -------
    m : float
        Mean.
    sigma : float
        Standard deviation
    [m - h, m + h] : list
        Confidence interval of the mean.
    [sigmaCI_lower, sigmaCI_upper] : list
        Confidence interval of the std.
    """
    arr = 1.0 * np.array(data)
    num = len(arr)
    avg, std_err = np.mean(arr), stats.sem(arr)
    h_int = std_err * stats.t.ppf((1 + ci) / 2., num - 1)
    var = np.var(data, ddof=1)
    var_ci_upper = var * (num - 1) / stats.chi2.ppf((1 - ci) / 2, num - 1)
    var_ci_lower = var * (num - 1) / stats.chi2.ppf(1 - (1 - ci) / 2, num - 1)
    sigma = np.sqrt(var)
    sigma_ci_lower = np.sqrt(var_ci_lower)
    sigma_ci_upper = np.sqrt(var_ci_upper)

    return avg, sigma, [avg - h_int, avg +
                        h_int], [sigma_ci_lower, sigma_ci_upper]


def itr(n, p, t):
    """Compute information transfer rate (ITR).
    Definition in [1]_.
    Parameters
    ----------
    n : int
        Number of targets.
    p : float
        Target identification accuracy (0 <= p <= 1).
    t : float
        Average time for a selection (s).
    Returns
    -------
    itr : float
        Information transfer rate [bits/min]
    References
    ----------
    .. [1] M. Cheng, X. Gao, S. Gao, and D. Xu,
        "Design and Implementation of a Brain-Computer Interface With High
        Transfer Rates", IEEE Trans. Biomed. Eng. 49, 1181-1186, 2002.
    """
    itr = 0

    if (p < 0 or 1 < p):
        raise ValueError('Accuracy needs to be between 0 and 1.')
    elif (p < 1 / n):
        raise ValueError('ITR might be incorrect because accuracy < chance')
    elif (p == 1):
        itr = np.log2(n) * 60 / t
    else:
        itr = (np.log2(n) + p * np.log2(p) + (1 - p) *
               np.log2((1 - p) / (n - 1))) * 60 / t

    return itr


def bandpass(eeg, sfreq, Wp, Ws):
    """Band-pass filter EEG data into one sub-band of a filter bank.
    Parameters
    ----------
    eeg : np.array, shape=(n_samples, n_chans[, n_trials])
        EEG data, filtered along the first (time) axis.
    sfreq : int
        Sampling frequency of the data.
    Wp : 2-tuple
        Passband for Chebyshev filter.
    Ws : 2-tuple
        Stopband for Chebyshev filter.
    Returns
    -------
    y: np.array, shape=(n_samples, n_chans[, n_trials])
        Band-pass filtered data (same shape as `eeg`).
    See Also
    --------
    scipy.signal.cheb1ord :
        Chebyshev type I filter order selection.
    """
    # Chebyshev type I filter order selection.
    N, Wn = cheb1ord(Wp, Ws, 3, 40, fs=sfreq)

    # Chebyshev type I filter design
    B, A = cheby1(N, 0.5, Wn, btype="bandpass", fs=sfreq)

    # the arguments 'axis=0, padtype='odd', padlen=3*(max(len(B),len(A))-1)'
    # correspond to Matlab filtfilt : https://dsp.stackexchange.com/a/47945
    y = filtfilt(B, A, eeg, axis=0, padtype='odd',
                 padlen=3 * (max(len(B), len(A)) - 1))
    return y


def schaefer_strimmer_cov(X):
    r"""Schaefer-Strimmer covariance estimator.
    Shrinkage estimator described in [1]_:
    .. math:: \hat{\Sigma} = (1 - \gamma)\Sigma_{scm} + \gamma T
    where :math:`T` is the diagonal target matrix:
    .. math:: T_{i,j} = \{ \Sigma_{scm}^{ii} \text{if} i = j,
         0 \text{otherwise} \}
    Note that the optimal :math:`\gamma` is estimated by the authors' method.
    Parameters
    ----------
    X: array, shape=(n_chans, n_samples)
        Signal matrix.
    Returns
    -------
    cov: array, shape=(n_chans, n_chans)
        Schaefer-Strimmer shrinkage covariance matrix.
    References
    ----------
    .. [1] Schafer, J., and K. Strimmer. 2005. A shrinkage approach to
       large-scale covariance estimation and implications for functional
       genomics. Statist. Appl. Genet. Mol. Biol. 4:32.
    """
    ns = X.shape[1]
    C_scm = np.cov(X, ddof=0)
    X_c = X - np.tile(X.mean(axis=1), [ns, 1]).T

    # Compute optimal gamma, the weighting between SCM and shrinkage estimator
    R = ns / (ns - 1.0) * np.corrcoef(X)
    var_R = (X_c ** 2).dot((X_c ** 2).T) - 2 * C_scm * X_c.dot(X_c.T)
    var_R += ns * C_scm ** 2

    var_R = ns / ((ns - 1) ** 3 * np.outer(X.var(1), X.var(1))) * var_R
    R -= np.diag(np.diag(R))
    var_R -= np.diag(np.diag(var_R))
    gamma = max(0, min(1, var_R.sum() / (R ** 2).sum()))

    cov = (1. - gamma) * (ns / (ns - 1.)) * C_scm
    cov += gamma * (ns / (ns - 1.)) * np.diag(np.diag(C_scm))

    return cov


def _check_data(X):
    """Check data is numpy array and has the proper dimensions."""
    if not isinstance(X, (np.ndarray, list)):
        raise AttributeError('data should be a list or a numpy array')

    dtype = np.complex128 if np.any(np.iscomplex(X)) else np.float64
    X = np.asanyarray(X, dtype=dtype)
    if X.ndim > 3:
        raise ValueError('Data must be 3D at most')

    return X


def theshapeof(X):
    """Return the shape of X."""
    X = _check_data(X)
    # if not isinstance(X, np.ndarray):
    #     raise AttributeError('X must be a numpy array')

    if X.ndim == 3:
        return X.shape[0], X.shape[1], X.shape[2]
    elif X.ndim == 2:
        return X.shape[0], X.shape[1], 1
    elif X.ndim == 1:
        return X.shape[0], 1, 1
    else:
        raise ValueError("Array contains more than 3 dimensions")

        
###################


"""Task-Related Component Analysis."""
# Authors: Giuseppe Ferraro <giuseppe.ferraro@isae-supaero.fr>
#          Ludovic Darmet <ludovic.darmet@isae-supaero.fr>
import numpy as np
import scipy.linalg as linalg
from pyriemann.utils.mean import mean_covariance
from pyriemann.estimation import Covariances


class TRCA:
    """Task-Related Component Analysis (TRCA).
    Parameters
    ----------
    sfreq : float
        Sampling rate.
    filterbank : list[[2-tuple, 2-tuple]]
        Filterbank frequencies. Each list element is itself a list of passband
        `Wp` and stopband `Ws` edge frequencies `[Wp, Ws]`. For example, this
        creates 3 bands, starting at 6, 14, and 22 Hz respectively::
            [[(6, 90), (4, 100)],
             [(14, 90), (10, 100)],
             [(22, 90), (16, 100)]]
        See :func:`scipy.signal.cheb1ord()` for more information on how to
        specify the `Wp` and `Ws`.
    ensemble : bool
        If True, perform the ensemble TRCA analysis (default=False).
    method : str in {'original'| 'riemann'}
        Use original implementation from [1]_ or a variation that uses
        regularization and the geodesic mean [2]_.
    estimator : str in {'schaefer' | 'lwf' | 'oas' | 'scm'}
        Covariance estimator used with the `riemann` method (default='scm').
        Prefer 'schaefer', 'lwf' or 'oas'; 'scm' adds no regularization and is
        almost equivalent to the original implementation.
    Attributes
    ----------
    trains : array, shape=(n_classes, n_bands, n_samples, n_chans)
        Class-averaged reference (training) signals, decomposed into sub-band
        components by the filter bank analysis.
    coef_ : array, shape=(n_bands, n_classes, n_chans)
        Weight coefficients for electrodes, one spatial filter per sub-band
        and class.
    classes : array
        Class labels seen during `fit`.
    n_bands : int
        Number of sub-bands.
    References
    ----------
    .. [1] M. Nakanishi, Y. Wang, X. Chen, Y. -T. Wang, X. Gao, and T.-P. Jung,
       "Enhancing detection of SSVEPs for a high-speed brain speller using
       task-related component analysis", IEEE Trans. Biomed. Eng,
       65(1):104-112, 2018.
    .. [2] Barachant, A., Bonnet, S., Congedo, M., & Jutten, C. (2010,
       October). Common spatial pattern revisited by Riemannian geometry. In
       2010 IEEE International Workshop on Multimedia Signal Processing (pp.
       472-476). IEEE.
    """

    def __init__(self, sfreq, filterbank, ensemble=False, method='original',
                 estimator='scm'):
        self.sfreq = sfreq
        self.ensemble = ensemble
        self.filterbank = filterbank
        self.n_bands = len(self.filterbank)
        self.coef_ = None
        self.method = method
        if estimator == 'schaefer':
            self.estimator = schaefer_strimmer_cov
        else:
            self.estimator = estimator
            
        self.can_train = True

    def fit(self, X, y):
        """Training stage of the TRCA-based SSVEP detection.
        Parameters
        ----------
        X : array, shape=(n_trials, n_chans, n_samples)
            Training EEG data.
        y : array, shape=(n_trials,)
            Integer label (0 to n_classes - 1) of each trial.
        """
        
        X = np.transpose(X, (2,1,0))
        
        n_samples, n_chans, _ = theshapeof(X)
        classes = np.unique(y)

        trains = np.zeros((len(classes), self.n_bands, n_samples, n_chans))

        W = np.zeros((self.n_bands, len(classes), n_chans))

        for class_i in classes:
            # Select data with a specific label
            eeg_tmp = X[..., y == class_i]
            for fb_i in range(self.n_bands):
                # Filter the class data with sub-band fb_i, always starting
                # from the unfiltered signal
                eeg_fb = bandpass(eeg_tmp, self.sfreq,
                                  Wp=self.filterbank[fb_i][0],
                                  Ws=self.filterbank[fb_i][1])
                if (eeg_fb.ndim == 3):
                    # Compute mean of the signal across trials
                    trains[class_i, fb_i] = np.mean(eeg_fb, -1)
                else:
                    trains[class_i, fb_i] = eeg_fb
                # Find the spatial filter for the corresponding filtered signal
                # and label
                if self.method == 'original':
                    w_best = trca(eeg_fb)
                elif self.method == 'riemann':
                    w_best = trca_regul(eeg_fb, self.estimator)
                else:
                    raise ValueError('Invalid `method` option.')

                # Store the spatial filter (eigenvectors of the symmetric
                # problem are real; drop the zero imaginary part)
                W[fb_i, class_i, :] = np.real(w_best)

        self.trains = trains
        self.coef_ = W
        self.classes = classes

        return self

    def predict(self, X):
        """Test phase of the TRCA-based SSVEP detection.
        Parameters
        ----------
        X: array, shape=(n_trials, n_chans, n_samples)
            Test data.
        Returns
        -------
        pred: np.array, shape=(n_trials,)
            The target estimated by the method.
        """
        
        X = np.transpose(X, (2,1,0))
        
        if self.coef_ is None:
            raise RuntimeError('TRCA is not fitted')

        # Alpha coefficients for the fusion of filterbank analysis
        fb_coefs = [(x + 1)**(-1.25) + 0.25 for x in range(self.n_bands)]
        _, _, n_trials = theshapeof(X)

        r = np.zeros((self.n_bands, len(self.classes)))
        pred = np.zeros((n_trials), 'int')  # To store predictions

        for trial in range(n_trials):
            test_tmp = X[..., trial]  # pick a trial to be analysed
            for fb_i in range(self.n_bands):

                # filterbank on testdata
                testdata = bandpass(test_tmp, self.sfreq,
                                    Wp=self.filterbank[fb_i][0],
                                    Ws=self.filterbank[fb_i][1])

                for class_i in self.classes:
                    # Retrieve reference signal for class i
                    # (shape: n_samples, n_chans)
                    traindata = np.squeeze(self.trains[class_i, fb_i])
                    if self.ensemble:
                        # shape = (n_chans, n_classes)
                        w = np.squeeze(self.coef_[fb_i]).T
                    else:
                        # shape = (n_chans)
                        w = np.squeeze(self.coef_[fb_i, class_i])

                    # Compute 2D correlation of spatially filtered test data
                    # with ref
                    r_tmp = np.corrcoef((testdata @ w).flatten(),
                                        (traindata @ w).flatten())
                    r[fb_i, class_i] = r_tmp[0, 1]

            rho = np.dot(fb_coefs, r)  # fusion for the filterbank analysis

            tau = np.argmax(rho)  # retrieving index of the max
            pred[trial] = int(tau)

        return pred


def trca(X):
    """Task-related component analysis.
    This function implements the method described in [1]_.
    Parameters
    ----------
    X : array, shape=(n_samples, n_chans[, n_trials])
        Training data.
    Returns
    -------
    W : array, shape=(n_chans,)
        Weight coefficients for electrodes which can be used as a spatial
        filter.
    References
    ----------
    .. [1] M. Nakanishi, Y. Wang, X. Chen, Y. -T. Wang, X. Gao, and T.-P. Jung,
       "Enhancing detection of SSVEPs for a high-speed brain speller using
       task-related component analysis", IEEE Trans. Biomed. Eng,
       65(1):104-112, 2018.
    """
    n_samples, n_chans, n_trials = theshapeof(X)

    # 1. Compute empirical covariance of all data (to be bounded)
    # -------------------------------------------------------------------------
    # Concatenate all the trials to have all the data as a sequence
    UX = np.zeros((n_chans, n_samples * n_trials))
    for trial in range(n_trials):
        UX[:, trial * n_samples:(trial + 1) * n_samples] = X[..., trial].T

    # Mean centering
    UX -= np.mean(UX, 1)[:, None]

    # Covariance
    Q = UX @ UX.T

    # 2. Compute average empirical covariance between all pairs of trials
    # -------------------------------------------------------------------------
    S = np.zeros((n_chans, n_chans))
    for trial_i in range(n_trials - 1):
        x1 = np.squeeze(X[..., trial_i])

        # Mean centering for the selected trial (without modifying X in place)
        x1 = x1 - np.mean(x1, 0)

        # Select a second trial that is different
        for trial_j in range(trial_i + 1, n_trials):
            x2 = np.squeeze(X[..., trial_j])

            # Mean centering for the selected trial
            x2 = x2 - np.mean(x2, 0)

            # Compute empirical covariance between the two selected trials and
            # sum it
            S = S + x1.T @ x2 + x2.T @ x1

    # 3. Compute eigenvalues and vectors
    # -------------------------------------------------------------------------
    lambdas, W = linalg.eig(S, Q, left=True, right=False)

    # Select the eigenvector corresponding to the biggest eigenvalue
    W_best = W[:, np.argmax(lambdas)]

    return W_best


def trca_regul(X, method):
    """Task-related component analysis.
    This function implements a variation of the method described in [1]_. It is
    inspired by a riemannian geometry approach to CSP [2]_. It adds
    regularization to the covariance matrices and uses the riemannian mean for
    the inter-trial covariance matrix `S`.
    Parameters
    ----------
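    X : array, shape=(n_samples, n_chans[, n_trials])
        Training data.
    method : str or callable
        Covariance estimator passed to `pyriemann.estimation.Covariances`
        (e.g. 'scm', 'lwf', 'oas'), or `schaefer_strimmer_cov` when `TRCA`
        is created with `estimator='schaefer'`.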
    Returns
    -------
    W : array, shape=(n_chans,)
        Weight coefficients for electrodes which can be used as a spatial
        filter.
    References
    ----------
    .. [1] M. Nakanishi, Y. Wang, X. Chen, Y. -T. Wang, X. Gao, and T.-P. Jung,
       "Enhancing detection of SSVEPs for a high-speed brain speller using
       task-related component analysis", IEEE Trans. Biomed. Eng,
       65(1):104-112, 2018.
    .. [2] Barachant, A., Bonnet, S., Congedo, M., & Jutten, C. (2010,
       October). Common spatial pattern revisited by Riemannian geometry. In
       2010 IEEE International Workshop on Multimedia Signal Processing (pp.
       472-476). IEEE.
    """
    n_samples, n_chans, n_trials = theshapeof(X)

    # 1. Compute empirical covariance of all data (to be bounded)
    # -------------------------------------------------------------------------
    # Concatenate all the trials to have all the data as a sequence
    UX = np.zeros((n_chans, n_samples * n_trials))
    for trial in range(n_trials):
        UX[:, trial * n_samples:(trial + 1) * n_samples] = X[..., trial].T

    # Mean centering
    UX -= np.mean(UX, 1)[:, None]

    # Compute empirical variance of all data (to be bounded)
    cov = Covariances(estimator=method).fit_transform(UX[np.newaxis, ...])
    Q = np.squeeze(cov)

    # 2. Compute average empirical covariance between all pairs of trials
    # -------------------------------------------------------------------------
    # Intertrial correlation computation
    data = np.concatenate((X, X), axis=1)

    # Swapaxes to fit pyriemann Covariances
    data = np.swapaxes(data, 0, 2)
    cov = Covariances(estimator=method).fit_transform(data)

    # Keep only inter-trial
    S = cov[:, :n_chans, n_chans:] + cov[:, n_chans:, :n_chans]

    # If the number of trials is large, use the log-Euclidean mean as a faster
    # approximation of the Riemannian mean
    if n_trials < 30:
        S = mean_covariance(S, metric='riemann')
    else:
        S = mean_covariance(S, metric='logeuclid')

    # 3. Compute eigenvalues and vectors
    # -------------------------------------------------------------------------
    lambdas, W = linalg.eig(S, Q, left=True, right=False)

    # Select the eigenvector corresponding to the biggest eigenvalue
    W_best = W[:, np.argmax(lambdas)]

    return W_best
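
In `predict`, the per-band correlations `r[fb_i, class_i]` are combined with weights `a(n) = n**-1.25 + 0.25` (sub-bands numbered `n = 1, 2, ...`) into one score per class, and the class with the highest score wins. A small numeric sketch of that fusion step, with made-up correlation values that are purely illustrative:

```python
import numpy as np

n_bands, n_classes = 3, 4
# Filter-bank weights a(n) = n**-1.25 + 0.25 for sub-bands n = 1..n_bands
fb_coefs = np.array([(n + 1) ** (-1.25) + 0.25 for n in range(n_bands)])

# Made-up correlations, shape (n_bands, n_classes)
r = np.array([[0.20, 0.60, 0.10, 0.30],
              [0.10, 0.40, 0.20, 0.50],
              [0.05, 0.30, 0.25, 0.45]])

rho = fb_coefs @ r          # weighted sum over sub-bands -> one score per class
pred = int(np.argmax(rho))  # predicted class index
print(np.round(fb_coefs, 3), np.round(rho, 3), pred)
```

Here class 1 wins even though class 3 has the higher correlation in the two upper sub-bands, because the lowest sub-band carries the largest weight (1.25 versus roughly 0.67 and 0.50).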

In [None]:
sfreq = sampling_rate  # 250 Hz; `data` now holds the loaded .mat dict
filterbank = [[(6, 90), (4, 100)],  # passband, stopband freqs [(Wp), (Ws)]
              [(14, 90), (10, 100)],
              [(22, 90), (16, 100)],
              [(30, 90), (24, 100)],
              [(38, 90), (32, 100)],
              [(46, 90), (40, 100)],
              [(54, 90), (48, 100)]]

trca_classifier = TRCA(sfreq, filterbank, ensemble=True)
test_accuracies = leave_one_block_evaluation(classifier=trca_classifier, X=eeg, Y=labels)
test_accuracies

Block: 1 | Train acc: 100.00% | Test acc: 97.50%
Block: 2 | Train acc: 100.00% | Test acc: 100.00%
Block: 3 | Train acc: 100.00% | Test acc: 100.00%
Block: 4 | Train acc: 100.00% | Test acc: 100.00%
Block: 5 | Train acc: 100.00% | Test acc: 97.50%
Block: 6 | Train acc: 100.00% | Test acc: 100.00%
Mean test accuracy: 99.2%


[0.975, 1.0, 1.0, 1.0, 0.975, 1.0]
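
With `ensemble` set to `True`, as in the cell above, `predict` does not rely only on the filter learned for the candidate class: it stacks the filters of all classes into a matrix `w` of shape `(n_chans, n_classes)`, projects both the test trial and the class template with it, and correlates the flattened projections. A sketch of that scoring step; `ensemble_score` is a hypothetical helper, and the filters and templates are random stand-ins rather than learned values.

```python
import numpy as np


def ensemble_score(test, template, filters):
    """test, template: (n_samples, n_chans); filters: (n_chans, n_classes)."""
    # Project both signals with every class's spatial filter at once
    a = (test @ filters).ravel()
    b = (template @ filters).ravel()
    return np.corrcoef(a, b)[0, 1]


rng = np.random.default_rng(0)
n_samples, n_chans, n_classes = 250, 9, 4
filters = rng.normal(size=(n_chans, n_classes))  # stand-in for learned filters
templates = rng.normal(size=(n_classes, n_samples, n_chans))
test = templates[2] + 0.3 * rng.normal(size=(n_samples, n_chans))
scores = [ensemble_score(test, tpl, filters) for tpl in templates]
print(int(np.argmax(scores)))
```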
