In [1]:
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

import moabb
from moabb.datasets import BNCI2014_001, Zhou2016, Schirrmeister2017, Weibo2014
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery

# Log MOABB messages at "info" level.
moabb.set_log_level("info")
# NOTE(review): this ignores *every* Python warning for the rest of the kernel session
# (including PyTorch / scikit-learn deprecation warnings); consider narrowing the filter.
warnings.filterwarnings("ignore")
# Optional: change MOABB's download directory. If re-enabled, use a raw string
# (r"...") so the backslashes in the Windows path are not treated as escapes.
# moabb.set_download_dir("D:\TA\database")

DATA LOAD

In [2]:
from torch.utils.data import Dataset, DataLoader

class MultisourceDataset(Dataset):
    """Map-style dataset pairing EEG trials with their label rows and channel masks.

    Args:
        X: indexable collection of trials; ``X[i]`` is one trial.
        YD: label rows aligned with ``X`` (one row per trial).
        channel_mask: per-trial channel masks aligned with ``X``.

    Raises:
        ValueError: if the three collections do not have the same length.

    ``dataset[i]`` returns the tuple ``(X[i], YD[i], channel_mask[i])``.
    """

    def __init__(self, X, YD, channel_mask):
        # Fail early instead of raising IndexError (or silently ignoring extra rows)
        # in the middle of a DataLoader epoch.
        if not (len(X) == len(YD) == len(channel_mask)):
            raise ValueError(
                f"Length mismatch: len(X)={len(X)}, len(YD)={len(YD)}, "
                f"len(channel_mask)={len(channel_mask)}"
            )
        self.X = X
        self.YD = YD
        self.channel_mask = channel_mask

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Returns (trial, label row, channel mask); the index itself is not returned.
        return self.X[idx], self.YD[idx], self.channel_mask[idx]

In [3]:
from torch.utils.data import Sampler
from collections import defaultdict
import random

class MultiDomainBatchSampler(Sampler):
    """Batch sampler that interleaves domains so every batch mixes many domains.

    An epoch is a sequence of rounds. Each round draws one index from every domain,
    visiting the domains in an order shuffled once per epoch. Within a domain, indices
    are drawn without replacement in shuffled order. Once a domain is exhausted it is
    topped up with random draws (with replacement) from its own pool, so small domains
    are oversampled. The epoch ends once the largest domain has been fully used, i.e.
    after ``largest_domain_size`` rounds. The resulting ``n_domains * largest_domain_size``
    indices are cut into batches of ``batch_size``; the last batch may be smaller.

    Args:
        yd: per-sample label rows; ``int(label[4])`` is the domain id.
        batch_size: number of indices per yielded batch (must be positive).
    """

    def __init__(self, yd, batch_size):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.batch_size = batch_size
        self.domain_to_indices = defaultdict(list)

        # Collect indices for each domain
        for idx, label in enumerate(yd):
            domain = int(label[4])  # domain info is in 5th column
            self.domain_to_indices[domain].append(idx)

        self.domains = list(self.domain_to_indices.keys())

    def _num_rounds(self):
        """Rounds per epoch: the size of the largest domain (0 when there is no data)."""
        return max((len(v) for v in self.domain_to_indices.values()), default=0)

    def __iter__(self):
        # Fresh without-replacement order for every domain, and a fresh domain order.
        shuffled = {d: random.sample(v, len(v)) for d, v in self.domain_to_indices.items()}
        domain_cycle = self.domains.copy()
        random.shuffle(domain_cycle)

        batch = []
        for round_idx in range(self._num_rounds()):
            for domain in domain_cycle:
                order = shuffled[domain]
                if round_idx < len(order):
                    idx = order[round_idx]
                else:
                    # Domain exhausted: oversample it so it stays represented.
                    idx = random.choice(self.domain_to_indices[domain])
                batch.append(idx)

                if len(batch) == self.batch_size:
                    yield batch
                    batch = []

        if batch:
            yield batch

    def __len__(self):
        # Must equal the number of batches __iter__ yields (DataLoader reports this
        # as len(loader)); ceiling division accounts for the trailing partial batch.
        total = len(self.domains) * self._num_rounds()
        return (total + self.batch_size - 1) // self.batch_size


eeg_data_200 = multi-source dataset
eeg_data_200_cross = multi-source dataset for cross subject classification
eeg_data_bnci = BNCI2014_001 subset of multi-source dataset
eeg_data_zhou = Zhou2016 subset of multi-source dataset
eeg_data_weibo = Weibo2014 subset of multi-source dataset

X = training rows/trials
YD = training labels
mask = training channel masks

X_val = validation rows/trials
YD_val = validation labels
mask_val = validation channel masks

X_test = test rows/trials
YD_test = test labels
mask_test = test channel masks

channel_xy = list of channels positions

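Labels (one row of YD per trial) = [class, dataset source id, subject id (within its dataset), session id (within its subject), domain id], i.e. columns 0–4; the code reads the class from column 0 and the domain id from column 4.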
domain id is unique
combination (dataset sources id, subject id) is unique
combination (dataset sources id, subject id, session id) is unique

In [4]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Full multi-source dataset: train / validation / test splits plus the shared
# electrode-position table, all mapped onto `device`.
data = torch.load('./Dataset/eeg_data_200.pt', map_location=device)
channel_xy_tensor = data['channel_xy']

# Training split.
X_tensor, YD_tensor, padding_masks_tensor = (data[k] for k in ('X', 'YD', 'mask'))
# Validation split (named `_half`; later cells rely on these names).
X_tensor_half, YD_tensor_half, padding_masks_tensor_half = (
    data[k] for k in ('X_val', 'YD_val', 'mask_val')
)
# Held-out test split.
X_tensor_test, YD_tensor_test, padding_masks_tensor_test = (
    data[k] for k in ('X_test', 'YD_test', 'mask_test')
)

dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

# Later cells re-wrap `<loader>.dataset` with their own batch settings.
dataloader, dataloader_test, dataloader_test0 = (
    DataLoader(ds, batch_size=10, shuffle=True) for ds in (dataset, dataset_test, dataset_test0)
)

SPATIAL ATTENTION

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def compute_cos_sin(x, y, K):
    """Fourier basis of 2-D channel positions used by SpatialAttention.

    For every channel ``c`` and frequency pair ``(k, l)`` with ``k, l`` in ``1..K``,
    computes ``cos(2*pi*(k*x[c] + l*y[c]))`` and ``sin(2*pi*(k*x[c] + l*y[c]))``.

    Args:
        x, y: 1-D tensors of channel coordinates with the same length ``C``.
        K: number of frequencies per axis.

    Returns:
        ``(cos_mat, sin_mat)``, float32 tensors of shape ``(C, K*K)``. Column
        ``l_idx * K + k_idx`` holds the pair ``k = k_idx + 1``, ``l = l_idx + 1``.
        An empty ``x`` yields ``(0, K*K)`` tensors instead of raising.
    """
    # Build on the inputs' own device rather than a global one.
    freqs = torch.arange(1, K + 1, device=x.device)
    k_freq = freqs.view(1, 1, K)  # varies along the last axis
    l_freq = freqs.view(1, K, 1)  # varies along the middle axis
    # (C, K, K) phases, vectorised over channels instead of a Python loop.
    phase = 2 * torch.pi * (k_freq * x.reshape(-1, 1, 1) + l_freq * y.reshape(-1, 1, 1))
    n_channels = x.shape[0]
    # Explicit K*K (not -1) so that reshaping zero channels is well defined.
    return (
        torch.cos(phase).reshape(n_channels, K * K).float(),
        torch.sin(phase).reshape(n_channels, K * K).float(),
    )


class SpatialAttention(nn.Module):
    """Position-aware spatial attention mapping a variable channel set to ``out_channels``.

    Each output channel owns a complex coefficient vector ``z`` with ``K*K`` entries.
    For an input channel at position (x, y) its attention logit is
    ``Re(z) . cos_basis(x, y) + Im(z) . sin_basis(x, y)`` (basis from ``compute_cos_sin``).
    Logits are softmax-normalised over the sample's valid channels and used to mix
    the channel signals.

    Args:
        out_channels: number of output (virtual) channels.
        K: spatial frequencies per axis (basis size ``K*K``).
        dropout: if True, during *training* drop the channels lying within a small
            neighbourhood of a random point inside the sample's x/y bounding box.
        seed: optional seed for a dedicated generator driving that dropout.
        drop_radius_sq: squared-distance threshold (in ``channel_xy`` units) of the
            dropout neighbourhood; the default 0.001 is a radius of about 0.032.
    """

    def __init__(self, out_channels, K, dropout=False, seed=None, drop_radius_sq=0.001):
        super().__init__()
        self.outchans = out_channels
        self.K = K
        # Each output channel has its own K*K complex coefficients (real part weights
        # the cosine basis, imaginary part the sine basis).
        self.z = nn.Parameter(
            torch.randn(self.outchans, K * K, dtype=torch.cfloat, device=device) / (K * K)
        )
        self.dropout = dropout
        self.seed = seed
        self.drop_radius_sq = drop_radius_sq
        # One generator for the module's lifetime, so successive draws advance it.
        # Re-creating and re-seeding a generator on every call (the previous code)
        # replayed the same "random" point for every sample and every step.
        self._generator = (
            torch.Generator(device=device).manual_seed(seed) if seed is not None else None
        )

    def _dropout_keep_mask(self, positions):
        """Boolean mask of channels to keep after dropping one random spatial neighbourhood.

        A point is drawn uniformly inside the x/y bounding box of ``positions``. Channels
        whose squared distance to it is below ``drop_radius_sq`` are dropped; if that would
        remove every channel, nothing is dropped, because a softmax over an empty channel
        set would zero the output.
        """
        xy = positions[:, :2]
        lo = xy.min(dim=0).values
        hi = xy.max(dim=0).values
        u = torch.rand(2, generator=self._generator, device=xy.device, dtype=xy.dtype)
        point = lo + u * (hi - lo)
        drop = ((xy - point) ** 2).sum(dim=1) < self.drop_radius_sq
        if bool(drop.all()):
            return torch.ones_like(drop)
        return ~drop

    def forward(self, X, masks, channel_xy):
        """Mix each sample's valid channels into ``out_channels`` virtual channels.

        Args:
            X: (batch, channels, T) EEG tensor (the original comment cites 27 channels).
            masks: per-sample channel selectors; ``masks[i]`` indexes the rows of ``X[i]``
                and of ``channel_xy`` (presumably boolean — TODO confirm).
            channel_xy: electrode positions, one row per channel; columns 0/1 are x/y.

        Returns:
            (batch, out_channels, T) tensor.
        """
        batch_size, _, _ = X.shape

        # Move everything to device before processing
        X = X.to(device)
        masks = masks.to(device)
        channel_xy = channel_xy.to(device)

        attended_outputs = []
        for i in range(batch_size):
            signals = X[i][masks[i]]          # (C_valid, T)
            positions = channel_xy[masks[i]]  # (C_valid, >=2)

            cos_mat, sin_mat = compute_cos_sin(positions[:, 0], positions[:, 1], self.K)
            cos_mat, sin_mat = cos_mat.to(device), sin_mat.to(device)

            # Attention logits, one row per output channel: (out_channels, C_valid)
            logits = torch.matmul(self.z.real, cos_mat.T) + torch.matmul(self.z.imag, sin_mat.T)

            # Spatial channel dropout is a training-time augmentation only; in eval()
            # mode (validation/test) every valid channel is used.
            if self.dropout and self.training and signals.shape[0] > 1:
                keep = self._dropout_keep_mask(positions)
                signals = signals[keep]
                logits = logits[:, keep]

            weights = F.softmax(logits, dim=1)  # normalise over the remaining channels
            attended_outputs.append(torch.matmul(weights, signals))  # (out_channels, T)

        # Pad to ensure uniform shape (all samples share T here, so this is defensive).
        max_len = max(att.shape[1] for att in attended_outputs)
        return torch.stack([F.pad(att, (0, max_len - att.shape[1])) for att in attended_outputs])


EEG_DG

In [6]:
# coding=utf-8
import torch
import torch.nn as nn
import numpy as np



class DG_Network(nn.Module):
    """Multi-scale EEGNet-style encoder with a mixture of domain-specific heads.

    Input ``features`` has shape (B, 1, channels, T). Pipeline:
      1. four parallel length-preserving temporal convs (kernels 8/16/32/64), concatenated;
      2. a depthwise conv spanning all ``channels`` rows, BN, ReLU, avg-pool(1, 4), dropout;
      3. four parallel separable convs (kernels 2/4/8/16), concatenated;
      4. ReLU, avg-pool(1, 8), dropout, flattened to ``feature_size`` values;
      5. one ``Linear(feature_size, 400)`` head per domain, mixed with softmax weights
         from a domain classifier, then a linear classifier over ``classes``.

    ``feature_size`` must equal ``16 * F1 * D * floor(floor(T / 4) / 8)``.

    Returns ``(logits, domain_logits, per_domain_features, mixed_feature)`` in training
    mode and ``(logits, mixed_feature)`` in eval mode.
    """

    def __init__(self, classes, domains, feature_size=4096, F1=4, D=2, channels=14):
        super(DG_Network, self).__init__()
        self.dropout = 0.25  # default:0.25

        temporal_maps = F1 * 4            # feature maps after concatenating block-1 branches
        spatial_maps = temporal_maps * D  # feature maps after the depthwise spatial conv

        # Modules are created in the same order as before so that a seeded
        # initialisation yields identical weights and identical state_dict keys.
        self.special_features = nn.ModuleList(
            [nn.Sequential(nn.Linear(feature_size, 400), nn.Dropout(self.dropout)) for _ in range(domains)]
        )
        self.domain_classifier = nn.Sequential(nn.Linear(feature_size, domains))
        self.classifier = nn.Sequential(nn.Linear(400, classes))

        # Block 1: multi-scale temporal convolutions.
        self.block1_1 = self._temporal_branch(F1, 8)
        self.block1_2 = self._temporal_branch(F1, 16)
        self.block1_3 = self._temporal_branch(F1, 32)
        self.block1_4 = self._temporal_branch(F1, 64)

        # Block 2: depthwise conv over every electrode row, then downsample time by 4.
        self.block2 = nn.Sequential(
            nn.Conv2d(temporal_maps, spatial_maps, kernel_size=(channels, 1), groups=temporal_maps, bias=False),
            nn.BatchNorm2d(spatial_maps),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.dropout),
        )

        # Block 3: multi-scale separable (depthwise + pointwise) convolutions.
        self.block3_1 = self._separable_branch(spatial_maps, 2)
        self.block3_2 = self._separable_branch(spatial_maps, 4)
        self.block3_3 = self._separable_branch(spatial_maps, 8)
        self.block3_4 = self._separable_branch(spatial_maps, 16)

        # Block 4: downsample time by 8.
        self.block4 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.dropout),
        )

    @staticmethod
    def _temporal_branch(out_maps, kernel):
        """ZeroPad -> Conv2d(1, out_maps, (1, kernel)) -> BN; keeps T for even ``kernel``."""
        return nn.Sequential(
            nn.ZeroPad2d((kernel // 2 - 1, kernel // 2, 0, 0)),
            nn.Conv2d(1, out_maps, kernel_size=(1, kernel), bias=False),
            nn.BatchNorm2d(out_maps),
        )

    @staticmethod
    def _separable_branch(maps, kernel):
        """Depthwise (1, kernel) conv -> BN -> ReLU -> pointwise 1x1 conv -> BN; keeps T."""
        return nn.Sequential(
            nn.ZeroPad2d((kernel // 2 - 1, kernel // 2, 0, 0)),
            nn.Conv2d(maps, maps, kernel_size=(1, kernel), groups=maps, bias=False),
            nn.BatchNorm2d(maps),
            nn.ReLU(inplace=True),
            nn.Conv2d(maps, maps, kernel_size=(1, 1), groups=1, bias=False),  # point-wise cnn
            nn.BatchNorm2d(maps),
        )

    def forward(self, features):
        temporal = torch.cat(
            [branch(features) for branch in (self.block1_1, self.block1_2, self.block1_3, self.block1_4)],
            dim=1,
        )
        spatial = self.block2(temporal)
        separable = torch.cat(
            [branch(spatial) for branch in (self.block3_1, self.block3_2, self.block3_3, self.block3_4)],
            dim=1,
        )
        flat = torch.flatten(self.block4(separable), 1)

        # One feature vector per domain head: list of (B, 400).
        Feat_s = [head(flat) for head in self.special_features]

        # Domain-classifier logits double as mixing weights over the heads.
        feat_ = self.domain_classifier(flat)
        mixing = nn.functional.softmax(feat_, dim=1).unsqueeze(1)  # (B, 1, domains)
        weighted_feature = torch.flatten(torch.bmm(mixing, torch.stack(Feat_s, dim=1)), 1)  # (B, 400)
        logits = self.classifier(weighted_feature)

        if not self.training:
            return logits, weighted_feature
        return logits, feat_, Feat_s, weighted_feature


In [7]:
import torch
import torch.nn as nn
from sklearn.metrics.pairwise import euclidean_distances

class Dist_Loss(nn.Module):
    """Class-structure loss over features grouped by domain.

    For each domain group ``(data, labels)`` it adds
    ``intraclass_compactness - alpha * interclass_separability``, pulling same-class
    samples together and pushing different-class samples apart.

    Args:
        center_weight: weight of an optional cross-domain term. The term is the sum over
            domain pairs (i < j) of the mean L2 distance between the class centres that
            both domains contain. The default 0.0 disables it, which matches the previous
            behaviour: that term used to be computed every call but never added to the
            returned loss.
    """

    def __init__(self, center_weight=0.0):
        super(Dist_Loss, self).__init__()
        self.center_weight = center_weight

    def intraclass_compactness(self, data, labels):
        """Sum of within-class pairwise L2 distances divided by the sample count (0 if empty)."""
        compactness = torch.tensor(0.0, device=data.device)
        for label in torch.unique(labels):
            class_data = data[labels == label]
            # cdist lists every unordered pair twice (plus a zero diagonal), hence / 2.
            compactness += torch.cdist(class_data, class_data, p=2).sum() / 2
        return compactness / max(data.shape[0], 1)

    def interclass_separability(self, data, labels):
        """Sum of L2 distances between samples of different classes (each ordered class
        pair counted), divided by the sample count (0 if empty)."""
        separability = torch.tensor(0.0, device=data.device)
        class_groups = [data[labels == label] for label in torch.unique(labels)]
        for i, group_i in enumerate(class_groups):
            for j, group_j in enumerate(class_groups):
                if i != j:
                    separability += torch.cdist(group_i, group_j, p=2).sum()
        return separability / max(data.shape[0], 1)

    def compute_class_centers(self, data, labels, all_unique_labels):
        """Per-class mean feature vectors, laid out along ``all_unique_labels``.

        Returns:
            ``(class_centers, valid_mask)``: centres of shape
            ``(len(all_unique_labels), feat_dim)`` (zeros for absent classes) and a
            boolean mask marking the classes present in ``labels``.
        """
        class_centers = torch.zeros(len(all_unique_labels), data.size(1), device=data.device)
        valid_mask = torch.zeros(len(all_unique_labels), dtype=torch.bool, device=data.device)

        for label in torch.unique(labels):
            class_data = data[labels == label]
            index = (all_unique_labels == label).nonzero(as_tuple=True)[0].item()

            if class_data.size(0) > 0:
                class_centers[index] = torch.mean(class_data, dim=0)
                valid_mask[index] = True

        return class_centers, valid_mask

    def _cross_domain_center_distance(self, all_data, all_labels):
        """Sum over domain pairs i < j of the mean distance between their shared class centres."""
        all_unique_labels = torch.unique(torch.cat(all_labels))
        centers = [
            self.compute_class_centers(data, labels, all_unique_labels)
            for data, labels in zip(all_data, all_labels)
        ]
        dist = torch.tensor(0.0, device=all_data[0].device)
        # j starts at i + 1: pairs of *different* domains only.
        for i in range(len(centers)):
            for j in range(i + 1, len(centers)):
                (src, src_valid), (tgt, tgt_valid) = centers[i], centers[j]
                shared = src_valid & tgt_valid
                if shared.any():
                    dist = dist + torch.diag(torch.cdist(src[shared], tgt[shared], p=2)).mean()
        return dist

    def forward(self, all_data, all_labels, alpha):
        """Compute the loss.

        Args:
            all_data: list of per-domain feature tensors, each (n_i, feat_dim).
            all_labels: list of matching class-label tensors, each (n_i,).
            alpha: weight of the inter-class separability (subtracted).

        Returns:
            Scalar tensor (0 for an empty list).
        """
        if len(all_data) == 0:
            return torch.tensor(0.0)

        total_dist = torch.tensor(0.0, device=all_data[0].device)
        for data, labels in zip(all_data, all_labels):
            dist_1 = self.intraclass_compactness(data, labels)
            dist_2 = self.interclass_separability(data, labels)
            total_dist += dist_1 - alpha * dist_2

        # Opt-in cross-domain alignment; skipped entirely (no wasted work) when disabled.
        if self.center_weight:
            total_dist = total_dist + self.center_weight * self._cross_domain_center_distance(
                all_data, all_labels
            )
        return total_dist

In [8]:
import torch
import torch.nn as nn


class MMD_loss(nn.Module):
    """Maximum Mean Discrepancy of each domain's features against the cross-domain mean.

    ``forward`` takes a stacked tensor ``source_all`` of shape (n_domains, batch, feat_dim).
    For every domain i it computes MMD^2(source_all[i], mean over domains of source_all)
    and returns the average over domains.

    Args:
        kernel_type: ``'rbf'`` (multi-bandwidth Gaussian kernel) or ``'linear'``
            (squared distance of feature means).
        kernel_mul: ratio between consecutive Gaussian bandwidths.
        kernel_num: number of Gaussian bandwidths.
        eps: small constant guarding divisions.
        fix_sigma: optional fixed base bandwidth; ``None`` means a data-driven bandwidth.

    Raises:
        ValueError: for an unsupported ``kernel_type`` (previously these silently
            produced a zero loss).
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, eps=1e-8, fix_sigma=None):
        super(MMD_loss, self).__init__()
        if kernel_type not in ('rbf', 'linear'):
            raise ValueError(f"Unsupported kernel_type {kernel_type!r}; expected 'rbf' or 'linear'")
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type
        self.eps = eps

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Sum of ``kernel_num`` Gaussian kernels over all row pairs of ``cat([source, target])``."""
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        n_rows, n_feat = int(total.size(0)), int(total.size(1))
        total0 = total.unsqueeze(0).expand(n_rows, n_rows, n_feat)
        total1 = total.unsqueeze(1).expand(n_rows, n_rows, n_feat)
        L2_distance = ((total0 - total1) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Data-driven base bandwidth: mean squared distance over off-diagonal pairs
            # (taken from .data, so it is not differentiated).
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples + self.eps)
        # Centre the geometric ladder of bandwidths on the base bandwidth.
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / (bw + self.eps)) for bw in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Squared L2 distance between the feature means of ``f_of_X`` and ``f_of_Y``."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        # torch.dot on the 1-D vector (``.T`` on a 1-D tensor is deprecated).
        return torch.dot(delta, delta)

    def forward(self, source_all):
        n_domains = source_all.shape[0]
        if n_domains == 0:
            return torch.tensor(0.0, device=source_all.device)

        # Mean across domains -> [batch_size, feat_dim]
        target = torch.mean(source_all, dim=0)

        total_loss = torch.tensor(0.0, device=source_all.device)
        for i in range(n_domains):
            source = source_all[i]
            if self.kernel_type == 'linear':
                total_loss += self.linear_mmd2(source, target)
            elif self.kernel_type == 'rbf':
                batch_size = int(source.shape[0])
                kernels = self.guassian_kernel(
                    source, target,
                    kernel_mul=self.kernel_mul,
                    kernel_num=self.kernel_num,
                    fix_sigma=self.fix_sigma,
                )
                XX = torch.mean(kernels[:batch_size, :batch_size])
                YY = torch.mean(kernels[batch_size:, batch_size:])
                XY = torch.mean(kernels[:batch_size, batch_size:])
                YX = torch.mean(kernels[batch_size:, :batch_size])
                total_loss += torch.mean(XX + YY - XY - YX)
            else:
                raise ValueError(f"Unsupported kernel_type {self.kernel_type!r}")
        return total_loss / n_domains


TRAIN LOOP

In [None]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# --- Reproducibility ---
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --- Hyperparameters ---
out_channels = 46       # virtual channels produced by SpatialAttention
K = 48                  # spatial Fourier frequencies per axis (basis size K*K)
num_classes = 3
num_domains = 23        # domain ids YD[:, 4] are used directly as class indices in [0, num_domains)
learning_rate = 0.0001
num_epochs = 500  # 166

# Loss weights
dist_weight = 0.1
mmd_weight = 0.1
dom_weight = 0.15
dist_alpha = 0.1        # weight of inter-class separability inside Dist_Loss

# --- Models (creation order kept fixed so the seeded initialisation is reproducible) ---
model = SpatialAttention(out_channels, K, dropout=True, seed=seed).to(device)
dg = DG_Network(num_classes, num_domains, channels=out_channels, feature_size=3328).to(device)
dl = Dist_Loss().to(device)
mmdl = MMD_loss().to(device)

# --- Loss and optimizer ---
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
criterion_dom = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(
    list(model.parameters()) +
    list(dg.parameters()), lr=learning_rate, weight_decay=5e-4
)
# Created but currently not stepped (see the commented call in the epoch loop).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# --- Data ---
sampler = MultiDomainBatchSampler(dataloader.dataset.YD, batch_size=30)
train_loader = DataLoader(dataloader.dataset, batch_sampler=sampler)
val_loader = DataLoader(dataloader_test.dataset, batch_size=39, shuffle=True)
test_loader = DataLoader(dataloader_test0.dataset, batch_size=39, shuffle=True)

best_val_acc = 0

# To resume from a checkpoint:
# resume_path = "./checkpoint-TEST.pth"
# if os.path.exists(resume_path):
#     checkpoint = torch.load(resume_path, map_location=device)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     dg.load_state_dict(checkpoint['dg_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


def save_checkpoint(path):
    """Save the SpatialAttention, DG_Network and optimizer state dicts to ``path``."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'dg_state_dict': dg.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


def group_by_domain(feat_all, domain_ids, labels):
    """Split a batch by domain id, pairing each domain's samples with its own head.

    ``feat_all[d]`` is the output of DG_Network's d-th domain-specific head for the whole
    batch. The domain classifier is trained with ``yd[:, 4]`` as the class index (see
    ``criterion_dom`` below), so head ``d`` belongs to domain id ``d``. The head is
    therefore indexed by the id itself, not by the id's position among the domains that
    happen to be present in the current batch.
    """
    grouped_feats, grouped_labels = [], []
    for dom_id in domain_ids.unique(sorted=True):
        in_domain = domain_ids == dom_id
        grouped_feats.append(feat_all[dom_id.item()][in_domain])
        grouped_labels.append(labels[in_domain])
    return grouped_feats, grouped_labels


def evaluate(loader):
    """Return (mean cross-entropy per batch, accuracy in %) of model + dg on ``loader``."""
    model.eval()
    dg.eval()
    loss_sum, n_correct, n_seen, n_batches = 0.0, 0, 0, 0
    with torch.no_grad():
        for x_batch, yd_batch, mask_batch in loader:
            output = model(x_batch, mask_batch, channel_xy_tensor).unsqueeze(1)
            logits, _ = dg(output)
            targets = yd_batch[:, 0]
            loss_sum += criterion(logits, targets).item()
            n_correct += (torch.argmax(logits, dim=1) == targets).sum().item()
            n_seen += yd_batch.size(0)
            n_batches += 1
    return loss_sum / max(n_batches, 1), n_correct / max(n_seen, 1) * 100


# Training loop
try:
    for epoch in range(num_epochs):
        model.train()
        dg.train()

        correct, total, running_loss, total_dist_loss, total_mmd_loss, total_dom_loss = 0, 0, 0, 0, 0, 0
        # Count batches actually seen rather than trusting len(train_loader).
        n_train_batches = 0
        for x_batch, yd_batch, mask_batch in train_loader:
            output = model(x_batch, mask_batch, channel_xy_tensor).unsqueeze(1)
            # `weight` holds the domain-classifier logits.
            logits, weight, Feat_s, weighted_feature = dg(output)

            feat_all = torch.stack(Feat_s, dim=0)  # [n_domains, batch_size, feat_dim]
            labels = yd_batch[:, 0]
            domain_ids = yd_batch[:, 4]
            grouped_feats, grouped_labels = group_by_domain(feat_all, domain_ids, labels)

            # Compute loss
            dist_loss = dl(grouped_feats, grouped_labels, dist_alpha)
            mmd_loss = mmdl(feat_all)
            loss = criterion(logits, labels)
            dom_loss = criterion_dom(weight, domain_ids)
            total_loss = loss + dist_weight * dist_loss + mmd_weight * mmd_loss + dom_weight * dom_loss

            # Backpropagation
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            total_dist_loss += dist_loss.item()
            total_mmd_loss += mmd_loss.item()
            total_dom_loss += dom_loss.item()

            # Training accuracy
            correct += (torch.argmax(logits, dim=1) == labels).sum().item()
            total += yd_batch.size(0)
            n_train_batches += 1

        n_train_batches = max(n_train_batches, 1)
        train_loss = running_loss / n_train_batches
        train_acc = correct / max(total, 1) * 100
        # scheduler.step()

        val_loss, val_acc = evaluate(val_loader)
        test_loss, test_acc = evaluate(test_loader)

        # Model selection uses validation accuracy only; the test split is just reported.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint("./checkpoint-TEST-BEST.pth")
            print("✅ Model saved.")

        save_checkpoint("./checkpoint-TEST.pth")

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Dist Loss: {total_dist_loss/n_train_batches:.4f}, MMD Loss: {total_mmd_loss/n_train_batches:.4f}, Dom Loss {total_dom_loss/n_train_batches:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

except KeyboardInterrupt:
    print("\n⛔ Training interrupted. Saving model...")
    save_checkpoint("./checkpoint-TEST.pth")
    print("✅ Model saved. Exiting cleanly.")



⛔ Training interrupted. Saving model...
✅ Model saved. Exiting cleanly.


PERFORMANCE EVALUATIONS

In [10]:
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Load the Weibo2014 subset (train / validation / test splits) ===
data = torch.load('./Dataset/eeg_data_weibo.pt', map_location=device)

X_tensor = data['X']
YD_tensor = data['YD']
padding_masks_tensor = data['mask']
channel_xy_tensor = data['channel_xy']

X_tensor_half = data['X_val']
YD_tensor_half = data['YD_val']
padding_masks_tensor_half = data['mask_val']

X_tensor_test = data['X_test']
YD_tensor_test = data['YD_test']
padding_masks_tensor_test = data['mask_test']

dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
dataloader_test = DataLoader(dataset_test, batch_size=10, shuffle=True)
dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)
dataloader_test0 = DataLoader(dataset_test0, batch_size=10, shuffle=True)

# === Initialize ===
seed = 42
num_channels = 46

# Load DataLoader (deterministic order for evaluation)
test_loader = DataLoader(dataloader_test0.dataset, batch_size=39, shuffle=False)

# Load models (multi-source model with 23 domain heads)
model = SpatialAttention(num_channels, 48, dropout=False, seed=seed).to(device)
dg = DG_Network(3, 23, channels=num_channels, feature_size=3328).to(device)
softmax = nn.Softmax(dim=1)

# Load checkpoint. Fail loudly: silently continuing would report metrics for
# randomly initialised weights.
save_path = "./ProposedModel/checkpoint46out.pth"
if not os.path.exists(save_path):
    raise FileNotFoundError(f"Checkpoint not found: {save_path}")
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
dg.load_state_dict(checkpoint['dg_state_dict'])

model.eval()
dg.eval()

# === Storage for metrics ===
domain_predictions = defaultdict(list)
domain_targets = defaultdict(list)
all_preds = []
all_targets = []

# === Inference ===
with torch.no_grad():
    for x_batch, yd_batch, mask_batch in test_loader:
        x_batch, mask_batch = x_batch.to(device), mask_batch.to(device)
        output = model(x_batch, mask_batch, channel_xy_tensor.to(device))

        output = output.unsqueeze(1)  # (B, 1, C_out, T)
        logits, _ = dg(output)
        preds = torch.argmax(logits, dim=1).cpu().numpy()

        domains = yd_batch[:, 4].cpu().numpy()
        targets = yd_batch[:, 0].cpu().numpy()

        # Save for overall metrics
        all_preds.extend(preds)
        all_targets.extend(targets)

        # Save for per-domain metrics
        for domain, target, pred in zip(domains, targets, preds):
            domain_predictions[domain].append(pred)
            domain_targets[domain].append(target)

# === Print per-domain metrics ===
print("Metrics by Domain:")
for domain in sorted(domain_predictions.keys()):
    print(f"\n- Domain {domain}:")

    y_true = domain_targets[domain]
    y_pred = domain_predictions[domain]

    accuracy = accuracy_score(y_true, y_pred)
    print(f"    - Accuracy: {accuracy:.4f}")

    classes = sorted(set(y_true))
    f1_scores = f1_score(y_true, y_pred, average=None, labels=classes)
    for cls, f1 in zip(classes, f1_scores):
        print(f"    - Class {cls}: F1 Score = {f1:.4f}")

    # Confusion matrix over every label seen in y_true OR y_pred. Restricting the
    # columns to the true classes dropped predictions of other classes, so each
    # row's percentages were computed over too few samples.
    cm_labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=cm_labels)
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

    print(f"    - Misclassification Breakdown (rows = true class):")
    for cls in classes:
        i = cm_labels.index(cls)
        breakdown = [
            f"{cm_percent[i, j] * 100:.1f}% → class {pred_cls}"
            for j, pred_cls in enumerate(cm_labels)
            if j != i
        ]
        if breakdown:
            print(f"        Class {cls} misclassified as: {', '.join(breakdown)}")
        else:
            print(f"        Class {cls} has no misclassifications.")

# === Overall Metrics ===
overall_accuracy = accuracy_score(all_targets, all_preds)
classes = sorted(set(all_targets))
f1_per_class = f1_score(all_targets, all_preds, average=None, labels=classes)

print("\nOverall Metrics:")
print(f"    - Overall Accuracy: {overall_accuracy:.4f}")
print(f"    - F1 Scores per Class:")
for cls, f1 in zip(classes, f1_per_class):
    print(f"        Class {cls}: F1 Score = {f1:.4f}")


Metrics by Domain:

- Domain 13:
    - Accuracy: 0.5612
    - Class 0: F1 Score = 0.7671
    - Class 1: F1 Score = 0.5385
    - Class 2: F1 Score = 0.2667
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 11.8% → class 1, 5.9% → class 2
        Class 1 misclassified as: 17.1% → class 0, 22.9% → class 2
        Class 2 misclassified as: 17.2% → class 0, 62.1% → class 1

- Domain 14:
    - Accuracy: 0.5882
    - Class 0: F1 Score = 0.6301
    - Class 1: F1 Score = 0.5556
    - Class 2: F1 Score = 0.5581
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 19.4% → class 1, 6.5% → class 2
        Class 1 misclassified as: 34.5% → class 0, 13.8% → class 2
        Class 2 misclassified as: 36.0% → class 0, 16.0% → class 1

- Domain 15:
    - Accuracy: 0.4894
    - Class 0: F1 Score = 0.7273
    - Class 1: F1 Score = 0.4615
    - Class 2: F1 Score = 0.1818
    - Misclassification Breakdown (rows = true class):
        

In [11]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Every tensor in the saved bundle is placed directly on `device` at load time.
data = torch.load('./Dataset/eeg_data_weibo.pt', map_location=device)

# Training split: signals, label/metadata rows (YD), and channel padding masks.
X_tensor, YD_tensor, padding_masks_tensor = data['X'], data['YD'], data['mask']
# Per-channel coordinates; passed to the model alongside each batch at inference.
channel_xy_tensor = data['channel_xy']
# Validation split, stored under the "_half" variable names.
X_tensor_half, YD_tensor_half, padding_masks_tensor_half = (
    data['X_val'], data['YD_val'], data['mask_val']
)
# Held-out test split.
X_tensor_test, YD_tensor_test, padding_masks_tensor_test = (
    data['X_test'], data['YD_test'], data['mask_test']
)

# Wrap each split as a dataset yielding (x, yd, mask) triples.
dataset = MultisourceDataset(X_tensor, YD_tensor, padding_masks_tensor)
dataset_test = MultisourceDataset(X_tensor_half, YD_tensor_half, padding_masks_tensor_half)
dataset_test0 = MultisourceDataset(X_tensor_test, YD_tensor_test, padding_masks_tensor_test)

# All three loaders shuffle; evaluation code that needs a fixed order should
# build its own loader from `<loader>.dataset` with shuffle=False.
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=10, shuffle=True)
dataloader_test0 = DataLoader(dataset_test0, batch_size=10, shuffle=True)

from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import pandas as pd
from collections import defaultdict
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

# === Initialize ===
seed = 42
num_channels = 46  # channel count the networks below (and the checkpoint) are built for
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Ordered, non-shuffled pass over the held-out test split. This reuses the
# dataset behind `dataloader_test0`, which itself shuffles.
test_loader = DataLoader(dataloader_test0.dataset, batch_size=39, shuffle=False)

# Two-stage model: SpatialAttention's output is fed to DG_Network at inference.
model = SpatialAttention(num_channels, 48, dropout=False, seed=seed).to(device)
dg = DG_Network(3, 10, channels=num_channels, feature_size=3328).to(device)
softmax = nn.Softmax(dim=1)  # not used by the inference below

# Load trained weights. Fail loudly if the checkpoint is missing: silently
# evaluating freshly initialised networks would print meaningless metrics.
save_path = "./ProposedModel/checkpoint46outWeibo.pth"
if not os.path.exists(save_path):
    raise FileNotFoundError(f"Checkpoint not found: {save_path}")
# map_location lets a checkpoint saved from a GPU run be restored on a
# CPU-only machine (and vice versa).
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
dg.load_state_dict(checkpoint['dg_state_dict'])

# Switch both networks to evaluation mode before inference.
model.eval()
dg.eval()

# === Storage for metrics ===
# Per-domain label/prediction lists, keyed by the domain id read from YD.
domain_predictions = defaultdict(list)
domain_targets = defaultdict(list)
# Flat lists across all domains, for the overall metrics.
all_preds = []
all_targets = []

# The channel coordinates are the same for every batch, so move them to the
# device once instead of on every iteration.
channel_xy_device = channel_xy_tensor.to(device)

# === Inference ===
with torch.no_grad():
    for x_batch, yd_batch, mask_batch in test_loader:
        x_batch, mask_batch = x_batch.to(device), mask_batch.to(device)
        output = model(x_batch, mask_batch, channel_xy_device)

        # Insert a singleton axis at dim 1 before DG_Network; presumably
        # (B, 1, C_out, T) — TODO confirm against DG_Network's expected input.
        output = output.unsqueeze(1)
        logits, _ = dg(output)
        predictions = torch.argmax(logits, dim=1)

        # YD columns: 0 is used as the class target, 4 as the domain id
        # (the same column MultiDomainBatchSampler reads).
        domains = yd_batch[:, 4].cpu().numpy()
        targets = yd_batch[:, 0].cpu().numpy()
        preds = predictions.cpu().numpy()

        # Save for overall metrics
        all_preds.extend(preds)
        all_targets.extend(targets)

        # Save for per-domain metrics
        for domain, target, pred in zip(domains, targets, preds):
            domain_predictions[domain].append(pred)
            domain_targets[domain].append(target)

# === Print per-domain metrics ===
def print_domain_metrics(y_true, y_pred):
    """Print accuracy, per-class F1 and a misclassification breakdown for one domain.

    Parameters
    ----------
    y_true, y_pred : sequences of class labels of equal length.

    F1 is reported for the classes present in ``y_true``. The breakdown has one
    row per true class. Its columns also include any class that was predicted
    but never occurs as ground truth in this domain, so those errors are not
    silently dropped from the percentages.
    """
    accuracy = accuracy_score(y_true, y_pred)
    print(f"    - Accuracy: {accuracy:.4f}")

    classes = sorted(set(y_true))
    f1_scores = f1_score(y_true, y_pred, average=None, labels=classes)
    for cls, f1 in zip(classes, f1_scores):
        print(f"    - Class {cls}: F1 Score = {f1:.4f}")

    # Build the matrix over true ∪ predicted labels. confusion_matrix(labels=...)
    # ignores samples whose label is not listed, which previously hid
    # predictions of classes absent from this domain's ground truth.
    labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    print(f"    - Misclassification Breakdown (rows = true class):")
    for i, cls in enumerate(labels):
        row_total = cm[i].sum()
        if row_total == 0:
            # Predicted-only label: it has no true samples, hence no row to report.
            continue
        if cm[i, i] == row_total:
            # Every true sample of this class was predicted correctly.
            print(f"        Class {cls} has no misclassifications.")
            continue
        breakdown = [
            f"{cm[i, j] / row_total * 100:.1f}% → class {pred_cls}"
            for j, pred_cls in enumerate(labels)
            if j != i
        ]
        print(f"        Class {cls} misclassified as: {', '.join(breakdown)}")


print("Metrics by Domain:")
for domain in sorted(domain_predictions.keys()):
    print(f"\n- Domain {domain}:")
    print_domain_metrics(domain_targets[domain], domain_predictions[domain])

# === Overall Metrics ===
# Pooled over every domain: a single accuracy figure plus one F1 score per
# class, restricted to the labels that actually occur in the ground truth.
classes = sorted(set(all_targets))
overall_accuracy = accuracy_score(all_targets, all_preds)
f1_per_class = f1_score(all_targets, all_preds, average=None, labels=classes)

per_class_lines = [
    f"        Class {cls}: F1 Score = {score:.4f}"
    for cls, score in zip(classes, f1_per_class)
]
print(
    "\nOverall Metrics:",
    f"    - Overall Accuracy: {overall_accuracy:.4f}",
    "    - F1 Scores per Class:",
    *per_class_lines,
    sep="\n",
)


Metrics by Domain:

- Domain 13:
    - Accuracy: 0.5918
    - Class 0: F1 Score = 0.8358
    - Class 1: F1 Score = 0.4688
    - Class 2: F1 Score = 0.4615
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 8.8% → class 1, 8.8% → class 2
        Class 1 misclassified as: 5.7% → class 0, 51.4% → class 2
        Class 2 misclassified as: 10.3% → class 0, 37.9% → class 1

- Domain 14:
    - Accuracy: 0.4000
    - Class 0: F1 Score = 0.4615
    - Class 1: F1 Score = 0.1333
    - Class 2: F1 Score = 0.5333
    - Misclassification Breakdown (rows = true class):
        Class 0 misclassified as: 25.8% → class 1, 25.8% → class 2
        Class 1 misclassified as: 51.7% → class 0, 37.9% → class 2
        Class 2 misclassified as: 16.0% → class 0, 20.0% → class 1

- Domain 15:
    - Accuracy: 0.3511
    - Class 0: F1 Score = 0.2857
    - Class 1: F1 Score = 0.3864
    - Class 2: F1 Score = 0.3529
    - Misclassification Breakdown (rows = true class):
        C