In [9]:

from tqdm import tqdm 
import mne
import pandas as pd
import os, glob, random, warnings, numpy as np, matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (EarlyStopping, ModelCheckpoint,
                                         LearningRateMonitor)
from torchmetrics.classification import Accuracy, ConfusionMatrix, MulticlassF1Score
from torch.utils import data as torch_data
from sklearn.metrics import classification_report
from pytorch_lightning.utilities import rank_zero_only 


import torch
import torchmetrics
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_data
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from pytorch_lightning.utilities import rank_zero_only


import lightning as L
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from torchmetrics.classification import Accuracy

%matplotlib inline
plt.rcParams['axes.facecolor'] = 'lightgray'

# 定义常量
EPOCH_DURATION = 6.0  # 每个epoch的秒数
SEGMENT_DURATION = 1.0  # 每个切片的秒数
N_SEGMENTS = int(EPOCH_DURATION / SEGMENT_DURATION)  # 每个epoch切成6片

# 标签映射
label_map = {
    769: "OVTK_GDF_Left", 
    770: "OVTK_GDF_Right",
    780: "OVTK_GDF_Up",
    774: "OVTK_GDF_Down",
    32769: "OVTK_StimulationId_ExperimentStart",
    32775: "OVTK_StimulationId_BaselineStart",
    33026: "OVTK_GDF_Feedback_Continuous"
}

# 数值标签映射（用于分类）
numeric_label_map = {
    769: 0,  # 左
    770: 1,  # 右
    780: 2,  # 上
    774: 3,  # 下
    999: 4   # 静息
}

# EEG通道列
channel_cols = [
    'Channel 1', 'Channel 2', 'Channel 3', 'Channel 4',
    'Channel 5', 'Channel 6', 'Channel 7', 'Channel 8'
]   


In [10]:
# 设置常量
EPOCH_DURATION = 6.0  # 每个epoch的秒数
SEGMENT_DURATION = 1.0  # 每个切片的秒数
N_SEGMENTS = int(EPOCH_DURATION / SEGMENT_DURATION)  # 每个epoch切成6片
SAMPLING_FREQ = 250.0  # 采样频率

# 标签映射
label_map = {
    769: "OVTK_GDF_Left", 
    770: "OVTK_GDF_Right",
    780: "OVTK_GDF_Up",
    774: "OVTK_GDF_Down",
    32769: "OVTK_StimulationId_ExperimentStart",
    32775: "OVTK_StimulationId_BaselineStart",
    33026: "OVTK_GDF_Feedback_Continuous"
}

# 数值标签映射（用于分类）
numeric_label_map = {
    769: 0,  # 左
    770: 1,  # 右
    780: 2,  # 上
    774: 3,  # 下
    999: 4   # 静息
}

# EEG通道列
channel_cols = [
    'Channel 1', 'Channel 2', 'Channel 3', 'Channel 4',
    'Channel 5', 'Channel 6', 'Channel 7', 'Channel 8'
]

def preprocess_csv(file_path):
    """预处理单个CSV文件"""
    print(f"正在处理文件: {os.path.basename(file_path)}")
    
    # 读取CSV文件
    df = pd.read_csv(file_path)
    
    # 删除不需要的通道（如果存在）
    columns_to_drop = [col for col in ["Channel 9", "Channel 10", "Channel 11"] if col in df.columns]
    if columns_to_drop:
        df.drop(columns=columns_to_drop, inplace=True)
    
    # 处理事件ID、日期和持续时间列，保留冒号前的部分
    def keep_first_part(x):
        if not isinstance(x, str):
            x = "" if pd.isna(x) else str(x)
        return x.split(":")[0] if x else ""
    
    for col in ["Event Id", "Event Date", "Event Duration"]:
        if col in df.columns:
            df[col] = df[col].fillna("").apply(keep_first_part)
    
    # 清空重复事件行
    if "Event Id" in df.columns and "Event Date" in df.columns and "Event Duration" in df.columns:
        same_mask = df["Event Id"] == df["Event Id"].shift(-1)
        rows_to_blank = same_mask.index[same_mask]
        rows_to_blank = rows_to_blank + 1
        rows_to_blank = rows_to_blank[rows_to_blank < len(df)]
        df.loc[rows_to_blank, ["Event Id", "Event Date", "Event Duration"]] = ""
        
        # 转换为数值类型
        df["Event Id"] = pd.to_numeric(df["Event Id"], errors="coerce")
        df["Event Date"] = pd.to_numeric(df["Event Date"], errors="coerce")
    
    return df

