In [49]:
import numpy as np
from moabb.datasets import BNCI2014_001
from moabb.paradigms import MotorImagery, LeftRightImagery
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from mne.decoding import CSP
from moabb.evaluations import CrossSubjectEvaluation
from sklearn.pipeline import make_pipeline
from scipy import signal
from scipy.io import loadmat
import os
import mne
import numpy as np
from pyriemann.utils.mean import mean_riemann
from pyriemann.utils.distance import distance_riemann
from sklearn.model_selection import KFold
from tqdm import tqdm
import scipy.linalg
from sklearn.mixture import GaussianMixture

In [2]:
import os
import numpy as np
import mne
from scipy.signal import butter, lfilter

# Causal band-pass filter: Hamming-windowed FIR (scipy.signal.firwin), applied forward-only.
def causal_bandpass_filter(data, lowcut=8, highcut=30, fs=250, order=50, axis=-1):
    """Band-pass ``data`` with a causal, linear-phase FIR filter.

    The design mirrors MATLAB's ``fir1(order, [low high])``: ``order + 1``
    Hamming-windowed taps passing ``lowcut``..``highcut`` Hz. The filter is
    applied with ``scipy.signal.lfilter`` (forward only, zero initial state),
    so the output is delayed by ``order / 2`` samples and the first ``order``
    samples of each filtered series contain the filter's start-up transient.

    Parameters
    ----------
    data : array_like
        Signal(s) to filter, 1-D or N-D. Filtering runs independently along
        ``axis``, so an entire (n_trials, n_channels, n_times) stack can be
        filtered in a single call.
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling frequency in Hz.
    order : int
        Filter order; the filter has ``order + 1`` taps.
    axis : int, default -1
        Axis of ``data`` that holds the time samples.

    Returns
    -------
    ndarray
        Filtered data with the same shape as ``data``.
    """
    nyq = 0.5 * fs
    # firwin (like MATLAB's fir1) expects cut-offs normalized to the Nyquist frequency.
    low = lowcut / nyq
    high = highcut / nyq
    # order + 1 taps, matching fir1, which returns n + 1 coefficients.
    b = signal.firwin(order + 1, [low, high], window='hamming', pass_zero=False)
    # Pure FIR (denominator [1.0]) applied causally: constant group delay of order/2 samples.
    return signal.lfilter(b, [1.0], data, axis=axis)

In [3]:

# Directory containing the BCI Competition IV 2a GDF files (A01T.gdf ... A09E.gdf).
# Set the BCICIV_2A_DIR environment variable to point elsewhere instead of editing this cell.
data_dir = os.environ.get('BCICIV_2A_DIR', '/home/vishwa/eeg_tl/Recreating papers/BCICIV_2a')

In [4]:
# Annotation description -> integer event code, as passed to mne.events_from_annotations.
# Presumably 769-772 are the left-hand / right-hand / feet / tongue cues and 783 the
# "cue unknown" marker of the evaluation sessions (BCI IV 2a description) — verify.
active_all_event_ids = {code: int(code) for code in ('769', '770', '771', '772')}
active_lr_event_ids = {code: int(code) for code in ('769', '770')}  # two-class subset
unknown_event_id = {code: int(code) for code in ('783',)}

In [5]:
# Epoch window in seconds, relative to each cue event. MNE includes both end points,
# so an epoch has (tmax - tmin) * sfreq + 1 samples: (2.5 - 0.5) * 250 + 1 = 501,
# which matches the (n_trials, 22, 501) shapes printed by the loading cells.
sfreq = 250      # Nominal sampling rate in Hz; the loading cells re-read it from raw.info.
tmin = 0.5       # Epoch starts 0.5 s after cue onset.
tmax = 2.5       # Epoch ends 2.5 s after cue onset -> a 2.0 s window.

### Train & Eval - active

In [6]:
train_active_X = []         # Per subject: band-passed epochs, shape (n_trials, 22, n_times).
train_active_y = []         # Per subject: cue event code (769-772) of each epoch.
train_active_metadata = []  # Per subject: MNE events array (sample, previous value, event id).

