In [None]:
import os
import numpy as np
import pyedflib
from scipy.signal import resample_poly
from tqdm import tqdm

# ==== Path configuration ====
base_path = r"F:\PSG_audio\V3\APNEA_EDF"        # directory containing the raw EDF recordings
save_base = r"F:\PSG_audio\V3\processed"        # root directory for the generated output

# Channel labels to extract from every recording, in output row order.
target_channels = [
    "EEG C3-A2",
    "EEG C4-A1",
    "EOG LOC-A2",
    "EMG Chin",
    "Flow Patient (Pressure Cannula)",
]

# Every channel is brought to this rate (Hz) and cut into windows of this
# duration (seconds); segment_length is the resulting window size in samples.
segment_duration = 30
target_sampling_rate = 100
segment_length = segment_duration * target_sampling_rate

# ==== Utility functions ====

def find_channel_indices(targets, actual_labels):
    """Resolve each requested channel name to the index of a matching label.

    Matching is case-insensitive. A label that equals the target is preferred;
    only when no such label exists does the first label *containing* the
    target as a substring win. This keeps a loose label such as
    "EEG C3-A2 (raw)" from shadowing an exact "EEG C3-A2" that appears later.

    Args:
        targets: Channel names to look up.
        actual_labels: Labels present in the recording, in signal order.

    Returns:
        A list of indices into ``actual_labels``, one per target and in the
        same order as ``targets``.

    Raises:
        ValueError: If one or more targets match no label; the message lists
            every missing target.
    """
    # Lower-case the labels once instead of once per target.
    lowered = [label.lower() for label in actual_labels]

    indices = []
    missing = []
    for ch in targets:
        wanted = ch.lower()
        if wanted in lowered:
            # Exact (case-insensitive) hit takes priority.
            indices.append(lowered.index(wanted))
            continue
        partial = next(
            (i for i, label in enumerate(lowered) if wanted in label),
            None,
        )
        if partial is None:
            missing.append(ch)
        else:
            indices.append(partial)

    if missing:
        raise ValueError(f"通道未找到: {missing}")
    return indices

def stable_resample(signal, orig_rate, target_rate):
    """Resample ``signal`` from ``orig_rate`` to ``target_rate`` with a polyphase filter.

    The up/down factors are derived from the exact rational ratio of the two
    rates, so fractional sampling rates (e.g. 62.5 Hz) are handled correctly
    rather than being truncated to an integer first.

    Args:
        signal: 1-D sequence of samples.
        orig_rate: Sampling rate of ``signal`` in Hz (int or float, > 0).
        target_rate: Desired sampling rate in Hz (int or float, > 0).

    Returns:
        The resampled signal as returned by ``scipy.signal.resample_poly``.

    Raises:
        ValueError: If either rate is not a finite positive number.
    """
    from fractions import Fraction

    orig = float(orig_rate)
    target = float(target_rate)
    if not (np.isfinite(orig) and np.isfinite(target) and orig > 0 and target > 0):
        raise ValueError(
            f"sampling rates must be finite and positive, got {orig_rate!r} -> {target_rate!r}"
        )

    # repr() of a float is its shortest round-trip decimal form, so going
    # through the string yields the intended rational (62.5 -> 125/2) instead
    # of the binary expansion. Fraction reduces to lowest terms automatically.
    ratio = Fraction(repr(target)) / Fraction(repr(orig))
    return resample_poly(signal, ratio.numerator, ratio.denominator)

def normalize(signal):
    """Z-score a signal: subtract its mean and divide by its standard deviation.

    A constant (zero-variance) signal cannot be scaled, but it is still
    centred so that the result is zero-mean like every other normalized
    channel instead of keeping its DC offset.

    Args:
        signal: Array-like of samples.

    Returns:
        The normalized signal as a NumPy array. Empty input, and input whose
        standard deviation is not a number (e.g. it contains NaN), is returned
        unchanged.
    """
    signal = np.asarray(signal)
    if signal.size == 0:
        # Avoid the mean-of-empty warning; there is nothing to normalize.
        return signal

    mean = np.mean(signal)
    std = np.std(signal)
    if std > 0:
        return (signal - mean) / std
    if std == 0:
        # Flat signal: only centring is possible.
        return signal - mean
    # std is NaN: leave the data untouched.
    return signal

