In [None]:
import os
import numpy as np
import mne
from tqdm import tqdm
import xml.etree.ElementTree as ET
import re

# Integer class code for each NSRR sleep-stage EventConcept that is kept.
# "Stage 3 sleep|3" and "Stage 4 sleep|4" share code 3 (merged deep-sleep class).
ANNOTATION_MAP = dict([
    ("Wake|0", 0),
    ("Stage 1 sleep|1", 1),
    ("Stage 2 sleep|2", 2),
    ("Stage 3 sleep|3", 3),
    ("Stage 4 sleep|4", 3),
    ("REM sleep|5", 4),
    ("Unscored|9", 5),
    ("Movement|6", 6),
])

# EventConcept prefixes (text before the first "|") treated as apnea events.
OSA_EVENTS = ["Central apnea", "Hypopnea", "Obstructive apnea"]

# Canonical channel name -> raw EDF channel names tried, in order, for renaming.
CHANNEL_PRIORITY = dict(
    EOG=["EOG(L)", "EOG", "EOGL"],
    EMG=["EMG"],
    AIRFLOW=["NEW AIR", "AIRFLOW"],
)

def _safe_tmax_for_crop(raw, tmax, file_id="unknown"):
    """安全裁剪时间范围避免越界"""
    last = float(raw.times[-1])
    t0 = float(tmax)
    tmax = min(t0, last)
    if not (tmax < last):
        tmax = np.nextafter(last, -np.inf)
    if tmax < 0.0:
        tmax = 0.0
    return float(tmax)

def normalize_per_channel(signal):
    """Z-score ``signal`` independently along axis 1 (one row per channel).

    Each row has its own mean subtracted and is divided by its own standard
    deviation plus 1e-8, so a constant row maps to zeros instead of NaN.

    Args:
        signal: array-like; statistics are computed along axis 1.

    Returns:
        numpy.ndarray of the same shape, normalized row by row.
    """
    reduce_kw = {"axis": 1, "keepdims": True}
    mu = np.mean(signal, **reduce_kw)
    sigma = np.std(signal, **reduce_kw) + 1e-8  # guard against zero variance
    return (signal - mu) / sigma

def read_and_preprocess_edf(file_path):
    """Read an EDF recording and harmonise its channel names and types.

    The header is read first without loading samples. Recordings lacking an
    ``EEG`` channel, or any airflow channel (``NEW AIR`` / ``AIRFLOW``), are
    therefore rejected cheaply. For surviving recordings:

    * the first present variant of each ``CHANNEL_PRIORITY`` entry is renamed
      to its canonical name (unless that name already exists);
    * channel types are set for the canonical channels that actually exist;
    * the samples are loaded into memory.

    Args:
        file_path: path to the ``.edf`` file.

    Returns:
        tuple: ``(raw, duration_seconds)`` on success. ``(None, None)`` when the
        recording lacks required channels or cannot be read; in the latter case
        a warning naming the file is emitted.
    """
    import warnings

    try:
        raw = mne.io.read_raw_edf(file_path, preload=False, verbose='ERROR')
        if 'EEG' not in raw.ch_names:
            return None, None
        if not any(ch in raw.ch_names for ch in ('NEW AIR', 'AIRFLOW')):
            return None, None

        ch_name_mapping = {}
        for target_name, variants in CHANNEL_PRIORITY.items():
            if target_name in raw.ch_names:
                continue  # canonical name already present; keep it as-is
            source = next((ch for ch in variants if ch in raw.ch_names), None)
            if source is not None:
                ch_name_mapping[source] = target_name
        if ch_name_mapping:
            raw.rename_channels(ch_name_mapping)

        # Only type channels that exist. set_channel_types raises for unknown
        # names, which previously discarded recordings lacking e.g. an exact
        # 'EMG' channel.
        ch_type_mapping = {
            name: kind
            for name, kind in (('EOG', 'eog'), ('EMG', 'emg'), ('AIRFLOW', 'misc'))
            if name in raw.ch_names
        }
        if ch_type_mapping:
            raw.set_channel_types(ch_type_mapping)

        raw.load_data(verbose='ERROR')
        duration = raw.n_times / raw.info['sfreq']
        return raw, duration
    except Exception as err:
        # Best-effort batch processing: skip this file, but leave a trace.
        warnings.warn(f"Skipping {file_path}: {err!r}")
        return None, None