# Loop over the nine training sessions, named "A01T.gdf" ... "A09T.gdf".
for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}T.gdf')

    # Read the GDF file into memory. The three EOG channels are typed as EOG so
    # that only the 22 EEG channels survive the pick below.
    train_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False, eog=['EOG-left', 'EOG-central', 'EOG-right'])
    # pick() replaces the legacy pick_types(); exclude='bads' keeps pick_types' default.
    train_raw.pick('eeg', exclude='bads')

    # Keep only the four motor-imagery cue events.
    train_active_events, _ = mne.events_from_annotations(train_raw, event_id=active_all_event_ids)

    train_active_epochs = mne.Epochs(train_raw, train_active_events, event_id=active_all_event_ids, tmin=tmin, tmax=tmax,
                                     baseline=None, preload=True, verbose=False)

    # (n_epochs, n_channels, n_times) with n_times = (tmax - tmin) * fs + 1.
    train_active_data = train_active_epochs.get_data()
    print(f"Subject {subj}: Epoch data shape {train_active_data.shape}")

    fs = int(train_raw.info['sfreq'])
    # causal_bandpass_filter filters along the last (time) axis, so every
    # trial/channel series is filtered independently in a single call.
    # NOTE(review): filtering each epoch separately means the first `order`
    # samples of every epoch are filter start-up transient; filtering the
    # continuous recording before epoching would avoid this.
    train_active_filtered_data = causal_bandpass_filter(
        train_active_data,
        lowcut=8,    # Lower bound of the sensorimotor band (Hz).
        highcut=30,  # Upper bound of the sensorimotor band (Hz).
        fs=fs,
        order=50,
    )

    train_active_X.append(train_active_filtered_data)
    train_active_y.append(train_active_epochs.events[:, 2])  # Third column holds the event code.
    train_active_metadata.append(train_active_epochs.events)

print("Loaded data for", len(train_active_X), "subjects.")

  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 1: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 2: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 3: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 4: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 5: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 6: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 7: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 8: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 9: Epoch data shape (288, 22, 501)
Loaded data for 9 subjects.


In [7]:
eval_active_X = []         # Per subject: band-passed epochs, shape (n_trials, 22, n_times).
eval_active_y = []         # Per subject: true labels mapped onto cue codes 769-772.
eval_active_metadata = []  # Per subject: MNE events array of the 783 cue events.

# Directory holding the true evaluation labels (A01E.mat ... A09E.mat).
labels_dir = '/home/vishwa/eeg_tl/Recreating papers/BCICIV_2A true labels'
# The .mat files store classes 1-4; this offset maps them onto the 769-772 codes
# used as training labels.
label_code_offset = 768

# Loop over the nine evaluation sessions, named "A01E.gdf" ... "A09E.gdf".
for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}E.gdf')
    mat_data = loadmat(os.path.join(labels_dir, f'A{subj:02d}E.mat'))
    # ravel() instead of a hard-coded reshape(288,): a session with a different
    # trial count is then reported by the explicit check below.
    true_y = np.array(mat_data['classlabel'], dtype=np.int64).ravel() + label_code_offset

    # Read the GDF file; EOG channels are typed as EOG so only EEG remains after pick().
    eval_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False, eog=['EOG-left', 'EOG-central', 'EOG-right'])
    # pick() replaces the legacy pick_types(); exclude='bads' keeps pick_types' default.
    eval_raw.pick('eeg', exclude='bads')

    # Evaluation cues are annotated only as "unknown" (783); labels come from the .mat file.
    eval_active_events, _ = mne.events_from_annotations(eval_raw, event_id=unknown_event_id)

    eval_active_epochs = mne.Epochs(eval_raw, eval_active_events, event_id=unknown_event_id, tmin=tmin, tmax=tmax,
                                    baseline=None, preload=True, verbose=False)

    # (n_epochs, n_channels, n_times) with n_times = (tmax - tmin) * fs + 1.
    eval_active_data = eval_active_epochs.get_data()
    print(f"Subject {subj}: Epoch data shape {eval_active_data.shape}")

    # Labels are matched to epochs purely by position, so a dropped epoch would
    # silently misalign every later label.
    if len(true_y) != len(eval_active_data):
        raise ValueError(
            f"Subject {subj}: {len(true_y)} labels but {len(eval_active_data)} epochs"
        )

    fs = int(eval_raw.info['sfreq'])
    # Filters every trial/channel along the last (time) axis in one call.
    # NOTE(review): per-epoch filtering leaves a start-up transient in the first
    # `order` samples of each epoch.
    eval_active_filtered_data = causal_bandpass_filter(
        eval_active_data,
        lowcut=8,    # Lower bound of the sensorimotor band (Hz).
        highcut=30,  # Upper bound of the sensorimotor band (Hz).
        fs=fs,
        order=50,
    )

    eval_active_X.append(eval_active_filtered_data)
    eval_active_y.append(true_y)  # Labels from the .mat file, not from the events array.
    eval_active_metadata.append(eval_active_epochs.events)

