# In [None]:
import os
import numpy as np
import mne
from tqdm import tqdm
import xml.etree.ElementTree as ET
import re

# Maps the XML ``EventConcept`` text of a stage event to an integer stage code.
# Stage 3 and Stage 4 are merged into a single N3 class (code 3).
# Codes 5 (Unscored) and 6 (Movement) have no entry in STAGE_NAMES below.
ANNOTATION_MAP = {
    "Wake|0": 0,
    "Stage 1 sleep|1": 1,
    "Stage 2 sleep|2": 2,
    "Stage 3 sleep|3": 3,
    "Stage 4 sleep|4": 3,  # Stage 4 -> N3
    "REM sleep|5": 4,
    "Unscored|9": 5,
    "Movement|6": 6,
}

# Canonical channel name -> candidate raw channel names, in priority order.
CHANNEL_PRIORITY = {
    'EOG': ['EOG(L)', 'EOG', 'EOGL'],
    'EMG': ['EMG'],
    'ECG': ['ECG'],
    'AIRFLOW': ['NEW AIR', 'AIRFLOW'],
}

# Stage code -> short label, for the five stage classes that are kept.
STAGE_NAMES = dict(enumerate(("W", "N1", "N2", "N3", "R")))

def normalize_per_channel(signal):
    """Return a z-scored copy of ``signal``, normalised independently per row.

    Each row (channel) has its own mean subtracted and is divided by its own
    population standard deviation (``np.std`` default, ddof=0). A constant
    1e-8 is added to the standard deviation so that constant rows become
    zeros instead of triggering a division by zero.

    Parameters
    ----------
    signal : array_like
        2-D data, one row per channel (statistics are taken along axis 1).

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``signal``.
    """
    mu = np.mean(signal, axis=1, keepdims=True)
    sigma = np.std(signal, axis=1, keepdims=True) + 1e-8
    centred = signal - mu
    return centred / sigma

def read_and_preprocess_edf(file_path):
    """Read an EDF file and harmonise the auxiliary channel names and types.

    The recording is loaded fully into memory (``preload=True``). It is
    rejected unless it has a channel named exactly ``'EEG'`` and an airflow
    channel (``'NEW AIR'`` or ``'AIRFLOW'``).

    For each entry of ``CHANNEL_PRIORITY``, the first listed variant present
    in the file is renamed to the canonical name (``'EOG'``, ``'EMG'``,
    ``'ECG'``, ``'AIRFLOW'``). This happens only if that canonical name does
    not already exist, which avoids rename collisions.

    Channel types are then set for whichever canonical channels are present.
    The EEG channel's type is left as read.

    Parameters
    ----------
    file_path : str
        Path to the ``.edf`` file.

    Returns
    -------
    tuple
        ``(raw, duration)``: the preloaded ``mne.io.Raw`` and the recording
        length in seconds. Returns ``(None, None)`` if the file is rejected or
        any step fails; the reason is printed.
    """
    try:
        raw = mne.io.read_raw_edf(
            file_path,
            infer_types=True,
            preload=True,
            verbose='ERROR'
        )

        # Both the primary EEG and an airflow channel are mandatory.
        if 'EEG' not in raw.ch_names:
            print(f"文件 {os.path.basename(file_path)} 缺少主 EEG 通道")
            return None, None
        if not any(ch in raw.ch_names for ch in ['NEW AIR', 'AIRFLOW']):
            print(f"文件 {os.path.basename(file_path)} 缺少 AIRFLOW 通道")
            return None, None

        # Build the rename mapping without creating name collisions.
        ch_name_mapping = {}
        used_target_names = set()
        for target_name, variants in CHANNEL_PRIORITY.items():
            for ch in variants:
                if ch in raw.ch_names:
                    # Skip if the canonical name already exists; renaming
                    # onto it would clash with the existing channel.
                    if target_name not in raw.ch_names and target_name not in used_target_names:
                        ch_name_mapping[ch] = target_name
                        used_target_names.add(target_name)
                        break

        if ch_name_mapping:
            raw.rename_channels(ch_name_mapping)

        # Only set types for channels that actually exist. set_channel_types
        # raises ValueError for unknown names. Passing the full mapping made
        # every file lacking e.g. an ECG channel fail with a misleading
        # "read failed" message, even though ECG is not used downstream.
        type_by_name = {
            'EOG': 'eog',
            'EMG': 'emg',
            'ECG': 'ecg',
            'AIRFLOW': 'misc'
        }
        ch_type_mapping = {
            name: ch_type
            for name, ch_type in type_by_name.items()
            if name in raw.ch_names
        }
        if ch_type_mapping:
            raw.set_channel_types(ch_type_mapping)

        duration = raw.n_times / raw.info['sfreq']
        return raw, duration

    except Exception as e:
        # Best-effort boundary: any failure rejects only this file.
        print(f"读取EDF文件失败: {file_path}, 错误: {str(e)}")
        return None, None
    
