In [None]:
# Imports
# Standard library
import glob
import os
import re
from typing import Union

# Third-party
import h5py
import librosa
import librosa as lb
import mne
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import get_cmap
from scipy import signal
from scipy.signal import correlate, correlation_lags
from scipy.signal import resample
%matplotlib qt

# Everything needed for ICA
from mne.preprocessing import ICA
from mne_icalabel import label_components


In [None]:
# Root folder of the recording; expected layout (taken from the glob patterns
# below): <root>/Data/Session*/*.vhdr and <root>/audio/M*.wav.
root_dir = r'D:\A-BCI声音脑机接口\4'


def natural_sort_key(path):
    """Sort key that compares embedded numbers numerically.

    Plain ``sorted`` orders 'M10.wav' before 'M2.wav' and 'Session10' before
    'Session2'; this key yields M1, M2, ..., M10 instead. Zero-padded names
    keep the order they already had.

    Parameters
    ----------
    path : str
        File path or file name.

    Returns
    -------
    list
        Alternating text / int tokens, suitable as ``key=`` for ``sorted``.
    """
    # re.split with a capturing group alternates text (even index) and
    # digit runs (odd index), so only same-typed tokens are ever compared.
    tokens = re.split(r'(\d+)', path)
    return [int(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]


# The file order defines session/subject numbering and trial numbering further
# down, so it has to follow the numbers in the names.
eeg_files = sorted(glob.glob(os.path.join(root_dir, "Data", "Session*", "*.vhdr")),
                   key=natural_sort_key)
audio_files = sorted(glob.glob(os.path.join(root_dir, 'audio', 'M*.wav')),
                     key=natural_sort_key)
out_file = os.path.join(root_dir, 'my_backward_dataset1.hdf5')

print(f"找到 {len(eeg_files)} 个 EEG Session 文件")
print(f"找到 {len(audio_files)} 个 Audio 文件")

In [None]:
# Mapping used to rename the 32 numbered channels ('1'..'32') to electrode
# labels. The tuple lists the labels in channel-number order, so entry k
# (1-based) is the label for the channel named str(k).
channel_map = {
    str(number): label
    for number, label in enumerate(
        (
            'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC5',
            'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8',
            'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4',
            'P8', 'POz', 'O1', 'Oz', 'O2', 'FT9', 'FT10', 'TP9',
        ),
        start=1,
    )
}

# Electrode positions applied to each subject's recording after renaming.
montage = mne.channels.make_standard_montage('standard_1020')

In [None]:
#Alignment EEG (StimTrak) and Audio

## Pad shorter signal with zeroes
def pad_zeros_right(s, padding_length):
    """Return ``s`` with ``padding_length`` zeros appended at its end.

    Parameters
    ----------
    s : array_like
        Signal to extend.
    padding_length : int
        Number of zeros to append.

    Returns
    -------
    numpy.ndarray
        The padded signal.
    """
    pad_before, pad_after = 0, padding_length
    padded = np.pad(s, pad_width=(pad_before, pad_after), mode='constant', constant_values=0)
    return padded

def padding(a, b, pad_function=None):
    """Bring two sequences to the same length by padding the shorter one.

    Parameters
    ----------
    a, b : sequence
        The two signals to equalise in length.
    pad_function : callable, optional
        Called as ``pad_function(signal, n_missing)`` on the shorter signal;
        must return the extended signal.

    Returns
    -------
    tuple
        ``(a, b)``, with the shorter one replaced by its padded version.
        Inputs of equal length are returned unchanged.

    Raises
    ------
    ValueError
        If the lengths differ and no ``pad_function`` was given.
    """
    if len(a) == len(b):
        return a, b

    if pad_function is None:
        raise ValueError(f"len(a)={len(a)} != len(b)={len(b)} and no pad_function provided")

    shortfall = len(b) - len(a)
    if shortfall > 0:
        a = pad_function(a, shortfall)
    else:
        b = pad_function(b, -shortfall)
    return a, b

def crosscorrelation(ref, sig):
    """Full cross-correlation of ``ref`` with ``sig`` plus the matching lag axis.

    Parameters
    ----------
    ref : array_like
        Reference signal (the StimTrak / EEG-side recording of the stimulus).
    sig : array_like
        Signal to locate inside ``ref`` (the audio).

    Returns
    -------
    c : numpy.ndarray
        Cross-correlation values (``mode='full'``).
    lags : numpy.ndarray
        Lag in samples for each entry of ``c``.
    """
    lag_axis = correlation_lags(len(ref), len(sig), mode='full')
    xcorr = correlate(ref, sig, mode='full')
    return xcorr, lag_axis

In [None]:
   #提取audio文件的音频 用于后续的对齐
    fs = 1000  # Hz

    # Load audio wav-file
    audios = []

    for audionr in range(1, 16):

        # Lade und resample direkt auf 1000 Hz
        data, sr = librosa.load(audio_files, sr=fs)  # resample the audio data from 48kHz to 1kHz to align it with EEG

        audios.append(data)

In [None]:
### Build envelopes_list from the audio files (delta-band filtered and downsampled to 1000 Hz; used when building the HDF5 file)
import numpy as np
import librosa as lb
from scipy import signal # <<< 加上这一行来导入 signal 模块
from typing import Union

# helper functions
def get_envelope_from_hilbert(data: np.ndarray):
    """Return the amplitude envelope of ``data``.

    The envelope is the magnitude of the analytic signal computed by
    ``scipy.signal.hilbert``.

    Parameters
    ----------
    data : numpy.ndarray
        Real-valued input signal.

    Returns
    -------
    numpy.ndarray
        Non-negative envelope with the same shape as ``data``.
    """
    return np.abs(signal.hilbert(data))

def apply_butterworth_bandpass(data, order=4, cutoff_low=1, cutoff_high=20, fs=1000, axis=-1):
    """Zero-phase Butterworth band-pass filter.

    Parameters
    ----------
    data : array_like
        Signal(s) to filter.
    order : int
        Order passed to ``scipy.signal.butter``.
    cutoff_low, cutoff_high : float
        Band edges, in the same unit as ``fs``.
    fs : float
        Sampling rate of ``data``.
    axis : int
        Axis of ``data`` along which to filter.

    Returns
    -------
    numpy.ndarray
        Filtered signal with the same shape as ``data``.
    """
    band_edges = [cutoff_low, cutoff_high]
    # Second-order sections, applied forward and backward by sosfiltfilt.
    sections = signal.butter(order, band_edges, btype="bandpass", output="sos", fs=fs)
    filtered = signal.sosfiltfilt(sections, data, axis=axis)
    return filtered

# ========== Audio preprocessing: done once, reused for every subject ==========
target_fs = 1000  # Hz; rate the envelopes are brought to (EEG sampling rate)
envelopes_list = []

# lb.load takes ONE path per call, so iterate over the files themselves
# (in the order of `audio_files`) instead of over a fixed counter.
for audio_path in audio_files:

    # 1. Load at the file's native sampling rate
    stimuli, sr_orig = lb.load(audio_path, sr=None, mono=True)

    # 2. Amplitude envelope
    envelope = get_envelope_from_hilbert(stimuli)

    # 3. Band-pass the envelope to 1-4 Hz
    delta_band_high = apply_butterworth_bandpass(
        envelope, order=4, cutoff_low=1, cutoff_high=4, fs=sr_orig
    )

    # 4. Downsample to target_fs
    num_samples_1khz = int(len(delta_band_high) * target_fs / sr_orig)
    delta_band_1khz = signal.resample(delta_band_high, num_samples_1khz)

    # 5. Collect, one entry per audio file
    envelopes_list.append(delta_band_1khz)


In [None]:
# --- 1. Initial setup ---
# This cell relies on names created by earlier cells: eeg_files, out_file,
# audios, envelopes_list, channel_map, montage, padding, pad_zeros_right and
# crosscorrelation.

# Every run rebuilds the dataset from scratch, so drop any previous output file.
if os.path.exists(out_file):
    os.remove(out_file)
    print(f"旧文件 {out_file} 已删除，准备创建新文件。")

# ID given to the first subject; incremented once per processed subject.
current_subject_id = 301

# Sampling rates in Hz. `audios` and `envelopes_list` were produced at 1000 Hz,
# and the lags are used as EEG sample indices, so the recording must match.
current_fs = 1000
target_fs = 128

# Each session file holds three participants with 32 channels each, named
# '1'..'32', '33'..'64' and '65'..'96'.
N_SUBJECTS_PER_SESSION = 3
N_CHANNELS_PER_SUBJECT = 32


def _compute_audio_lags(stimtrak, audio_signals):
    """Return, per audio signal, the lag (in samples) of its best match in ``stimtrak``."""
    found_lags = []
    for audio in audio_signals:
        audio_p, stim_p = padding(audio, stimtrak, pad_function=pad_zeros_right)
        corr, lags = crosscorrelation(stim_p, audio_p)
        # Peak of |correlation|, so an inverted recording still aligns.
        found_lags.append(lags[np.argmax(np.abs(corr))])
    return found_lags


def _extract_subject_raw(raw, sub_idx):
    """Return a copy of ``raw`` holding one subject's channels, renamed and with montage set."""
    first = sub_idx * N_CHANNELS_PER_SUBJECT
    subj_ch = [str(first + k) for k in range(1, N_CHANNELS_PER_SUBJECT + 1)]
    raw_sub = raw.copy().pick_channels(subj_ch)

    # channel_map is keyed '1'..'32'; shift the keys to this subject's channel
    # numbers, otherwise rename_channels rejects the mapping for subjects 2 and 3.
    rename_map = {str(first + int(num)): label for num, label in channel_map.items()}
    raw_sub.rename_channels(rename_map)
    raw_sub.set_montage(montage)
    return raw_sub


def _delta_band_eeg(raw_sub):
    """Remove ICA artifact components and return the 1-4 Hz data matrix of ``raw_sub``."""
    # Copy filtered to 1-100 Hz and average-referenced, used for ICA fitting and ICLabel.
    raw_ica = raw_sub.copy().filter(1., 100., fir_design='firwin', verbose=False)
    raw_ica.set_eeg_reference('average', projection=False, verbose=False)

    ica = mne.preprocessing.ICA(n_components=15, method='infomax', fit_params=dict(extended=True),
                                random_state=97, max_iter='auto')
    ica.fit(raw_ica, verbose=False)

    labels = label_components(raw_ica, ica, method='iclabel')
    ica.exclude = [idx for idx, label in enumerate(labels['labels'])
                   if label in ['eye blink', 'muscle artifact']]

    # NOTE(review): the ICA is fitted on average-referenced data but applied to
    # the un-referenced recording -- confirm this is intended.
    raw_clean = ica.apply(raw_sub.copy(), verbose=False)

    raw_clean.notch_filter(freqs=50, verbose=False)
    raw_clean.filter(1., 4., fir_design='firwin', verbose=False)
    return raw_clean.get_data()


def _cut_and_downsample(eeg_mat, envelopes, lags, subject_id):
    """Cut one EEG segment per envelope and downsample both to ``target_fs``.

    Returns a list of ``(trial_id, eeg, envelope)`` tuples. ``trial_id`` is the
    1-based position of the envelope in ``envelopes``, so it stays tied to the
    audio file even when a trial is skipped.
    """
    trials = []
    for trial_id, (envelope, lag) in enumerate(zip(envelopes, lags), start=1):
        start = int(lag)
        stop = start + len(envelope)

        # A negative start would wrap around to the end of the array; a stop
        # past the end would give a truncated segment. Skip only this trial.
        if start < 0 or stop > eeg_mat.shape[1]:
            print(f"Warning: Trial {trial_id} beyond data range for Sub {subject_id}")
            continue

        eeg_trial = eeg_mat[:, start:stop]
        n_eeg = int(eeg_trial.shape[1] * target_fs / current_fs)
        n_env = int(len(envelope) * target_fs / current_fs)
        # resample works along an axis, so all channels are handled in one call.
        trials.append((trial_id, resample(eeg_trial, n_eeg, axis=1), resample(envelope, n_env)))
    return trials


def _write_subject_trials(h5_path, subject_id, trials):
    """Append one subject's trials to the HDF5 file, replacing existing entries."""
    # Mode "a": subjects written earlier in this run must be kept.
    with h5py.File(h5_path, "a") as f:
        for trial_id, eeg_data, env_data in trials:
            eeg_ds_path = f"eeg/{subject_id}/{trial_id}"
            if eeg_ds_path in f:
                del f[eeg_ds_path]
            eeg_ds = f.create_dataset(eeg_ds_path, data=np.asarray(eeg_data))

            # The EEG dataset points at its stimulus through this code.
            stim_code = f"s{subject_id}_t{trial_id}"
            eeg_ds.attrs["stimulus"] = stim_code

            env_group_path = f"stimulus_files/{stim_code}"
            if env_group_path in f:
                del f[env_group_path]
            stim_group = f.require_group(env_group_path)
            stim_group.create_dataset("attended_env", data=np.asarray(env_data))


# Lags come from `audios`, segments from `envelopes_list`; zip() would silently
# drop trials if the two lists disagreed.
if len(audios) != len(envelopes_list):
    raise ValueError(
        f"len(audios)={len(audios)} != len(envelopes_list)={len(envelopes_list)}"
    )

# --- 2. Loop over sessions ---
for sess_id, eeg_path in enumerate(eeg_files, start=1):
    print(f"\n=== 处理 Session {sess_id} : {os.path.basename(eeg_path)} ===")

    raw = mne.io.read_raw_brainvision(eeg_path, preload=True)
    if raw.info['sfreq'] != current_fs:
        raise ValueError(
            f"{os.path.basename(eeg_path)}: sfreq={raw.info['sfreq']} Hz, expected {current_fs} Hz"
        )

    # The last channel is taken as the StimTrak (stimulus) recording; only that
    # channel is read, not the whole data matrix.
    stimtrak = raw.get_data(picks=[len(raw.ch_names) - 1])[0]

    # Lags are recomputed for every session.
    print("   正在计算 Audio Lags...")
    lag_samples_list = _compute_audio_lags(stimtrak, audios)

    # --- 3. Loop over the subjects of this session ---
    for sub_idx in range(N_SUBJECTS_PER_SESSION):
        print(f"   正在处理 Subject {current_subject_id} (Session {sess_id} 的第 {sub_idx+1} 个人)...")

        # Built one at a time so only one subject copy is in memory.
        raw_sub = _extract_subject_raw(raw, sub_idx)
        eeg_mat = _delta_band_eeg(raw_sub)
        trials = _cut_and_downsample(eeg_mat, envelopes_list, lag_samples_list, current_subject_id)

        # --- 4. Write to HDF5 ---
        print(f"      -> 写入 Subject {current_subject_id} 到 HDF5...")
        _write_subject_trials(out_file, current_subject_id, trials)

        # --- 5. Next subject (same session or the next one) gets the next ID ---
        current_subject_id += 1

print(f"\n所有 Session 处理完成。文件保存为: {out_file}")