In [None]:
import os
import mne
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


# Compute device for all models/tensors: the default CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def process_set_files(set_files, words, tmin, tmax, n_segments=3):
    """Load EEGLAB ``.set`` files and build labelled "think" samples.

    For every file the recording is read with MNE, the ``VEOG``/``HEOG``/
    ``Trigger`` channels are dropped when present, a 0.5-80 Hz FIR filter is
    applied, and epochs are cut around annotations ``"50"``, ``"51"``, ...:
    annotation ``str(50 + i)`` is mapped to class index ``i`` (``words[i]``).
    Each epoch is then split along time into ``n_segments`` equal,
    consecutive, non-overlapping windows; every window becomes one sample
    carrying the epoch's label. Trailing time points that do not fill a
    whole window are dropped.

    Args:
        set_files: Paths of the EEGLAB ``.set`` files to read.
        words: Class names; they fix the class order and name the MNE
            event ids.
        tmin, tmax: Epoch window in seconds relative to each event.
        n_segments: Number of windows each epoch is split into (default 3,
            the value previously hard-coded).

    Returns:
        ``(all_X, all_y_word)``: ``all_X`` has shape
        ``(n_samples, n_channels, n_times // n_segments)``, with the windows of
        one epoch stored consecutively; ``all_y_word`` holds each sample's
        integer class index.

    Raises:
        ValueError: if ``n_segments < 1``, if an epoch is shorter than
            ``n_segments`` samples, or if no file yields any matching epoch.
        RuntimeError: if a file cannot be read (chained to the MNE error).
    """
    if n_segments < 1:
        raise ValueError(f"n_segments must be >= 1, got {n_segments}")

    all_X = []
    all_y_word = []
    for set_file in set_files:
        try:
            raw = mne.io.read_raw_eeglab(set_file, preload=True)
        except Exception as e:
            print(f"读取文件时出错: {e}")
            # Raise instead of exit(): exit() tears down the whole interpreter /
            # notebook kernel and loses which file failed.
            raise RuntimeError(f"Failed to read EEGLAB file {set_file!r}") from e
        else:
            print("文件读取成功！")

        # Remove EOG and trigger channels if this recording has them.
        to_drop = [ch for ch in ('VEOG', 'HEOG', 'Trigger') if ch in raw.ch_names]
        if to_drop:
            raw.drop_channels(to_drop)

        raw.filter(0.5, 80., fir_design='firwin')

        events, current_event_id_map = mne.events_from_annotations(raw)
        # Annotation description str(50 + i) marks the "think" event for words[i].
        event_id = {}
        event_code_to_word = {}
        for i, word in enumerate(words):
            think_event_id = current_event_id_map.get(str(50 + i))
            if think_event_id is not None:
                event_id[f'想_{word}'] = think_event_id
                event_code_to_word[think_event_id] = i

        if not event_id:
            print(f"{set_file}: no think events found, skipped")
            continue

        # baseline=(0, 0) averages over the interval [0 s, 0 s], presumably just the
        # sample at t=0. NOTE(review): confirm this baseline choice is intended.
        epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                            preload=True, verbose='ERROR', baseline=(0, 0))
        X = epochs.get_data()

        # Label every epoch (-1 for codes that are not a think event) and then drop
        # the unlabelled ones, so X and y_word always stay the same length.
        y_word = np.array(
            [event_code_to_word.get(code, -1) for code in epochs.events[:, 2]],
            dtype=int,
        )
        valid = y_word != -1
        X, y_word = X[valid], y_word[valid]
        if len(X) == 0:
            continue

        n_epochs, n_channels, n_times = X.shape
        seg_len = n_times // n_segments
        if seg_len == 0:
            raise ValueError(
                f"{set_file}: epochs have {n_times} samples, too few for {n_segments} segments"
            )
        # (epochs, ch, n_segments*seg_len) -> (epochs, n_segments, ch, seg_len) -> flat:
        # row e*n_segments + j is X[e, :, j*seg_len:(j+1)*seg_len].
        segments = (
            X[:, :, : n_segments * seg_len]
            .reshape(n_epochs, n_channels, n_segments, seg_len)
            .transpose(0, 2, 1, 3)
            .reshape(-1, n_channels, seg_len)
        )
        all_X.append(segments)
        all_y_word.append(np.repeat(y_word, n_segments))

    if not all_X:
        raise ValueError("No epochs matching the requested words were found in set_files")
    return np.concatenate(all_X, axis=0), np.concatenate(all_y_word, axis=0)