def create_raw_from_dataframe(df):
    """从DataFrame创建MNE Raw对象"""
    # 提取EEG数据
    data = df[channel_cols].to_numpy().T
    
    # 创建MNE信息对象
    ch_names = ['ch1', 'ch2', 'ch3', 'ch4', 'ch5', 'ch6', 'ch7', 'ch8']
    ch_types = ['eeg'] * 8
    info = mne.create_info(ch_names=ch_names, sfreq=SAMPLING_FREQ, ch_types=ch_types)
    
    # 创建Raw对象
    raw = mne.io.RawArray(data, info)
    
    # 应用50Hz陷波滤波器（去除电源线噪声）
    raw.notch_filter(freqs=[50], picks='eeg')
    
    return raw

def extract_events_from_dataframe(df, raw):
    """从DataFrame提取事件数据"""
    events = []
    
    if "Event Id" in df.columns and "Event Date" in df.columns:
        event_indices = df.index[~pd.isna(df["Event Id"]) & (df["Event Id"] != "")]
        
        for idx in event_indices:
            event_id = df.loc[idx, "Event Id"]
            event_date = df.loc[idx, "Event Date"]
            
            # 只保留任务相关事件
            if event_id in [769, 770, 780, 774]:
                # 事件样本点位置
                sample = int(event_date)
                
                # 创建事件元组 [sample, 0, event_id]
                events.append([sample, 0, int(event_id)])
    
    return np.array(events, dtype=int) if events else np.empty((0, 3), dtype=int)

def process_eeg_data(raw, events):
    """
    处理EEG数据
    提取所有类别的epochs并进行一致的切片
    """
    
    # 提取任务相关epochs（左、右、上、下）
    task_event_id = {
        'left': 769,
        'right': 770,
        'up': 780,
        'down': 774
    }
    
    task_epochs = mne.Epochs(
        raw,
        events=events,
        event_id=task_event_id,
        tmin=0.0,
        tmax=EPOCH_DURATION,
        baseline=None,
        picks='eeg',
        preload=True
    )
    
    X_task = (task_epochs.get_data() * 1e-6).astype(np.float32)
    y_task = np.array([numeric_label_map[val] for val in task_epochs.events[:, 2]], dtype=np.int64)
    
    print(f"任务相关epochs数量: {len(task_epochs)}")
    
    # 标记已被任务epochs占用的样本点
    used_samples = np.zeros(len(raw.times), dtype=bool)
    
    for idx in range(len(task_epochs)):
        event_sample = task_epochs.events[idx, 0]
        start_sample = event_sample
        end_sample = start_sample + int(EPOCH_DURATION * raw.info['sfreq'])
        if end_sample <= len(used_samples):
            used_samples[start_sample:end_sample] = True
    
    # 寻找未被占用的样本点作为静息epochs
    rest_events = []
    rest_length = int(EPOCH_DURATION * raw.info['sfreq'])
    
    i = 0
    while i < len(used_samples):
        if not used_samples[i]:
            start_sample = i
            end_sample = min(start_sample + rest_length, len(used_samples))
            
            # 确保整个区间没有被占用
            if end_sample - start_sample >= rest_length and not any(used_samples[start_sample:end_sample]):
                rest_events.append([start_sample, 0, 999])  # 999为静息状态代码
                i = end_sample
            else:
                i += 1
        else:
            i += 1
    
    rest_events = np.array(rest_events, dtype=int) if rest_events else np.empty((0, 3), dtype=int)
    
    # 创建静息epochs
    if len(rest_events) > 0:
        rest_event_id = {'rest': 999}
        rest_epochs = mne.Epochs(
            raw,
            events=rest_events,
            event_id=rest_event_id,
            tmin=0.0,
            tmax=EPOCH_DURATION,
            baseline=None,
            picks='eeg',
            preload=True
        )
        
        X_rest = (rest_epochs.get_data() * 1e-6).astype(np.float32)
        y_rest = np.ones(len(rest_epochs), dtype=np.int64) * numeric_label_map[999]
        
        print(f"静息epochs数量: {len(rest_epochs)}")
        
        # 合并任务和静息数据
        X_combined = np.vstack([X_task, X_rest])
        y_combined = np.concatenate([y_task, y_rest])
    else:
        print("未找到静息epochs")
        X_combined = X_task
        y_combined = y_task
    
    # 对所有epochs进行一致的切片处理
    all_segments = []
    all_segment_labels = []
    
    samples_per_segment = int(raw.info['sfreq'] * SEGMENT_DURATION)
    
    for i in range(len(X_combined)):
        epoch_data = X_combined[i]
        label = y_combined[i]
        
        for j in range(N_SEGMENTS):
            start_idx = j * samples_per_segment
            end_idx = (j + 1) * samples_per_segment
            
            if end_idx <= epoch_data.shape[1]:  # 确保不越界
                segment_data = epoch_data[:, start_idx:end_idx]
                all_segments.append(segment_data)
                all_segment_labels.append(label)
    
    X_segments = np.array(all_segments)
    y_segments = np.array(all_segment_labels)
    
    print(f"原始epochs总数: {len(X_combined)}")
    print(f"切片后segments总数: {len(all_segments)}")
    print(f"切片后数据形状: {X_segments.shape}")

    return X_segments, y_segments

