# In [None]:
import mne
import numpy as np
import xml.etree.ElementTree as ET
import math
import os
import re
import pywt
import logging
import glob
import warnings

# --- Run configuration -------------------------------------------------------
# Directory scanned for recordings matching "mesa-sleep-*.edf".
edf_dir = "/data02/mesa/edfs"
# Directory holding the XML annotations, looked up as "mesa-sleep-<id>-nsrr.xml".
xml_dir = "/data02/mesa/annotations-events-nsrr"
# Epoch .npy files for each recording are written to "<output_base_dir>/<id>/".
output_base_dir = "/data02/mesa/apnea2"
# Channels kept (in this order) when present in the EDF; others are dropped.
channels = ['EEG1', 'EEG3', 'EOG-L', 'EMG', 'Pres']
# Process at most this many recordings (after sorting by numeric id);
# a non-int value disables the cap.
n_files = 3000
# Target sampling rate passed to raw.resample(); a falsy value skips resampling.
resample_rate = 100
# Epoch length in seconds (same time base as raw.times and the XML Start/Duration).
epoch_length = 30
log_file = 'log.txt'

# Only ERROR-level (and above) records are written; filemode='w' truncates
# the log file at the start of every run.
logging.basicConfig(
    filename=log_file, 
    level=logging.ERROR, 
    format='%(asctime)s - %(levelname)s - %(message)s',
    filemode='w'
)

def _safe_tmax_for_crop(raw, tmax, file_id="unknown"):
    """Return a crop end time that stays strictly inside the recording.

    Args:
        raw: Object exposing a ``times`` sequence (seconds); only its last
            element is read.
        tmax: Requested end time in seconds.
        file_id: Currently unused; kept for call-site compatibility.

    Returns:
        float: ``tmax`` when it lies before the final sample time. Otherwise
        (including NaN) the largest float strictly below the final sample
        time. Negative results are clamped to 0.0.
    """
    end_of_data = float(raw.times[-1])
    requested = float(tmax)
    # Anything at or past the last sample is pulled just inside the data range.
    if requested < end_of_data:
        candidate = requested
    else:
        candidate = np.nextafter(end_of_data, -np.inf)
    return 0.0 if candidate < 0.0 else float(candidate)

def _get_max_end_time(xml_or_root, event_type):
    """Return the latest ``Start + Duration`` among ScoredEvents of one type.

    Args:
        xml_or_root: Path to an XML annotation file (``str``), or an
            already-parsed ElementTree root element.
        event_type: ``EventType`` text to match. The XML value is compared
            after stripping surrounding whitespace.

    Returns:
        float: The largest end time over matching events, never below 0.0.
        An event whose Start or Duration cannot be parsed contributes 0.0.
    """
    root = xml_or_root
    if isinstance(root, str):
        root = ET.parse(root).getroot()

    end_times = [0.0]
    for node in root.iter('ScoredEvent'):
        if node.findtext('EventType', '').strip() != event_type:
            continue
        try:
            onset = float(node.findtext('Start', '0') or 0)
            length = float(node.findtext('Duration', '0') or 0)
        except ValueError:
            onset = length = 0.0
        end_times.append(onset + length)
    return max(end_times)

def pick_channels_in_order(raw, desired_names):
    """Keep only the requested channels of ``raw``, ordered as in ``desired_names``.

    Names absent from ``raw.ch_names`` are ignored. If none of the requested
    names are present, ``raw`` is left untouched. ``raw`` is modified in
    place through its ``pick`` and (when needed) ``reorder_channels`` methods.

    Returns:
        None.
    """
    present = set(raw.ch_names)
    wanted = list(filter(present.__contains__, desired_names))
    if not wanted:
        return
    raw.pick(wanted)
    # Only reorder explicitly if pick() did not already yield the requested order.
    if raw.ch_names != wanted:
        raw.reorder_channels(wanted)