def select_channels(raw):
    """Choose the five signal channels used downstream.

    Args:
        raw: object whose ``info['ch_names']`` lists the channel names.

    Returns:
        ``[eeg2, eeg, eog, emg, airflow]`` channel names, or ``None`` if any of
        them is missing. The choices are made as follows:

        * ``eeg`` is always ``'EEG'``.
        * ``eeg2`` is the first of ``EEG2``, ``EEG 2``, ``EEG(SEC)``,
          ``EEG(sec)`` that is present.
        * ``eog`` / ``emg`` are the exact canonical ``'EOG'`` / ``'EMG'`` when
          present, otherwise the first channel whose name contains that tag.
        * ``airflow`` is ``'NEW AIR'`` if present, else ``'AIRFLOW'``.
    """
    names = raw.info['ch_names']

    def _find(tag):
        # Prefer the canonical (possibly renamed) channel over other variants
        # such as 'EOG(R)' that merely contain the tag.
        if tag in names:
            return tag
        return next((ch for ch in names if tag in ch), None)

    if 'EEG' not in names:
        return None
    eeg_channel = 'EEG'
    eeg2_variants = ('EEG2', 'EEG 2', 'EEG(SEC)', 'EEG(sec)')
    eeg2_channel = next((ch for ch in eeg2_variants if ch in names), None)
    if eeg2_channel is None:
        return None

    eog_channel = _find('EOG')
    emg_channel = _find('EMG')
    if not eog_channel or not emg_channel:
        return None

    if 'NEW AIR' in names:
        airflow_channel = 'NEW AIR'
    elif 'AIRFLOW' in names:
        airflow_channel = 'AIRFLOW'
    else:
        return None
    return [eeg2_channel, eeg_channel, eog_channel, emg_channel, airflow_channel]

def parse_sleep_annotations(annotation_path):
    """Parse sleep-stage events from an XML annotation file.

    Only ``ScoredEvent`` elements meeting both conditions below are kept:

    * the ``EventType`` text is exactly ``"Stages|Stages"``;
    * the ``EventConcept`` text is a key of ``ANNOTATION_MAP``.

    Args:
        annotation_path: path to the XML file.

    Returns:
        list of dicts, or ``None`` if anything raises while reading or parsing.
        Each dict has these keys:

        * ``'onset'`` and ``'duration'``: floats read from ``Start`` and
          ``Duration``;
        * ``'description'``: the EventConcept string;
        * ``'stage'``: the ANNOTATION_MAP code for that string.
    """
    try:
        root = ET.parse(annotation_path).getroot()
        stage_events = []
        for node in root.iterfind('.//ScoredEvent'):
            if node.find('EventType').text != "Stages|Stages":
                continue
            concept = node.find('EventConcept').text
            onset = float(node.find('Start').text)
            length = float(node.find('Duration').text)
            if concept in ANNOTATION_MAP:
                stage_events.append({
                    'onset': onset,
                    'duration': length,
                    'description': concept,
                    'stage': ANNOTATION_MAP[concept],
                })
        return stage_events
    except Exception:
        return None

def extract_osa_events(annotation_path):
    """Extract apnea/hypopnea events lasting at least 10 s from an XML file.

    An event is kept when the part of its ``EventConcept`` before the first
    ``|`` (whitespace-stripped) is listed in ``OSA_EVENTS`` and its
    ``Duration`` is >= 10.

    Args:
        annotation_path: path to the XML annotation file.

    Returns:
        list of dicts, each with these keys:

        * ``'start'`` and ``'duration'``: floats;
        * ``'end'``: ``start + duration``;
        * ``'type'``: the matched OSA_EVENTS name.

        If the file cannot be read or parsed, a warning is emitted and ``[]``
        is returned.
    """
    import warnings

    try:
        root = ET.parse(annotation_path).getroot()
    except (OSError, ET.ParseError) as err:
        warnings.warn(f"Could not read apnea events from {annotation_path}: {err!r}")
        return []

    events = []
    for scored_event in root.findall('.//ScoredEvent'):
        # findtext returns None for a missing element and '' for an empty one.
        # Treat both as "not an apnea event" instead of letting one odd entry
        # abort the whole file (which used to silently yield []).
        concept = scored_event.findtext('EventConcept') or ''
        event_type = concept.split('|')[0].strip()
        if event_type not in OSA_EVENTS:
            continue
        try:
            start = float(scored_event.findtext('Start'))
            duration = float(scored_event.findtext('Duration'))
        except (TypeError, ValueError):
            continue  # malformed timing on this event only; skip it
        if duration >= 10:
            events.append({
                'start': start,
                'duration': duration,
                'end': start + duration,
                'type': event_type,
            })
    return events