print("Loaded data for", len(eval_active_X), "subjects.")

  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 1: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 2: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 3: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 4: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 5: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 6: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 7: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 8: Epoch data shape (288, 22, 501)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 9: Epoch data shape (288, 22, 501)
Loaded data for 9 subjects.


### Train and Eval - resting

In [8]:
train_resting_X = []         # Per subject: band-passed resting epochs.
train_resting_metadata = []  # Per subject: MNE events array of the shifted (resting) events.

# Offset in seconds from each cue event to the start of the resting window,
# and the resting-window length in seconds.
# NOTE(review): an earlier comment called cue + 6 s the "trial end". In the BCI IV 2a
# paradigm, however, the cue is at t = 2 s and the fixation cross disappears at
# t = 6 s of the trial, i.e. 4 s after the cue. Cue + 6 s is therefore ~8 s into the
# trial. Confirm which interval is intended.
rest_offset_s = 6
rest_duration_s = 1.5

for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}T.gdf')

    # Load the GDF file and keep only the EEG channels.
    train_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False,
                                    eog=['EOG-left', 'EOG-central', 'EOG-right'])
    train_raw.pick('eeg', exclude='bads')

    # Resting windows are anchored on the four motor-imagery cue events.
    cue_events, _ = mne.events_from_annotations(train_raw, event_id=active_all_event_ids)

    # Shift a copy, so the array returned by MNE is not modified in place.
    fs = int(train_raw.info['sfreq'])
    train_resting_events = cue_events.copy()
    train_resting_events[:, 0] += int(rest_offset_s * fs)

    # Windows running past the end of the recording are presumably dropped by MNE,
    # which would explain 287 rather than 288 resting epochs per subject.
    train_resting_epochs = mne.Epochs(train_raw, train_resting_events, tmin=0, tmax=rest_duration_s,
                                      baseline=None, preload=True, verbose=False)

    train_resting_data = train_resting_epochs.get_data()
    print(f"Subject {subj}: Resting data shape {train_resting_data.shape}")

    # Filters every trial/channel along the last (time) axis in one call.
    train_resting_filtered_data = causal_bandpass_filter(
        train_resting_data,
        lowcut=8,
        highcut=30,
        fs=fs,
        order=50,
    )

    train_resting_X.append(train_resting_filtered_data)
    train_resting_metadata.append(train_resting_epochs.events)

print("Loaded resting data for", len(train_resting_X), "subjects.")


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 1: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 2: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 3: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 4: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 5: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 6: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 7: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 8: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['769', '770', '771', '772']
Subject 9: Resting data shape (287, 22, 376)
Loaded resting data for 9 subjects.


In [9]:
eval_resting_X = []         # Per subject: band-passed resting epochs.
eval_resting_metadata = []  # Per subject: MNE events array of the shifted (resting) events.

# Offset in seconds from each cue event to the start of the resting window,
# and the resting-window length in seconds.
# NOTE(review): an earlier comment called cue + 6 s the "trial end". In the BCI IV 2a
# paradigm, however, the cue is at t = 2 s and the fixation cross disappears at
# t = 6 s of the trial, i.e. 4 s after the cue. Confirm which interval is intended.
rest_offset_s = 6
rest_duration_s = 1.5