def select_channels(raw):
    """Choose the five channels used for staging, or ``None`` if any is missing.

    Looks up, in ``raw.info['ch_names']``:

    * the primary EEG, which must be named exactly ``'EEG'``;
    * a second EEG: the first existing name among ``'EEG2'``, ``'EEG 2'``,
      ``'EEG(SEC)'`` and ``'EEG(sec)'``;
    * the first channel (in recording order) whose name contains ``'EOG'``,
      and likewise for ``'EMG'``;
    * airflow: ``'NEW AIR'`` if present, otherwise ``'AIRFLOW'``.

    Returns
    -------
    list of str or None
        ``[eeg2, 'EEG', eog, emg, airflow]``. Returns ``None`` (after printing
        the reason) as soon as one of these channels cannot be found.
    """
    names = raw.info['ch_names']

    if 'EEG' not in names:
        print("主 EEG 通道 'EEG' 缺失，跳过此文件。")
        return None

    second_eeg = None
    for candidate in ('EEG2', 'EEG 2', 'EEG(SEC)', 'EEG(sec)'):
        if candidate in names:
            second_eeg = candidate
            break
    if second_eeg is None:
        print("第二 EEG 通道缺失，跳过此文件。")
        return None

    eog = next(filter(lambda name: 'EOG' in name, names), None)
    emg = next(filter(lambda name: 'EMG' in name, names), None)
    if not (eog and emg):
        print("EOG 或 EMG 通道缺失，跳过此文件。")
        return None

    for airflow in ('NEW AIR', 'AIRFLOW'):
        if airflow in names:
            break
    else:
        print("AIRFLOW 通道缺失，跳过此文件。")
        return None

    return [second_eeg, 'EEG', eog, emg, airflow]

def parse_sleep_annotations(annotations_path):
    """Extract sleep-stage events from an XML annotation file.

    An event is kept only if both conditions hold:

    * its ``ScoredEvent`` element has ``EventType`` text exactly
      ``"Stages|Stages"``;
    * its ``EventConcept`` text is a key of ``ANNOTATION_MAP``.

    Returns
    -------
    list of dict or None
        One dict per kept event, in document order, with keys:

        * ``'onset'`` and ``'duration'``: floats parsed from ``Start`` and
          ``Duration``;
        * ``'description'``: the ``EventConcept`` text;
        * ``'stage'``: its ``ANNOTATION_MAP`` code.

        Returns ``None`` (after printing the error) if the file cannot be
        parsed, a required child element is missing, or ``Start``/``Duration``
        of a stage event is not numeric.
    """
    try:
        scored = ET.parse(annotations_path).getroot().iterfind('.//ScoredEvent')
        stage_events = []
        for node in scored:
            if node.find('EventType').text != "Stages|Stages":
                continue

            concept = node.find('EventConcept').text
            onset = float(node.find('Start').text)
            length = float(node.find('Duration').text)
            if concept not in ANNOTATION_MAP:
                continue

            stage_events.append(dict(
                onset=onset,
                duration=length,
                description=concept,
                stage=ANNOTATION_MAP[concept],
            ))
        return stage_events

    except Exception as err:
        print(f"解析注释文件失败: {annotations_path}, 错误: {str(err)}")
        return None

