# In [None]:
import os
import mne
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Lookup table from electrode label to the scalp region it is assigned to.
# Built from per-region groups; insertion order follows the groups below.
electrode_to_region = {
    electrode: region
    for region, electrodes in (
        ('Frontal', (
            'Fp1', 'Fp2', 'Fpz', 'AF3', 'AF4',
            'F11', 'F7', 'F5', 'F3', 'F1',
            'Fz', 'F2', 'F4', 'F6', 'F8',
            'F12', 'FT11', 'FC5', 'FC3', 'FC1',
            'FCz', 'FC2', 'FC4', 'FC6', 'FT12',
        )),
        ('Central', ('C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6')),
        ('Temporal', ('T7', 'T8', 'TP7', 'TP8')),
        ('Parietal', (
            'P7', 'P5', 'P3', 'P1', 'Pz',
            'P2', 'P4', 'P6', 'P8', 'PO7',
            'PO3', 'POz', 'PO4', 'PO8',
        )),
        ('Occipital', ('O1', 'Oz', 'O2')),
        ('Reference', ('M1', 'M2')),
    )
    for electrode in electrodes
}
# Electrode position encoding
def generate_position_encoding(positions, d=256):
    """Build a sinusoidal encoding for a set of 3-D coordinates.

    For every frequency index ``k`` in ``range(d // 2)`` each of the x, y and
    z coordinates is divided by ``10000 ** (2 * k / d)`` and passed through
    sine and cosine. Per position the values are laid out as
    ``[sin x, cos x, sin y, cos y, sin z, cos z]`` for ``k = 0``, followed by
    the same six values for ``k = 1``, and so on.

    Args:
        positions: Array-like of shape ``(n_positions, 3)``; each row is
            ``(x, y, z)``.
        d: Size of the frequency schedule. Note that the encoding width is
            ``6 * (d // 2)`` (i.e. ``3 * d`` for even ``d``), not ``d``.

    Returns:
        ``numpy.ndarray`` of dtype float64 and shape
        ``(n_positions, 6 * (d // 2))``. Empty input yields an array of shape
        ``(0, 6 * (d // 2))``.

    Raises:
        ValueError: If ``positions`` is non-empty and not shaped ``(n, 3)``.
    """
    coords = np.asarray(positions, dtype=np.float64)
    half = d // 2
    width = 6 * half

    if coords.size == 0:
        # Keep a well-formed 2-D result so callers can still index columns.
        return np.empty((0, width), dtype=np.float64)
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError(
            f"positions must have shape (n, 3), got {coords.shape}"
        )

    # One divisor per frequency index, shared by all three coordinates.
    scales = 10000.0 ** (2 * np.arange(half) / d)
    # angles[n, k, c] = coordinate c of position n at frequency k.
    angles = coords[:, np.newaxis, :] / scales[np.newaxis, :, np.newaxis]
    # The trailing axis interleaves sin/cos; a C-order reshape then gives the
    # frequency-major, coordinate-minor layout described above.
    interleaved = np.stack((np.sin(angles), np.cos(angles)), axis=-1)
    return interleaved.reshape(coords.shape[0], width)