def process_edf_file(edf_path, save_dir, epoch_counter):
    """Cut one EDF recording into fixed-length epochs and save each as a .npy file.

    The channels listed in ``target_channels`` are read, resampled to
    ``target_sampling_rate`` when their rate differs, normalized, truncated to
    the shortest channel and split into consecutive windows of
    ``segment_length`` samples. Each window is saved to ``save_dir`` as
    ``e<counter>.npy`` with one row per channel.

    Args:
        edf_path: Path of the EDF file to process.
        save_dir: Directory receiving the epoch files (created if missing).
        epoch_counter: Running epoch number for the current patient; used to
            name the output files.

    Returns:
        A tuple ``(saved_segments, epoch_counter)``: the number of epochs
        actually written for this file and the updated running counter.
        Any error is printed rather than raised, so one bad file does not
        abort a batch run; epochs written before the error are still counted.
    """
    saved_segments = 0  # epochs actually written, reported even on failure
    try:
        with pyedflib.EdfReader(edf_path) as f:
            labels = f.getSignalLabels()
            freqs = f.getSampleFrequencies()

            # When exactly two channels are both labelled "Flow Patient",
            # rename them so the pressure-cannula target can be resolved.
            # NOTE(review): assumes the first is the thermistor and the second
            # the pressure cannula - confirm against the recording setup.
            fixed_labels = labels.copy()
            flow_indices = [i for i, label in enumerate(labels) if label.lower() == 'flow patient']
            if len(flow_indices) == 2:
                fixed_labels[flow_indices[0]] = 'Flow Patient (Thermistor)'
                fixed_labels[flow_indices[1]] = 'Flow Patient (Pressure Cannula)'

            selected_indices = find_channel_indices(target_channels, fixed_labels)

            data_list = []
            for idx in selected_indices:
                raw = f.readSignal(idx)
                orig_fs = freqs[idx]
                if orig_fs != target_sampling_rate:
                    raw = stable_resample(raw, orig_fs, target_sampling_rate)
                data_list.append(normalize(raw))

        # All samples are in memory now; the reader is closed before writing.
        # Channels may differ in length by a few samples after resampling, so
        # truncate to the shortest before stacking.
        min_len = min(len(sig) for sig in data_list)
        aligned = np.stack([sig[:min_len] for sig in data_list], axis=0)
        total_segments = min_len // segment_length

        os.makedirs(save_dir, exist_ok=True)

        for j in range(total_segments):
            start = j * segment_length
            segment = aligned[:, start:start + segment_length]
            epoch_filename = f"e{epoch_counter:04}.npy"
            np.save(os.path.join(save_dir, epoch_filename), segment)
            epoch_counter += 1
            saved_segments += 1
    except Exception as e:
        # Deliberately broad: report the failure and let the batch continue.
        print(f"处理失败：{edf_path} | 错误：{str(e)}")

    return saved_segments, epoch_counter

# ==== Main processing ====

# Gather every patient sub-directory under base_path, in sorted order.
patient_dirs = []
for entry_name in os.listdir(base_path):
    candidate = os.path.join(base_path, entry_name)
    if os.path.isdir(candidate):
        patient_dirs.append(candidate)
patient_dirs.sort()

print(f"共发现 {len(patient_dirs)} 个病人目录，开始处理...")
total_epochs = 0

for patient_number, patient_path in enumerate(tqdm(patient_dirs, desc="病人处理"), start=1):
    # Output goes to <save_base>/subNNN, numbered by position in the sorted list.
    save_dir = os.path.join(save_base, f"sub{patient_number:03}")
    epoch_counter = 0  # epoch numbering restarts for every patient

    # Process this patient's EDF files in sorted filename order.
    edf_files = sorted(
        name for name in os.listdir(patient_path) if name.lower().endswith(".edf")
    )

    for edf_file in edf_files:
        segments, epoch_counter = process_edf_file(
            os.path.join(patient_path, edf_file), save_dir, epoch_counter
        )
        total_epochs += segments

print(f"\n所有EDF文件处理完成！共生成 {total_epochs} 个数据段。")
