In [1]:
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

import moabb
from moabb.datasets import BNCI2014_001, Zhou2016, Schirrmeister2017, Weibo2014
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery

# Verbose MOABB logging (dataset download / processing messages).
moabb.set_log_level("info")
# NOTE(review): this silences *every* warning for the rest of the kernel session
# (including PyTorch/MNE deprecation warnings); consider narrowing it with
# `category=` / `module=` filters.
warnings.filterwarnings("ignore")
# Optional: change where MOABB stores downloaded datasets. On Windows use a raw
# string (r"D:\TA\database") so the backslashes are not parsed as escape sequences.
# moabb.set_download_dir("D:\TA\database")

DATA LOAD

In [2]:
from torch.utils.data import Dataset, DataLoader

class MultisourceDataset(Dataset):
    """Map-style dataset over pre-loaded, index-aligned trial containers.

    Args:
        X: trials, indexed along the first dimension.
        YD: label rows aligned with ``X`` (one row per trial).
        channel_mask: per-trial channel masks aligned with ``X``.

    Raises:
        ValueError: if the three containers do not have the same length, which
            would otherwise surface later as an IndexError or silently drop trials.
    """

    def __init__(self, X, YD, channel_mask):
        if not (len(X) == len(YD) == len(channel_mask)):
            raise ValueError(
                "X, YD and channel_mask must have the same length, got "
                f"{len(X)}, {len(YD)} and {len(channel_mask)}"
            )
        self.X = X
        self.YD = YD
        self.channel_mask = channel_mask

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Returns (trial, label row, channel mask) for trial `idx` — not the index itself.
        return self.X[idx], self.YD[idx], self.channel_mask[idx]

eeg_data_200 = multi-source dataset
eeg_data_200_cross = multi-source dataset for cross-subject classification
eeg_data_bnci = BNCI2014_001 subset of the multi-source dataset
eeg_data_zhou = Zhou2016 subset of the multi-source dataset
eeg_data_weibo = Weibo2014 subset of the multi-source dataset

X = training trials (one row per trial)
YD = training labels
mask = training channel masks

X_val = validation trials (one row per trial)
YD_val = validation labels
mask_val = validation channel masks

X_test = test trials (one row per trial)
YD_test = test labels
mask_test = test channel masks

channel_xy = list of channel positions

Each label row = [class, dataset source id, subject id (within its dataset), session id (within its subject), domain id]
- the domain id is unique
- the combination (dataset source id, subject id) is unique
- the combination (dataset source id, subject id, session id) is unique

In [3]:
import torch
# Use the GPU when available; every tensor is loaded directly onto `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.load('./Dataset/eeg_data_200.pt', map_location=device)

# Training split
X_tensor = data['X']
YD_tensor = data['YD']
padding_masks_tensor = data['mask']
channel_xy_tensor = data['channel_xy']

# Validation split (the "_half" suffix is kept because later cells use these names)
X_tensor_half = data['X_val']
YD_tensor_half = data['YD_val']
padding_masks_tensor_half = data['mask_val']

# Test split
X_tensor_test = data['X_test']
YD_tensor_test = data['YD_test']
padding_masks_tensor_test = data['mask_test']

# `dataloader_test` wraps the VALIDATION split and `dataloader_test0` the TEST split.
# Only the training loader is shuffled: evaluation order does not affect metrics, and
# shuffling there just adds nondeterminism and consumes the global RNG.
dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
dataloader_test = DataLoader(dataset_test, batch_size=10, shuffle=False)
dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)
dataloader_test0 = DataLoader(dataset_test0, batch_size=10, shuffle=False)