def process_raw_data(raw, sleep_events, resample_freq=100):
    """Annotate, preprocess and cut a recording into fixed 30-s epochs.

    ``raw`` is modified in place, in this order:

    1. Its annotations are replaced by one per entry of ``sleep_events``.
    2. The ``EEG`` channel (if present) is band-pass filtered at 0.1-40 Hz
       with a zero-phase, Hamming-window FIR.
    3. Every channel is z-scored over the whole recording
       (``normalize_per_channel``).
    4. It is resampled to ``resample_freq`` Hz when its rate is higher.

    Each stage annotation is then split into 30-s chunks, which are epoched.

    Args:
        raw: MNE Raw object with data loaded. Filtering and the ``_data``
            overwrite below need it in memory.
        sleep_events: list of dicts with 'onset' and 'duration' (seconds) and
            'description'. The description is expected to be an
            ANNOTATION_MAP key, as produced by ``parse_sleep_annotations``.
        resample_freq: target sampling rate in Hz. Only down-sampling is done.

    Returns:
        Preloaded ``mne.Epochs``, each spanning 30 s, with event codes from
        ANNOTATION_MAP. ``None`` if no valid epochs could be built.
    """
    epoch_length = 30.0  # seconds per scoring epoch

    annotations = mne.Annotations(
        onset=[e['onset'] for e in sleep_events],
        duration=[e['duration'] for e in sleep_events],
        description=[e['description'] for e in sleep_events],
    )
    raw.set_annotations(annotations)

    if 'EEG' in raw.ch_names:
        # NOTE(review): only the primary 'EEG' channel is filtered; the
        # secondary EEG and other channels stay unfiltered. Confirm intended.
        raw.filter(l_freq=0.1, h_freq=40, picks=['EEG'], method='fir',
                   fir_window='hamming', phase='zero')

    # Normalise before resampling. This writes MNE's private data buffer directly.
    raw._data = normalize_per_channel(raw.get_data())
    if raw.info['sfreq'] > resample_freq:
        raw.resample(resample_freq, npad='auto')

    # The window covers [0, 30 s) relative to each event. It is deliberately
    # NOT clipped against raw.times[-1]. That value is an absolute recording
    # time, and clipping could only shorten windows (for sub-30-s recordings).
    # The result would be epochs that downstream code still treats as 30 s.
    # MNE instead drops epochs that run past the end of the data.
    tmax = epoch_length - 1.0 / raw.info['sfreq']
    try:
        events_from_annot, event_id = mne.events_from_annotations(
            raw,
            event_id=ANNOTATION_MAP,
            chunk_duration=epoch_length,
            verbose=False,
        )
        epochs = mne.Epochs(
            raw,
            events=events_from_annot,
            event_id=event_id,
            tmin=0.0,
            tmax=tmax,
            baseline=None,
            preload=True,
            verbose=False,
        )
    except (ValueError, RuntimeError):
        # Two expected failures:
        # - no annotation matched ANNOTATION_MAP (ValueError from
        #   events_from_annotations);
        # - events MNE rejects, e.g. repeated event samples under the default
        #   event_repeated='error'.
        # Skip this subject rather than abort the batch.
        return None
    return epochs

def process_subject(raw, epochs, osa_events, file_index, output_root="F:/SHHS2_apnea"):
    """Label each scored epoch of one subject for apnea and save it to disk.

    Epochs with event code 0, 5 or 6 are skipped. These codes are Wake,
    Unscored and Movement in ANNOTATION_MAP. Each other epoch gets label 1 if
    any event in ``osa_events`` overlaps its 30-s window by at least 6 s, and
    0 otherwise. It is saved as ``e_<counter:04d>_<label>.npy`` inside
    ``<output_root>/SHHS2_<file_index:04d>``.

    Args:
        raw: MNE Raw the epochs were cut from. It supplies ``info['sfreq']``
            and ``first_samp``, which turn event samples into seconds.
        epochs: preloaded MNE Epochs. ``events[i, 0]`` is the onset sample and
            ``events[i, 2]`` the stage code.
        osa_events: list of dicts with 'start' and 'duration' in seconds,
            taken to be measured from the start of the recording.
        file_index: subject number used in the output folder name.
        output_root: parent directory for the per-subject folders. The default
            is the previously hard-coded location.

    Returns:
        int: number of epochs written.
    """
    epoch_length = 30.0   # seconds per scoring epoch
    min_overlap = 6.0     # seconds of overlap needed for a positive label
    skipped_stages = (0, 5, 6)

    sfreq = raw.info['sfreq']
    # MNE event samples include first_samp. Subtract it so epoch times share
    # the annotation time base. For EDF files first_samp is normally 0.
    first_samp = raw.first_samp
    output_dir = os.path.join(output_root, f"SHHS2_{file_index:04d}")
    os.makedirs(output_dir, exist_ok=True)

    saved_count = 0
    for i in range(len(epochs)):
        start_sample, _, stage = epochs.events[i]
        if stage in skipped_stages:
            continue
        epoch_start = (start_sample - first_samp) / sfreq
        epoch_end = epoch_start + epoch_length
        is_apnea = any(
            min(ev['start'] + ev['duration'], epoch_end) - max(ev['start'], epoch_start)
            >= min_overlap
            for ev in osa_events
        )
        osa_label = int(is_apnea)
        epoch_data = epochs.get_data(item=i)[0]
        file_name = f"e_{saved_count:04d}_{osa_label}.npy"
        np.save(os.path.join(output_dir, file_name), epoch_data)
        saved_count += 1
    return saved_count