# EEGNet model definition
class EEGNet(nn.Module):
    """Convolutional EEG classifier with a per-sample attention head.

    Pipeline: three ``Conv2d`` + ReLU + ``MaxPool2d((1, 2))`` stages, a linear
    projection to ``input_dim`` features, scaled dot-product attention over
    each sample's own feature token, and a linear classification layer.

    Args:
        num_classes: Number of output logits.
        num_regions: Number of region-decoder sub-networks to register.
        input_dim: Width of the feature space produced by ``fc1``.
        num_heads: Accepted for interface compatibility; not used by the
            current layers.
        num_filters: Output channels of the first convolution; the second and
            third convolutions use 2x and 4x this value.
        num_channels: Optional electrode count. Stored for reference only;
            the forward pass adapts to the input's spatial size.

    Notes:
        ``region_decoders`` and ``fc_attention`` are registered (so existing
        checkpoints keep loading) but are not part of the forward
        computation: their results never reached the output.
    """

    def __init__(self, num_classes=10, num_regions=20, input_dim=256, num_heads=4,
                 num_filters=32, num_channels=None):
        super().__init__()

        self.num_channels = num_channels
        # Number of spatial positions per feature map that fc1 expects.
        self.flat_positions = 64

        # Feature extraction
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size=(1, 64), padding='same')
        self.conv2 = nn.Conv2d(num_filters, num_filters * 2, kernel_size=(1, 64), padding='same')
        self.conv3 = nn.Conv2d(num_filters * 2, num_filters * 4, kernel_size=(1, 64), padding='same')
        self.pool = nn.MaxPool2d((1, 2))
        # Projects the flattened feature maps into the input_dim feature space.
        self.fc1 = nn.Linear(num_filters * 4 * self.flat_positions, input_dim)

        # Region decoders (registered but not yet wired into forward)
        self.region_decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, 128),
                nn.ReLU(),
                nn.BatchNorm1d(128),
                nn.MaxPool1d(kernel_size=2)
            ) for _ in range(num_regions)
        ])

        # Dynamic attention scorer (registered but not yet wired into forward)
        self.fc_attention = nn.Linear(input_dim, 1)

        # Attention projections
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

        # Final classification layer
        self.fc_out = nn.Linear(input_dim, num_classes)

    def forward(self, x, pos_encodings=None):
        """Compute class logits for a batch.

        Args:
            x: 4-D tensor whose second axis has size 1 (the convolution's
                single input channel); presumably
                ``(batch, 1, electrodes, time)`` - confirm against caller.
            pos_encodings: Accepted so callers may pass electrode position
                encodings; currently ignored.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # Step 1: convolutional feature extraction.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # fc1 needs exactly flat_positions values per feature map. Inputs that
        # already match are flattened untouched; anything else is average-
        # pooled down (or up) to that size instead of failing in fc1.
        if x.size(2) * x.size(3) != self.flat_positions:
            x = F.adaptive_avg_pool2d(x, (1, self.flat_positions))

        x = self.fc1(x.reshape(x.size(0), -1))

        # Step 2: attention. Each sample is a length-1 token sequence, so the
        # score matrix is (batch, 1, 1) and no information flows between
        # different samples of the batch.
        tokens = x.unsqueeze(1)
        Q = self.query(tokens)
        K = self.key(tokens)
        V = self.value(tokens)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        attended = torch.matmul(attention_weights, V).squeeze(1)

        # Step 3: classification.
        return self.fc_out(attended)
# Process EEG data
def process_set_files(set_files, words, tmin, tmax, bad_channels=None, filter_params=None, p=0.2):
    """Load EEGLAB ``.set`` recordings and cut them into labelled epochs.

    For each file: read it with MNE, optionally mark bad channels and
    band-pass filter, map annotation codes ``"50"``, ``"51"``, ... to the
    index of the corresponding entry in ``words``, epoch the recording around
    those events, and collect the epoch data and integer labels.

    Args:
        set_files: Iterable of paths to ``.set`` files. Files that cannot be
            read, or that contain none of the expected event codes, are
            skipped with a message.
        words: Sequence of class names; entry ``i`` is matched to annotation
            code ``str(50 + i)``.
        tmin: Epoch start passed to ``mne.Epochs``.
        tmax: Epoch end passed to ``mne.Epochs``.
        bad_channels: Optional list of channel names assigned to
            ``raw.info['bads']``.
        filter_params: Optional dict with ``'l_freq'`` and ``'h_freq'`` keys
            passed to ``raw.filter``.
        p: Probability, drawn once per file, that ALL epochs of that file are
            replaced by zeros.

    Returns:
        Tuple ``(X, y, position_encodings)``: the epoch arrays of all usable
        files concatenated along axis 0, the matching ``int64`` word indices,
        and the position encodings computed from the channel locations of
        the LAST usable file.

    Raises:
        ValueError: If no file yielded any epochs.
    """
    X_all, y_all = [], []
    position_encodings = None
    for set_file in set_files:
        print(f"读取EEGLAB .set 文件: {set_file}...")

        try:
            raw = mne.io.read_raw_eeglab(set_file, preload=True)
        except Exception as e:
            print(f"读取文件时出错: {e}")
            continue
        print("文件读取成功！")

        # Manually mark bad channels (copy so recordings never share the list).
        if bad_channels:
            raw.info['bads'] = list(bad_channels)

        # Filtering
        if filter_params:
            raw.filter(filter_params['l_freq'], filter_params['h_freq'], fir_design='firwin')

        events, current_event_id_map = mne.events_from_annotations(raw)
        event_id = {}
        event_code_to_word = {}

        for i, word in enumerate(words):
            # Only the "think" events are used; their codes start at 50.
            think_event_id = current_event_id_map.get(str(50 + i))
            if think_event_id is not None:
                event_id[f'想_{word}'] = think_event_id
                event_code_to_word[think_event_id] = i

        if not event_id:
            # Nothing to epoch in this recording.
            print(f"No matching event codes found in {set_file}; skipping.")
            continue

        epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, preload=True, verbose='ERROR', baseline=(0, 0))
        X = epochs.get_data()

        # Build one label per epoch FIRST (-1 for unknown codes) so the mask
        # below always has the same length as X and stays aligned with it.
        y = np.array(
            [event_code_to_word.get(code, -1) for code in epochs.events[:, 2]],
            dtype=np.int64,
        )
        valid_indices = y != -1
        X = X[valid_indices]
        y = y[valid_indices]

        # 3-D electrode coordinates: first three entries of each channel's 'loc'.
        positions = np.array([ch['loc'][:3] for ch in raw.info['chs']])
        position_encodings = generate_position_encoding(positions, d=256)

        # Random masking.
        # NOTE(review): this blanks the entire file's data (labels are kept)
        # with probability p - confirm this whole-file masking is intended.
        if np.random.rand() < p:
            X = np.zeros_like(X)

        X_all.append(X)
        y_all.append(y)

    if not X_all:
        raise ValueError("No epochs could be loaded from the given .set files.")
    return np.concatenate(X_all, axis=0), np.concatenate(y_all, axis=0), position_encodings

# Data standardization
def standardize_data(X_train, X_test):
    """Standardize 3-D train/test arrays slice by slice along the last axis.

    For every index ``i`` of axis 2, a ``StandardScaler`` is fitted on
    ``X_train[:, :, i]`` (rows = axis 0, features = axis 1) and then applied
    to both ``X_train[:, :, i]`` and ``X_test[:, :, i]``. Statistics
    therefore come from the training data only.

    NaNs in a training slice are replaced by that slice's NaN-ignoring mean
    before fitting. NaNs in ``X_test`` are not modified by this function.

    Args:
        X_train: 3-D float array; modified in place.
        X_test: 3-D float array with the same axis-1 and axis-2 sizes as
            ``X_train``; modified in place.

    Returns:
        The same ``(X_train, X_test)`` objects after in-place scaling.
    """
    scaler = StandardScaler()
    # NOTE(review): for arrays from mne.Epochs.get_data() axis 2 is presumably
    # the time axis, not the channel axis; the log wording below is unchanged.
    for i in range(X_train.shape[2]):
        train_slice = X_train[:, :, i]  # view: writes go through to X_train
        nan_mask = np.isnan(train_slice)
        if nan_mask.any():
            print(f"Channel {i} contains NaN values, filling with mean.")
            # Index the 2-D slice with the 2-D mask (indexing the 3-D array
            # with this mask plus extra indices raises IndexError).
            train_slice[nan_mask] = np.nanmean(train_slice)
        X_train[:, :, i] = scaler.fit_transform(train_slice)
        X_test[:, :, i] = scaler.transform(X_test[:, :, i])
    return X_train, X_test

# EEG dataset class
class EEGDataset(Dataset):
    """Dataset yielding ``(signal, label, position_encoding)`` triples.

    Inputs may be NumPy arrays (or array-likes) or ``torch.Tensor`` objects;
    each keeps its container type and is cast to float32 (signals and
    encodings) or int64 (labels).

    Args:
        X: Signals, indexed by sample along axis 0.
        y: Integer labels, one per sample.
        position_encodings: Either one encoding per sample (first dimension
            equal to ``len(X)``), which is indexed per item, or a single
            table shared by all samples (any other first dimension), which
            is returned unchanged with every item.
    """

    def __init__(self, X, y, position_encodings):
        self.X = self._cast(X, np.float32, torch.float32)
        self.y = self._cast(y, np.int64, torch.int64)
        self.position_encodings = self._cast(position_encodings, np.float32, torch.float32)
        # A table whose length differs from the sample count cannot be
        # indexed by sample index, so it is treated as shared.
        self._per_sample_encodings = len(self.position_encodings) == len(self.X)

    @staticmethod
    def _cast(data, np_dtype, torch_dtype):
        """Cast tensors with ``.to`` and everything else via NumPy."""
        if torch.is_tensor(data):
            return data.to(torch_dtype)
        return np.asarray(data).astype(np_dtype)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if self._per_sample_encodings:
            encoding = self.position_encodings[idx]
        else:
            encoding = self.position_encodings
        return self.X[idx], self.y[idx], encoding


# Main program: load the recordings, train the classifier, report test metrics.
if __name__ == "__main__":
    words = [
        '农民种菜', '厨师做饭', '祖父喝茶', '病人咳嗽', 
        '叛军投降', '孔雀开屏', '老牛耕地', '母鸡下蛋', 
        '蜻蜓点水', '螳螂捕蝉'
    ]

    # Training and test recordings (several files may be listed for each).
    train_set_files = [
        'c:/Users/clock/Desktop/hanzi/008C1.set',
    ]

    test_set_files = [
        'c:/Users/clock/Desktop/hanzi/008C2.set',
    ]

    # Channels to mark as bad
    bad_channels = ['VEOG', 'HEOG', 'Trigger']
    # Filter parameters (low and high cut-off frequencies)
    filter_params = {'l_freq': 1.0, 'h_freq': 80.0}

    # Load and epoch the data.
    # NOTE(review): with the default p, process_set_files may blank a whole
    # training file at random - confirm that is wanted for training.
    X_train, y_train, position_encodings_train = process_set_files(train_set_files, words, tmin=0, tmax=4, bad_channels=bad_channels, filter_params=filter_params)
    # Evaluation data must never be randomly blanked, hence p=0.0.
    X_test, y_test, position_encodings_test = process_set_files(test_set_files, words, tmin=0, tmax=4, bad_channels=bad_channels, filter_params=filter_params, p=0.0)

    # Standardize the data
    X_train, X_test = standardize_data(X_train, X_test)

    # Insert the single input-channel axis the first Conv2d layer expects.
    # The data stays in NumPy: EEGDataset casts with .astype, and the
    # DataLoader turns each batch into tensors.
    X_train = np.expand_dims(X_train, axis=1)
    X_test = np.expand_dims(X_test, axis=1)

    # The encodings are one table per recording, while the dataset indexes
    # them by sample; repeat the table so that every sample index is valid.
    position_encodings_train = np.broadcast_to(
        position_encodings_train, (len(X_train),) + position_encodings_train.shape
    )
    position_encodings_test = np.broadcast_to(
        position_encodings_test, (len(X_test),) + position_encodings_test.shape
    )

    # Datasets and loaders
    train_dataset = EEGDataset(X_train, y_train, position_encodings_train)
    test_dataset = EEGDataset(X_test, y_test, position_encodings_test)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Model: one output class per word.
    model = EEGNet(num_classes=len(words)).to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        # The position encodings are loaded but not consumed by the model.
        for inputs, labels, _ in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}")

    # Evaluation
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels, _ in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Print the classification report
    print(classification_report(all_labels, all_preds))