def get_eeg_channels(raw):
    """获取EEG通道数"""
    eeg_channel_inds = mne.pick_types(
        raw.info,
        meg=False,
        eeg=True,
        stim=False,
        eog=False,
        exclude='bads',
    )
    return len(eeg_channel_inds)

def load_all_csv_files(folder_path, pattern="*.csv"):
    """加载文件夹中所有符合模式的CSV文件，并处理它们"""
    all_X_segments = []
    all_y_segments = []
    
    # 获取所有匹配的CSV文件
    csv_files = glob.glob(os.path.join(folder_path, pattern))
    
    if not csv_files:
        print(f"在 {folder_path} 中未找到符合 {pattern} 的CSV文件")
        return None, None
    
    print(f"找到 {len(csv_files)} 个CSV文件")
    
    # 处理每个CSV文件
    for file_path in tqdm(csv_files, desc="处理文件"):
        try:
            # 预处理CSV
            df = preprocess_csv(file_path)
            
            # 创建Raw对象
            raw = create_raw_from_dataframe(df)
            
            # 提取事件
            events = extract_events_from_dataframe(df, raw)
            
            if len(events) > 0:
                # 处理EEG数据并获取切片
                X_segments, y_segments = process_eeg_data(raw, events)
                
                # 添加到总集合
                all_X_segments.append(X_segments)
                all_y_segments.append(y_segments)
            else:
                print(f"文件 {os.path.basename(file_path)} 中未找到有效事件")
        
        except Exception as e:
            print(f"处理文件 {os.path.basename(file_path)} 时出错: {str(e)}")
    
    # 合并所有数据
    if all_X_segments and all_y_segments:
        X_all = np.vstack(all_X_segments)
        y_all = np.concatenate(all_y_segments)
        
        print(f"所有文件处理完成!")
        print(f"总数据形状: {X_all.shape}")
        print(f"总标签形状: {y_all.shape}")
        
        # 显示每个类别的样本计数
        for i in range(5):
            count = np.sum(y_all == i)
            percent = count / len(y_all) * 100
            print(f"类别 {i} ({list(label_map.values())[i] if i < 4 else '静息'}): {count} 样本 ({percent:.2f}%)")
        
        return X_all, y_all
    else:
        print("未能从任何文件中提取到有效数据")
        return None, None

