In [2]:
import os
import mne
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def process_set_files(set_files, words, tmin, tmax):
    all_X = []
    all_y_word = []
    for set_file in set_files:
        print(f"读取EEGLAB .set 文件: {set_file}...")
        try:
            raw = mne.io.read_raw_eeglab(set_file, preload=True)
            print("文件读取成功！")
        except Exception as e:
            print(f"读取文件时出错: {e}")
            exit()

        if 'VEOG' in raw.ch_names: raw.drop_channels(['VEOG'])
        if 'HEOG' in raw.ch_names: raw.drop_channels(['HEOG'])
        if 'Trigger' in raw.ch_names: raw.drop_channels(['Trigger'])

        raw.filter(0.5, 80., method='iir', iir_params={'order': 6, 'ftype': 'butter'})

        events, current_event_id_map = mne.events_from_annotations(raw)
        event_id = {}
        event_code_to_word = {}
        
        for i, word in enumerate(words):
            think_code = str(40 + i)  
            think_event_id = current_event_id_map.get(think_code)
            
            if think_event_id is not None:
                event_id[f'想_{word}'] = think_event_id
                event_code_to_word[think_event_id] = i  
        epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, preload=True, verbose='ERROR', baseline=(0, 0))
        X = epochs.get_data() 
        y_event_ids = epochs.events[:, 2]
        
        y_word = []
        for label in y_event_ids:
            word = event_code_to_word.get(label, -1)
            if word != -1:
                y_word.append(word)
        
        y_word = np.array(y_word)
        valid_indices = (y_word != -1)
        X = X[valid_indices]
        y_word = y_word[valid_indices]
        new_X = []
        new_y_word = []
        for i in range(X.shape[0]):
            
            for j in range(3): 
                new_X.append(X[i, :, j * int(X.shape[2] / 3): (j + 1) * int(X.shape[2] / 3)])
                new_y_word.append(y_word[i]) 
        
        new_X = np.array(new_X)
        new_y_word = np.array(new_y_word)
        
        all_X.append(new_X)
        all_y_word.append(new_y_word)
    
    all_X = np.concatenate(all_X, axis=0)
    all_y_word = np.concatenate(all_y_word, axis=0)
    return all_X, all_y_word

def standardize_data(X):
    scaler = StandardScaler()
    num_samples, num_channels, num_timepoints = X.shape
    X = X.reshape(num_samples, -1)
    X = scaler.fit_transform(X)
    X = X.reshape(num_samples, num_channels, num_timepoints)
    return X

class EEGDataset(Dataset):
    def __init__(self, X, y_word):
        self.X = X.astype(np.float32)  
        self.y_word = y_word
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y_word[idx]

class EEGNet(nn.Module):
    def __init__(self, num_classes_word, dropout=0.3):
        super(EEGNet, self).__init__()
        self.firstconv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False),
            nn.BatchNorm2d(16)
        )
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=(64, 1), stride=(1, 1), groups=16, bias=False), 
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
            nn.Dropout(dropout)
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(1, 15), stride=(1, 1), padding=(0, 7), bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)),
            nn.Dropout(dropout)
        )
        self.flatten = nn.Flatten()
        self.classify = nn.Linear(992, num_classes_word)
    
    def forward(self, x):
        x = x.unsqueeze(1)  
        x = self.firstconv(x)
        x = self.depthwiseConv(x)
        x = self.separableConv(x)
        x = self.flatten(x)
        word_out = self.classify(x)
        return word_out

def train_model(model, train_loader, criterion_word, optimizer, num_epochs):
    train_losses = []
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    early_stop_patience = 10
    early_stop_counter = 0
    best_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        running_loss_word = 0.0
        correct_word = 0
        total = 0
        
        for inputs, labels_word in train_loader:
            inputs = inputs.float().to(device) 
            labels_word = labels_word.long().to(device)

            optimizer.zero_grad()
            outputs_word = model(inputs)
            loss_word = criterion_word(outputs_word, labels_word)
            loss_word.backward()
            optimizer.step()
            
            running_loss_word += loss_word.item() * inputs.size(0)
            _, predicted_word = torch.max(outputs_word.data, 1)
            total += labels_word.size(0)
            correct_word += (predicted_word == labels_word).sum().item()

        epoch_loss_word = running_loss_word / total
        epoch_accuracy_word = 100 * correct_word / total
        scheduler.step(epoch_loss_word)
        
        train_losses.append(epoch_loss_word)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss Word: {epoch_loss_word:.4f}, Accuracy Word: {epoch_accuracy_word:.2f}%')

        if epoch_loss_word < best_loss:
            best_loss = epoch_loss_word
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print("Early stopping triggered.")
                break

    return train_losses

