In [1]:
# 包导入
import torch
import torch.nn as nn
from torch.autograd import Function
import numpy as np

import mne
import matplotlib.pyplot as pl
%matplotlib inline
# from functions import ReverseLayerF

## 模型定义：`DG_Network`

`DG_Network.__init__` 的参数：

- `classes`：目标分类的类别数量。
- `channels`：输入数据的通道数（即空间深度卷积核的高度）。
- `F1`：每个时间卷积分支的卷积核数量。
- `D`：深度可分离卷积的通道倍数（depth multiplier）。
- `domains`：源领域的数量（领域分类器的输出维度）。

In [2]:
# Model definition


class DG_Network(nn.Module):
    """Multi-scale EEGNet-style network with a domain-weighted mixture of feature heads.

    For an input of shape ``(N, 1, channels, samples)`` the layers produce:

    1. ``block1_1``..``block1_4``: temporal convolutions of width 8/16/32/64 with
       width-preserving zero padding, ``F1`` maps each, concatenated on the channel
       axis -> ``(N, 4*F1, channels, samples)``.
    2. ``block2``: depthwise convolution spanning all ``channels`` (depth multiplier
       ``D``), BatchNorm, ReLU, ``AvgPool2d((1, 4))``, dropout
       -> ``(N, 4*F1*D, 1, samples // 4)``.
    3. ``block3_1``..``block3_4``: separable convolutions (depthwise width 2/4/8/16
       followed by a 1x1 pointwise convolution), concatenated
       -> ``(N, 16*F1*D, 1, samples // 4)``.
    4. ``block4``: ReLU, ``AvgPool2d((1, 8))``, dropout; then flattened to
       ``feature_dim = 16*F1*D * (samples // 4 // 8)`` features.
    5. Three linear heads (``special_features1..3``, 400 outputs each) are mixed per
       sample with the softmax of ``domain_classifier``; the mixture goes through
       ``classifier`` and a softmax.

    NOTE(review): the domain classifier's softmax is only used as mixing weights;
    no gradient reversal is applied anywhere in this class.

    Args:
        classes: number of target classes.
        channels: number of input channels (height of the depthwise spatial kernel).
        F1: number of filters in each temporal branch.
        D: depth multiplier of the depthwise spatial convolution.
        domains: number of source domains; must be 3 because the network has exactly
            three feature heads to mix.
        samples: time samples per trial, used to size the linear layers. With the
            default of 1000 (and F1=4, D=2) the feature width is 3968.
        dropout: dropout probability used in ``block2`` and ``block4``.

    Raises:
        ValueError: if ``domains != 3`` or ``samples`` is shorter than 32 (too short
            to survive the two pooling stages).
    """

    HEAD_FEATURES = 400  # width of each domain-specific head and of the mixed feature

    def __init__(self, classes, channels, F1=4, D=2, domains=3, samples=1000, dropout=0.25):
        super(DG_Network, self).__init__()
        if domains != 3:
            raise ValueError(f"DG_Network mixes exactly 3 feature heads, got domains={domains}")
        pooled_width = samples // 4 // 8  # block2 pools the width by 4, block4 by 8
        if pooled_width < 1:
            raise ValueError(f"samples={samples} is too short; need at least 32 time samples")
        self.dropout = dropout  # default: 0.25

        temporal_channels = F1 * 4  # four temporal branches with F1 maps each
        spatial_channels = temporal_channels * D

        # Four parallel temporal convolutions extracting features at different scales.
        self.block1_1 = self._temporal_branch(F1, 8)
        self.block1_2 = self._temporal_branch(F1, 16)
        self.block1_3 = self._temporal_branch(F1, 32)
        self.block1_4 = self._temporal_branch(F1, 64)

        # Depthwise spatial convolution: groups=temporal_channels gives D filters per input map.
        self.block2 = nn.Sequential(
            nn.Conv2d(temporal_channels, spatial_channels, kernel_size=(channels, 1),
                      groups=temporal_channels, bias=False),
            nn.BatchNorm2d(spatial_channels),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.dropout),
        )

        # Four parallel separable convolutions for richer multi-scale features.
        self.block3_1 = self._separable_branch(spatial_channels, 2)
        self.block3_2 = self._separable_branch(spatial_channels, 4)
        self.block3_3 = self._separable_branch(spatial_channels, 8)
        self.block3_4 = self._separable_branch(spatial_channels, 16)

        # Activation, temporal average pooling and dropout.
        self.block4 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.dropout),
        )

        # Four concatenated block3 branches, spatial height 1, pooled width.
        self.feature_dim = 4 * spatial_channels * pooled_width

        # Three domain-specific feature heads.
        self.special_features1 = nn.Sequential(nn.Linear(self.feature_dim, self.HEAD_FEATURES))
        self.special_features2 = nn.Sequential(nn.Linear(self.feature_dim, self.HEAD_FEATURES))
        self.special_features3 = nn.Sequential(nn.Linear(self.feature_dim, self.HEAD_FEATURES))

        # Domain classifier whose softmax gives the per-sample head mixing weights.
        self.domain_classifier = nn.Sequential(nn.Linear(self.feature_dim, domains))

        # Final target classifier on the mixed feature.
        self.classifier = nn.Sequential(nn.Linear(self.HEAD_FEATURES, classes))

    @staticmethod
    def _same_pad(kernel_width):
        """Zero padding (left k/2 - 1, right k/2) that keeps the width for an even kernel k."""
        return nn.ZeroPad2d((kernel_width // 2 - 1, kernel_width // 2, 0, 0))

    @classmethod
    def _temporal_branch(cls, out_channels, kernel_width):
        """Pad -> (1, k) convolution from the single input map -> BatchNorm."""
        return nn.Sequential(
            cls._same_pad(kernel_width),
            nn.Conv2d(1, out_channels, kernel_size=(1, kernel_width), bias=False),
            nn.BatchNorm2d(out_channels),
        )

    @classmethod
    def _separable_branch(cls, n_channels, kernel_width):
        """Pad -> depthwise (1, k) conv -> BN -> ReLU -> pointwise 1x1 conv -> BN."""
        return nn.Sequential(
            cls._same_pad(kernel_width),
            nn.Conv2d(n_channels, n_channels, kernel_size=(1, kernel_width),
                      groups=n_channels, bias=False),
            nn.BatchNorm2d(n_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_channels, n_channels, kernel_size=(1, 1), groups=1, bias=False),  # point-wise
            nn.BatchNorm2d(n_channels),
        )

    @staticmethod
    def _as_4d(data):
        """Cast to float32 and, for 3-D input, insert the singleton conv-channel axis."""
        data = data.to(torch.float32)
        if data.dim() == 3:
            # NOTE(review): assumes 3-D input is (N, channels, samples); the conv stack
            # needs (N, 1, channels, samples). BatchNorm2d rejects 3-D input otherwise.
            data = data.unsqueeze(1)
        return data

    def _extract_features(self, data):
        """Shared convolutional trunk: 4-D float input -> flattened (N, feature_dim) features."""
        temporal = torch.cat(
            [self.block1_1(data), self.block1_2(data), self.block1_3(data), self.block1_4(data)],
            dim=1,
        )
        spatial = self.block2(temporal)
        separable = torch.cat(
            [self.block3_1(spatial), self.block3_2(spatial),
             self.block3_3(spatial), self.block3_4(spatial)],
            dim=1,
        )
        return torch.flatten(self.block4(separable), 1)

    def _mix_and_classify(self, features):
        """Apply the three heads, mix them with the domain softmax, and classify.

        Returns ``(out, weight, head_features, weighted_feature)``.
        """
        head_features = [
            self.special_features1(features),
            self.special_features2(features),
            self.special_features3(features),
        ]
        weight = nn.functional.softmax(self.domain_classifier(features), dim=1)  # (N, 3)
        stacked = torch.stack(head_features, dim=1)                              # (N, 3, 400)
        # (N, 1, 3) @ (N, 3, 400) -> (N, 1, 400): per-sample weighted sum of the heads.
        weighted_feature = torch.flatten(torch.bmm(weight.unsqueeze(1), stacked), 1)
        out = nn.functional.softmax(self.classifier(weighted_feature), dim=1)
        return out, weight, head_features, weighted_feature

    def forward(self, data_train1, data_train2, data_train3):
        """Classify one batch from each of the three source domains.

        The three batches are concatenated along dim 0 in argument order and processed
        together, so every output row follows that order.

        Args:
            data_train1, data_train2, data_train3: batches shaped
                ``(n_i, 1, channels, samples)`` (3-D ``(n_i, channels, samples)`` is also
                accepted); cast to float32.

        Returns:
            tuple ``(out, weight, Feat_s, weighted_feature)`` where, with
            ``N = n_1 + n_2 + n_3``: ``out`` is the class probabilities ``(N, classes)``;
            ``weight`` the domain mixing weights ``(N, 3)``; ``Feat_s`` a list of the three
            head outputs, each ``(N, 400)``; ``weighted_feature`` the mixed feature ``(N, 400)``.
        """
        data = torch.cat(
            [self._as_4d(x) for x in (data_train1, data_train2, data_train3)], dim=0
        )
        return self._mix_and_classify(self._extract_features(data))

    def predict(self, data):
        """Classify a single batch (e.g. held-out data).

        Same computation as ``forward`` for one batch, without the per-head features.
        Does not switch the module to eval mode or disable autograd; the caller decides.

        Returns:
            tuple ``(out, weight, weighted_feature)`` with the same meaning as in ``forward``.
        """
        out, weight, _, weighted_feature = self._mix_and_classify(
            self._extract_features(self._as_4d(data))
        )
        return out, weight, weighted_feature


In [5]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Input directories (relative to the notebook's working directory) holding the
# preprocessed BCI2a per-subject trial arrays and the matching label arrays.
dataDir, labelDir = (
    "./DataProcessed/BCI2a/data/",
    "./DataProcessed/BCI2a/labels/",
)

class EEGDataset(Dataset):
    """Map-style dataset over a pair of ``.npy`` files: trials and their labels.

    Both arrays are loaded fully into memory. Item ``idx`` is a dict
    ``{'data': float32 tensor of self.data[idx], 'label': int64 tensor of self.labels[idx]}``.

    Args:
        data_path: path to the ``.npy`` file with one trial per leading-axis entry.
        labels_path: path to the ``.npy`` file with one label per trial.

    Raises:
        ValueError: if the two arrays do not have the same number of entries.
    """

    def __init__(self, data_path, labels_path):
        self.data = np.load(data_path)
        self.labels = np.load(labels_path)
        # Catch mismatched files up front. Otherwise extra labels are silently ignored,
        # or an IndexError surfaces part-way through an epoch.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"{data_path} has {len(self.data)} trials but "
                f"{labels_path} has {len(self.labels)} labels"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            'data': torch.tensor(self.data[idx], dtype=torch.float32),
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
        }

# Three source subjects (A01, A02, A03) for training and one held-out subject (A05)
# used as the test/validation set. Every file is loaded from dataDir / labelDir as
# data_<tag>.npy / labels_<tag>.npy.
BATCH_SIZE = 8
SOURCE_SUBJECTS = ("A01T", "A02T", "A03T")
TARGET_SUBJECT = "A05T"


def _subject_dataset(subject):
    """Build the EEGDataset for one subject/session tag such as ``"A01T"``."""
    return EEGDataset(f"{dataDir}data_{subject}.npy", f"{labelDir}labels_{subject}.npy")


dataset1, dataset2, dataset3 = (_subject_dataset(s) for s in SOURCE_SUBJECTS)
test_dataset = _subject_dataset(TARGET_SUBJECT)

# Source loaders are shuffled for training; the held-out loader keeps file order.
dataloader1, dataloader2, dataloader3 = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True) for ds in (dataset1, dataset2, dataset3)
)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# 训练