class EEGDataset(torch_data.Dataset):
    """
    处理EEG数据的Dataset类，支持五类分类（左、右、上、下、静息）
    并提供训练、验证和测试集的拆分功能
    """
    def __init__(self, x, y=None, inference=False, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, random_state=42):
        """
        初始化EEG数据集
        
        参数:
            x: 形状为(n_samples, n_channels, n_times)的numpy数组
            y: 形状为(n_samples,)的numpy数组，类别标签(0-4对应左、右、上、下、静息)
            inference: 是否为推理模式（不需要标签）
            train_ratio: 训练集比例
            val_ratio: 验证集比例
            test_ratio: 测试集比例
            random_state: 随机种子，用于数据集拆分
        """
        super().__init__()
        self.__split = None
        
        # 确保比例和为1
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-10, "数据集比例必须和为1"
        
        if not inference:
            # 使用sklearn的train_test_split进行更可靠的拆分
            # 首先分离出测试集
            X_temp, X_test, y_temp, y_test = train_test_split(
                x, y, test_size=test_ratio, random_state=random_state, stratify=y
            )
            
            # 再从剩余数据中分离出验证集
            val_ratio_adjusted = val_ratio / (train_ratio + val_ratio)  # 调整验证集比例
            X_train, X_val, y_train, y_val = train_test_split(
                X_temp, y_temp, test_size=val_ratio_adjusted, 
                random_state=random_state, stratify=y_temp
            )
            
            self.train_ds = {
                'x': X_train,
                'y': y_train,
            }
            self.val_ds = {
                'x': X_val,
                'y': y_val,
            }
            self.test_ds = {
                'x': X_test,
                'y': y_test,
            }
            
            # 打印每个集合的样本数和类别分布
            print(f"训练集大小: {len(X_train)} 样本")
            print(f"验证集大小: {len(X_val)} 样本")
            print(f"测试集大小: {len(X_test)} 样本")
            
            for i in range(5):  # 五类: 0-左, 1-右, 2-上, 3-下, 4-静息
                print(f"类别 {i} 分布: 训练集 {sum(y_train == i)}, "
                      f"验证集 {sum(y_val == i)}, 测试集 {sum(y_test == i)}")
        else:
            self.inference_ds = {'x': x}
            print(f"推理数据集大小: {len(x)} 样本")

    def __len__(self):
        """返回当前拆分的数据集长度"""
        return len(self.dataset['x'])

    def __getitem__(self, idx):
        """获取指定索引的样本"""
        x_ = torch.tensor(self.dataset['x'][idx], dtype=torch.float32)  # shape=(n_channels, n_times)
        
        if self.__split != "inference":
            y_ = torch.tensor(self.dataset['y'][idx], dtype=torch.long)  # 使用long类型用于分类
            return x_, y_
        else:
            return x_

    def split(self, __split):
        """设置当前使用的数据集拆分"""
        self.__split = __split
        return self

    @property
    def dataset(self):
        """根据当前拆分返回相应的数据集"""
        assert self.__split is not None, "必须先指定数据集拆分(train/val/test/inference)!"
        
        if self.__split == "train":
            return self.train_ds
        elif self.__split == "val":
            return self.val_ds
        elif self.__split == "test":
            return self.test_ds
        elif self.__split == "inference":
            return self.inference_ds
        else:
            raise ValueError(f"未知的数据集拆分: {self.__split}")

    def get_loaders(self, batch_size=32, num_workers=4):
        """获取所有数据加载器"""
        train_loader = DataLoader(
            self.split("train"), batch_size=batch_size, shuffle=True, 
            num_workers=num_workers, pin_memory=True
        )
        
        val_loader = DataLoader(
            self.split("val"), batch_size=batch_size, shuffle=False, 
            num_workers=num_workers, pin_memory=True
        )
        
        test_loader = DataLoader(
            self.split("test"), batch_size=batch_size, shuffle=False, 
            num_workers=num_workers, pin_memory=True
        )
        
        return train_loader, val_loader, test_loader


data_folder = "D:/data/code/eeg/OpenViBE/data/TEST"
X_all, y_all = load_all_csv_files(data_folder)
    