for subj in range(1, 10):
    filename = os.path.join(data_dir, f'A{subj:02d}E.gdf')

    # Load the GDF file and keep only the EEG channels.
    eval_raw = mne.io.read_raw_gdf(filename, preload=True, verbose=False,
                                   eog=['EOG-left', 'EOG-central', 'EOG-right'])
    eval_raw.pick('eeg', exclude='bads')

    # In the evaluation sessions the cue events used here are the 783 ("unknown") annotations.
    cue_events, _ = mne.events_from_annotations(eval_raw, event_id=unknown_event_id)

    # Shift a copy, so the array returned by MNE is not modified in place.
    fs = int(eval_raw.info['sfreq'])
    eval_resting_events = cue_events.copy()
    eval_resting_events[:, 0] += int(rest_offset_s * fs)

    # Windows running past the end of the recording are presumably dropped by MNE.
    eval_resting_epochs = mne.Epochs(eval_raw, eval_resting_events, tmin=0, tmax=rest_duration_s,
                                     baseline=None, preload=True, verbose=False)

    eval_resting_data = eval_resting_epochs.get_data()
    print(f"Subject {subj}: Resting data shape {eval_resting_data.shape}")

    # Filters every trial/channel along the last (time) axis in one call.
    eval_resting_filtered_data = causal_bandpass_filter(
        eval_resting_data,
        lowcut=8,
        highcut=30,
        fs=fs,
        order=50,
    )

    eval_resting_X.append(eval_resting_filtered_data)
    # Resting epochs carry no class labels; only their events are kept.
    eval_resting_metadata.append(eval_resting_epochs.events)

print("Loaded resting data for", len(eval_resting_X), "subjects.")

  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 1: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 2: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 3: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 4: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 5: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 6: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 7: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 8: Resting data shape (287, 22, 376)


  next(self.gen)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['783']
Subject 9: Resting data shape (287, 22, 376)
Loaded resting data for 9 subjects.


### Cov + Affine transform

In [10]:
def cov(X, normalize=False):
    """Per-trial scatter matrices ``x @ x.T`` for a stack of epochs.

    This is not a statistical sample covariance: the signal is not mean-centred
    and, by default, not divided by the number of samples. The scale therefore
    grows with epoch length (e.g. 501-sample active epochs vs 376-sample resting
    epochs in this notebook). Pass ``normalize=True`` to divide each matrix by
    its own epoch's number of samples.

    Parameters
    ----------
    X : iterable of ndarray, each of shape (n_channels, n_times)
        Typically an array of shape (n_trials, n_channels, n_times).
    normalize : bool, default False
        If True, divide each scatter matrix by that trial's ``n_times``.

    Returns
    -------
    ndarray, shape (n_trials, n_channels, n_channels)
    """
    covs = []
    for trial in X:
        scatter = trial @ trial.T
        if normalize:
            # Per-trial length, so trials of different lengths are also handled.
            scatter = scatter / trial.shape[-1]
        covs.append(scatter)
    return np.array(covs)

In [14]:
# One stack of trial scatter matrices per subject, for each session and condition.
train_active_covs = list(map(cov, train_active_X))
eval_active_covs = list(map(cov, eval_active_X))

train_resting_covs = list(map(cov, train_resting_X))
eval_resting_covs = list(map(cov, eval_resting_X))

In [17]:
# Per-subject Riemannian mean of the resting matrices: the re-centring reference R.
train_resting_covs_mean = list(map(mean_riemann, train_resting_covs))
eval_resting_covs_mean = list(map(mean_riemann, eval_resting_covs))

