# In [None]:
import mne
import numpy as np
import xml.etree.ElementTree as ET
import math
import pywt
import os
import re
import glob
from mne import Annotations

def wavelet_denoise(data, wavelet='sym6', level=5, mode='soft', boundary_mode='symmetric'):
    """Denoise every row of a 2-D array by thresholding wavelet coefficients.

    Each row is decomposed with ``pywt.wavedec``, all detail coefficient
    arrays are thresholded with ``sigma * sqrt(2 * ln(n))`` (``sigma`` is the
    median absolute value of the last coefficient array divided by 0.6745,
    ``n`` the row length), and the row is rebuilt with ``pywt.waverec``. The
    approximation coefficients are left untouched.

    Args:
        data: 2-D NumPy array; rows (first axis) are processed independently.
        wavelet: Wavelet name handed to PyWavelets.
        level: Requested decomposition depth. It is capped at
            ``pywt.dwt_max_level`` for the row length when that maximum is
            non-zero.
        mode: Thresholding mode, ``'soft'`` or ``'hard'``.
        boundary_mode: Signal extension mode passed to ``wavedec``/``waverec``.

    Returns:
        A new array with the same shape as ``data``. Floating/complex inputs
        keep their dtype; any other dtype (e.g. integers) yields ``float64``
        so the denoised values are not truncated.

    Raises:
        ValueError: If ``data`` is not 2-D or ``mode`` is not supported.
        RuntimeError: If processing of a row fails; the original exception is
            chained as ``__cause__``.
    """
    if data.ndim != 2:
        raise ValueError(f"输入数据需为二维数组，当前维度：{data.ndim}")
    if mode not in ('soft', 'hard'):
        raise ValueError(f"无效阈值模式：{mode}，可选 'soft' 或 'hard'")

    # zeros_like(data) would inherit an integer dtype and silently truncate
    # the reconstructed samples, so integer input is promoted to float64.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    denoised_data = np.zeros_like(data, dtype=out_dtype)
    n = data.shape[1]

    for i in range(data.shape[0]):
        channel_data = data[i, :]
        try:
            max_level = pywt.dwt_max_level(n, pywt.Wavelet(wavelet))
            # A maximum of 0 (row shorter than the filter) falls back to the
            # requested level unchanged, as the code has always done.
            actual_level = min(level, max_level) if max_level else level
            coeffs = pywt.wavedec(channel_data, wavelet, level=actual_level, mode=boundary_mode)

            # Noise estimate from the last (finest) detail coefficients.
            detail_coeffs = coeffs[-1]
            sigma = np.median(np.abs(detail_coeffs)) / 0.6745 if len(detail_coeffs) else 0
            threshold = sigma * np.sqrt(2 * np.log(n)) if sigma > 0 else 0

            # Keep the approximation, threshold every detail level.
            coeffs_thresholded = [coeffs[0]]
            coeffs_thresholded.extend(
                pywt.threshold(c, threshold, mode=mode) for c in coeffs[1:]
            )
            denoised_channel = pywt.waverec(coeffs_thresholded, wavelet, mode=boundary_mode)

            # waverec can return one extra sample for odd-length input.
            if len(denoised_channel) > n:
                denoised_data[i, :] = denoised_channel[:n]
            else:
                denoised_data[i, :len(denoised_channel)] = denoised_channel
        except ValueError as e:
            # PyWavelets reports bad wavelet names / modes / inputs with
            # ValueError. NOTE(review): the previous `except pywt.Error`
            # referenced an attribute pywt does not appear to export, which
            # would raise AttributeError while handling the real error.
            raise RuntimeError(f"小波处理异常，通道{i}错误：{str(e)}") from e
        except Exception as e:
            raise RuntimeError(f"通道{i}处理失败：{str(e)}") from e
    return denoised_data