def evaluate_model(model, test_loader):
    model.eval()
    y_pred_word = []
    y_true_word = []
    total = 0
    correct_word = 0
    
    with torch.no_grad():
        for inputs, labels_word in test_loader:
            inputs = inputs.to(device)
            labels_word = labels_word.to(device)
            
            outputs_word = model(inputs)
            _, predicted_word = torch.max(outputs_word.data, 1)
            
            total += labels_word.size(0)
            correct_word += (predicted_word == labels_word).sum().item()
            
            y_pred_word.extend(predicted_word.cpu().numpy())
            y_true_word.extend(labels_word.cpu().numpy())

    accuracy_word = 100 * correct_word / total
    print(f'测试集准确率 Word: {accuracy_word:.2f}%')
    
    print("\nClassification Report (Word):")
    print(classification_report(y_true_word, y_pred_word))
    
    print("\nConfusion Matrix (Word):")
    print(confusion_matrix(y_true_word, y_pred_word))
    
    return y_pred_word

def k_fold_cross_validation(X, y, k=5, num_epochs=30):
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    fold = 1
    for train_idx, test_idx in kfold.split(X):
        print(f'Fold {fold}/{k}')
        fold += 1
        
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        X_train = standardize_data(X_train)
        X_test = standardize_data(X_test)

        train_dataset = EEGDataset(X_train, y_train)
        test_dataset = EEGDataset(X_test, y_test)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        num_classes_word = len(np.unique(y))

        model = EEGNet(num_classes_word=num_classes_word, dropout=0.3).to(device)
        
        criterion_word = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)
        train_losses = train_model(model, train_loader, criterion_word, optimizer, num_epochs)

        evaluate_model(model, test_loader)

if __name__ == "__main__":
    words = [
        '农民种菜', '厨师做饭', '祖父喝茶', '病人咳嗽', 
        '叛军投降', '孔雀开屏', '老牛耕地', '母鸡下蛋', 
        '蜻蜓点水', '螳螂捕蝉'
    ]
    set_files = ['c:/Users/clock/Desktop/hanzi/008C1.set', 'c:/Users/clock/Desktop/hanzi/008C2.set']
    X, y_word = process_set_files(set_files, words, 0, 3)
    k_fold_cross_validation(X, y_word, k=5, num_epochs=50)


读取EEGLAB .set 文件: c:/Users/clock/Desktop/hanzi/008C1.set...
Reading c:\Users\clock\Desktop\hanzi\008C1.fdt
Reading 0 ... 903022  =      0.000 ...   903.022 secs...
文件读取成功！
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 80 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 24 (effective, after forward-backward)
- Cutoffs at 0.50, 80.00 Hz: -6.02, -6.02 dB

Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
读取EEGLAB .set 文件: c:/Users/clock/Desktop/hanzi/008C2.set...
Reading c:\Users\clock\Desktop\hanzi\008C2.fdt
Reading 0 ... 901822  =      0.000 ...   901.822 secs...


  X = epochs.get_data()


文件读取成功！
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 80 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 24 (effective, after forward-backward)
- Cutoffs at 0.50, 80.00 Hz: -6.02, -6.02 dB

Used Annotations descriptions: ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59']


  X = epochs.get_data()


Fold 1/5
Epoch [1/50], Loss Word: 2.3853, Accuracy Word: 9.38%
Epoch [2/50], Loss Word: 2.1808, Accuracy Word: 23.61%
Epoch [3/50], Loss Word: 1.9837, Accuracy Word: 31.25%
Epoch [4/50], Loss Word: 1.7413, Accuracy Word: 43.06%
Epoch [5/50], Loss Word: 1.5499, Accuracy Word: 47.92%
Epoch [6/50], Loss Word: 1.3472, Accuracy Word: 57.81%
Epoch [7/50], Loss Word: 1.1769, Accuracy Word: 66.84%
Epoch [8/50], Loss Word: 0.9997, Accuracy Word: 71.70%
Epoch [9/50], Loss Word: 0.9094, Accuracy Word: 74.31%
Epoch [10/50], Loss Word: 0.7719, Accuracy Word: 79.17%
Epoch [11/50], Loss Word: 0.6975, Accuracy Word: 81.25%
Epoch [12/50], Loss Word: 0.6025, Accuracy Word: 85.07%
Epoch [13/50], Loss Word: 0.5239, Accuracy Word: 85.94%
Epoch [14/50], Loss Word: 0.4370, Accuracy Word: 89.41%
Epoch [15/50], Loss Word: 0.3959, Accuracy Word: 90.45%
Epoch [16/50], Loss Word: 0.3677, Accuracy Word: 91.32%
Epoch [17/50], Loss Word: 0.3278, Accuracy Word: 92.01%
Epoch [18/50], Loss Word: 0.2997, Accuracy Word: 

KeyboardInterrupt: 