In [None]:
import os
import pandas as pd
import numpy as np
import numpy as np

# mne imports
import mne
from mne import io
from mne.datasets import sample
from mne.preprocessing import ICA


# tools for plotting confusion matrices
from matplotlib import pyplot as plt

# Recording geometry: number of EEG channels and number of samples per trial.
channels, samples = 14, 384

# Folder that holds the per-recording CSV files.
data_folder = './EEGData/'

# Accumulators filled by the loading loop: one trial array and one label
# array per CSV file.
data_list, label_list = [], []

import warnings

# Per-recording constants, identical for every file, so built once.
sfreq = 128  # sampling rate in Hz
ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
ch_types = ['eeg'] * channels

# The label slices further down encode a fixed per-file layout:
# trials 1-3 -> 0, trials 4-18 -> 1, trials 19-21 -> 2. They only line up
# with the intended trials when a file yields exactly this many trials.
trials_per_file = 21

# Visit every CSV recording in the folder. sorted() makes the trial order of
# the concatenated dataset reproducible; os.listdir() order is arbitrary and
# filesystem-dependent.
for file_name in sorted(os.listdir(data_folder)):
    if not file_name.endswith(".csv"):
        continue
    file_path = os.path.join(data_folder, file_name)

    # Read CSV columns 2..(channels+1) (0-based indices 1..channels); the first
    # column is skipped -- presumably an index/timestamp, TODO confirm.
    df = pd.read_csv(file_path, usecols=list(range(1, 1 + channels)), header=0)

    # Wrap the recording in an MNE Raw object; RawArray expects an array of
    # shape (n_channels, n_times), hence the transpose.
    # NOTE(review): values are passed through unscaled, while MNE treats EEG as
    # volts -- confirm the CSV units if MNE plotting/amplitude thresholds are used.
    info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
    raw = mne.io.RawArray(df.to_numpy().T, info=info)

    # Band-pass filter 1-50 Hz (in place).
    raw.filter(l_freq=1.0, h_freq=50.0)

    # Fit ICA with one component per channel.
    ica = ICA(n_components=channels, random_state=0, max_iter=1000)
    ica.fit(raw)

    # NOTE(review): no components are excluded, so apply() presumably
    # reconstructs the filtered signal essentially unchanged. Either select
    # artifact components into ica.exclude or drop the ICA step -- TODO decide.
    ica.exclude = []
    ica.apply(raw)

    # get_data() is (n_channels, n_times); transpose to (n_times, n_channels).
    data = raw.get_data().T

    # Cut the continuous recording into non-overlapping trials of `samples`
    # rows each, discarding any trailing partial trial.
    num_trials = data.shape[0] // samples
    data = data[:num_trials * samples, :]
    data = data.reshape(num_trials, samples, channels)

    data_list.append(data)

    if num_trials != trials_per_file:
        warnings.warn(
            f"{file_name}: expected {trials_per_file} trials of {samples} samples, "
            f"got {num_trials}; the fixed label layout may be misaligned."
        )

    # Per-trial labels for this file (0-based slices of the layout above).
    labels = np.zeros(num_trials, dtype=int)
    labels[3:18] = 1  # trials 4-18 labeled as 1
    labels[18:] = 2   # trials 19-21 labeled as 2
    label_list.append(labels)

# Stack the per-file arrays, in the order the files were loaded, into one
# dataset. Fail with a clear message instead of np.concatenate's generic
# "need at least one array" error when no CSV files were found.
if not data_list:
    raise ValueError(
        f"No CSV recordings were loaded from {data_folder!r}; nothing to concatenate."
    )

# (n_trials, samples, channels)
data_array = np.concatenate(data_list, axis=0)

# (n_trials,) -- one label per trial, aligned with data_array's first axis.
label_array = np.concatenate(label_list, axis=0)

# For the original dataset: (399, 384, 14) -> 399 trials, 384 samples per
# trial, 14 channels.
print("Data shape:", data_array.shape)
# For the original dataset: (399,) -> one label per trial.
print("Label shape:", label_array.shape)