def wavelet_denoise(signal, wavelet='sym6', level=5, mode='soft', boundary_mode='symmetric'):
    """Denoise each row of a 2-D array by wavelet-coefficient thresholding.

    Each row is processed independently:

    1. Decompose the row with ``pywt.wavedec``.
    2. Estimate the noise scale as
       ``sigma = median(|finest detail coeffs|) / 0.6745``.
    3. Threshold every detail band at ``sigma * sqrt(2 * ln(N))``, where N is
       the row length. The approximation band is left untouched.
    4. Reconstruct the row with ``pywt.waverec``.

    Args:
        signal: 2-D array; each row is denoised independently.
        wavelet: Wavelet name understood by PyWavelets.
        level: Requested decomposition depth. It is capped at
            ``pywt.dwt_max_level`` for the row length.
        mode: Thresholding mode passed to ``pywt.threshold`` (e.g. 'soft').
        boundary_mode: Signal-extension mode for the forward/inverse transform.

    Returns:
        Array with the same shape as ``signal``. Inexact (float/complex)
        inputs keep their dtype; integer/bool inputs come back as float64.

    Raises:
        ValueError: If ``signal`` is not 2-D.
    """
    if signal.ndim != 2:
        raise ValueError("输入数据必须为二维")

    # np.zeros_like on an integer input would silently truncate the
    # real-valued reconstruction, so only reuse the dtype when it is inexact.
    if np.issubdtype(signal.dtype, np.inexact):
        out_dtype = signal.dtype
    else:
        out_dtype = np.float64
    denoised = np.zeros(signal.shape, dtype=out_dtype)

    # Everything below depends only on the row length, which is shared by all
    # rows, so compute it once instead of per channel.
    n_samples = signal.shape[1]
    wavelet_obj = pywt.Wavelet(wavelet)
    actual_lvl = min(level, pywt.dwt_max_level(n_samples, wavelet_obj))
    universal_factor = np.sqrt(2 * np.log(n_samples))

    for row_idx, row in enumerate(signal):
        coeffs = pywt.wavedec(row, wavelet_obj, level=actual_lvl, mode=boundary_mode)
        # Robust noise estimate from the finest-scale detail band.
        sigma = np.median(np.abs(coeffs[-1])) / 0.6745
        threshold = sigma * universal_factor
        coeffs[1:] = [pywt.threshold(c, threshold, mode=mode) for c in coeffs[1:]]
        reconstructed = pywt.waverec(coeffs, wavelet_obj, mode=boundary_mode)
        # The reconstruction can be a sample longer than the input; trim it.
        denoised[row_idx, :len(reconstructed)] = reconstructed[:n_samples]
    return denoised

def standardize_signal(signal):
    """Z-score ``signal`` using its global mean and standard deviation.

    When the standard deviation is not positive (constant input, or NaN),
    the signal is only mean-centred and left unscaled.
    """
    centered = signal - np.mean(signal)
    spread = np.std(signal)
    if spread > 0:
        return centered / spread
    return centered / 1

def preprocess_channel(signal, ch_name, fs):
    """Preprocess one channel's 1-D samples according to its name.

    - 'EEG1' / 'EEG3': ``wavelet_denoise`` with its default settings, then
      ``standardize_signal``.
    - 'EOG-L': ``wavelet_denoise`` with ``level=3``, then ``standardize_signal``.
    - 'EMG' / 'Pres': returned unchanged (the same object that was passed in).

    ``fs`` is accepted for interface symmetry but is not used.

    Raises:
        ValueError: For any other channel name.
    """
    if ch_name in ['EMG', 'Pres']:
        return signal

    # Extra keyword arguments forwarded to wavelet_denoise, per denoised channel.
    denoise_kwargs = {
        'EEG1': {},
        'EEG3': {},
        'EOG-L': {'level': 3},
    }
    if ch_name not in denoise_kwargs:
        raise ValueError(f"未知通道: {ch_name}")

    as_matrix = signal[np.newaxis, :]
    cleaned = wavelet_denoise(as_matrix, **denoise_kwargs[ch_name])[0]
    return standardize_signal(cleaned)

