In [None]:
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

import moabb
from moabb.datasets import BNCI2014_001, Zhou2016, Weibo2014
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery

In [None]:
# One MOABB dataset object per domain. The list order is significant: the epoching
# cell enumerates the loaded data in this order and stores each position as the
# dataset index in YD column 1, which the split cells filter on (0, 1, 2).
datasets = [Zhou2016(), BNCI2014_001(), Weibo2014()]

In [None]:
# Load every dataset; each entry is a nested {subject: {session: {run: Raw}}} dict,
# kept in the same order as `datasets`.
sessionss = [dataset.get_data() for dataset in datasets]

In [None]:
import mne

# Union of the EEG channel names found in any run of any dataset.
# The names are read through mne.pick_types on `raw.info` rather than by calling
# raw.pick_types(), which would drop every non-EEG channel from the cached Raw
# objects in `sessionss` as a side effect of merely counting channels.
channel_set = set()
for sessions in sessionss:
    for subject_data in sessions.values():
        for session_data in subject_data.values():
            for raw in session_data.values():
                # Same selection as raw.pick_types(eeg=True): EEG only, bad channels excluded.
                eeg_picks = mne.pick_types(raw.info, eeg=True)
                channel_set.update(raw.ch_names[i] for i in eeg_picks)

# Sorted list so the channel -> index mapping built later is deterministic.
unique_channels = sorted(channel_set)
print(len(unique_channels))

In [None]:
import mne

# Every annotation/event name that occurs in any run of any dataset. This is
# informational only: the label mapping below uses an explicit `desired_events` list.
all_event_names = set()
every_run = (
    run_raw
    for sessions in sessionss
    for subject_data in sessions.values()
    for session_data in subject_data.values()
    for run_raw in session_data.values()
)
for raw in every_run:
    event_dict = mne.events_from_annotations(raw)[1]
    all_event_names.update(event_dict)

In [None]:
# Annotation names to keep; epochs with any other event name are dropped during epoching.
desired_events = ["feet", "left_hand", "right_hand"]
# Binary left-vs-right variant:
# desired_events = ["left_hand", "right_hand"]

# Alphabetical order fixes the integer class code of each event name.
standardized_labels = dict(zip(sorted(desired_events), range(len(desired_events))))

print("Standardized Labels:", standardized_labels)  # Debugging


In [None]:
import mne
import numpy as np

# Preprocessing settings shared by every run.
TARGET_SFREQ = 200          # Hz, common sampling rate for all datasets
L_FREQ, H_FREQ = 8, 30      # Hz, band-pass edges
tmin, tmax = -0.2, 4        # s, epoch window relative to each event

channel_to_index = {ch: i for i, ch in enumerate(unique_channels)}  # channel name -> global row
num_channels = len(unique_channels)  # total number of unique channels across datasets

# Per-run results are collected here and concatenated once at the end; calling
# np.vstack inside the loop re-copies everything gathered so far (quadratic).
X_chunks, YD_chunks, mask_chunks = [], [], []

# sessionss[dom] is the nested {subject: {session: {run: Raw}}} dict of dataset `dom`.
for dom, sessions in enumerate(sessionss):
    for subj_idx, subject_data in enumerate(sessions.values()):
        for sess_idx, session_data in enumerate(subject_data.values()):
            for run_raw in session_data.values():
                # Work on a copy so re-running this cell does not resample/filter the
                # cached Raw objects in `sessionss` a second time.
                raw = run_raw.copy().pick_types(eeg=True)
                raw.resample(TARGET_SFREQ)
                raw.filter(l_freq=L_FREQ, h_freq=H_FREQ, fir_design='firwin')

                events, event_dict = mne.events_from_annotations(raw)
                event_id_to_name = {code: name for name, code in event_dict.items()}

                epochs = mne.Epochs(
                    raw, events, event_dict, tmin=tmin, tmax=tmax, baseline=(None, 0),
                    preload=False,          # data is read by get_data() below
                    event_repeated='drop',  # avoid duplicate-event errors
                )
                X_raw = epochs.get_data()  # (n_epochs, n_channels, n_times)

                # Z-score every channel of every epoch along the time axis.
                X_mean = np.mean(X_raw, axis=2, keepdims=True)
                X_std = np.std(X_raw, axis=2, keepdims=True)
                X_raw = (X_raw - X_mean) / (X_std + 1e-8)  # epsilon avoids division by zero

                # Map event codes -> names -> standardized labels; drop unwanted events.
                Y_names = [event_id_to_name.get(code, "unknown") for code in epochs.events[:, -1]]
                # dtype=int keeps labels integer even when a run has no epochs.
                Y_all = np.array([standardized_labels.get(name, -1) for name in Y_names], dtype=int)
                keep_epoch = Y_all != -1
                Y = Y_all[keep_epoch]
                X_raw = X_raw[keep_epoch]

                # Place each channel's data in its global row by NAME. Source and
                # destination indices are paired explicitly, so a channel missing from
                # `channel_to_index` cannot shift the remaining channels into wrong rows.
                src_idx = [i for i, ch in enumerate(epochs.ch_names) if ch in channel_to_index]
                dst_idx = [channel_to_index[epochs.ch_names[i]] for i in src_idx]

                n_epochs, n_times = X_raw.shape[0], X_raw.shape[2]
                aligned_X = np.zeros((n_epochs, num_channels, n_times))
                padding_mask = np.zeros((n_epochs, num_channels))  # 1 = real channel, 0 = padding
                aligned_X[:, dst_idx, :] = X_raw[:, src_idx, :]
                padding_mask[:, dst_idx] = 1

                # YD columns: label, dataset index, subject index, session index, subject ID.
                subject_id = dom * 10000 + subj_idx * 100
                YD = np.column_stack((
                    Y,
                    np.full(n_epochs, dom),
                    np.full(n_epochs, subj_idx),
                    np.full(n_epochs, sess_idx),
                    np.full(n_epochs, subject_id),
                ))

                X_chunks.append(aligned_X)
                YD_chunks.append(YD)
                mask_chunks.append(padding_mask)

if not X_chunks:
    raise RuntimeError("No runs were found in `sessionss`; nothing to stack.")

X_all0 = np.concatenate(X_chunks, axis=0)
YD_all0 = np.concatenate(YD_chunks, axis=0)
padding_masks0 = np.concatenate(mask_chunks, axis=0)

# Column 4 holds the raw subject ID (dom * 10000 + subj_idx * 100); replace it with a
# dense 0..K-1 index over the subjects actually present.
unique_ids, remapped_ids = np.unique(YD_all0[:, 4], return_inverse=True)
YD_all0[:, 4] = remapped_ids

print(X_all0.shape)  # (n_epochs, n_channels, n_times)
print(YD_all0.shape)  # (n_epochs, 5): label, dataset, subject, session, subject ID
print(padding_masks0.shape)  # (n_epochs, n_channels) - binary padding mask

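SUBJECT-DEPENDENT SPLIT (test on held-out sessions). This section and the CROSS-SUBJECT section write the same variables (X_half, X_val, X_all, …), so run only one of them; with Run All, the later section's split is the one that gets saved.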

In [None]:
# Held-out TEST set for the subject-dependent protocol
# (YD column 1 = dataset index, column 3 = session index):
#   * dataset 1, session index 1
#   * dataset 0, session index 2
#   * a random 50% of the epochs of dataset 2
test_rule = ((YD_all0[:, 1] == 1) & (YD_all0[:, 3] == 1)) | ((YD_all0[:, 1] == 0) & (YD_all0[:, 3] == 2))
domain0_indices = np.where(test_rule)[0]
domain2_indices = np.where(YD_all0[:, 1] == 2)[0]

rng = np.random.RandomState(42)
# replace=False: sampling with replacement drew duplicates that union1d then collapsed,
# so noticeably fewer than 50% of dataset 2 ended up in the test set.
n_domain2_test = int(domain2_indices.shape[0] * 0.5)
domain0_indices = np.union1d(
    domain0_indices, rng.choice(domain2_indices, size=n_domain2_test, replace=False)
)

X_half = X_all0[domain0_indices]
YD_half = YD_all0[domain0_indices]
padding_masks_half = padding_masks0[domain0_indices]
print(X_half.shape)  # (n_epochs, n_channels, n_times)
print(YD_half.shape)  # (n_epochs, 5)
print(padding_masks_half.shape)  # (n_epochs, n_channels) - binary padding mask

In [None]:
# Every epoch outside the test set is split at random: 70% TRAIN, the remaining 30% VALIDATION.
domain1_indices = np.setdiff1d(np.arange(YD_all0.shape[0]), domain0_indices)
rng = np.random.RandomState(42)
n_train = int(len(domain1_indices) * 0.7)
selected_indices = rng.choice(domain1_indices, size=n_train, replace=False)
not_selected = np.setdiff1d(domain1_indices, selected_indices)

# Not referenced by any later cell in this notebook.
final_indices = np.union1d(domain1_indices, not_selected)

# Validation split.
X_val, YD_val, mask_val = (
    X_all0[not_selected],
    YD_all0[not_selected],
    padding_masks0[not_selected],
)

# Training split.
X_all, YD_all, padding_masks = (
    X_all0[selected_indices],
    YD_all0[selected_indices],
    padding_masks0[selected_indices],
)

print(X_all.shape)  # (n_epochs, n_channels, n_times)
print(YD_all.shape)  # (n_epochs, 5)
print(padding_masks.shape)  # (n_epochs, n_channels) - binary padding mask

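CROSS-SUBJECT SPLIT (test on held-out subjects). This section overwrites the split variables (X_half, X_val, X_all, …) produced by the subject-dependent section above.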

In [None]:
# Held-out TEST subjects for the cross-subject protocol
# (YD column 1 = dataset index, column 2 = subject index):
#   dataset 1 -> subject index > 5, dataset 0 -> index > 2, dataset 2 -> index > 6.
in_test_01 = ((YD_all0[:, 1] == 1) & (YD_all0[:, 2] > 5)) | ((YD_all0[:, 1] == 0) & (YD_all0[:, 2] > 2))
domain0_indices = np.flatnonzero(in_test_01)
domain2_indices = np.flatnonzero((YD_all0[:, 1] == 2) & (YD_all0[:, 2] > 6))
rng = np.random.RandomState(42)  # not consumed in this cell
domain0_indices = np.union1d(domain0_indices, domain2_indices)

X_half, YD_half, padding_masks_half = (
    X_all0[domain0_indices],
    YD_all0[domain0_indices],
    padding_masks0[domain0_indices],
)
print(X_half.shape)  # (n_epochs, n_channels, n_times)
print(YD_half.shape)  # (n_epochs, 5)
print(padding_masks_half.shape)  # (n_epochs, n_channels) - binary padding mask

In [None]:

# VALIDATION subjects (YD column 1 = dataset index, column 2 = subject index):
#   dataset 1 -> index > 3, dataset 0 -> index > 1, dataset 2 -> index > 4,
# minus every epoch already placed in the test set.
val_candidate = (
    ((YD_all0[:, 1] == 1) & (YD_all0[:, 2] > 3))
    | ((YD_all0[:, 1] == 0) & (YD_all0[:, 2] > 1))
    | ((YD_all0[:, 1] == 2) & (YD_all0[:, 2] > 4))
)
domain1_indices = np.flatnonzero(val_candidate)
not_selected = np.setdiff1d(domain1_indices, domain0_indices)

# TRAIN = all epochs minus test minus validation.
selected_indices = np.setdiff1d(
    np.setdiff1d(np.arange(YD_all0.shape[0]), domain0_indices),
    not_selected,
)

# Not referenced by any later cell in this notebook.
final_indices = np.union1d(domain1_indices, not_selected)

X_val, YD_val, mask_val = (
    X_all0[not_selected],
    YD_all0[not_selected],
    padding_masks0[not_selected],
)

X_all, YD_all, padding_masks = (
    X_all0[selected_indices],
    YD_all0[selected_indices],
    padding_masks0[selected_indices],
)

print(X_all.shape)  # (n_epochs, n_channels, n_times)
print(YD_all.shape)  # (n_epochs, 5)
print(padding_masks.shape)  # (n_epochs, n_channels) - binary padding mask

In [None]:
import mne
import numpy as np
import torch

# 2-D (x, y) electrode coordinates from the standard 10-20 montage.
montage_1020 = mne.channels.make_standard_montage('standard_1020')
channel_positions = {
    ch: pos[:2] for ch, pos in montage_1020.get_positions()['ch_pos'].items()  # keep (x, y) only
}
# Dataset channel names may differ from the montage names only in capitalisation,
# so fall back to a case-insensitive match when the exact name is not found.
positions_lower = {ch.lower(): pos for ch, pos in channel_positions.items()}

# One row per global channel slot (same width as the padding masks).
num_channels = padding_masks.shape[1]

# Channels without a montage position stay at (0, 0).
channel_xy = np.zeros((num_channels, 2))
missing_channels = []
for ch, idx in channel_to_index.items():
    pos = channel_positions.get(ch)
    if pos is None:
        pos = positions_lower.get(ch.lower())
    if pos is None:
        missing_channels.append(ch)
    else:
        channel_xy[idx] = pos

if missing_channels:
    # (0, 0) is also a valid coordinate, so report these instead of failing silently.
    print(f"No 10-20 position for {len(missing_channels)} channel(s): {missing_channels}")

print(channel_xy.shape)


In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: `device` is still defined for training code, but the tensors below are built on
# the CPU. Their only consumer in this notebook is torch.save, and torch.load restores
# tensors onto the device they were saved from, which fails on a machine without CUDA
# unless map_location is passed. Move them with `.to(device)` when training.


def _to_tensors(X, YD, mask):
    """Convert one split to (float32 data, int64 labels, bool padding mask) CPU tensors."""
    return (
        torch.tensor(X, dtype=torch.float32),
        torch.tensor(YD, dtype=torch.long),
        torch.tensor(mask, dtype=torch.bool),
    )


# Training split.
X_tensor, YD_tensor, padding_masks_tensor = _to_tensors(X_all, YD_all, padding_masks)
channel_xy_tensor = torch.tensor(channel_xy, dtype=torch.float32)

# Validation split (the "_half" suffix is kept so existing references still work).
X_tensor_half, YD_tensor_half, padding_masks_tensor_half = _to_tensors(X_val, YD_val, mask_val)

# Held-out test split.
X_tensor_test, YD_tensor_test, padding_masks_tensor_test = _to_tensors(
    X_half, YD_half, padding_masks_half
)

In [None]:
# Bundle the train / validation / test splits and the shared electrode layout in one file.
save_path = '/content/drive/MyDrive/eeg_data_TEST.pt'
checkpoint = {
    # training split
    'X': X_tensor,
    'YD': YD_tensor,
    'mask': padding_masks_tensor,
    # electrode (x, y) positions, shared by all splits
    'channel_xy': channel_xy_tensor,
    # validation split
    'X_val': X_tensor_half,
    'YD_val': YD_tensor_half,
    'mask_val': padding_masks_tensor_half,
    # held-out test split
    'X_test': X_tensor_test,
    'YD_test': YD_tensor_test,
    'mask_test': padding_masks_tensor_test,
}
torch.save(checkpoint, save_path)