def process_raw_data(raw, sleep_events, resample_freq=100):
    """Attach stage annotations to ``raw``, preprocess it and cut 30 s epochs.

    Steps, in this order:

    1. Build ``mne.Annotations`` from ``sleep_events`` and set them on ``raw``.
    2. Band-pass the ``'EEG'`` channel (if present) at 0.1-40 Hz with a
       zero-phase Hamming-window FIR filter.
    3. Z-score every channel over the whole recording (``normalize_per_channel``).
    4. Resample to ``resample_freq`` if the current rate is higher.
    5. Turn the annotations into one event per 30 s (``chunk_duration=30``),
       using the ``ANNOTATION_MAP`` values as event codes.
    6. Cut non-baselined epochs of 30 s starting at each event.

    ``raw`` is modified in place (annotations, data and sampling rate).

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording.
    sleep_events : list of dict
        As returned by ``parse_sleep_annotations``. Uses keys ``'onset'``,
        ``'duration'`` and ``'description'``.
    resample_freq : float
        Target sampling rate; only downsampling is applied.

    Returns
    -------
    mne.Epochs or None
        Preloaded epochs, or ``None`` if MNE refuses to build them (the error
        is printed).
    """
    annotations = mne.Annotations(
        onset=[e['onset'] for e in sleep_events],
        duration=[e['duration'] for e in sleep_events],
        description=[e['description'] for e in sleep_events]
    )
    raw.set_annotations(annotations)

    # NOTE(review): only the primary 'EEG' channel is band-passed; the second
    # EEG and the other channels stay unfiltered. Confirm this is intended.
    if 'EEG' in raw.ch_names:
        raw.filter(l_freq=0.1, h_freq=40, picks=['EEG'], method='fir', fir_window='hamming', phase='zero')

    # Whole-recording per-channel z-score, written back into the Raw buffer.
    data_array = raw.get_data()
    normalized_data = normalize_per_channel(data_array)
    raw._data = normalized_data

    if raw.info['sfreq'] > resample_freq:
        raw.resample(resample_freq, npad='auto')

    # chunk_duration splits stage annotations longer than 30 s (consecutive
    # epochs of the same stage) into one event per 30 s.
    events_from_annot, event_id = mne.events_from_annotations(
        raw,
        event_id=ANNOTATION_MAP,
        chunk_duration=30.0
    )

    # MNE includes the tmax sample, so subtracting one sample period yields
    # exactly 30 * sfreq samples per epoch.
    tmax = 30.0 - 1.0 / raw.info['sfreq']
    try:
        epochs = mne.Epochs(
            raw,
            events=events_from_annot,
            event_id=event_id,
            tmin=0.0,
            tmax=tmax,
            baseline=None,
            preload=True,
            verbose=False
        )
    except (ValueError, RuntimeError) as e:
        # RuntimeError is raised by mne.Epochs when two events fall on the
        # same sample (default event_repeated='error'). Skip the file rather
        # than abort the whole batch.
        print(f"创建Epochs失败: {str(e)}")
        return None
    return epochs

def save_epochs_to_npy(epochs, output_dir, file_id, file_index):
    """Write each scored epoch of one recording to its own ``.npy`` file.

    Files are written to ``<output_dir>/SHHS1_<file_index:04d>/e<i:04d>_<label>.npy``:

    * ``i`` is the epoch's position in ``epochs``. Skipped epochs leave gaps
      in this numbering.
    * ``label`` is the epoch's event code.

    Epochs whose code is not a key of ``STAGE_NAMES`` (e.g. Unscored=5,
    Movement=6) are skipped.

    Parameters
    ----------
    epochs : mne.Epochs
        Preloaded epochs whose event codes are stage codes.
    output_dir : str
        Root output directory; created if missing.
    file_id : str
        Source file name. Currently unused; kept for interface compatibility.
    file_index : int
        Number used to name the per-recording subfolder.

    Returns
    -------
    dict
        Stage code -> number of epochs saved for that stage.
    """
    subfolder_path = os.path.join(output_dir, f"SHHS1_{file_index:04d}")
    # makedirs creates output_dir as well. exist_ok avoids the
    # check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(subfolder_path, exist_ok=True)

    stage_counts = dict.fromkeys(STAGE_NAMES, 0)
    for i, epoch_label in enumerate(epochs.events[:, 2]):
        if epoch_label not in STAGE_NAMES:
            continue

        # get_data(item=i) has shape (1, n_channels, n_times); save the single epoch.
        epoch_data = epochs.get_data(item=i)
        stage_counts[epoch_label] += 1
        npy_path = os.path.join(subfolder_path, f"e{i:04d}_{epoch_label}.npy")
        np.save(npy_path, epoch_data[0])

    return stage_counts