def generate_stage_annotations(xml_path, epoch_length=30, total_duration=None):
    """Convert scored stage events in an annotation XML into per-epoch labels.

    Only ``ScoredEvent`` elements whose ``EventType`` is ``Stages|Stages`` are
    used. The label is the integer after the ``|`` in ``EventConcept``; a raw
    value of 5 is remapped to 4 and anything outside 0..4 is dropped. Events
    with unparsable, non-finite, negative-start or non-positive-duration
    timing are skipped.

    Args:
        xml_path: Path (or file object) accepted by ``ET.parse``.
        epoch_length: Epoch duration in the same unit as ``Start``/``Duration``.
        total_duration: Optional total duration used to size the output. When
            ``None`` the latest event end is used.

    Returns:
        A list of ``{'epoch': index, 'label': label}`` dicts, one per epoch.
        Epochs not covered by any kept event have label 0; where events
        overlap, the later event in document order wins.
    """
    root = ET.parse(xml_path).getroot()

    stage_events = []
    for event in root.iter('ScoredEvent'):
        if event.findtext('EventType', '').strip() != 'Stages|Stages':
            continue

        concept = event.findtext('EventConcept', '')
        parts = [p.strip() for p in concept.split('|')]
        if len(parts) < 2 or not parts[1].isdigit():
            print(f"[WARNING] 未能正确解析 EventConcept: {concept}，跳过此事件")
            continue

        raw_label = int(parts[1])
        # NOTE(review): raw 4 is kept as 4 and raw 5 is remapped to 4, so both
        # end up with the same label — confirm this merge is intended.
        label = 4 if raw_label == 5 else raw_label
        if not (0 <= label <= 4):
            continue

        try:
            start = float(event.findtext('Start', '0'))
            duration = float(event.findtext('Duration', '0'))
        except ValueError as e:
            print(f"[ERROR] 解析 Start 或 Duration 时出错: {e}, EventConcept: {concept}")
            continue

        # float() accepts "nan"/"inf"; those slip past the range checks below
        # and would crash math.ceil / int further down, so reject them here.
        if not (math.isfinite(start) and math.isfinite(duration)):
            continue
        if duration <= 0 or start < 0:
            continue
        stage_events.append({'start': start, 'end': start + duration, 'label': label})

    if total_duration is not None:
        max_time = total_duration
    else:
        max_time = max((e['end'] for e in stage_events), default=0)
    total_epochs = math.ceil(max_time / epoch_length)

    annotations = [{'epoch': i, 'label': 0} for i in range(total_epochs)]
    for e in stage_events:
        start_epoch = max(0, int(e['start'] // epoch_length))
        # The tiny offset keeps an event that ends exactly on an epoch
        # boundary from spilling into the following epoch.
        end_epoch = min(total_epochs - 1, int((e['end'] - 1e-9) // epoch_length))
        for i in range(start_epoch, end_epoch + 1):
            annotations[i]['label'] = e['label']
    return annotations

def process_psg_with_segmentation(edf_path, xml_path, output_dir, channels=None, epoch_length=30, resample_rate=None):
    """Split one recording into fixed-length epochs and save each as ``.npy``.

    The recording is read with ``mne.io.read_raw``, optionally restricted to
    ``channels`` and resampled, then cut into consecutive windows of
    ``epoch_length`` seconds. Window ``k`` is written to
    ``output_dir/ex{k:04d}_{label}.npy`` where ``label`` comes from
    ``generate_stage_annotations`` for the same epoch index.

    Args:
        edf_path: Recording file to read.
        xml_path: Annotation XML with the scored stages.
        output_dir: Existing directory that receives the ``.npy`` files.
        channels: Optional list of channel names to keep. Names absent from
            the recording are reported and ignored.
        epoch_length: Epoch duration in seconds, used for both the label grid
            and the sample windows.
        resample_rate: Target sampling rate in Hz; falsy keeps the native rate.

    Returns:
        ``True`` when the recording was processed, ``False`` when an error was
        caught and reported. Errors are printed, never raised.
    """
    try:
        raw = mne.io.read_raw(edf_path, preload=True, verbose=False)

        # Pick before resampling so discarded channels are not resampled.
        if channels:
            # dict.fromkeys drops duplicate names while keeping their order.
            existing = [ch for ch in dict.fromkeys(channels) if ch in raw.ch_names]
            missing = set(channels) - set(existing)
            if missing:
                print(f"[WARNING] 缺失通道: {missing}")
            if not existing:
                raise ValueError("请求的通道在该记录中均不存在")
            # pick_types(eeg=True, include=...) keeps every EEG-typed channel,
            # which defeats the selection when the reader types all channels
            # as EEG; picking by name keeps only the requested ones.
            raw.pick(existing)

        if resample_rate:
            print(f"[INFO] 下采样为 {resample_rate} Hz")
            raw.resample(sfreq=resample_rate)

        # Use the same epoch length for the labels as for the sample windows.
        stage_anns = generate_stage_annotations(xml_path, epoch_length=epoch_length)
        print(f"[INFO] 正在处理文件: {edf_path}")

        sfreq = raw.info['sfreq']
        samples_per_epoch = int(round(epoch_length * sfreq))
        if samples_per_epoch <= 0:
            raise ValueError(f"无效的片段长度: epoch_length={epoch_length}, sfreq={sfreq}")
        data = raw.get_data()

        # Annotations may extend past the signal; saving those windows would
        # produce truncated or empty arrays, so only complete epochs are kept.
        complete_epochs = data.shape[1] // samples_per_epoch
        n_epochs = min(len(stage_anns), complete_epochs)
        if n_epochs < len(stage_anns):
            print(f"[WARNING] 信号长度不足，丢弃 {len(stage_anns) - n_epochs} 个不完整片段")

        for epoch_idx, ann in enumerate(stage_anns[:n_epochs]):
            first = epoch_idx * samples_per_epoch
            segment_data = data[:, first:first + samples_per_epoch]
            filename = f"ex{epoch_idx:04d}_{ann['label']}.npy"
            np.save(os.path.join(output_dir, filename), segment_data)

        print(f"[INFO] 完成文件: {edf_path} 的处理，共保存 {n_epochs} 个片段")
        return True
    except Exception as e:
        # Best-effort batch step: report the problem and signal failure to
        # the caller instead of propagating.
        print(f"[ERROR] 处理文件 {edf_path} 时发生错误: {e}")
        return False

# ==== Configuration ====
edf_dir = "E:/mesa/polysomnography/edfs"
xml_dir = "E:/mesa/polysomnography/annotations-events-nsrr"
output_base_dir = "G:/mesa_staging"

n_files = 3000
channels = ['EEG2', 'Flow', 'SpO2', 'Thor', 'Abdo']
# NOTE(review): denoise_enabled and normalize_enabled are not read anywhere in
# the visible code — confirm whether denoising/normalisation should be wired in.
denoise_enabled = True
normalize_enabled = True
resample_rate = 100  # target sampling rate in Hz; set to None to skip resampling

# Compiled once; used both for the numeric sort and for extracting the file ID.
_EDF_NAME_RE = re.compile(r'mesa-sleep-(\d+)\.edf')

# The glob can match names the pattern rejects (non-numeric ID, different
# letter case on case-insensitive filesystems). Drop those up front so the
# numeric sort key cannot fail on a missing match.
edf_files = []
for _candidate in glob.glob(os.path.join(edf_dir, "mesa-sleep-*.edf")):
    if _EDF_NAME_RE.search(os.path.basename(_candidate)) is None:
        print(f"[SKIP] 文件名不符合命名规则: {os.path.basename(_candidate)}")
        continue
    edf_files.append(_candidate)
edf_files.sort(key=lambda x: int(_EDF_NAME_RE.search(os.path.basename(x)).group(1)))
if isinstance(n_files, int):
    edf_files = edf_files[:n_files]
    print(f"准备处理前 {n_files} 个文件")

total_files = len(edf_files)
processed_count = 0
failed_count = 0
skipped_count = 0

for file_index, edf_path in enumerate(edf_files, start=1):
    # Assigned outside the try so the error message below can always use it.
    base_name = os.path.basename(edf_path)
    try:
        file_id = _EDF_NAME_RE.search(base_name).group(1)
        xml_filename = f"mesa-sleep-{file_id}-nsrr.xml"
        xml_path = os.path.join(xml_dir, xml_filename)
        if not os.path.exists(xml_path):
            print(f"[SKIP] 缺失标注文件: {xml_filename}")
            skipped_count += 1
            continue
        output_dir = os.path.join(output_base_dir, file_id)
        os.makedirs(output_dir, exist_ok=True)
        print(f"\n{'='*50}")
        print(f"正在处理文件 ({file_index}/{total_files}): {base_name}")
        outcome = process_psg_with_segmentation(
            edf_path=edf_path,
            xml_path=xml_path,
            output_dir=output_dir,
            channels=channels,
            epoch_length=30,
            resample_rate=resample_rate
        )
    except Exception as e:
        print(f"[ERROR] 处理文件 {base_name} 失败: {str(e)}")
        failed_count += 1
        continue
    else:
        # A file counts as failed when the call raised (above) or explicitly
        # returned False; any other return value is treated as success. Each
        # file therefore lands in exactly one of the three counters.
        if outcome is False:
            failed_count += 1
        else:
            processed_count += 1

print("\n批量处理完成！处理情况：")
print(f"总文件数: {total_files}")
print(f"成功处理: {processed_count}")
print(f"失败文件: {failed_count}")
print(f"跳过文件: {skipped_count}")
