In [None]:
import numpy as np
import mne
from moabb.datasets import BNCI2014001, BNCI2014_001
from moabb.paradigms import MotorImagery, LeftRightImagery
from moabb.evaluations import WithinSessionEvaluation
from moabb.analysis.meta_analysis import compute_dataset_statistics
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.preprocessing import StandardScaler
import ot  # POT library for Optimal Transport[1]

##############################################################################
# Design notes:
# StandardScaler (and typical scikit-learn estimators) require 2D data:
# (n_samples, n_features). However, EEG data as returned by MOABB can be 3D:
# (n_trials, n_channels, n_times). Hence, we provide a custom FlattenTransform
# that reshapes each trial into a 1D vector (n_channels*n_times), which avoids
# the dimension mismatch error
# ("ValueError: Found array with dim 3. StandardScaler expected <= 2.").
#
# We also define a "Backward OT" domain adaptation transform, following the
# principle from reference [1]. The pipeline first flattens, then
# standardizes, then maps the target data into the source space, so that the
# already-trained classifier can be used directly, without retraining [1].

class FlattenTransform(BaseEstimator, TransformerMixin):
    """Collapse each 3D trial into a single feature vector.

    Input of shape (n_trials, n_channels, n_times) becomes
    (n_trials, n_channels * n_times) so that estimators expecting 2D
    input can consume it.
    """

    def fit(self, X, y=None):
        """Nothing to learn; return the transformer itself."""
        return self

    def transform(self, X, y=None):
        """Reshape ``X`` to 2D.

        Raises
        ------
        ValueError
            If ``X`` does not have exactly three dimensions.
        """
        if X.ndim == 3:
            trial_count = X.shape[0]
            features_per_trial = X.shape[1] * X.shape[2]
            return X.reshape(trial_count, features_per_trial)
        raise ValueError(f"Expected 3D array, got shape={X.shape}")


class BackwardOTTransform(BaseEstimator, TransformerMixin):
    """
    Backward OT: the target data is transported into the source domain, so
    no model retraining is needed. Uses a regularized Sinkhorn approach[1].

    ``fit`` only memorises the source samples (and labels). ``transform``
    computes an entropy-regularised transport plan between the incoming
    (target) samples and the stored source samples, then returns the
    barycentric projection of each target sample onto the source samples.

    Parameters
    ----------
    reg : float, default=1.0
        Entropic regularisation strength forwarded to ``ot.sinkhorn``.
    use_labels : bool, default=False
        When True, and labels are available for both the stored source
        samples and the samples given to ``transform``, the cost of pairing
        samples with different labels is increased by the mean pairwise cost.
    metric : str, default='sqeuclidean'
        Ground metric forwarded to ``ot.dist``.

    Notes
    -----
    NOTE(review): inside a scikit-learn ``Pipeline`` the ``transform`` step is
    called without ``y``, so ``use_labels=True`` appears to take effect only
    when ``transform(X, y)`` is called directly -- confirm intended usage.
    """
    def __init__(self, reg=1.0, use_labels=False, metric='sqeuclidean'):
        self.reg = reg
        self.use_labels = use_labels
        self.metric = metric

    def fit(self, X, y=None):
        """Store the source samples ``X`` and their labels ``y``; return self."""
        self.Xs_ = X
        self.ys_ = y
        return self

    def transform(self, X, y=None):
        """Map target samples ``X`` onto the stored source samples.

        Parameters
        ----------
        X : 2D array, one row per target sample.
        y : optional labels of the target samples; used only when
            ``use_labels`` is True and source labels were given to ``fit``.

        Returns
        -------
        2D array with one transported row per row of ``X`` and the same
        number of columns as the stored source data.

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called.
        """
        if not hasattr(self, 'Xs_'):
            raise RuntimeError("BackwardOTTransform must be fit first.")

        X_src = self.Xs_
        X_tar = X
        n_tar = X_tar.shape[0]
        n_src = X_src.shape[0]

        # Pairwise ground cost between target rows and source rows.
        # ot.dist handles 'sqeuclidean' natively, so no need to square the
        # euclidean distance by hand.
        cost_matrix = ot.dist(X_tar, X_src, metric=self.metric)

        # Uniform distributions over samples
        a = np.full(n_tar, 1.0 / n_tar)
        b = np.full(n_src, 1.0 / n_src)

        # Optional label knowledge: penalise pairs whose labels differ by the
        # mean of the unpenalised costs (computed once, before any penalty).
        if self.use_labels and y is not None and self.ys_ is not None:
            mismatch = np.asarray(y)[:, None] != np.asarray(self.ys_)[None, :]
            cost_matrix = cost_matrix + cost_matrix.mean() * mismatch

        # Solve Sinkhorn for a regularized OT plan.
        # NOTE(review): with large unnormalised costs relative to `reg` the
        # Sinkhorn kernel may underflow -- consider scaling the cost matrix.
        gamma_ = ot.sinkhorn(a, b, cost_matrix, self.reg)

        # Barycentric mapping: each row of the plan carries only the mass of
        # one target sample (about 1/n_tar), so the rows must be normalised
        # to sum to one before averaging the source samples. Without this
        # the output is shrunk by the number of target samples.
        row_mass = gamma_.sum(axis=1, keepdims=True)
        with np.errstate(divide='ignore', invalid='ignore'):
            weights = gamma_ / row_mass
        # Rows with no (or non-finite) mass get zero weights, mirroring POT's
        # own barycentric mapping in ot.da.
        weights = np.nan_to_num(weights, nan=0.0, posinf=0.0, neginf=0.0)
        return weights.dot(X_src)