def preprocess_all_channels(data, ch_names, fs):
    """Run ``preprocess_channel`` on each row of ``data``.

    Args:
        data: Array whose first axis indexes channels; row ``k`` is treated
            as the channel named ``ch_names[k]``.
        ch_names: Channel names, one per row to process.
        fs: Sampling rate, forwarded unchanged to ``preprocess_channel``.

    Returns:
        A new array created with ``np.zeros_like(data)`` (same shape and
        dtype) holding the processed rows. Rows past ``len(ch_names)`` stay zero.
    """
    output = np.zeros_like(data)
    for row_index, name in enumerate(ch_names):
        channel_samples = data[row_index]
        output[row_index] = preprocess_channel(channel_samples, name, fs)
    return output

def _parse_apnea_events(xml_path, min_event_duration):
    """Return (events sorted by start, latest kept event end) from an annotation XML.

    A ScoredEvent is kept when all of the following hold:

    - the text before '|' in its EventConcept contains 'apnea' or 'hypopnea'
      (case-insensitive);
    - its Start/Duration parse as floats;
    - its duration is positive and at least ``min_event_duration`` seconds.

    The latest end is 0 when no event is kept.
    """
    root = ET.parse(xml_path).getroot()
    events = []
    latest_end = 0
    for node in root.iter('ScoredEvent'):
        concept = node.findtext('EventConcept', '').split('|')[0].strip()
        lowered = concept.lower()
        if 'apnea' not in lowered and 'hypopnea' not in lowered:
            continue
        try:
            start = float(node.findtext('Start', '0'))
            duration = float(node.findtext('Duration', '0'))
        except (ValueError, TypeError):
            continue
        # Written as a positive test so NaN durations are rejected too.
        if not (duration > 0 and duration >= min_event_duration):
            continue
        end = start + duration
        events.append({
            'start': start,
            'end': end,
            'label': 1,
            'type': concept,
            'duration': duration,
        })
        latest_end = max(latest_end, end)
    events.sort(key=lambda ev: ev['start'])
    return events, latest_end