In [25]:
def affine_transform(covmats, reference):
    """
    Re-centre SPD matrices on a reference via congruence:
    C -> R^(-1/2) * C * R^(-1/2), so that `reference` itself maps to the identity.

    R^(-1/2) is built from the symmetric eigendecomposition R = V diag(w) V^T,
    i.e. R^(-1/2) = V diag(w^(-1/2)) V^T. Unlike inv(sqrtm(R)), this is always
    real and symmetric. scipy's sqrtm can return a complex array with tiny
    imaginary parts, which then leaks into every transformed matrix.

    Parameters
    ----------
    covmats : iterable of ndarray, each of shape (n, n)
        Matrices to re-centre.
    reference : ndarray, shape (n, n)
        Symmetric positive-definite reference matrix R.

    Returns
    -------
    ndarray, shape (n_matrices, n, n)

    Raises
    ------
    ValueError
        If `reference` has a non-positive eigenvalue.
    """
    eigvals, eigvecs = np.linalg.eigh(np.asarray(reference))
    if np.any(eigvals <= 0):
        raise ValueError("reference must be symmetric positive definite")
    # Scaling column j of V by w_j^(-1/2) gives V diag(w^(-1/2)).
    ref_sqrt_inv = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T
    return np.array([ref_sqrt_inv @ c @ ref_sqrt_inv for c in covmats])

In [29]:
# Re-centre every subject's active matrices on that subject's own resting-state reference.
affine_transformed_train_covs = [
    affine_transform(subject_covs, train_resting_covs_mean[s])
    for s, subject_covs in enumerate(train_active_covs)
]
affine_transformed_eval_covs = [
    affine_transform(subject_covs, eval_resting_covs_mean[s])
    for s, subject_covs in enumerate(eval_active_covs)
]

### MDM - classifier

In [33]:
def mdm_classify(cov_train, y_train, cov_test, y_test):
    """
    Minimum Distance to Mean (MDM) classification accuracy.

    Each class is summarised by the Riemannian mean of its training matrices.
    Every test matrix is assigned to the class whose mean is nearest in
    Riemannian distance; ties go to the first class in np.unique order.

    Returns
    -------
    float
        Fraction of test matrices whose predicted label equals y_test.
    """
    labels = np.unique(y_train)
    class_means = {label: mean_riemann(cov_train[y_train == label]) for label in labels}

    predicted = []
    for trial_cov in cov_test:
        distance_to = {label: distance_riemann(trial_cov, class_means[label]) for label in labels}
        predicted.append(min(distance_to, key=distance_to.get))

    return np.mean(np.array(predicted) == y_test)

In [35]:
# Subject-1 sanity check: MDM accuracy without, then with, resting-state re-centring.
for covs_train, covs_eval in ((train_active_covs, eval_active_covs),
                              (affine_transformed_train_covs, affine_transformed_eval_covs)):
    print(mdm_classify(covs_train[0], train_active_y[0], covs_eval[0], eval_active_y[0]))

0.7743055555555556
0.7708333333333334


In [40]:
print("MDM - untransformed")
for i in range(9):
    print(f"Subject {i+1}: {mdm_classify(train_active_covs[i], train_active_y[i], eval_active_covs[i], eval_active_y[i]):.2f}")
    

MDM - untransformed
Subject 1: 0.77
Subject 2: 0.48
Subject 3: 0.69
Subject 4: 0.64
Subject 5: 0.48
Subject 6: 0.47
Subject 7: 0.67
Subject 8: 0.69
Subject 9: 0.72


In [41]:
print("MDM - affine transformed")
for i in range(9):
    print(f"Subject {i+1}: {mdm_classify(affine_transformed_train_covs[i], train_active_y[i], affine_transformed_eval_covs[i], eval_active_y[i]):.2f}")

MDM - affine transformed
Subject 1: 0.77
Subject 2: 0.52
Subject 3: 0.80
Subject 4: 0.62
Subject 5: 0.48
Subject 6: 0.51
Subject 7: 0.78
Subject 8: 0.79
Subject 9: 0.73