def match_psg_annotation_files(psg_dir, annotation_dir):
    """Pair each ``.edf`` file in ``psg_dir`` with an ``.xml`` annotation file.

    A PSG file ``<stem>.edf`` matches an annotation whose own stem either
    equals ``<stem>`` or starts with ``<stem>`` followed by ``-`` or ``_``.
    For example, ``shhs2-200077.edf`` matches ``shhs2-200077-nsrr.xml``. When
    several annotations match, the longest file name wins, with ties broken
    alphabetically.

    Args:
        psg_dir: directory containing the ``.edf`` recordings.
        annotation_dir: directory containing the ``.xml`` annotations.

    Returns:
        dict mapping PSG file name -> annotation file name. PSG files without
        a match are omitted.
    """
    # Index every annotation once, under its full stem and under each prefix
    # that ends right before a '-' or '_' separator. Lookups are then O(1)
    # and the annotation directory is listed only once.
    by_base = {}
    for name in os.listdir(annotation_dir):
        if not name.endswith('.xml'):
            continue
        stem = name[:-len('.xml')]
        keys = {stem}
        keys.update(stem[:i] for i, ch in enumerate(stem) if ch in '-_')
        for key in keys:
            by_base.setdefault(key, []).append(name)

    file_map = {}
    for psg_file in os.listdir(psg_dir):
        if not psg_file.endswith('.edf'):
            continue
        # Keep the full stem (e.g. 'shhs2-200077'). Truncating at the first
        # '-' or '_' would collapse every SHHS ID to 'shhs2' and pair all
        # recordings with the same annotation.
        base_id = psg_file[:-len('.edf')]
        candidates = by_base.get(base_id)
        if candidates:
            file_map[psg_file] = max(candidates, key=lambda n: (len(n), n))
    return file_map

def _process_pair(psg_path, annotation_path, file_index):
    """Run the whole pipeline for one PSG/annotation pair.

    Returns the number of epochs saved, or ``None`` if the subject was skipped
    at any step. Keeping the per-subject objects local to this helper means
    they are released when it returns, before the next recording is loaded.
    """
    raw, _duration = read_and_preprocess_edf(psg_path)
    if raw is None:
        return None
    picked_channels = select_channels(raw)
    if picked_channels is None:
        return None
    sleep_events = parse_sleep_annotations(annotation_path)
    if not sleep_events:  # None (unreadable file) or no usable stage events
        return None
    osa_events = extract_osa_events(annotation_path)
    raw.pick_channels(picked_channels)
    epochs = process_raw_data(raw, sleep_events, resample_freq=100)
    if epochs is None or len(epochs) == 0:
        return None
    return process_subject(raw, epochs, osa_events, file_index)


def main():
    """Process every matched SHHS2 PSG/annotation pair and save labelled epochs.

    Subjects that fail any step are skipped. Output folders are numbered
    consecutively over successfully processed subjects only (``file_index``
    advances only on success). A one-line summary is printed at the end.
    """
    ANNOTATION_DIR = "D:/shhs/polysomnography/annotations-events-nsrr/shhs2"
    PSG_DIR = "D:/shhs/polysomnography/edfs/shhs2"
    OUTPUT_DIR = "F:/SHHS2_apnea"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    file_map = match_psg_annotation_files(PSG_DIR, ANNOTATION_DIR)
    processed_files = 0
    total_epochs = 0
    for psg_file, annotation_file in tqdm(file_map.items(), desc="Processing files"):
        epoch_count = _process_pair(
            os.path.join(PSG_DIR, psg_file),
            os.path.join(ANNOTATION_DIR, annotation_file),
            processed_files + 1,
        )
        if epoch_count is None:
            continue
        total_epochs += epoch_count
        processed_files += 1
    print(f"Processed {processed_files}/{len(file_map)} subjects, "
          f"saved {total_epochs} epochs under {OUTPUT_DIR}")


if __name__ == "__main__":
    main()