def match_psg_annotation_files(psg_dir, annotation_dir):
    """Pair each ``.edf`` recording in ``psg_dir`` with an ``.xml`` annotation file.

    An annotation file matches a recording when two conditions hold:

    * its name starts with the recording's stem (file name without ``.edf``);
    * the character right after the stem is not a letter or digit.

    For example, ``shhs1-200001.edf`` matches ``shhs1-200001-nsrr.xml`` but
    not ``shhs1-2000011-nsrr.xml``. When several files match, the longest
    name wins, with ties broken alphabetically.

    Parameters
    ----------
    psg_dir : str
        Directory containing the ``.edf`` files.
    annotation_dir : str
        Directory containing the ``.xml`` files.

    Returns
    -------
    dict
        ``{psg_file_name: annotation_file_name}`` (basenames), inserted in
        sorted recording-name order. Recordings without a match are omitted.
    """
    # List the annotation directory once instead of once per recording.
    annotation_files = [f for f in os.listdir(annotation_dir) if f.endswith('.xml')]

    file_map = {}
    for psg_file in sorted(f for f in os.listdir(psg_dir) if f.endswith('.edf')):
        # Use the full stem. Truncating at the first '-'/'_' reduced every SHHS
        # name ("shhs1-200001") to "shhs1", so every recording matched
        # every XML file and all were paired with the same annotation.
        stem = psg_file[:-len('.edf')]
        boundary = len(stem)
        candidates = [
            f for f in annotation_files
            if f.startswith(stem) and not f[boundary:boundary + 1].isalnum()
        ]
        if candidates:
            file_map[psg_file] = min(candidates, key=lambda f: (-len(f), f))
    return file_map

def main(max_files=6000):
    """Convert matched SHHS-1 EDF/XML pairs into per-epoch ``.npy`` files.

    For each matched pair, the pipeline:

    1. reads the recording;
    2. selects five channels;
    3. parses the stage annotations;
    4. preprocesses and epochs the data;
    5. saves every scored epoch under ``OUTPUT_DIR``.

    A pair failing any step is skipped. Per-stage epoch counts are
    accumulated across files and printed at the end.

    Parameters
    ----------
    max_files : int
        Stop once this many recordings have been successfully saved.
    """
    ANNOTATION_DIR = "D:/shhs/polysomnography/annotations-events-nsrr/shhs1"
    PSG_DIR = "D:/shhs/polysomnography/edfs/shhs1"
    OUTPUT_DIR = "E:/数据/shhs_staging"

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    file_map = match_psg_annotation_files(PSG_DIR, ANNOTATION_DIR)
    print(f"找到 {len(file_map)} 对匹配的文件")
    global_stage_counts = {stage: 0 for stage in STAGE_NAMES.keys()}
    processed_files = 0

    for psg_file, annotation_file in tqdm(file_map.items(), desc="Processing files"):
        print(psg_file)
        if processed_files >= max_files:
            print(f"已处理 {max_files} 个文件，停止处理.")
            break

        psg_path = os.path.join(PSG_DIR, psg_file)
        annotation_path = os.path.join(ANNOTATION_DIR, annotation_file)

        raw, _duration = read_and_preprocess_edf(psg_path)
        if raw is None:
            continue

        picked_channels = select_channels(raw)
        if picked_channels is None:
            continue

        sleep_events = parse_sleep_annotations(annotation_path)
        if not sleep_events:  # None (parse error) or no stage events
            continue

        # ordered=True keeps the channels in the order returned by
        # select_channels for every file. With ordered=False they follow
        # each EDF's own channel order, so saved arrays could have different
        # channel layouts across recordings.
        raw.pick_channels(picked_channels, ordered=True)
        epochs = process_raw_data(raw, sleep_events, resample_freq=100)
        if epochs is None or len(epochs) == 0:
            continue

        # Subfolder numbering counts only successfully processed recordings.
        stage_counts = save_epochs_to_npy(epochs, OUTPUT_DIR, psg_file, processed_files + 1)
        for stage, count in stage_counts.items():
            global_stage_counts[stage] += count

        processed_files += 1
        print(f"已处理 {processed_files} 个文件")

    print("\n处理完成，最终统计数据:")
    for stage, count in global_stage_counts.items():
        print(f"阶段 {stage}: {count} 个epochs")
    print(f"总计: {sum(global_stage_counts.values())} 个epochs")

if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main(6000)