if X_all is not None and y_all is not None:
    eeg_dataset = EEGDataset(X_all, y_all)
    train_loader, val_loader, test_loader = eeg_dataset.get_loaders(batch_size=64)

找到 3 个CSV文件


处理文件:   0%|          | 0/3 [00:00<?, ?it/s]

正在处理文件: motor-imagery-1-[2025.04.20-12.14.27].csv


  df = pd.read_csv(file_path)


Creating RawArray with float64 data, n_channels=8, n_times=101728
    Range : 0 ... 101727 =      0.000 ...   406.908 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1651 samples (6.604 s)

Not setting metadata
25 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 25 events and 1501 original time points ...
0 bad epochs dropped
任务相关epochs数量: 25
Not setting metadata
66 matching events found
No baseline correction applied
0 proje

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
  X_task = (task_epochs.get_data() * 1e-6).astype(np.float32)
  X_rest = (rest_epochs.get_data() * 1e-6).astype(np.float32)
处理文件:  33%|███▎      | 1/3 [00:00<00:01,  1.77it/s]

静息epochs数量: 66
原始epochs总数: 91
切片后segments总数: 546
切片后数据形状: (546, 8, 250)
正在处理文件: motor-imagery-1-[2025.04.20-12.22.41].csv
Creating RawArray with float64 data, n_channels=8, n_times=103744
    Range : 0 ... 103743 =      0.000 ...   414.972 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1651 samples (6.604 s)

Not setting metadata
25 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 25 events and 1501 original time points ...


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
  X_task = (task_epochs.get_data() * 1e-6).astype(np.float32)
  X_rest = (rest_epochs.get_data() * 1e-6).astype(np.float32)
处理文件:  67%|██████▋   | 2/3 [00:00<00:00,  2.21it/s]

静息epochs数量: 67
原始epochs总数: 92
切片后segments总数: 552
切片后数据形状: (552, 8, 250)
正在处理文件: motor-imagery-1-[2025.04.20-12.31.29].csv
Creating RawArray with float64 data, n_channels=8, n_times=101248
    Range : 0 ... 101247 =      0.000 ...   404.988 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1651 samples (6.604 s)

Not setting metadata
25 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 25 events and 1501 original time points ...


  df = pd.read_csv(file_path)
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
  X_task = (task_epochs.get_data() * 1e-6).astype(np.float32)
  X_rest = (rest_epochs.get_data() * 1e-6).astype(np.float32)
处理文件: 100%|██████████| 3/3 [00:01<00:00,  2.32it/s]

静息epochs数量: 66
原始epochs总数: 91
切片后segments总数: 546
切片后数据形状: (546, 8, 250)
所有文件处理完成!
总数据形状: (1644, 8, 250)
总标签形状: (1644,)
类别 0 (OVTK_GDF_Left): 108 样本 (6.57%)
类别 1 (OVTK_GDF_Right): 114 样本 (6.93%)
类别 2 (OVTK_GDF_Up): 120 样本 (7.30%)
类别 3 (OVTK_GDF_Down): 108 样本 (6.57%)
类别 4 (静息): 1194 样本 (72.63%)
训练集大小: 1314 样本
验证集大小: 165 样本
测试集大小: 165 样本
类别 0 分布: 训练集 86, 验证集 11, 测试集 11
类别 1 分布: 训练集 92, 验证集 11, 测试集 11
类别 2 分布: 训练集 96, 验证集 12, 测试集 12
类别 3 分布: 训练集 86, 验证集 11, 测试集 11
类别 4 分布: 训练集 954, 验证集 120, 测试集 120





In [11]:

 

class AvgMeter:
    """滑动窗口平均；窗口大小 = num"""
    def __init__(self, num: int = 40):
        self.num = num
        self.reset()

    def reset(self):
        self.losses: list[torch.Tensor] = []

    def update(self, val: torch.Tensor):
        self.losses.append(val.detach().clone())

    def show(self) -> torch.Tensor:
        if len(self.losses) == 0:
            return torch.tensor(0.0, device=self.losses[0].device)
        recent = self.losses[-self.num :]
        return torch.stack(recent).mean()