先用 3 个被试的数据（A01T、A02T、A03T）作为源域进行训练，并以 A05T 作为验证集，做一次初步实验。

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

num_classes = 4
learning_rate = 0.005
num_epochs = 500
SEED = 42

# Seed torch's global generator before building the model. Weight initialisation,
# dropout masks and the shuffled DataLoaders (created without an explicit generator)
# all draw from it, so runs become repeatable.
torch.manual_seed(SEED)

# Initialise the model.
model = DG_Network(classes=num_classes, channels=22)

# DG_Network returns softmax *probabilities*, not logits. nn.CrossEntropyLoss applies
# log_softmax internally, so feeding it probabilities would apply softmax twice and
# flatten the gradients. Instead, take the log of the probabilities here (clamped so
# an underflowed probability cannot give -inf) and use the negative log-likelihood loss.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
PROB_EPS = 1e-12

for epoch in range(num_epochs):
    model.train()
    epoch_loss_sum, n_batches = 0.0, 0
    # zip() stops at the shortest loader; each step takes one batch from every source subject.
    for batch_data1, batch_data2, batch_data3 in zip(dataloader1, dataloader2, dataloader3):
        optimizer.zero_grad()

        # Unpack data and labels for each source subject.
        data1, labels1 = batch_data1['data'], batch_data1['label']
        data2, labels2 = batch_data2['data'], batch_data2['label']
        data3, labels3 = batch_data3['data'], batch_data3['label']

        # Forward pass. Domain weights, per-head features and the mixed feature are unused.
        # NOTE(review): the domain classifier gets no direct supervision in this loop.
        outputs, _, _, _ = model(data1, data2, data3)

        # The model concatenates its three inputs along dim 0 in argument order, so the
        # targets are concatenated in the same order. EEGDataset yields labels as torch.long
        # class indices (not one-hot); NLLLoss needs them in [0, num_classes).
        # NOTE(review): confirm the saved labels are 0-based.
        targets = torch.cat([labels1, labels2, labels3])
        loss = criterion(torch.log(outputs.clamp_min(PROB_EPS)), targets)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("No training batches: at least one source DataLoader is empty.")
    # Report the mean loss over the epoch rather than only the last batch's loss.
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss_sum / n_batches}')

# Save the trained model parameters once training has finished.
torch.save(model.state_dict(), './model.pth')


NameError: name 'learning_rate' is not defined