def generate_apnea_annotations_with_coverage(xml_path, epoch_length=30, total_duration=None,
                                             min_event_duration=10, min_overlap=6):
    """Build per-epoch apnea/hypopnea labels from an XML annotation file.

    An epoch is labelled 1 when at least one qualifying event overlaps it by
    ``min_overlap`` seconds or more; otherwise it stays 0.

    Args:
        xml_path: Path (or file object) accepted by ``ET.parse``.
        epoch_length: Epoch length in seconds.
        total_duration: If not None, the span (seconds) to cover. Otherwise
            the latest end of a qualifying event is used.
        min_event_duration: Events shorter than this (seconds) are ignored.
        min_overlap: Minimum overlap (seconds) between an event and an epoch
            for that epoch to be labelled 1.

    Returns:
        tuple ``(annotations, max_time)``. ``annotations`` holds one dict per
        epoch with these keys:

        - ``epoch``, ``start``, ``end``: epoch index and its time bounds;
        - ``label``: 0 or 1;
        - ``original_label``: always 0;
        - ``events``: list of overlapping events with their overlap durations;
        - ``has_significant_overlap``: True when an overlap reached ``min_overlap``;
        - ``corrected``: True when the label was raised from 0 to 1.

        ``max_time`` is the span that was covered.
    """
    apnea_events, events_end = _parse_apnea_events(xml_path, min_event_duration)
    max_time = total_duration if total_duration is not None else events_end
    total_epochs = math.ceil(max_time / epoch_length)

    annotations = [
        {
            'epoch': i,
            'start': i * epoch_length,
            'end': (i + 1) * epoch_length,
            'label': 0,
            'original_label': 0,
            'events': [],
            'has_significant_overlap': False,
            'corrected': False,
        }
        for i in range(total_epochs)
    ]

    for event_idx, event in enumerate(apnea_events):
        first_epoch = max(int(event['start'] // epoch_length), 0)
        # The epsilon keeps an event ending exactly on a boundary out of the next epoch.
        last_epoch = min(int((event['end'] - 1e-9) // epoch_length), total_epochs - 1)
        for i in range(first_epoch, last_epoch + 1):
            epoch_start = i * epoch_length
            epoch_end = (i + 1) * epoch_length
            overlap_duration = min(event['end'], epoch_end) - max(event['start'], epoch_start)
            annotations[i]['events'].append({
                'event_start': event['start'],
                'event_end': event['end'],
                'overlap_duration': overlap_duration,
                'event_type': event['type'],
                'event_idx': event_idx,
                'total_duration': event['duration'],
            })
            if overlap_duration >= min_overlap:
                annotations[i]['has_significant_overlap'] = True

    # Every epoch starts at label 0, so each promotion here is a correction.
    for ann in annotations:
        if ann['has_significant_overlap']:
            ann['label'] = 1
            ann['corrected'] = True
    return annotations, max_time

def generate_stage_annotations(xml_path, epoch_length=30, total_duration=None):
    """Build per-epoch sleep-stage labels from an XML annotation file.

    Only ScoredEvents whose EventType (whitespace-stripped) is 'Stages|Stages'
    are used. The stage number is the integer after '|' in EventConcept.
    The value 5 is remapped to 4; any other number is kept as-is.

    Epochs not covered by a stage event keep label 0. When events overlap the
    same epoch, the later event in the file wins.

    Args:
        xml_path: Path (or file object) accepted by ``ET.parse``.
        epoch_length: Epoch length in seconds.
        total_duration: If not None, the span (seconds) to cover. Otherwise
            the latest stage-event end is used (never below 0).

    Returns:
        tuple ``(annotations, max_time)``. ``annotations`` is a list of
        ``{'epoch': i, 'label': stage}`` dicts, one per epoch; ``max_time``
        is the span that was covered.
    """
    root = ET.parse(xml_path).getroot()
    events = []
    for ev in root.iter('ScoredEvent'):
        if (ev.findtext('EventType') or '').strip() != 'Stages|Stages':
            continue
        parts = ev.findtext('EventConcept', '').split('|')
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        if len(parts) < 2 or not parts[1].strip().isdecimal():
            continue
        label = int(parts[1])
        # NOTE(review): 5 is presumably REM in the NSRR stage coding; folding it
        # into 4 means a scored stage 4 (if any) would share the label — confirm.
        if label == 5:
            label = 4
        try:
            start = float(ev.findtext('Start', '0'))
            duration = float(ev.findtext('Duration', '0'))
        except (ValueError, TypeError):
            continue
        # Positive test so NaN durations are skipped instead of crashing below.
        if not duration > 0:
            continue
        events.append({'start': start, 'end': start + duration, 'label': label})

    # Use `is not None` so an explicit 0 is honoured, matching the apnea generator.
    if total_duration is not None:
        max_time = total_duration
    else:
        max_time = max([e['end'] for e in events] + [0])
    total_epochs = math.ceil(max_time / epoch_length)

    annotations = [{'epoch': i, 'label': 0} for i in range(total_epochs)]
    for e in events:
        # Clamp at 0: a negative index would silently overwrite the last epochs.
        first = max(int(e['start'] // epoch_length), 0)
        last = min(int((e['end'] - 1e-9) // epoch_length), total_epochs - 1)
        for i in range(first, last + 1):
            annotations[i]['label'] = e['label']
    return annotations, max_time

def process_apnea_epoch_segments(edf_path, xml_path, output_dir):
    """Cut one recording into fixed-length epochs and save the stage-scored ones.

    Steps:
      1. Load the EDF with MNE and keep the configured ``channels`` (in that
         order) that are present.
      2. Crop to the shortest of three times, floored to a whole number of
         ``epoch_length`` epochs:
         - the EDF's last sample time;
         - the last respiratory-event end;
         - the last stage-event end.
      3. Resample to ``resample_rate`` if it is set.
      4. Build apnea labels and stage labels for that span.
      5. Preprocess every channel with ``preprocess_all_channels``.
      6. Save each epoch whose stage label is non-zero as
         ``<output_dir>/ep{index:04d}_label{apnea_label}.npy``, holding the
         ``data[:, start:end]`` slice for that epoch.

    Reads the module-level settings ``channels``, ``resample_rate`` and
    ``epoch_length``.

    NOTE(review): if the XML has no 'Respiratory|Respiratory' events, the
    respiratory bound is 0 and nothing is saved for this recording — confirm
    that excluding such recordings is intended.
    """
    base_name = os.path.basename(edf_path)
    file_id_match = re.search(r'mesa-sleep-(\d+)', base_name)
    file_id = file_id_match.group(1) if file_id_match else "unknown"

    raw = mne.io.read_raw(edf_path, preload=True, verbose=False)
    if channels:
        # pick_channels_in_order skips absent names and is a no-op if none exist.
        pick_channels_in_order(raw, channels)

    edf_last_time = float(raw.times[-1])
    resp_max_time = _get_max_end_time(xml_path, 'Respiratory|Respiratory')
    stage_max_time = _get_max_end_time(xml_path, 'Stages|Stages')
    max_bound = min(edf_last_time, resp_max_time, stage_max_time)
    adjusted_max_time = math.floor(max_bound / epoch_length) * epoch_length
    raw.crop(tmax=_safe_tmax_for_crop(raw, adjusted_max_time, file_id))
    if resample_rate:
        raw.resample(resample_rate)

    try:
        apnea_anns, _ = generate_apnea_annotations_with_coverage(
            xml_path, epoch_length, adjusted_max_time)
    except Exception:
        # Best-effort fallback to all-negative labels, but record why: these
        # epochs would otherwise be silently saved as label 0.
        logging.exception("Apnea annotation failed for file %s; using all-zero labels", file_id)
        apnea_anns = [{'epoch': i, 'label': 0}
                      for i in range(int(adjusted_max_time // epoch_length))]
    stage_anns, _ = generate_stage_annotations(xml_path, epoch_length, adjusted_max_time)

    fs = raw.info['sfreq']
    data = preprocess_all_channels(raw.get_data(), raw.ch_names, fs)
    total_epochs = int(adjusted_max_time // epoch_length)
    # round() avoids int() truncating a product like 2999.9999999 down to 2999.
    samples_per_epoch = int(round(epoch_length * fs))
    n_epochs = min(total_epochs, len(apnea_anns), len(stage_anns))

    os.makedirs(output_dir, exist_ok=True)
    for i in range(n_epochs):
        if stage_anns[i]['label'] == 0:
            continue
        start_idx = i * samples_per_epoch
        end_idx = start_idx + samples_per_epoch
        # Resampling can leave the tail a few samples short; stop at the first
        # epoch that would run past the data.
        if end_idx > data.shape[1]:
            break
        final_label = apnea_anns[i]['label']
        fname = os.path.join(output_dir, f"ep{i:04d}_label{final_label}.npy")
        np.save(fname, data[:, start_idx:end_idx])

    print(f"处理文件 {file_id}")

if __name__ == "__main__":
    # Collect EDFs whose names carry a numeric subject id. glob's '*' also
    # matches non-numeric names; those used to crash the int() sort key and
    # abort the whole run, so they are now skipped and logged.
    file_id_pattern = re.compile(r"mesa-sleep-(\d+)\.edf")
    indexed_files = []
    for path in glob.glob(os.path.join(edf_dir, "mesa-sleep-*.edf")):
        match = file_id_pattern.search(os.path.basename(path))
        if match is None:
            logging.error("Skipping EDF with unexpected file name: %s", path)
            continue
        indexed_files.append((match.group(1), path))
    indexed_files.sort(key=lambda item: int(item[0]))
    if isinstance(n_files, int):
        indexed_files = indexed_files[:n_files]

    print(f"开始处理 {len(indexed_files)} 个文件")
    for file_id, edf_path in indexed_files:
        xml_path = os.path.join(xml_dir, f"mesa-sleep-{file_id}-nsrr.xml")
        if not os.path.exists(xml_path):
            logging.error("Missing annotation XML for file %s: %s", file_id, xml_path)
            continue
        output_dir = os.path.join(output_base_dir, file_id)
        try:
            process_apnea_epoch_segments(edf_path, xml_path, output_dir)
        except Exception:
            # Keep processing the remaining files, but leave a traceback in the
            # configured log file instead of discarding the error.
            logging.exception("Failed to process file %s (%s)", file_id, edf_path)
    print("\n所有文件处理完成")