def make_backward_ot_pipeline(reg=1.0, use_labels=False, metric='sqeuclidean'):
    """
    Build the four-stage backward-OT classification pipeline.

    Stages, in order:
      'flatten' - FlattenTransform, 3D trials to 2D feature rows,
      'scaler'  - StandardScaler on the flattened features,
      'bot'     - BackwardOTTransform configured with the given
                  ``reg``, ``use_labels`` and ``metric``,
      'lda'     - LDA with the 'lsqr' solver and automatic shrinkage.
    """
    ot_step = BackwardOTTransform(reg=reg, use_labels=use_labels, metric=metric)
    classifier = LDA(solver='lsqr', shrinkage='auto')
    stages = [
        ('flatten', FlattenTransform()),
        ('scaler', StandardScaler()),
        ('bot', ot_step),
        ('lda', classifier),
    ]
    return Pipeline(stages)


##############################################################################
# Example usage: BNCI 2014-001, left-hand vs right-hand motor imagery,
# within-session evaluation.

# BNCI2014_001 is the current name of the dataset class; the old
# BNCI2014001 alias is deprecated and scheduled for removal.
dataset = BNCI2014_001()  # 9 subjects
dataset.subject_list = [1]  # restrict the run to subject 1
# LeftRightImagery keeps the left_hand / right_hand trials only.
paradigm = LeftRightImagery(fmin=8, fmax=30, channels=None, baseline=None)

evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=[dataset],
    suffix='BackwardOT_dimfix'
)

# Define some pipelines with different hyperparams
pipelines = {
    'BOT-S-reg1': make_backward_ot_pipeline(reg=1.0, use_labels=False),
    # 'BOT-S-reg2': make_backward_ot_pipeline(reg=2.0, use_labels=False),
    # 'BOT-GL-reg1': make_backward_ot_pipeline(reg=1.0, use_labels=True),
    # 'BOT-GL-reg2': make_backward_ot_pipeline(reg=2.0, use_labels=True),
}

# Run the evaluation; `results` is used below as a DataFrame with at least
# 'pipeline' and 'score' columns.
results = evaluation.process(pipelines)
print("Results (head):")
print(results.head())

stats = compute_dataset_statistics(results)
print("\nAggregate stats:")
print(stats)

print("\nPipeline mean accuracies:")
print(results.groupby('pipeline')['score'].mean())


BNCI2014001 has been renamed to BNCI2014_001. BNCI2014001 will be removed in version 1.1.
The dataset class name 'BNCI2014001' must be an abbreviation of its code 'BNCI2014-001'. See moabb.datasets.base.is_abbrev for more information.
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_han

No hdf5_path provided, models will not be saved.