EEGNet

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class EEGNet0(nn.Module):
    """EEGNet-style CNN: temporal conv -> depthwise spatial conv -> separable conv.

    Args:
        nb_classes: number of output classes.
        Chans: number of EEG channels (height of the input).
        Samples: number of time samples (width of the input).
        dropoutRate: dropout probability after each pooling stage.
        kernLength: length of the first temporal convolution kernel.
        F1: number of temporal filters.
        D: depth multiplier of the depthwise spatial convolution.
        F2: number of pointwise filters of the separable convolution.
        norm_rate: accepted for API compatibility; currently unused (no max-norm
            constraint is applied to any layer).

    Raises:
        ValueError: if ``Samples`` is too short to leave at least one time step
            after the pooling stages.

    ``forward(x)`` expects ``x`` of shape (batch, 1, Chans, Samples) and returns
    ``(logits, features)``, where ``features`` is the flattened classifier input.
    """

    def __init__(self, nb_classes, Chans=64, Samples=128, dropoutRate=0.5,
                 kernLength=64, F1=8, D=2, F2=16, norm_rate=0.25):
        super(EEGNet0, self).__init__()

        # First temporal convolution
        self.conv1 = nn.Conv2d(1, F1, (1, kernLength), padding=(0, kernLength // 2), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)

        # Depthwise convolution spanning all Chans rows (spatial filter)
        self.depthwiseConv = nn.Conv2d(F1, F1 * D, (Chans, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1 * D)

        # Pool and dropout after depthwise
        self.pool1 = nn.AvgPool2d((1, 4))
        self.drop1 = nn.Dropout(dropoutRate)

        # Separable convolution (depthwise + pointwise)
        self.sep_depth = nn.Conv2d(F1 * D, F1 * D, (1, 16), padding=(0, 8), groups=F1 * D, bias=False)
        self.sep_point = nn.Conv2d(F1 * D, F2, (1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(F2)

        # Pool and dropout after separable conv
        self.pool2 = nn.AvgPool2d((1, 8))
        self.drop2 = nn.Dropout(dropoutRate)

        # Track the time length through every layer instead of assuming Samples // 32:
        # symmetric padding of k // 2 with an even kernel adds one sample per conv, so
        # the shortcut undercounts for some lengths (e.g. Samples=127 -> 4 steps, not 3).
        n_time = Samples + 2 * (kernLength // 2) - kernLength + 1  # conv1
        n_time //= 4                                               # pool1
        n_time = n_time + 2 * 8 - 16 + 1                           # sep_depth
        n_time //= 8                                               # pool2
        if n_time < 1:
            raise ValueError(
                f"Samples={Samples} is too short for EEGNet0 with kernLength={kernLength}"
            )

        # Final classifier
        self.classifier = nn.Linear(F2 * n_time, nb_classes)

    def forward(self, x):
        # x shape: (batch, 1, Chans, Samples)
        x = self.conv1(x)
        x = self.bn1(x)

        x = self.depthwiseConv(x)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.pool1(x)
        x = self.drop1(x)

        x = self.sep_depth(x)
        x = self.sep_point(x)
        x = self.bn3(x)
        x = F.elu(x)
        x = self.pool2(x)
        x = self.drop2(x)

        x = x.flatten(start_dim=1)
        logits = self.classifier(x)
        return logits, x


DATA AUGMENTATIONS

In [5]:
import torch
import random

def add_gaussian_noise_classwise(X, YD, mask, n_aug=1, N=5):
    """Create noisy copies of trials, with the noise scale computed per class group.

    Trials are grouped by class label (``YD[i][0]``). Each class's indices are
    shuffled with Python's ``random`` module and split into consecutive groups of
    ``N``. Leftover trials that do not fill a complete group are not augmented.
    Every trial ``x`` of a group gets ``n_aug`` copies ``x + randn_like(x) * g``,
    where ``g`` is the scalar mean over all values of the group's trials.

    Args:
        X: trials, indexable by trial index (presumably [n_trials, C, T]).
        YD: label rows; column 0 is the class label.
        mask: per-trial channel masks, copied unchanged to the augmented trials.
        n_aug: number of noisy copies per trial.
        N: group size.

    Returns:
        ``(aug_X, aug_YD, aug_mask)`` concatenated along dim 0, or
        ``(None, None, None)`` when no class has at least ``N`` trials. This matches
        ``recombination_in_time`` and the callers' ``if aug_X is not None`` check.
        Previously ``torch.cat`` raised on the empty list.
    """
    # Group trial indices by class
    class_to_indices = {}
    for i in range(X.shape[0]):
        label = int(YD[i][0].item())
        class_to_indices.setdefault(label, []).append(i)

    aug_X, aug_YD, aug_mask = [], [], []

    for label, indices in class_to_indices.items():
        random.shuffle(indices)

        # Only full groups of N trials are used
        for start in range(0, len(indices) - N + 1, N):
            group = indices[start:start + N]
            # NOTE(review): the group *mean* is used as the noise scale. For zero-centred
            # EEG this can be close to 0 (almost no noise); a std may have been
            # intended — confirm.
            group_mean = torch.stack([X[j] for j in group]).mean().item()

            for j in group:
                x_j = X[j]
                for _ in range(n_aug):
                    noise = torch.randn_like(x_j) * group_mean
                    aug_X.append((x_j + noise).unsqueeze(0))
                    aug_YD.append(YD[j].unsqueeze(0))
                    aug_mask.append(mask[j].unsqueeze(0))

    if not aug_X:
        return None, None, None

    return (
        torch.cat(aug_X, dim=0),
        torch.cat(aug_YD, dim=0),
        torch.cat(aug_mask, dim=0),
    )


In [6]:
def recombination_in_time(X, YD, mask, n_aug=1, n_chunks=5, output_length=None):
    """Build synthetic trials by concatenating time segments of same-class trials.

    Only classes (``YD[:, 0]``) with at least ``n_chunks`` trials are used. For each
    such class, ``n_aug * n_class_trials`` new trials are generated. Each new trial
    works as follows:

    - It draws ``n_chunks`` distinct trials of the class at random.
    - From each one it cuts a window of ``output_length // n_chunks`` samples. The
      window start is drawn uniformly and independently per trial, so the windows
      are not time-aligned.
    - It concatenates the windows along the time axis and zero-pads the result at
      the end up to ``output_length``.

    The label row and channel mask of a new trial are copied from the FIRST chosen
    trial. NOTE(review): this assumes the chosen trials share one channel layout,
    which holds within a single subject/dataset but not across sources.

    Args:
        X: trials of shape [n_trials, C, T].
        YD: label rows; column 0 is the class label.
        mask: per-trial channel masks.
        n_aug: number of synthetic trials per real trial of a class.
        n_chunks: number of segments per synthetic trial (must be >= 1).
        output_length: length of the synthetic trials; defaults to ``X.shape[2]``.

    Returns:
        ``(X_aug, YD_aug, mask_aug)`` concatenated along dim 0, or
        ``(None, None, None)`` if no trial could be generated.

    Raises:
        ValueError: if ``n_chunks < 1``.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")

    augmented_X, augmented_YD, augmented_mask = [], [], []
    classes = torch.unique(YD[:, 0])

    if output_length is None:
        output_length = X.shape[2]  # default to original trial length
    chunk_len = output_length // n_chunks  # same for every synthetic trial

    for cls in classes:
        idx_cls = (YD[:, 0] == cls).nonzero(as_tuple=True)[0]
        if len(idx_cls) < n_chunks:
            continue

        for _ in range(n_aug * len(idx_cls)):
            chosen = idx_cls[torch.randperm(len(idx_cls))[:n_chunks]]
            segments = []

            for ch in chosen:
                trial = X[ch]
                trial_len = trial.shape[1]
                if trial_len < chunk_len:
                    continue  # trial too short for one window

                start = torch.randint(0, trial_len - chunk_len + 1, (1,)).item()
                segments.append(trial[:, start:start + chunk_len])

            if len(segments) != n_chunks:
                continue  # skip incomplete recombinations

            new_trial = torch.cat(segments, dim=1)

            # n_chunks * chunk_len <= output_length, so the trial can only be too short.
            pad_len = output_length - new_trial.shape[1]
            if pad_len > 0:
                pad = torch.zeros((new_trial.shape[0], pad_len), dtype=new_trial.dtype, device=new_trial.device)
                new_trial = torch.cat([new_trial, pad], dim=1)

            augmented_X.append(new_trial.unsqueeze(0))
            augmented_YD.append(YD[chosen[0]].unsqueeze(0))
            augmented_mask.append(mask[chosen[0]].clone().unsqueeze(0))

    if not augmented_X:
        return None, None, None

    return (
        torch.cat(augmented_X, dim=0),
        torch.cat(augmented_YD, dim=0),
        torch.cat(augmented_mask, dim=0)
    )


TRAIN LOOP (NO AUGMENTATION)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import numpy as np
import random

# Within-subject training of EEGNet0 for each BNCI2014_001 subject (no augmentation).
# Subjects are selected through the domain id column YD[:, 4]. They are assumed to
# occupy domain ids 4..12 — TODO confirm against the dataset builder.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load once: the file does not depend on the subject. The directory is './Dataset/'
# like in every other cell; this cell previously read './Datasets/'.
data = torch.load('./Dataset/eeg_data_bnci.pt', map_location=device)
channel_xy_tensor = data['channel_xy']

for indSubj in range(9):
    domain_id = indSubj + 4
    idx = data['YD'][:, 4] == domain_id
    idx_half = data['YD_val'][:, 4] == domain_id
    idx_test = data['YD_test'][:, 4] == domain_id

    # Training split of this subject
    X_tensor = data['X'][idx]
    YD_tensor = data['YD'][idx]
    padding_masks_tensor = data['mask'][idx]

    # Validation split ("half")
    X_tensor_half = data['X_val'][idx_half]
    YD_tensor_half = data['YD_val'][idx_half]
    padding_masks_tensor_half = data['mask_val'][idx_half]

    # Test split
    X_tensor_test = data['X_test'][idx_test]
    YD_tensor_test = data['YD_test'][idx_test]
    padding_masks_tensor_test = data['mask_test'][idx_test]

    dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
    dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
    dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

    # Re-seed per subject so each subject's run is reproducible on its own.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    out_channels = 22  # must equal the number of unmasked channels per trial
    num_classes = 3
    learning_rate = 0.0005
    num_epochs = 500

    model = EEGNet0(nb_classes=num_classes, Chans=out_channels, Samples=841).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    # Evaluation order is irrelevant. Shuffling only made the batch-averaged loss
    # nondeterministic and drew from the global RNG between training epochs.
    val_loader = DataLoader(dataset_test, batch_size=32, shuffle=False)
    test_loader = DataLoader(dataset_test0, batch_size=32, shuffle=False)

    save_path = f"./EEGNet0-BNCI-s{indSubj}-TEST.pth"
    # -inf (not 0) so a checkpoint is written even if validation accuracy stays at 0%.
    best_val_acc = float("-inf")
    for epoch in range(num_epochs):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for x_batch, yd_batch, mask_batch in train_loader:
            # Drop padded channels. This assumes every trial in the batch has the
            # same number of valid channels.
            X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
            X_valid = X_valid.unsqueeze(1).to(device)
            labels = yd_batch[:, 0].to(device)

            output, _ = model(X_valid)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += (output.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / len(train_loader)
        train_acc = correct / total * 100

        # Validation selects the checkpoint; the test split is only reported.
        model.eval()
        eval_results = {}
        with torch.no_grad():
            for split, loader in (("val", val_loader), ("test", test_loader)):
                split_loss, split_correct, split_total = 0.0, 0, 0
                for x_batch, yd_batch, mask_batch in loader:
                    X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
                    X_valid = X_valid.unsqueeze(1).to(device)
                    labels = yd_batch[:, 0].to(device)

                    output, _ = model(X_valid)
                    split_loss += criterion(output, labels).item()
                    split_correct += (output.argmax(dim=1) == labels).sum().item()
                    split_total += labels.size(0)
                eval_results[split] = (split_loss / len(loader), split_correct / split_total * 100)
        val_loss, val_acc = eval_results["val"]
        test_loss, test_acc = eval_results["test"]

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)

            print("✅ Model saved.")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import numpy as np
import random

# Within-subject training of EEGNet0 for each Weibo2014 subject (no augmentation).
# Subjects are selected through the domain id column YD[:, 4]. They are assumed to
# occupy domain ids 13..22 — TODO confirm against the dataset builder.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load once: the file does not depend on the subject.
data = torch.load('./Dataset/eeg_data_weibo.pt', map_location=device)
channel_xy_tensor = data['channel_xy']

for indSubj in range(10):
    domain_id = indSubj + 13
    idx = data['YD'][:, 4] == domain_id
    idx_half = data['YD_val'][:, 4] == domain_id
    idx_test = data['YD_test'][:, 4] == domain_id

    # Training split of this subject
    X_tensor = data['X'][idx]
    YD_tensor = data['YD'][idx]
    padding_masks_tensor = data['mask'][idx]

    # Validation split ("half")
    X_tensor_half = data['X_val'][idx_half]
    YD_tensor_half = data['YD_val'][idx_half]
    padding_masks_tensor_half = data['mask_val'][idx_half]

    # Test split
    X_tensor_test = data['X_test'][idx_test]
    YD_tensor_test = data['YD_test'][idx_test]
    padding_masks_tensor_test = data['mask_test'][idx_test]

    dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
    dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
    dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

    # Re-seed per subject so each subject's run is reproducible on its own.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    out_channels = 60  # must equal the number of unmasked channels per trial
    num_classes = 3
    learning_rate = 0.0005
    num_epochs = 500

    model = EEGNet0(nb_classes=num_classes, Chans=out_channels, Samples=841).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    # Evaluation order is irrelevant. Shuffling only made the batch-averaged loss
    # nondeterministic and drew from the global RNG between training epochs.
    val_loader = DataLoader(dataset_test, batch_size=32, shuffle=False)
    test_loader = DataLoader(dataset_test0, batch_size=32, shuffle=False)

    save_path = f"./EEGNet0-Weibo-s{indSubj}-TEST.pth"
    # -inf (not 0) so a checkpoint is written even if validation accuracy stays at 0%.
    best_val_acc = float("-inf")
    for epoch in range(num_epochs):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for x_batch, yd_batch, mask_batch in train_loader:
            # Drop padded channels. This assumes every trial in the batch has the
            # same number of valid channels.
            X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
            X_valid = X_valid.unsqueeze(1).to(device)
            labels = yd_batch[:, 0].to(device)

            output, _ = model(X_valid)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += (output.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / len(train_loader)
        train_acc = correct / total * 100

        # Validation selects the checkpoint; the test split is only reported.
        model.eval()
        eval_results = {}
        with torch.no_grad():
            for split, loader in (("val", val_loader), ("test", test_loader)):
                split_loss, split_correct, split_total = 0.0, 0, 0
                for x_batch, yd_batch, mask_batch in loader:
                    X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
                    X_valid = X_valid.unsqueeze(1).to(device)
                    labels = yd_batch[:, 0].to(device)

                    output, _ = model(X_valid)
                    split_loss += criterion(output, labels).item()
                    split_correct += (output.argmax(dim=1) == labels).sum().item()
                    split_total += labels.size(0)
                eval_results[split] = (split_loss / len(loader), split_correct / split_total * 100)
        val_loss, val_acc = eval_results["val"]
        test_loss, test_acc = eval_results["test"]

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)

            print("✅ Model saved.")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import numpy as np
import random

# Within-subject training of EEGNet0 for each Zhou2016 subject (no augmentation).
# Subjects are selected through the domain id column YD[:, 4]. They are assumed to
# occupy domain ids 0..3 — TODO confirm against the dataset builder.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load once: the file does not depend on the subject.
data = torch.load('./Dataset/eeg_data_zhou.pt', map_location=device)
channel_xy_tensor = data['channel_xy']

for indSubj in range(4):
    domain_id = indSubj + 0
    idx = data['YD'][:, 4] == domain_id
    idx_half = data['YD_val'][:, 4] == domain_id
    idx_test = data['YD_test'][:, 4] == domain_id

    # Training split of this subject
    X_tensor = data['X'][idx]
    YD_tensor = data['YD'][idx]
    padding_masks_tensor = data['mask'][idx]

    # Validation split ("half")
    X_tensor_half = data['X_val'][idx_half]
    YD_tensor_half = data['YD_val'][idx_half]
    padding_masks_tensor_half = data['mask_val'][idx_half]

    # Test split
    X_tensor_test = data['X_test'][idx_test]
    YD_tensor_test = data['YD_test'][idx_test]
    padding_masks_tensor_test = data['mask_test'][idx_test]

    dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
    dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
    dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

    # Re-seed per subject so each subject's run is reproducible on its own.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    out_channels = 14  # must equal the number of unmasked channels per trial
    num_classes = 3
    learning_rate = 0.0005
    num_epochs = 500

    model = EEGNet0(nb_classes=num_classes, Chans=out_channels, Samples=841).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    # Evaluation order is irrelevant. Shuffling only made the batch-averaged loss
    # nondeterministic and drew from the global RNG between training epochs.
    val_loader = DataLoader(dataset_test, batch_size=32, shuffle=False)
    test_loader = DataLoader(dataset_test0, batch_size=32, shuffle=False)

    save_path = f"./EEGNet0-Zhou-s{indSubj}-TEST.pth"
    # -inf (not 0) so a checkpoint is written even if validation accuracy stays at 0%.
    best_val_acc = float("-inf")
    for epoch in range(num_epochs):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for x_batch, yd_batch, mask_batch in train_loader:
            # Drop padded channels. This assumes every trial in the batch has the
            # same number of valid channels.
            X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
            X_valid = X_valid.unsqueeze(1).to(device)
            labels = yd_batch[:, 0].to(device)

            output, _ = model(X_valid)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += (output.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / len(train_loader)
        train_acc = correct / total * 100

        # Validation selects the checkpoint; the test split is only reported.
        model.eval()
        eval_results = {}
        with torch.no_grad():
            for split, loader in (("val", val_loader), ("test", test_loader)):
                split_loss, split_correct, split_total = 0.0, 0, 0
                for x_batch, yd_batch, mask_batch in loader:
                    X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
                    X_valid = X_valid.unsqueeze(1).to(device)
                    labels = yd_batch[:, 0].to(device)

                    output, _ = model(X_valid)
                    split_loss += criterion(output, labels).item()
                    split_correct += (output.argmax(dim=1) == labels).sum().item()
                    split_total += labels.size(0)
                eval_results[split] = (split_loss / len(loader), split_correct / split_total * 100)
        val_loss, val_acc = eval_results["val"]
        test_loss, test_acc = eval_results["test"]

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)

            print("✅ Model saved.")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


TRAIN LOOP + AUGMENT (NOISE ADDITION/RECOMBINATION IN TIME)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import numpy as np
import random

# Within-subject training of EEGNet0 for each BNCI2014_001 subject. The training
# split is enlarged by a data augmentation; validation and test splits are untouched.
# Subjects are selected through the domain id column YD[:, 4]. They are assumed to
# occupy domain ids 4..12 — TODO confirm against the dataset builder.
#
# Augmentation applied to the training split. For time recombination, set
# `augment_fn = recombination_in_time`: its defaults (n_chunks=5,
# output_length=X.shape[2]) match the previously commented-out call.
augment_fn = add_gaussian_noise_classwise

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load once: the file does not depend on the subject.
data = torch.load('./Dataset/eeg_data_bnci.pt', map_location=device)
channel_xy_tensor = data['channel_xy']

for indSubj in range(9):
    domain_id = indSubj + 4

    # Seed BEFORE augmenting. Both augmentations draw from the global `random` /
    # torch RNGs, so seeding only afterwards left the synthetic trials unreproducible.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Training split of this subject, extended with augmented copies
    idx = data['YD'][:, 4] == domain_id
    X_tensor = data['X'][idx]
    YD_tensor = data['YD'][idx]
    padding_masks_tensor = data['mask'][idx]

    aug_X, aug_YD, aug_mask = augment_fn(X_tensor, YD_tensor, padding_masks_tensor)
    # recombination_in_time returns (None, None, None) when it generates no trial
    if aug_X is not None:
        X_tensor = torch.cat([X_tensor, aug_X], dim=0)
        YD_tensor = torch.cat([YD_tensor, aug_YD], dim=0)
        padding_masks_tensor = torch.cat([padding_masks_tensor, aug_mask], dim=0)

    # Validation split ("half")
    idx_half = data['YD_val'][:, 4] == domain_id
    X_tensor_half = data['X_val'][idx_half]
    YD_tensor_half = data['YD_val'][idx_half]
    padding_masks_tensor_half = data['mask_val'][idx_half]

    # Test split
    idx_test = data['YD_test'][:, 4] == domain_id
    X_tensor_test = data['X_test'][idx_test]
    YD_tensor_test = data['YD_test'][idx_test]
    padding_masks_tensor_test = data['mask_test'][idx_test]

    dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
    dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
    dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

    out_channels = 22  # must equal the number of unmasked channels per trial
    num_classes = 3
    learning_rate = 0.0005
    num_epochs = 500

    model = EEGNet0(nb_classes=num_classes, Chans=out_channels, Samples=841).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    # Evaluation order is irrelevant. Shuffling only made the batch-averaged loss
    # nondeterministic and drew from the global RNG between training epochs.
    val_loader = DataLoader(dataset_test, batch_size=32, shuffle=False)
    test_loader = DataLoader(dataset_test0, batch_size=32, shuffle=False)

    # Distinct file name so this run does not overwrite the non-augmented
    # checkpoints, which are saved as EEGNet0-BNCI-s{indSubj}-TEST.pth.
    save_path = f"./EEGNet0-BNCI-aug-s{indSubj}-TEST.pth"
    # -inf (not 0) so a checkpoint is written even if validation accuracy stays at 0%.
    best_val_acc = float("-inf")
    for epoch in range(num_epochs):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for x_batch, yd_batch, mask_batch in train_loader:
            # Drop padded channels. This assumes every trial in the batch has the
            # same number of valid channels.
            X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
            X_valid = X_valid.unsqueeze(1).to(device)
            labels = yd_batch[:, 0].to(device)

            output, _ = model(X_valid)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += (output.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / len(train_loader)
        train_acc = correct / total * 100

        # Validation selects the checkpoint; the test split is only reported.
        model.eval()
        eval_results = {}
        with torch.no_grad():
            for split, loader in (("val", val_loader), ("test", test_loader)):
                split_loss, split_correct, split_total = 0.0, 0, 0
                for x_batch, yd_batch, mask_batch in loader:
                    X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
                    X_valid = X_valid.unsqueeze(1).to(device)
                    labels = yd_batch[:, 0].to(device)

                    output, _ = model(X_valid)
                    split_loss += criterion(output, labels).item()
                    split_correct += (output.argmax(dim=1) == labels).sum().item()
                    split_total += labels.size(0)
                eval_results[split] = (split_loss / len(loader), split_correct / split_total * 100)
        val_loss, val_acc = eval_results["val"]
        test_loss, test_acc = eval_results["test"]

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)

            print("✅ Model saved.")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import numpy as np
import random

# Within-subject training on the Weibo2014 subset: one EEGNet0 per subject.
# Weibo subjects are the domain ids 13..22 (column 4 of YD). For each subject the
# epoch with the best validation accuracy is checkpointed; test metrics are only
# printed for monitoring and never used for model selection.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the file once: each subject below only selects its own rows from it,
# so re-reading it on every iteration was redundant I/O.
data = torch.load('./Dataset/eeg_data_weibo.pt', map_location=device)
channel_xy_tensor = data['channel_xy']

seed = 42
out_channels = 60   # channel count the model expects after the padding mask is applied
num_classes = 3
learning_rate = 0.0005
num_epochs = 500
batch_size = 32

for indSubj in range(10):
    # Seed at the start of every subject so augmentation noise (if the helper
    # draws from these RNGs), weight init and batch shuffling do not depend on
    # earlier subjects or on cells executed before this one.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    domain_id = indSubj + 13
    idx = data['YD'][:, 4] == domain_id
    idx_half = data['YD_val'][:, 4] == domain_id
    idx_test = data['YD_test'][:, 4] == domain_id

    # Training split for this subject, optionally enlarged with augmented copies.
    X_tensor = data['X'][idx]
    YD_tensor = data['YD'][idx]
    padding_masks_tensor = data['mask'][idx]

    aug_X, aug_YD, aug_mask = add_gaussian_noise_classwise(X_tensor, YD_tensor, padding_masks_tensor)
    # aug_X, aug_YD, aug_mask = recombination_in_time(
    #     X_tensor, YD_tensor, padding_masks_tensor,
    #     n_chunks=5, output_length=X_tensor.shape[2]
    # )
    if aug_X is not None:  # the augmenter may return no samples
        X_tensor = torch.cat([X_tensor, aug_X], dim=0)
        YD_tensor = torch.cat([YD_tensor, aug_YD], dim=0)
        padding_masks_tensor = torch.cat([padding_masks_tensor, aug_mask], dim=0)

    # Validation ("half") and test splits are never augmented.
    X_tensor_half = data['X_val'][idx_half]
    YD_tensor_half = data['YD_val'][idx_half]
    padding_masks_tensor_half = data['mask_val'][idx_half]

    X_tensor_test = data['X_test'][idx_test]
    YD_tensor_test = data['YD_test'][idx_test]
    padding_masks_tensor_test = data['mask_test'][idx_test]

    dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
    dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
    dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

    if len(dataset) == 0:
        print(f"⚠️  Subject {indSubj} (domain {domain_id}) has no training trials; skipped.")
        continue

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Evaluation order cannot change the metrics. Not shuffling keeps these
    # loaders from drawing on the global RNG that drives training shuffles.
    val_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset_test0, batch_size=batch_size, shuffle=False)

    model = EEGNet0(nb_classes=num_classes, Chans=out_channels, Samples=841).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    # -inf rather than 0, so a checkpoint is written even if validation
    # accuracy never rises above 0%.
    best_val_acc = float('-inf')
    for epoch in range(num_epochs):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for x_batch, yd_batch, mask_batch in train_loader:
            # Keep only the unpadded channels and add the singleton input
            # dimension -> (B, 1, channels, samples). Assumes every trial in a
            # batch has the same number of valid channels.
            X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
            X_valid = X_valid.unsqueeze(1).to(device)
            targets = yd_batch.to(device)[:, 0]  # column 0 of YD is the class label

            output, _ = model(X_valid)
            loss = criterion(output, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch loss is a true per-trial mean
            # (the last batch is usually smaller than batch_size).
            running_loss += loss.item() * targets.size(0)
            correct += (torch.argmax(output, dim=1) == targets).sum().item()
            total += targets.size(0)

        train_loss = running_loss / total
        train_acc = correct / total * 100

        # Validation and test use the same evaluation procedure.
        model.eval()
        split_metrics = {}
        with torch.no_grad():
            for split_name, loader in (("val", val_loader), ("test", test_loader)):
                split_loss, split_correct, split_total = 0.0, 0, 0
                for x_batch, yd_batch, mask_batch in loader:
                    X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
                    X_valid = X_valid.unsqueeze(1).to(device)
                    targets = yd_batch.to(device)[:, 0]

                    output, _ = model(X_valid)
                    split_loss += criterion(output, targets).item() * targets.size(0)
                    split_correct += (torch.argmax(output, dim=1) == targets).sum().item()
                    split_total += targets.size(0)
                # Report NaN instead of raising ZeroDivisionError when this
                # subject has no trials in the split.
                if split_total:
                    split_metrics[split_name] = (split_loss / split_total,
                                                 split_correct / split_total * 100)
                else:
                    split_metrics[split_name] = (float('nan'), float('nan'))
        val_loss, val_acc = split_metrics["val"]
        test_loss, test_acc = split_metrics["test"]

        # Model selection is driven by validation accuracy only.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_path = f"./EEGNet0-Weibo-s{indSubj}-TEST.pth"
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)

            print("✅ Model saved.")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import numpy as np
import random

# Within-subject training on the Zhou2016 subset: one EEGNet0 per subject.
# Zhou subjects are the domain ids 0..3 (column 4 of YD). For each subject the
# epoch with the best validation accuracy is checkpointed; test metrics are only
# printed for monitoring and never used for model selection.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the file once: each subject below only selects its own rows from it,
# so re-reading it on every iteration was redundant I/O.
data = torch.load('./Dataset/eeg_data_zhou.pt', map_location=device)
channel_xy_tensor = data['channel_xy']

seed = 42
out_channels = 14   # channel count the model expects after the padding mask is applied
num_classes = 3
learning_rate = 0.0005
num_epochs = 500
batch_size = 32

for indSubj in range(4):
    # Seed at the start of every subject so augmentation noise (if the helper
    # draws from these RNGs), weight init and batch shuffling do not depend on
    # earlier subjects or on cells executed before this one.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    domain_id = indSubj + 0
    idx = data['YD'][:, 4] == domain_id
    idx_half = data['YD_val'][:, 4] == domain_id
    idx_test = data['YD_test'][:, 4] == domain_id

    # Training split for this subject, optionally enlarged with augmented copies.
    X_tensor = data['X'][idx]
    YD_tensor = data['YD'][idx]
    padding_masks_tensor = data['mask'][idx]

    aug_X, aug_YD, aug_mask = add_gaussian_noise_classwise(X_tensor, YD_tensor, padding_masks_tensor)
    # aug_X, aug_YD, aug_mask = recombination_in_time(
    #     X_tensor, YD_tensor, padding_masks_tensor,
    #     n_chunks=5, output_length=X_tensor.shape[2]
    # )
    if aug_X is not None:  # the augmenter may return no samples
        X_tensor = torch.cat([X_tensor, aug_X], dim=0)
        YD_tensor = torch.cat([YD_tensor, aug_YD], dim=0)
        padding_masks_tensor = torch.cat([padding_masks_tensor, aug_mask], dim=0)

    # Validation ("half") and test splits are never augmented.
    X_tensor_half = data['X_val'][idx_half]
    YD_tensor_half = data['YD_val'][idx_half]
    padding_masks_tensor_half = data['mask_val'][idx_half]

    X_tensor_test = data['X_test'][idx_test]
    YD_tensor_test = data['YD_test'][idx_test]
    padding_masks_tensor_test = data['mask_test'][idx_test]

    dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
    dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
    dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

    if len(dataset) == 0:
        print(f"⚠️  Subject {indSubj} (domain {domain_id}) has no training trials; skipped.")
        continue

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Evaluation order cannot change the metrics. Not shuffling keeps these
    # loaders from drawing on the global RNG that drives training shuffles.
    val_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset_test0, batch_size=batch_size, shuffle=False)

    model = EEGNet0(nb_classes=num_classes, Chans=out_channels, Samples=841).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

    # -inf rather than 0, so a checkpoint is written even if validation
    # accuracy never rises above 0%.
    best_val_acc = float('-inf')
    for epoch in range(num_epochs):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for x_batch, yd_batch, mask_batch in train_loader:
            # Keep only the unpadded channels and add the singleton input
            # dimension -> (B, 1, channels, samples). Assumes every trial in a
            # batch has the same number of valid channels.
            X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
            X_valid = X_valid.unsqueeze(1).to(device)
            targets = yd_batch.to(device)[:, 0]  # column 0 of YD is the class label

            output, _ = model(X_valid)
            loss = criterion(output, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch loss is a true per-trial mean
            # (the last batch is usually smaller than batch_size).
            running_loss += loss.item() * targets.size(0)
            correct += (torch.argmax(output, dim=1) == targets).sum().item()
            total += targets.size(0)

        train_loss = running_loss / total
        train_acc = correct / total * 100

        # Validation and test use the same evaluation procedure.
        model.eval()
        split_metrics = {}
        with torch.no_grad():
            for split_name, loader in (("val", val_loader), ("test", test_loader)):
                split_loss, split_correct, split_total = 0.0, 0, 0
                for x_batch, yd_batch, mask_batch in loader:
                    X_valid = x_batch[mask_batch].reshape(x_batch.shape[0], -1, x_batch.shape[-1])
                    X_valid = X_valid.unsqueeze(1).to(device)
                    targets = yd_batch.to(device)[:, 0]

                    output, _ = model(X_valid)
                    split_loss += criterion(output, targets).item() * targets.size(0)
                    split_correct += (torch.argmax(output, dim=1) == targets).sum().item()
                    split_total += targets.size(0)
                # Report NaN instead of raising ZeroDivisionError when this
                # subject has no trials in the split.
                if split_total:
                    split_metrics[split_name] = (split_loss / split_total,
                                                 split_correct / split_total * 100)
                else:
                    split_metrics[split_name] = (float('nan'), float('nan'))
        val_loss, val_acc = split_metrics["val"]
        test_loss, test_acc = split_metrics["test"]

        # Model selection is driven by validation accuracy only.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_path = f"./EEGNet0-Zhou-s{indSubj}-TEST.pth"
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)

            print("✅ Model saved.")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


TRAIN LOOP MULTI-SOURCE DATASET

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import numpy as np
import random

# Train one EEGNet0 on the pooled multi-source dataset (all domains together).
# The epoch with the best validation accuracy is checkpointed; test metrics are
# only printed for monitoring.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.load('./Dataset/eeg_data_200.pt', map_location=device)

X_tensor = data['X']
YD_tensor = data['YD']
padding_masks_tensor = data['mask']
channel_xy_tensor = data['channel_xy']

X_tensor_half = data['X_val']
YD_tensor_half = data['YD_val']
padding_masks_tensor_half = data['mask_val']

X_tensor_test = data['X_test']
YD_tensor_test = data['YD_test']
padding_masks_tensor_test = data['mask_test']

dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

out_channels = 60
num_classes = 3
learning_rate = 0.0005
num_epochs = 500
batch_size = 32

model = EEGNet0(nb_classes=num_classes, Chans=out_channels, Samples=841).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Evaluation order cannot change the metrics. Not shuffling keeps these loaders
# from drawing on the global RNG that drives training shuffles.
val_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset_test0, batch_size=batch_size, shuffle=False)


def _to_model_input(x_batch):
    # The channel mask is deliberately NOT applied here: the model receives every
    # (padded) channel row of x_batch as (B, 1, channels, samples).
    # NOTE(review): presumably padded channels are zero-filled — confirm.
    return x_batch.unsqueeze(1).to(device)


# -inf rather than 0, so a checkpoint is written even if validation accuracy
# never rises above 0%.
best_val_acc = float('-inf')
for epoch in range(num_epochs):
    model.train()
    correct, total, running_loss = 0, 0, 0.0
    for x_batch, yd_batch, mask_batch in train_loader:
        X_valid = _to_model_input(x_batch)
        targets = yd_batch.to(device)[:, 0]  # column 0 of YD is the class label

        output, _ = model(X_valid)
        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a true per-trial mean
        # (the last batch is usually smaller than batch_size).
        running_loss += loss.item() * targets.size(0)
        correct += (torch.argmax(output, dim=1) == targets).sum().item()
        total += targets.size(0)

    train_loss = running_loss / total
    train_acc = correct / total * 100

    # Validation and test use the same evaluation procedure.
    model.eval()
    split_metrics = {}
    with torch.no_grad():
        for split_name, loader in (("val", val_loader), ("test", test_loader)):
            split_loss, split_correct, split_total = 0.0, 0, 0
            for x_batch, yd_batch, mask_batch in loader:
                X_valid = _to_model_input(x_batch)
                targets = yd_batch.to(device)[:, 0]

                output, _ = model(X_valid)
                split_loss += criterion(output, targets).item() * targets.size(0)
                split_correct += (torch.argmax(output, dim=1) == targets).sum().item()
                split_total += targets.size(0)
            # Report NaN instead of raising ZeroDivisionError on an empty split.
            if split_total:
                split_metrics[split_name] = (split_loss / split_total,
                                             split_correct / split_total * 100)
            else:
                split_metrics[split_name] = (float('nan'), float('nan'))
    val_loss, val_acc = split_metrics["val"]
    test_loss, test_acc = split_metrics["test"]

    # Model selection is driven by validation accuracy only.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        save_path = "./EEGNet0-X-TEST.pth"
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, save_path)

        print("✅ Model saved.")

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
          f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")


PERFORMANCE EVALUATIONS

In [7]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.load('./Dataset/eeg_data_weibo.pt', map_location=device)

X_tensor = data['X']
YD_tensor = data['YD']
padding_masks_tensor = data['mask']
channel_xy_tensor = data['channel_xy']

X_tensor_half = data['X_val']
YD_tensor_half = data['YD_val']
padding_masks_tensor_half = data['mask_val']

X_tensor_test = data['X_test']
YD_tensor_test = data['YD_test']
padding_masks_tensor_test = data['mask_test']

dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
dataloader_test = DataLoader(dataset_test, batch_size=10, shuffle=True)
dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)
dataloader_test0 = DataLoader(dataset_test0, batch_size=10, shuffle=True)

from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
import numpy as np
import os

# Evaluate the per-subject Weibo models: each checkpoint is scored only on the
# test trials of its own subject/domain, then all predictions are pooled.

# === Seeding ===
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# === Hyperparameters ===
Chans = 60  # Input EEG channels (after masking)
Samples = 841  # Number of time samples
num_classes = 3
learning_rate = 0.0005  # not used for evaluation; kept from the training config
num_epochs = 500  # not used for evaluation; kept from the training config

# === Test Loader ===
test_loader = DataLoader(dataloader_test0.dataset, batch_size=32, shuffle=False)

# === Global Containers for Overall Metrics ===
overall_preds = []
overall_targets = []

# === Loop over last 10 domains (13 to 22 in dataset) ===
print("Metrics by Domain (Each with its own EEGNet0 model):")
for model_id in range(10):
    dataset_domain_id = model_id + 13

    # === Initialize Model ===
    model = EEGNet0(nb_classes=num_classes, Chans=Chans, Samples=Samples).to(device)
    model.eval()

    # === Load Checkpoint ===
    save_path = f"./BaselineModel/EEGNet0-Weibo-s{model_id}-rt-500.pth"
    if not os.path.exists(save_path):
        print(f"⚠️  Checkpoint not found for model {model_id}: {save_path}")
        continue

    # map_location: a checkpoint written from a GPU run would otherwise fail
    # to load on a CPU-only machine.
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # === Containers for Domain-Specific Evaluation ===
    y_true_domain = []
    y_pred_domain = []

    with torch.no_grad():
        for x_batch, yd_batch, mask_batch in test_loader:
            # Keep only trials from this model's domain (column 4 of YD).
            in_domain = yd_batch[:, 4].cpu().numpy() == dataset_domain_id
            if not np.any(in_domain):
                continue

            x_domain = x_batch[in_domain]
            yd_domain = yd_batch[in_domain]
            mask_domain = mask_batch[in_domain]

            # Apply masking and reshape: (B, 1, Chans, Samples). Assumes every
            # trial has the same number of valid channels.
            X_valid = x_domain[mask_domain].reshape(x_domain.shape[0], -1, x_domain.shape[-1])
            X_valid = X_valid.unsqueeze(1).to(device)

            logits, _ = model(X_valid)
            predictions = torch.argmax(logits, dim=1)

            y_true_domain.extend(yd_domain[:, 0].cpu().numpy())
            y_pred_domain.extend(predictions.cpu().numpy())

    # === Metrics for current domain ===
    if not y_true_domain:
        print(f"\n- Domain {dataset_domain_id} (Model {model_id}): No data.")
        continue

    print(f"\n- Domain {dataset_domain_id} (Model {model_id}):")
    accuracy = accuracy_score(y_true_domain, y_pred_domain)
    print(f"    - Accuracy: {accuracy:.4f}")

    classes = sorted(set(y_true_domain))
    f1_scores = f1_score(y_true_domain, y_pred_domain, average=None, labels=classes)
    for cls, f1 in zip(classes, f1_scores):
        print(f"    - Class {cls}: F1 Score = {f1:.4f}")

    # Row-normalised confusion matrix; every row has at least one true sample
    # because `classes` is built from y_true_domain.
    cm = confusion_matrix(y_true_domain, y_pred_domain, labels=classes)
    cm_percent = cm.astype(float) / cm.sum(axis=1, keepdims=True)

    print(f"    - Misclassification Breakdown (rows = true class):")
    for i, cls in enumerate(classes):
        breakdown = [
            f"{cm_percent[i, j] * 100:.1f}% → class {pred_cls}"
            for j, pred_cls in enumerate(classes)
            if j != i
        ]
        if breakdown:
            print(f"        Class {cls} misclassified as: {', '.join(breakdown)}")
        else:
            print(f"        Class {cls} has no misclassifications.")

    # Accumulate for overall metrics
    overall_preds.extend(y_pred_domain)
    overall_targets.extend(y_true_domain)

# === Print overall metrics ===
if overall_preds:
    overall_accuracy = accuracy_score(overall_targets, overall_preds)
    all_classes = sorted(set(overall_targets))
    f1_per_class = f1_score(overall_targets, overall_preds, average=None, labels=all_classes)

    print("\nOverall Metrics:")
    print(f"    - Overall Accuracy: {overall_accuracy:.4f}")
    for cls, f1 in zip(all_classes, f1_per_class):
        print(f"    - Class {cls}: F1 Score = {f1:.4f}")
else:
    print("\nNo predictions were collected. Please check domain IDs or dataset.")


Metrics by Domain (Each with its own EEGNet0 model):

- Domain 13 (Model 0):
    - Accuracy: 0.5612
    - Class 0: F1 Score = 0.7719
    - Class 1: F1 Score = 0.5185
    - Class 2: F1 Score = 0.4138
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 23.5% → class 1, 11.8% → class 2
        Class 1 misclassified as: 2.9% → class 0, 37.1% → class 2
        Class 2 misclassified as: 0.0% → class 0, 58.6% → class 1

- Domain 14 (Model 1):
    - Accuracy: 0.3882
    - Class 0: F1 Score = 0.4675
    - Class 1: F1 Score = 0.4000
    - Class 2: F1 Score = 0.1818
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 35.5% → class 1, 6.5% → class 2
        Class 1 misclassified as: 48.3% → class 0, 10.3% → class 2
        Class 2 misclassified as: 56.0% → class 0, 32.0% → class 1

- Domain 15 (Model 2):
    - Accuracy: 0.3298
    - Class 0: F1 Score = 0.3158
    - Class 1: F1 Score = 0.3721
    - Class 2: F1 Score = 0.2667


In [9]:
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
import numpy as np
import os

# Evaluate the single multi-source EEGNet0 on the pooled test split, reporting
# metrics per domain (column 4 of YD) and over all domains.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.load('./Dataset/eeg_data_200.pt', map_location=device)

X_tensor = data['X']
YD_tensor = data['YD']
padding_masks_tensor = data['mask']
channel_xy_tensor = data['channel_xy']

X_tensor_half = data['X_val']
YD_tensor_half = data['YD_val']
padding_masks_tensor_half = data['mask_val']

X_tensor_test = data['X_test']
YD_tensor_test = data['YD_test']
padding_masks_tensor_test = data['mask_test']

dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
dataloader_test = DataLoader(dataset_test, batch_size=10, shuffle=True)
dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)
dataloader_test0 = DataLoader(dataset_test0, batch_size=10, shuffle=True)

# === Seeding ===
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# === Hyperparameters ===
Chans = 60  # Input EEG channels
Samples = 841  # Number of time samples
num_classes = 3

# === Test Loader ===
test_loader = DataLoader(dataloader_test0.dataset, batch_size=32, shuffle=False)

# === Initialize and Load One Model ===
model = EEGNet0(nb_classes=num_classes, Chans=Chans, Samples=Samples).to(device)
model.eval()

# === Load Checkpoint ===
save_path = "./BaselineModel/EEGNet0-X.pth"
if not os.path.exists(save_path):
    raise FileNotFoundError(f"Model checkpoint not found: {save_path}")

# map_location: a checkpoint written from a GPU run would otherwise fail to
# load on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# === Containers for Evaluation ===
overall_preds = []
overall_targets = []
domain_wise_results = {}

with torch.no_grad():
    for x_batch, yd_batch, mask_batch in test_loader:
        domains = yd_batch[:, 4].cpu().numpy()
        targets = yd_batch[:, 0].cpu().numpy()

        # The multi-source model takes every (padded) channel row: the channel
        # mask is deliberately not applied. Shape: (B, 1, channels, samples).
        x_valid = x_batch.unsqueeze(1).to(device)

        logits, _ = model(x_valid)
        predictions = torch.argmax(logits, dim=1).cpu().numpy()

        # Store all predictions and targets
        overall_preds.extend(predictions)
        overall_targets.extend(targets)

        # Store per-domain
        for dom, target, pred in zip(domains, targets, predictions):
            bucket = domain_wise_results.setdefault(int(dom), {"y_true": [], "y_pred": []})
            bucket["y_true"].append(target)
            bucket["y_pred"].append(pred)

# === Per-Domain Metrics ===
print("Metrics by Domain (Using One Model):")
for dom in sorted(domain_wise_results.keys()):
    y_true = domain_wise_results[dom]["y_true"]
    y_pred = domain_wise_results[dom]["y_pred"]

    print(f"\n- Domain {dom}:")
    accuracy = accuracy_score(y_true, y_pred)
    print(f"    - Accuracy: {accuracy:.4f}")

    classes = sorted(set(y_true))
    f1_scores = f1_score(y_true, y_pred, average=None, labels=classes)
    for cls, f1 in zip(classes, f1_scores):
        print(f"    - Class {cls}: F1 Score = {f1:.4f}")

    # Row-normalised confusion matrix; every row has at least one true sample
    # because `classes` is built from y_true.
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    cm_percent = cm.astype(float) / cm.sum(axis=1, keepdims=True)

    print(f"    - Misclassification Breakdown (rows = true class):")
    for i, cls in enumerate(classes):
        breakdown = [
            f"{cm_percent[i, j] * 100:.1f}% → class {pred_cls}"
            for j, pred_cls in enumerate(classes)
            if j != i
        ]
        if breakdown:
            print(f"        Class {cls} misclassified as: {', '.join(breakdown)}")
        else:
            print(f"        Class {cls} has no misclassifications.")

# === Overall Metrics ===
if overall_preds:
    overall_accuracy = accuracy_score(overall_targets, overall_preds)
    all_classes = sorted(set(overall_targets))
    f1_per_class = f1_score(overall_targets, overall_preds, average=None, labels=all_classes)

    print("\nOverall Metrics (All Domains):")
    print(f"    - Overall Accuracy: {overall_accuracy:.4f}")
    for cls, f1 in zip(all_classes, f1_per_class):
        print(f"    - Class {cls}: F1 Score = {f1:.4f}")
else:
    print("\nNo predictions were collected. Please check masking or data loading.")


Metrics by Domain (Using One Model):

- Domain 0:
    - Accuracy: 0.7800
    - Class 0: F1 Score = 0.8000
    - Class 1: F1 Score = 0.7872
    - Class 2: F1 Score = 0.7586
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 8.0% → class 1, 20.0% → class 2
        Class 1 misclassified as: 2.0% → class 0, 24.0% → class 2
        Class 2 misclassified as: 6.0% → class 0, 6.0% → class 1

- Domain 1:
    - Accuracy: 0.8800
    - Class 0: F1 Score = 0.9899
    - Class 1: F1 Score = 0.8317
    - Class 2: F1 Score = 0.8200
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 0.0% → class 1, 2.0% → class 2
        Class 1 misclassified as: 0.0% → class 0, 16.0% → class 2
        Class 2 misclassified as: 0.0% → class 0, 18.0% → class 1

- Domain 2:
    - Accuracy: 0.8200
    - Class 0: F1 Score = 0.8257
    - Class 1: F1 Score = 0.8261
    - Class 2: F1 Score = 0.8081
    - Misclassification Breakdown (rows = true class):