In [None]:
import mne
import numpy as np
import xml.etree.ElementTree as ET
import math
import os
import re
import glob
import pywt

# ================== Parameters ====================
# Input directories (EDF recordings and their NSRR XML annotations) and the
# base directory under which one sub-directory per recording is written.
edf_dir = "E:/mesa/polysomnography/edfs"
xml_dir = "E:/mesa/polysomnography/annotations-events-nsrr"
output_base_dir = "G:/mesa_apnea"

# Channels to keep; names not present in a recording are dropped.
channels = ['EEG2', 'Flow', 'SpO2', 'Thor', 'Abdo']
n_files = 100  # Set to None to process all files
resample_rate = 100  # Hz; set to None to skip resampling
denoise_enabled = True  # Wavelet-denoise channels whose name contains 'EEG'
normalize_enabled = True  # Z-score every channel except 'SpO2'
epoch_length = 30  # seconds
# ===========================================


def wavelet_denoise(data, wavelet='sym6', level=5, mode='soft', boundary_mode='symmetric'):
    """Denoise every row of a 2-D array by thresholding its wavelet coefficients.

    Each row is decomposed with ``pywt.wavedec``, all detail coefficients are
    thresholded with ``pywt.threshold`` and the row is rebuilt with
    ``pywt.waverec``. The noise level is estimated from the last coefficient
    array as ``median(|c|) / 0.6745`` and the threshold is
    ``sigma * sqrt(2 * ln(n_samples))``.

    Args:
        data: 2-D ``numpy`` array; rows are processed independently.
        wavelet: Wavelet name understood by PyWavelets.
        level: Requested decomposition depth; capped at the maximum depth
            PyWavelets allows for the row length.
        mode: Thresholding mode passed to ``pywt.threshold``.
        boundary_mode: Signal-extension mode passed to ``wavedec``/``waverec``.

    Returns:
        Array with the same shape as ``data``. Floating (and complex) inputs
        keep their dtype; any other dtype is returned as ``float64`` so the
        reconstructed values are not truncated.

    Raises:
        ValueError: If ``data`` is not two-dimensional.
        RuntimeError: If processing of a row fails; the original exception is
            attached as ``__cause__``.
    """
    if data.ndim != 2:
        raise ValueError("输入数据必须为二维")

    # An integer output buffer would silently truncate the reconstruction.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    denoised_data = np.zeros_like(data, dtype=out_dtype)

    n_samples = data.shape[1]
    for i in range(data.shape[0]):
        signal = data[i, :]
        try:
            max_level = pywt.dwt_max_level(n_samples, pywt.Wavelet(wavelet))
            # Keep the caller's ``level`` untouched; use a per-call effective depth.
            effective_level = min(level, max_level)
            coeffs = pywt.wavedec(signal, wavelet, level=effective_level, mode=boundary_mode)
            sigma = np.median(np.abs(coeffs[-1])) / 0.6745
            threshold = sigma * np.sqrt(2 * np.log(n_samples))
            # The approximation coefficients (first entry) are left as they are.
            coeffs_thresh = [coeffs[0]] + [pywt.threshold(c, threshold, mode=mode) for c in coeffs[1:]]
            rec = pywt.waverec(coeffs_thresh, wavelet, mode=boundary_mode)
            # The reconstruction can be longer than the input row; trim it.
            denoised_data[i, :len(rec)] = rec[:n_samples]
        except Exception as e:
            raise RuntimeError(f"通道 {i} 去噪失败: {e}") from e
    return denoised_data