In [44]:
def bayesian_classify(cov_train, y_train, cov_test, y_test, eps=1e-9):
    """
    Bayesian (Riemannian Gaussian) classification accuracy (eq. 8).

    Each class k is modelled by its Riemannian mean (centre of mass) and a
    dispersion sigma_k. sigma_k is estimated as the root-mean-square Riemannian
    distance of the class's training matrices to that mean.

    Exact classification rule:
      arg min_k { log(zeta(sigma_k)) + dR^2(C, mean_k) / [2 * sigma_k^2] }
    zeta(sigma_k) depends on sigma_k and therefore differs between classes, so
    it DOES influence the arg min. It is not computed here; log(sigma_k) is used
    as an approximate stand-in, giving the score
      log(sigma_k) + dR^2(C, mean_k) / [2 * sigma_k^2].

    Parameters
    ----------
    cov_train : ndarray, shape (n_train, m, m)
    y_train : ndarray, shape (n_train,)
    cov_test : ndarray, shape (n_test, m, m)
    y_test : ndarray, shape (n_test,)
    eps : float, default 1e-9
        Lower bound for sigma_k. Without it, a class whose training matrices
        all coincide (e.g. a single training trial) gives sigma_k = 0 and a
        non-finite score (log(0) plus a division by zero).

    Returns
    -------
    float
        Fraction of test matrices whose predicted label equals y_test.
    """
    classes = np.unique(y_train)
    means = {}
    sigmas = {}

    # Estimate mean and dispersion for each class.
    for c in classes:
        cov_per_class = cov_train[y_train == c]
        means[c] = mean_riemann(cov_per_class)
        dists = [distance_riemann(cov_, means[c]) for cov_ in cov_per_class]
        # Every class from np.unique(y_train) has >= 1 member, so dists is non-empty;
        # the clamp only matters when the spread is (numerically) zero.
        sigmas[c] = max(np.sqrt(np.mean(np.square(dists))), eps)

    predictions = []
    for test_cov in cov_test:
        scores = {}
        for c in classes:
            d2 = distance_riemann(test_cov, means[c]) ** 2
            # log(sigma_k) approximates log(zeta(sigma_k)); smallest score wins.
            scores[c] = np.log(sigmas[c]) + (d2 / (2.0 * (sigmas[c] ** 2)))
        predictions.append(min(scores, key=scores.get))

    return np.mean(np.array(predictions) == y_test)

In [45]:
print("BC - untransformed")
for i in range(9):
    print(f"Subject {i+1}: {bayesian_classify(train_active_covs[i], train_active_y[i], eval_active_covs[i], eval_active_y[i]):.2f}")
    

BC - untransformed
Subject 1: 0.76
Subject 2: 0.46
Subject 3: 0.69
Subject 4: 0.60
Subject 5: 0.37
Subject 6: 0.49
Subject 7: 0.65
Subject 8: 0.72
Subject 9: 0.74


In [46]:
print("BC - affine transformed")
for i in range(9):
    print(f"Subject {i+1}: {bayesian_classify(affine_transformed_train_covs[i], train_active_y[i], affine_transformed_eval_covs[i], eval_active_y[i]):.2f}")

BC - affine transformed
Subject 1: 0.78
Subject 2: 0.51
Subject 3: 0.80
Subject 4: 0.62
Subject 5: 0.48
Subject 6: 0.51
Subject 7: 0.78
Subject 8: 0.78
Subject 9: 0.73


In [74]:
import numpy as np
import scipy.special
from pyriemann.clustering import Kmeans
from pyriemann.utils.distance import distance_riemann

def log_zeta(sigma, m):
    """
    Log of the Riemannian-Gaussian normalizing factor zeta(sigma) on m x m SPD matrices.

    Only m == 2 has a closed form here:
        log zeta = 1.5 log(2 pi) + 2 log(sigma) + sigma^2 / 4 + log(erf(sigma / 2))
    with eps = 1e-10 added inside both logs to guard against sigma = 0.
    For any other m the function returns 0.0, i.e. no sigma-dependent
    normalization is applied.

    NOTE(review): with m != 2 (e.g. 22-channel covariances) callers therefore
    score clusters without the normalizer at all; zeta(sigma) would have to be
    evaluated numerically to use the full model.
    """
    if m != 2:
        return 0.0
    eps = 1e-10
    return (1.5 * np.log(2 * np.pi)
            + 2 * np.log(sigma + eps)
            + (sigma ** 2) / 4
            + np.log(scipy.special.erf(sigma / 2) + eps))