class ModelWrapper(pl.LightningModule):
    def __init__(
        self,
        arch: nn.Module,
        dataset,                   # EEGDataset 实例
        batch_size: int = 64,
        lr: float = 1e-3,
        max_epoch: int = 100,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["arch", "dataset"])

        self.arch        = arch
        self.dataset_obj = dataset        # 原始数据集对象
        self.batch_size  = batch_size
        self.lr          = lr
        self.max_epoch   = max_epoch
        self.num_classes = 5              # ←← 改为 5

        # --- metrics ---
        self.train_acc = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.val_acc   = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.test_acc  = Accuracy(task="multiclass", num_classes=self.num_classes)

        self.confmat   = ConfusionMatrix(
            task="multiclass", num_classes=self.num_classes
        )
        self.f1score   = MulticlassF1Score(num_classes=self.num_classes)

        # --- recorders ---
        self.tr_loss_meter, self.val_loss_meter = AvgMeter(), AvgMeter()
        self.tr_acc_meter , self.val_acc_meter  = AvgMeter(), AvgMeter()

        # 关闭手动优化，Lightning 自动反向传播/step
        self.automatic_optimization = True

    # ---------------- core ----------------
    def forward(self, x):
        return self.arch(x)

    def _shared_step(self, batch):
        x, y = batch
        y    = y.long().squeeze()       # shape -> (B,)
        logits = self(x)
        loss   = F.cross_entropy(logits, y)
        preds  = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, _):
        loss, preds, y = self._shared_step(batch)
        self.tr_loss_meter.update(loss)
        self.tr_acc_meter.update(self.train_acc(preds, y))
        self.log_dict(
            {"train/loss": loss, "train/acc": self.train_acc(preds, y)},
            on_step=False, on_epoch=True, prog_bar=True, sync_dist=True
        )
        return loss

    def on_train_epoch_end(self):
        rank_zero_only(
            lambda: print(
                f"[Epoch {self.current_epoch}] "
                f"train_loss={self.tr_loss_meter.show():.4f}, "
                f"train_acc={self.tr_acc_meter.show():.4f}"
            )
        )
        self.tr_loss_meter.reset()
        self.tr_acc_meter.reset()

    def validation_step(self, batch, _):
        loss, preds, y = self._shared_step(batch)
        self.val_loss_meter.update(loss)
        self.val_acc_meter .update(self.val_acc(preds, y))
        self.log_dict(
            {"val/loss": loss, "val/acc": self.val_acc(preds, y)},
            on_step=False, on_epoch=True, prog_bar=True, sync_dist=True
        )

    def on_validation_epoch_end(self):
        rank_zero_only(
            lambda: print(
                f"[Epoch {self.current_epoch}] "
                f"val_loss={self.val_loss_meter.show():.4f}, "
                f"val_acc={self.val_acc_meter.show():.4f}"
            )
        )
        self.val_loss_meter.reset()
        self.val_acc_meter.reset()

    def test_step(self, batch, _):
        loss, preds, y = self._shared_step(batch)
        self.test_acc.update(preds, y)
        self.confmat .update(preds, y)
        self.log_dict(
            {"test/loss": loss},
            on_step=False, on_epoch=True, prog_bar=True, sync_dist=True
        )

    def on_test_epoch_end(self):
        cm = self.confmat.compute().cpu().numpy()
        acc = self.test_acc.compute().item()
        f1  = self.f1score.compute().item()
        rank_zero_only(lambda: print(f"\nTest ACC={acc:.4f}, F1={f1:.4f}\nConfusion‑Matrix:\n{cm}"))
        self.confmat.reset()

        # ------ 详细分类报告 ------
        if rank_zero_only.rank == 0:
            report = classification_report(
                y_true = np.concatenate([t.cpu().numpy() for t in self.confmat.target]),
                y_pred = np.concatenate([p.cpu().numpy() for p in self.confmat.preds]),
                target_names=["Left","Right","Up","Down","Rest"],
                digits=4
            )
            print(report)

    # ---------------- data loaders ---------------
    def train_dataloader(self):
        return self.dataset_obj.get_loaders(
            batch_size=self.batch_size, num_workers=os.cpu_count()
        )[0]

    def val_dataloader(self):
        return self.dataset_obj.get_loaders(
            batch_size=self.batch_size, num_workers=os.cpu_count()
        )[1]

    def test_dataloader(self):
        return self.dataset_obj.get_loaders(
            batch_size=self.batch_size, num_workers=os.cpu_count()
        )[2]

    # --------------- optim & sched ---------------
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            anneal_strategy="cos",
            final_div_factor=1e2,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }


In [None]:
# ============= 0. 依赖 =============
import os, random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import pytorch_lightning as L
from torchmetrics.classification import Accuracy, ConfusionMatrix, MulticlassF1Score
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger

# ============= 1. 随机种子 =============
SEED = 42
L.seed_everything(SEED, workers=True)

# =========================================================
# ----------------- 模型组件定义 ---------------------------
# =========================================================
class PositionalEncoding(nn.Module):
    """Transformer 常用位置编码 (batch, seq_len, embed_dim)"""
    def __init__(self, embed_dim: int, dropout: float, max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-np.log(10000.0) / embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)   # (1, max_len, embed_dim)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerBlock(nn.Module):
    """简化版 Encoder Block"""
    def __init__(self, embed_dim: int, num_heads: int,
                 dim_feedforward: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, dim_feedforward),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # Self‑Attention
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        # Feed‑Forward
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.drop(ffn_out))
        return x


class EEGClassificationModel(nn.Module):
    """
    整体架构：
    Conv1d → BatchNorm → Conv1d → BatchNorm →
    2×TransformerBlock → MLP → 5‑way logits
    """
    def __init__(self, eeg_channel: int, dropout: float = 0.1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(eeg_channel, eeg_channel, kernel_size=11,
                      stride=1, padding=5, bias=False),
            nn.BatchNorm1d(eeg_channel),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Conv1d(eeg_channel, eeg_channel * 2, kernel_size=11,
                      stride=1, padding=5, bias=False),
            nn.BatchNorm1d(eeg_channel * 2),
        )

        self.transformer = nn.Sequential(
            PositionalEncoding(eeg_channel * 2, dropout),
            TransformerBlock(eeg_channel * 2, num_heads=4,
                             dim_feedforward=eeg_channel // 4, dropout=dropout),
            TransformerBlock(eeg_channel * 2, num_heads=4,
                             dim_feedforward=eeg_channel // 4, dropout=dropout),
        )

        self.mlp = nn.Sequential(
            nn.Linear(eeg_channel * 2, eeg_channel // 2),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(eeg_channel // 2, 5),   # ← 5 类
        )

    def forward(self, x):          # x: (B, C, T)
        x = self.conv(x)           # (B, 2C, T)
        x = x.permute(0, 2, 1)     # (B, T, 2C)
        x = self.transformer(x)    # (B, T, 2C)
        x = x.permute(0, 2, 1)     # (B, 2C, T)
        x = x.mean(dim=-1)         # (B, 2C)
        return self.mlp(x)         # (B, 5)

# =========================================================
# ---------------- Lightning 包装器 ------------------------
# =========================================================
class ModelWrapper(L.LightningModule):
    def __init__(self, arch: nn.Module, dataset,
                 batch_size: int = 64, lr: float = 2e-3, max_epoch: int = 100):
        super().__init__()
        self.save_hyperparameters(ignore=["arch", "dataset"])
        self.arch        = arch
        self.dataset_obj = dataset
        self.num_classes = 5

        # metrics
        self.tr_acc = Accuracy(task="multiclass", num_classes=5)
        self.val_acc = Accuracy(task="multiclass", num_classes=5)
        self.test_acc = Accuracy(task="multiclass", num_classes=5)
        self.confmat = ConfusionMatrix(task="multiclass", num_classes=5)
        self.f1 = MulticlassF1Score(num_classes=5)

    def forward(self, x):
        return self.arch(x)

    # ------ shared step ------
    def _common_step(self, batch):
        x, y = batch
        y = y.long().squeeze()
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = logits.argmax(dim=1)
        return loss, preds, y

    # ------ train / val / test ------
    def training_step(self, batch, _):
        loss, preds, y = self._common_step(batch)
        self.log_dict({"train/loss": loss,
                       "train/acc": self.tr_acc(preds, y)},
                      prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, _):
        loss, preds, y = self._common_step(batch)
        self.log_dict({"val/loss": loss,
                       "val/acc": self.val_acc(preds, y)},
                      prog_bar=True, on_step=False, on_epoch=True)

    def test_step(self, batch, _):
        loss, preds, y = self._common_step(batch)
        self.confmat.update(preds, y)
        self.log("test/loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("test/acc", self.test_acc(preds, y), prog_bar=True, on_step=False, on_epoch=True)

    def on_test_epoch_end(self):
        cm = self.confmat.compute().cpu().numpy()
        acc = self.test_acc.compute().item()
        f1  = self.f1.compute().item()
        print(f"\nTest ACC={acc:.4f}  F1={f1:.4f}\nConfusion‑Matrix:\n{cm}")

    # ------ loaders ------
    def train_dataloader(self):
        return self.dataset_obj.get_loaders(batch_size=self.hparams.batch_size)[0]

    def val_dataloader(self):
        return self.dataset_obj.get_loaders(batch_size=self.hparams.batch_size)[1]

    def test_dataloader(self):
        return self.dataset_obj.get_loaders(batch_size=self.hparams.batch_size)[2]

    # ------ optim ------
    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
        sched = torch.optim.lr_scheduler.OneCycleLR(
            opt, max_lr=self.hparams.lr,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3, div_factor=25.0, final_div_factor=1e2
        )
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sched, "interval": "step"}}

# =========================================================
# -------------------- 一键训练脚本 ------------------------
# =========================================================
# ==== 假设你已经执行过预处理，得到了 X_all, y_all, eeg_dataset ====
print("X_all shape:", X_all.shape)      # (n_segments, 8, 250)

EEG_CHANNEL   = X_all.shape[1]          # 通道数 8
MAX_EPOCH     = 80
BATCH_SIZE    = 64
LR            = 2e-3
CHECKPOINT_DIR = os.getcwd()

# -- 构建基础模型 --
base_model = EEGClassificationModel(eeg_channel=EEG_CHANNEL, dropout=0.20)

if torch.cuda.is_available() and int(torch.__version__.split('.')[0]) >= 2:
    torch.set_float32_matmul_precision("high")
    base_model = torch.compile(base_model)

# -- LightningModule --
wrapper = ModelWrapper(
    arch       = base_model,
    dataset    = eeg_dataset,
    batch_size = BATCH_SIZE,
    lr         = LR,
    max_epoch  = MAX_EPOCH,
)

# ---- 回调 ----
ckpt_cb = ModelCheckpoint(
    dirpath   = CHECKPOINT_DIR,
    filename  = "best-{epoch:02d}-{val_acc:.4f}",
    monitor   = "val/acc",
    mode      = "max",
    save_top_k= 1,
)
early_cb = EarlyStopping(monitor="val/acc", mode="max", patience=12, verbose=True)
lr_cb    = LearningRateMonitor(logging_interval="epoch")
logger   = CSVLogger(save_dir=CHECKPOINT_DIR, name="eeg_log")

# ---- Trainer ----
trainer = L.Trainer(
    max_epochs        = MAX_EPOCH,
    accelerator       = "auto",
    precision         = "bf16-mixed" if torch.cuda.is_available() else 32,
    gradient_clip_val = 1.0,
    callbacks         = [ckpt_cb, early_cb, lr_cb],
    logger            = logger,
    log_every_n_steps = 10,
    deterministic     = True,
)

# ---- 训练 & 测试 ----
trainer.fit(wrapper)
trainer.test(wrapper, ckpt_path=ckpt_cb.best_model_path)
print("Best checkpoint:", ckpt_cb.best_model_path)


Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


X_all shape: (1644, 8, 250)


c:\ProgramData\anaconda3\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory D:\data\code\eeg\src exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
c:\ProgramData\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