def standardize_data(X):
    """Z-score every (channel, timepoint) feature across the samples of ``X``.

    The 3-D array ``(samples, channels, timepoints)`` is flattened to one row
    per sample, scaled column-wise by a freshly fitted ``StandardScaler`` and
    restored to its original shape. The scaler is fitted on all of ``X``.
    """
    n_samples, n_channels, n_timepoints = X.shape
    flat = X.reshape(n_samples, -1)
    scaled = StandardScaler().fit_transform(flat)
    return scaled.reshape(n_samples, n_channels, n_timepoints)


class EEGDataset(Dataset):
    """Map-style dataset pairing EEG samples with their word labels.

    ``X`` is converted to a float32 copy once at construction; ``y_word`` is
    stored unchanged. Item ``idx`` is the tuple ``(X[idx], y_word[idx])``.
    """

    def __init__(self, X, y_word):
        super().__init__()
        # Cast once up front so the DataLoader collates float32 batches.
        self.X = X.astype(np.float32)
        self.y_word = y_word

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        label = self.y_word[idx]
        return sample, label

class EEGNetWithAttention(nn.Module):
    """EEGNet-style convolutional feature extractor that emits a feature sequence.

    Despite its name this module contains no attention layer. It maps EEG of
    shape ``(batch, num_channels, time)`` to ``(batch, time // 32, 64)``: one
    64-dim feature vector per pooled time step, meant to be consumed by a
    sequence model such as ``RNNClassifier``.

    Args:
        num_classes_word: Unused here; kept for interface compatibility
            (classification happens in the downstream model).
        dropout: Dropout probability after each pooling stage.
        num_channels: Number of EEG electrodes in the input. The depthwise
            spatial convolution spans all of them. Defaults to 64, the value
            previously hard-coded.
    """

    def __init__(self, num_classes_word, dropout=0.3, num_channels=64):
        super().__init__()
        self.num_channels = num_channels
        # Temporal convolution: 16 filters of length 51, padding keeps the time length.
        self.firstconv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False),
            nn.BatchNorm2d(16)
        )
        # Depthwise spatial convolution over all electrodes (2 filters per temporal
        # filter), collapsing the electrode axis to height 1; then 4x time pooling.
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=(num_channels, 1), stride=(1, 1), groups=16, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
            nn.Dropout(dropout)
        )
        # Third stage (named "separable", implemented as a full 32->64 temporal
        # convolution); then 8x time pooling.
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(1, 15), stride=(1, 1), padding=(0, 7), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Map ``(batch, num_channels, time)`` EEG to ``(batch, time // 32, 64)`` features.

        Raises:
            ValueError: if ``x`` is not 3-D with ``num_channels`` electrodes.
        """
        if x.dim() != 3 or x.size(1) != self.num_channels:
            raise ValueError(
                f"expected input of shape (batch, {self.num_channels}, time), got {tuple(x.shape)}"
            )
        x = x.unsqueeze(1)          # (B, 1, C, T)
        x = self.firstconv(x)       # (B, 16, C, T)
        x = self.depthwiseConv(x)   # (B, 32, 1, T // 4)
        x = self.separableConv(x)   # (B, 64, 1, T // 32)
        x = x.permute(0, 3, 1, 2)   # (B, T // 32, 64, 1)
        return x.reshape(x.size(0), x.size(1), -1)  # (B, T // 32, 64)
class RNNClassifier(nn.Module):
    """Bidirectional LSTM classifier over a feature sequence.

    Args:
        input_size: Feature dimension of each time step.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes_word: Number of output classes.
        dropout: Dropout between stacked LSTM layers. Ignored when
            ``num_layers == 1``, where PyTorch applies none and only warns.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes_word, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        self.fc = nn.Linear(hidden_size * 2, num_classes_word)

    def forward(self, x):
        """Return logits ``(batch, num_classes_word)`` for ``x`` of shape ``(batch, seq, input_size)``."""
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, batch, hidden): the last two rows are the top layer's
        # final forward state (after the last step) and final backward state (after
        # reading back to the first step). output[:, -1] would instead pair the forward
        # summary with a backward state that has seen only the last step.
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.fc(summary)