def generate_stage_annotations(xml_path, epoch_length=30, total_duration=None):
    """Build a per-epoch sleep-stage label list from an annotation XML file.

    Only ``ScoredEvent`` elements whose ``EventType`` is ``Stages|Stages`` are
    used. The stage code is the integer after the first ``|`` in
    ``EventConcept``; code 5 is stored as 4, and codes outside 0-4 (after that
    remapping) are ignored.
    NOTE(review): a raw code of 4 also ends up as label 4, so it is not
    distinguishable from remapped code 5 - confirm this is intended.

    Args:
        xml_path: Path (or file object) accepted by ``ET.parse``.
        epoch_length: Epoch size in seconds.
        total_duration: Time span in seconds to cover. When ``None`` the
            latest stage-event end time is used.

    Returns:
        List of ``{'epoch': index, 'label': stage}`` dicts, one per epoch.
        Epochs not covered by any stage event keep label 0; when events
        overlap, the one appearing later in the file wins.
    """
    root = ET.parse(xml_path).getroot()

    # Collect (start, end, stage) for every usable stage event.
    intervals = []
    for node in root.iter('ScoredEvent'):
        if node.findtext('EventType', '').strip() != 'Stages|Stages':
            continue
        fields = node.findtext('EventConcept', '').split('|')
        if len(fields) < 2 or not fields[1].isdigit():
            continue
        code = int(fields[1])
        stage = 4 if code == 5 else code  # REM=4
        if stage < 0 or stage > 4:
            continue
        begin = float(node.findtext('Start', '0'))
        span = float(node.findtext('Duration', '0'))
        intervals.append((begin, begin + span, stage))

    if total_duration is not None:
        horizon = total_duration
    else:
        horizon = max([stop for _, stop, _ in intervals] + [0])

    n_epochs = math.ceil(horizon / epoch_length)
    labels = [0] * n_epochs
    for begin, stop, stage in intervals:
        first = int(begin // epoch_length)
        # The small offset keeps an event ending exactly on an epoch boundary
        # from spilling into the next epoch.
        last = int((stop - 1e-9) // epoch_length)
        for k in range(max(first, 0), min(last, n_epochs - 1) + 1):
            labels[k] = stage

    return [{'epoch': k, 'label': value} for k, value in enumerate(labels)]

def generate_apnea_annotations_binary(xml_path, epoch_length=30):
    """Build a per-epoch binary respiratory-event label list from an XML file.

    A ``ScoredEvent`` counts as positive when the part of its ``EventConcept``
    before the first ``|`` (stripped) is ``Central apnea``, ``Hypopnea`` or
    ``Obstructive apnea``. Every epoch touched by such an event gets label 1.

    Args:
        xml_path: Path (or file object) accepted by ``ET.parse``.
        epoch_length: Epoch size in seconds.

    Returns:
        Tuple ``(annotations, max_time)``: ``annotations`` is a list of
        ``{'epoch': index, 'label': 0 or 1}`` dicts covering the time up to
        the latest matching event end, and ``max_time`` is that end time in
        seconds (0 when no event matched).
    """
    positive_concepts = ('Central apnea', 'Hypopnea', 'Obstructive apnea')

    root = ET.parse(xml_path).getroot()

    spans = []
    latest_end = 0
    for node in root.iter('ScoredEvent'):
        name = node.findtext('EventConcept', '').split('|')[0].strip()
        if name in positive_concepts:
            onset = float(node.findtext('Start', '0'))
            length = float(node.findtext('Duration', '0'))
            stop = onset + length
            spans.append((onset, stop))
            latest_end = max(latest_end, stop)

    n_epochs = math.ceil(latest_end / epoch_length)
    flags = [0] * n_epochs
    for onset, stop in spans:
        first = int(onset // epoch_length)
        # The small offset keeps an event ending exactly on an epoch boundary
        # from marking the following epoch.
        last = int((stop - 1e-9) // epoch_length)
        for k in range(max(first, 0), min(last, n_epochs - 1) + 1):
            flags[k] = 1

    annotations = [{'epoch': k, 'label': flag} for k, flag in enumerate(flags)]
    return annotations, latest_end

def _load_preprocessed_samples(raw, n_samples):
    """Return the first ``n_samples`` of ``raw`` after optional denoise/normalize."""
    # Read by sample index instead of cropping by time: an inclusive time crop
    # ending exactly at the recording duration points past the last sample.
    data = raw.get_data(stop=n_samples)

    if denoise_enabled:
        eeg_chs = [i for i, name in enumerate(raw.ch_names) if 'EEG' in name]
        if eeg_chs:
            data[eeg_chs, :] = wavelet_denoise(data[eeg_chs, :])

    if normalize_enabled:
        # SpO2 is left in its original units; every other channel is z-scored.
        mask = np.array([ch != 'SpO2' for ch in raw.ch_names])
        if np.any(mask):
            mean = data[mask].mean(axis=1, keepdims=True)
            std = data[mask].std(axis=1, keepdims=True)
            std[std == 0] = 1  # constant channels: avoid division by zero
            data[mask] = (data[mask] - mean) / std

    return data


def process_apnea_epoch_segments(edf_path, xml_path, output_dir):
    """Split one recording into labelled epochs and save the non-wake ones.

    Reads the module-level settings ``epoch_length``, ``channels``,
    ``resample_rate``, ``denoise_enabled`` and ``normalize_enabled``. Each
    saved file is named ``ep{index:04d}_label{0|1}.npy`` and holds an array of
    shape (channels, samples per epoch) for one epoch whose stage label is
    not 0.

    NOTE(review): the usable time span is the smaller of the recording length
    and the end of the last respiratory event, so a recording without any
    matching event produces no files - confirm this is intended.

    Args:
        edf_path: Path of the recording, read with ``mne.io.read_raw``.
        xml_path: Path of the matching annotation XML file.
        output_dir: Directory for the ``.npy`` files; created if missing.
    """
    apnea_anns, apnea_max_time = generate_apnea_annotations_binary(xml_path, epoch_length)
    stage_anns = generate_stage_annotations(xml_path, epoch_length, apnea_max_time)

    raw = mne.io.read_raw(edf_path, preload=True, verbose=False)

    # Drop unwanted channels first so only the kept ones are resampled.
    if channels:
        exist = [ch for ch in channels if ch in raw.ch_names]
        missing = [ch for ch in channels if ch not in raw.ch_names]
        if missing:
            print(f"[WARN] {os.path.basename(edf_path)}: 缺失通道 {missing}")
        raw.pick_channels(exist)

    if resample_rate:
        raw.resample(resample_rate)

    sfreq = raw.info['sfreq']
    duration = raw.n_times / sfreq
    max_time = min(duration, apnea_max_time)
    # Only whole epochs are kept.
    total_epochs = int(max_time // epoch_length)
    samples_per_epoch = int(epoch_length * sfreq)

    apnea_anns = apnea_anns[:total_epochs]
    stage_anns = stage_anns[:total_epochs]

    data = None
    if total_epochs > 0:
        data = _load_preprocessed_samples(raw, total_epochs * samples_per_epoch)

    os.makedirs(output_dir, exist_ok=True)
    saved = 0
    for i in range(total_epochs):
        if stage_anns[i]['label'] == 0:
            continue  # skip wake epochs
        label = apnea_anns[i]['label']
        seg = data[:, i * samples_per_epoch: (i + 1) * samples_per_epoch]
        filename = f"ep{i:04d}_label{label}.npy"
        np.save(os.path.join(output_dir, filename), seg)
        saved += 1

    print(f"[INFO] {os.path.basename(edf_path)}: 共保存 {saved} 个非清醒期片段")

# ========== Main loop ==========
if __name__ == "__main__":
    # Recording id is the digit run between "mesa-sleep-" and ".edf".
    name_pattern = re.compile(r"mesa-sleep-(\d+)\.edf")

    # Pair every recording with its id up front; a file that fits the glob but
    # not the naming pattern is skipped instead of aborting the whole run.
    recordings = []
    for path in glob.glob(os.path.join(edf_dir, "mesa-sleep-*.edf")):
        match = name_pattern.search(os.path.basename(path))
        if match is None:
            print(f"[SKIP] 文件名不符合命名规则: {path}")
            continue
        recordings.append((match.group(1), path))

    # Sort numerically, but keep the id string as written for building paths.
    recordings.sort(key=lambda item: int(item[0]))

    if isinstance(n_files, int):
        recordings = recordings[:n_files]

    print(f"[INFO] 处理 {len(recordings)} 个文件")
    for idx, (file_id, edf_path) in enumerate(recordings, 1):
        # Best effort: a failure in one recording must not stop the batch.
        try:
            xml_path = os.path.join(xml_dir, f"mesa-sleep-{file_id}-nsrr.xml")
            output_dir = os.path.join(output_base_dir, file_id)
            if not os.path.exists(xml_path):
                print(f"[SKIP] 缺失标注文件: {xml_path}")
                continue
            print(f"\n[{idx}/{len(recordings)}] 处理文件: {file_id}")
            process_apnea_epoch_segments(edf_path, xml_path, output_dir)
        except Exception as e:
            print(f"[ERROR] 文件 {edf_path} 处理失败: {e}")