def gmm_classify(cov_train, y_train, cov_test, y_test, n_components=2, random_state=None):
    """
    Classify test SPD matrices with a per-class Riemannian Gaussian mixture
    and return the classification accuracy.

    For each class, the training matrices are split into `n_components`
    clusters with Riemannian k-means. Each non-empty cluster gets:
    - a centre (the k-means centroid);
    - a dispersion sigma (RMS Riemannian distance of its members to the centre);
    - a weight (its share of ALL training matrices).
    A test matrix Y receives the class of the cluster minimising
        -log(weight) + log_zeta(sigma, m) + d^2(Y, centre) / (2 sigma^2).

    Parameters
    ----------
    cov_train : ndarray, shape (n_train, m, m)
        Training SPD matrices.
    y_train : ndarray, shape (n_train,)
        Training class labels.
    cov_test : ndarray, shape (n_test, m, m)
        Test SPD matrices.
    y_test : ndarray, shape (n_test,)
        True test labels.
    n_components : int, default=2
        Number of mixture components (clusters) per class.
    random_state : int, RandomState or None, default=None
        Passed to the k-means clustering. With init='random' and None, results
        change from run to run; pass an int for reproducible accuracies.

    Returns
    -------
    float
        Fraction of test matrices whose predicted label equals y_test.
    """
    eps = 1e-10  # avoids division by zero and log(0)
    m = cov_train.shape[1]  # SPD matrices are m x m
    cluster_params = []  # each entry: (class label, center, sigma, test-independent score part)

    for cls in np.unique(y_train):
        mats = cov_train[y_train == cls]
        # Riemannian (affine-invariant) k-means within this class.
        kmeans = Kmeans(n_clusters=n_components, metric='riemann', init='random', tol=1e-6, max_iter=50,
                        random_state=random_state)
        labels = kmeans.fit_predict(mats)
        # IMPORTANT: use the centroids() method (NOT the attribute "centroids_").
        centers = kmeans.centroids()

        for k in range(n_components):
            cluster_members = mats[labels == k]
            if len(cluster_members) == 0:
                continue
            center = centers[k]
            dists_sq = np.array([distance_riemann(M, center, squared=True) for M in cluster_members])
            sigma = np.sqrt(np.mean(dists_sq))
            # Relative frequency of this cluster among all training matrices.
            weight = len(cluster_members) / float(len(cov_train))
            # The prior and normalizer terms do not depend on the test matrix: compute once.
            const = -np.log(weight + eps) + log_zeta(sigma, m)
            cluster_params.append((cls, center, sigma, const))

    predictions = []
    for Y in cov_test:
        best_score = np.inf
        best_class = None
        for cls, center, sigma, const in cluster_params:
            dist_sq = distance_riemann(Y, center, squared=True)
            # Approximate negative log-likelihood of Y under this cluster.
            score = const + dist_sq / (2 * (sigma ** 2 + eps))
            if score < best_score:
                best_score = score
                best_class = cls
        predictions.append(best_class)

    return np.mean(np.array(predictions) == y_test)

In [75]:
print("GM-4 - untransformed")
for i in range(9):
    print(f"Subject {i+1}: {gmm_classify(train_active_covs[i], train_active_y[i], eval_active_covs[i], eval_active_y[i]):.2f}")
    

GM-4 - untransformed
Subject 1: 0.49
Subject 2: 0.31
Subject 3: 0.52
Subject 4: 0.25
Subject 5: 0.25
Subject 6: 0.28
Subject 7: 0.65
Subject 8: 0.43
Subject 9: 0.48


In [76]:
print("GM-4 - affine transformed")
for i in range(9):
    print(f"Subject {i+1}: {gmm_classify(affine_transformed_train_covs[i], train_active_y[i], affine_transformed_eval_covs[i], eval_active_y[i]):.2f}")

GM-4 - affine transformed
Subject 1: 0.48
Subject 2: 0.33
Subject 3: 0.49
Subject 4: 0.27
Subject 5: 0.25
Subject 6: 0.30
Subject 7: 0.58
Subject 8: 0.41
Subject 9: 0.49