def train_model(model, train_loader, criterion_word, optimizer, num_epochs,
                feature_extractor=None):
    """Jointly train a feature extractor and a classifier head.

    Each batch ``(inputs, labels_word)`` is passed through
    ``feature_extractor`` and then ``model``. ``criterion_word`` is
    back-propagated and ``optimizer`` stepped, so ``optimizer`` must hold the
    parameters of both modules for both to learn. ``ReduceLROnPlateau``
    (factor 0.5, patience 3) watches the epoch training loss. Training stops
    early once the loss has not beaten its best value for 10 consecutive
    epochs.

    Args:
        model: Classifier applied to the extracted features.
        train_loader: Iterable of ``(inputs, labels_word)`` batches.
        criterion_word: Loss taking ``(logits, labels)``.
        optimizer: Optimizer over the trainable parameters.
        num_epochs: Maximum number of epochs.
        feature_extractor: Module applied to inputs before ``model``. When
            omitted, the module-level global ``eegnet`` is used, matching the
            original behaviour of this notebook.

    Returns:
        List of per-epoch mean training losses.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if feature_extractor is None:
        # Backward compatibility: this function originally always used the global
        # `eegnet` created in the __main__ cell.
        feature_extractor = eegnet

    # Move batches to wherever the classifier's parameters live (CPU if it has none).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    train_losses = []
    # `verbose` is deprecated on LR schedulers; LR changes are reported below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    early_stop_patience = 10
    early_stop_counter = 0
    best_loss = float('inf')

    for epoch in range(num_epochs):
        # Both modules may contain dropout / batch-norm, so both need train mode.
        model.train()
        feature_extractor.train()
        running_loss_word = 0.0
        correct_word = 0
        total = 0

        for inputs, labels_word in train_loader:
            inputs = inputs.float().to(run_device)
            labels_word = labels_word.long().to(run_device)

            optimizer.zero_grad()
            outputs_word = model(feature_extractor(inputs))
            loss_word = criterion_word(outputs_word, labels_word)
            loss_word.backward()
            optimizer.step()

            batch_size = labels_word.size(0)
            running_loss_word += loss_word.item() * batch_size
            total += batch_size
            correct_word += (outputs_word.argmax(dim=1) == labels_word).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        epoch_loss_word = running_loss_word / total
        epoch_accuracy_word = 100 * correct_word / total

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(epoch_loss_word)
        lrs_after = [group['lr'] for group in optimizer.param_groups]

        train_losses.append(epoch_loss_word)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss Word: {epoch_loss_word:.4f}, Accuracy Word: {epoch_accuracy_word:.2f}%')
        if lrs_after != lrs_before:
            print(f'Learning rate reduced to {lrs_after}')

        # Early stopping on the training loss (no validation set is used here).
        if epoch_loss_word < best_loss:
            best_loss = epoch_loss_word
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print("Early stopping triggered.")
                break

    return train_losses

if __name__ == "__main__":
    # Pre-training cell: load every session, hold out one CV fold, jointly train the
    # EEGNet feature extractor + BiLSTM head on the rest, and save both state dicts.
    torch.manual_seed(42)
    words = [
        '农民种菜', '厨师做饭', '祖父喝茶', '病人咳嗽', 
        '叛军投降', '孔雀开屏', '老牛耕地', '母鸡下蛋', 
        '蜻蜓点水', '螳螂捕蝉'
    ]
    set_files = ['./001C.set','./006C1.set','./006C2.set','./007C1.set','./007C2.set','./008C1.set','./008C2.set','./009C1.set','./009C2.set','./010C1.set','./010C2.set','./011C1.set','./011C2.set',]
    X, y_word = process_set_files(set_files, words, 0, 3)

    # process_set_files stores 3 consecutive windows per epoch. Split folds by epoch so
    # windows cut from the same trial never end up on both sides of train/test.
    segments_per_epoch = 3
    epoch_of_sample = np.arange(len(X)) // segments_per_epoch
    epoch_ids = np.unique(epoch_of_sample)

    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    for fold, (train_epochs, test_epochs) in enumerate(kfold.split(epoch_ids)):
        print(f'Fold {fold+1}/5')
        train_idx = np.flatnonzero(np.isin(epoch_of_sample, epoch_ids[train_epochs]))
        test_idx = np.flatnonzero(np.isin(epoch_of_sample, epoch_ids[test_epochs]))
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y_word[train_idx], y_word[test_idx]

        # Same per-feature z-scoring as standardize_data, but fitted on the training
        # fold only so no test-set statistics leak into training.
        scaler = StandardScaler().fit(X_train.reshape(len(X_train), -1))
        X_train = scaler.transform(X_train.reshape(len(X_train), -1)).reshape(X_train.shape)
        X_test = scaler.transform(X_test.reshape(len(X_test), -1)).reshape(X_test.shape)

        train_dataset = EEGDataset(X_train, y_train)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

        eegnet = EEGNetWithAttention(num_classes_word=len(words)).to(device)
        rnn_classifier = RNNClassifier(input_size=64, hidden_size=128, num_layers=2, num_classes_word=len(words)).to(device)

        criterion_word = nn.CrossEntropyLoss()
        # One optimizer over both modules: they are trained end to end.
        optimizer = torch.optim.AdamW(list(eegnet.parameters()) + list(rnn_classifier.parameters()), lr=0.001, weight_decay=5e-4)

        # train_model uses the global `eegnet` defined above as its feature extractor.
        train_model(rnn_classifier, train_loader, criterion_word, optimizer, num_epochs=50)

        torch.save(eegnet.state_dict(), 'eegnet_pretrained.pth')
        torch.save(rnn_classifier.state_dict(), 'rnn_classifier_pretrained.pth')

        print("预训练模型已保存")
        # Only the first fold is used for pre-training; X_test / y_test stay held out.
        break


Fold 1/5


  from .autonotebook import tqdm as notebook_tqdm


Epoch [1/50], Loss Word: 2.3036, Accuracy Word: 10.68%
Epoch [2/50], Loss Word: 2.2907, Accuracy Word: 12.02%
Epoch [3/50], Loss Word: 2.2707, Accuracy Word: 14.21%
Epoch [4/50], Loss Word: 2.2566, Accuracy Word: 15.76%
Epoch [5/50], Loss Word: 2.2296, Accuracy Word: 16.59%
Epoch [6/50], Loss Word: 2.2111, Accuracy Word: 17.41%
Epoch [7/50], Loss Word: 2.1865, Accuracy Word: 19.39%
Epoch [8/50], Loss Word: 2.1733, Accuracy Word: 20.14%
Epoch [9/50], Loss Word: 2.1469, Accuracy Word: 21.05%
Epoch [10/50], Loss Word: 2.1125, Accuracy Word: 23.24%
Epoch [11/50], Loss Word: 2.1034, Accuracy Word: 23.05%
Epoch [12/50], Loss Word: 2.0774, Accuracy Word: 24.31%
Epoch [13/50], Loss Word: 2.0703, Accuracy Word: 24.71%
Epoch [14/50], Loss Word: 2.0173, Accuracy Word: 26.98%
Epoch [15/50], Loss Word: 2.0027, Accuracy Word: 26.55%
Epoch [16/50], Loss Word: 1.9566, Accuracy Word: 29.62%
Epoch [17/50], Loss Word: 1.9061, Accuracy Word: 31.76%
Epoch [18/50], Loss Word: 1.9093, Accuracy Word: 31.86%
E