# Preprocessing Pipeline for Sleep-EDF Dataset

---


This pipeline performs the following tasks on the Sleep-EDF dataset:

1.   Load PSG EEG channels (Fpz-Cz & Pz-Oz) and hypnogram
2.   Apply a bandpass filter of 0.3-35 Hz
3.   Epoch EEG signals into 30s segments
4.   Map stages to labels
5.   Align labels to epochs
6.   Drop movement/unwanted data
7.   Save preprocessed data into .npy files



---



The preprocessed output consists of 2 numpy arrays:


*   X - filtered EEG epochs, shape (n_epochs, n_channels, n_samples)
*   y - integer sleep-stage labels, shape (n_epochs,)


# Import + define stages-label map

In [2]:
import os
import glob
import numpy as np
import mne

# Sleep-stage annotation description -> integer class label.
# Classes: W=0, N1=1, N2=2, N3=3 (stages 3 and 4 merged), REM=4.
# A label of -1 marks epochs that are dropped during preprocessing.
stage_map = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,  # combine 3 + 4 into N3
    'Sleep stage R': 4,
    'Sleep stage M': -1,  # drop
    'Movement time': -1,  # drop (movement description used in Sleep-EDF hypnograms)
    'Sleep stage ?': -1   # drop (not scored)
}


# Fetch PSG + hypnogram files

In [None]:
# Root folder that holds the downloaded PhysioNet tree.
data_dir = "/raw_data/"

# Collect the sleep-cassette recordings and their hypnograms, each list
# sorted in place so both follow the same alphabetical file order.
psg_files = glob.glob(data_dir + "/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/*PSG.edf")
psg_files.sort()

hyp_files = glob.glob(data_dir + "/physionet.org/files/sleep-edfx/1.0.0/sleep-cassette/*Hypnogram.edf")
hyp_files.sort()

print("PSG:", len(psg_files), "Hypnograms:", len(hyp_files))


Mounted at /content/drive
PSG: 153 Hypnograms: 153


# Match PSG-hypnogram pairs

In [4]:
# Index hypnograms by the recording prefix of their *file name* (first six
# characters before the first "-", e.g. "SC4031"). Comparing basenames rather
# than testing `prefix in full_path` avoids accidental hits on directory
# names, and the dict lookup replaces a scan over every hypnogram per PSG.
# setdefault keeps the first candidate in sorted order when several
# hypnograms share a prefix.
hyp_by_prefix = {}
for hyp in hyp_files:
    hyp_by_prefix.setdefault(os.path.basename(hyp).split("-")[0][:6], hyp)

pairs = []
unmatched = []

for psg in psg_files:
    base = os.path.basename(psg).split("-")[0][:6]  # e.g. SC4031
    hyp = hyp_by_prefix.get(base)
    if hyp is None:
        unmatched.append(psg)
    else:
        pairs.append((psg, hyp))

print("Matched pairs:", len(pairs))
if unmatched:
    # Surface recordings that would otherwise be skipped silently.
    print("PSG files without a hypnogram:", [os.path.basename(p) for p in unmatched])


Matched pairs: 153


# Main preprocessing function

In [7]:
def preprocess_recording(psg_path, hyp_path, epoch_length=30):
    """Preprocess one Sleep Cassette PSG + Hypnogram into CNN-ready X, y.

    Parameters
    ----------
    psg_path : str
        Path to the ``*PSG.edf`` recording.
    hyp_path : str
        Path to the matching ``*Hypnogram.edf`` annotation file.
    epoch_length : float, default 30
        Epoch duration in seconds.

    Returns
    -------
    X : np.ndarray
        Band-pass filtered EEG epochs, shape (n_epochs, channels, samples).
    y : np.ndarray
        One integer label per epoch, taken from ``stage_map``. Epochs whose
        label is -1 (unmapped or uncovered) are removed from both arrays.

    Raises
    ------
    ValueError
        If neither "EEG Fpz-Cz" nor "EEG Pz-Oz" is present in the recording.

    Notes
    -----
    If only one of the two EEG channels exists in the file, that single
    channel is used, so the channel dimension of X can be 1.
    """

    # --------------------------
    # 1. Read PSG (EEGs)
    # --------------------------
    raw = mne.io.read_raw_edf(psg_path, preload=True)

    # Keep whichever of the two documented EEG derivations the file contains
    eeg_channels = [ch for ch in ["EEG Fpz-Cz", "EEG Pz-Oz"] if ch in raw.ch_names]

    if not eeg_channels:
        # Single formatted message (passing two args renders as a tuple)
        raise ValueError(f"EEG channels missing from file: {psg_path}")

    raw.pick_channels(eeg_channels)

    # Bandpass filter 0.3-35 Hz (sleep scoring standard)
    raw.filter(0.3, 35., fir_design="firwin")

    # --------------------------
    # 2. Read HYPNOGRAM
    # --------------------------
    annots = mne.read_annotations(hyp_path)
    raw.set_annotations(annots)

    # --------------------------
    # 3. Split eeg into fixed-length epochs
    # --------------------------
    epochs = mne.make_fixed_length_epochs(raw,
                                          duration=epoch_length,
                                          preload=True)

    X = epochs.get_data()  # shape: (n_epochs, channels, samples)

    # --------------------------
    # 4. ASSIGN STAGE LABEL TO EACH EPOCH
    # --------------------------
    # Epoch start times in seconds, derived from the epoch event samples.
    starts = epochs.events[:, 0] / raw.info['sfreq']

    # Every epoch starts as -1 (dropped). Walk the annotations once and label
    # all epochs whose start falls inside [onset, onset + duration). The
    # `unlabeled` mask makes the first covering annotation win, which matches
    # a per-epoch "first match" search without the nested Python loop.
    y = np.full(len(starts), -1, dtype=int)
    unlabeled = np.ones(len(starts), dtype=bool)

    for onset, duration, description in zip(annots.onset,
                                            annots.duration,
                                            annots.description):
        covered = unlabeled & (onset <= starts) & (starts < onset + duration)
        # Descriptions missing from stage_map are treated as "drop"
        y[covered] = stage_map.get(str(description), -1)
        unlabeled &= ~covered

    # --------------------------
    # 5. Drop invalid epochs (Movement, ?, unlabeled)
    # --------------------------
    valid = y != -1
    X = X[valid]
    y = y[valid]

    return X, y


(Optional) Silence verbose MNE log output and Python runtime/user warnings

In [8]:
import warnings

# Limit MNE's logger to the WARNING level.
mne.set_log_level('WARNING')

# Ignore RuntimeWarning and UserWarning messages from here on.
for _ignored_category in (RuntimeWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# Run preprocessing for each Cassette subject

In [None]:
import gc
import numpy as np

# Folder receiving one X_<i>.npy / y_<i>.npy pair per recording.
save_dir = "/content/processed_sc"
os.makedirs(save_dir, exist_ok=True)

for i, (psg, hyp) in enumerate(pairs):
    print(f"Processing file {i+1}/{len(pairs)}: {os.path.basename(psg)}")
    X, y = preprocess_recording(psg, hyp)

    print("  Shapes:", X.shape, y.shape)
    print("  Channels:", X.shape[1])
    print("  Samples per epoch:", X.shape[2])
    print("  Unique labels:", np.unique(y))

    # Show small data snippet. Guarded so that a recording with no valid
    # epochs does not abort the whole run with an IndexError.
    if len(y) > 0:
        print("  First label:", y[0])
        print("  First epoch (first 3 samples):", X[0, 0, :3])
    else:
        print("  No valid epochs in this recording")
    print("  ----------------------------------")

    # Save EACH SUBJECT separately
    np.save(f"{save_dir}/X_{i}.npy", X)
    np.save(f"{save_dir}/y_{i}.npy", y)
    print(f"  Saved X_{i}.npy and y_{i}.npy")

    # CLEAR RAM before loading the next recording
    del X, y
    gc.collect()


Processing file 1/153: SC4001E0-PSG.edf
  Shapes: (2650, 2, 3000) (2650,)
  Channels: 2
  Samples per epoch: 3000
  Unique labels: [0 1 2 3 4]
  First label: 0
  First epoch (first 3 samples): [ 2.03287907e-20 -7.12998692e-06 -4.31699508e-06]
  ----------------------------------
  Saved X_0.npy and y_0.npy
Processing file 2/153: SC4002E0-PSG.edf
  Shapes: (2829, 2, 3000) (2829,)
  Channels: 2
  Samples per epoch: 3000
  Unique labels: [0 1 2 3 4]
  First label: 0
  First epoch (first 3 samples): [3.38813179e-21 3.68525224e-05 2.13092512e-05]
  ----------------------------------
  Saved X_1.npy and y_1.npy
Processing file 3/153: SC4011E0-PSG.edf
  Shapes: (2802, 2, 3000) (2802,)
  Channels: 2
  Samples per epoch: 3000
  Unique labels: [0 1 2 3 4]
  First label: 0
  First epoch (first 3 samples): [ 3.26107685e-20 -4.79873198e-06 -5.56032260e-06]
  ----------------------------------
  Saved X_2.npy and y_2.npy
Processing file 4/153: SC4012E0-PSG.edf
  Shapes: (2848, 2, 3000) (2848,)
  Cha

# Combine data for each subject

In [None]:
save_dir = "/content/processed_sc"
os.makedirs(save_dir, exist_ok=True)

# Note: names are sorted lexicographically (X_10 comes before X_2), so the
# concatenation order is not the original recording order.
files = sorted(os.listdir(save_dir))

X_list = []
y_list = []

for f in files:
    if not f.startswith("X_"):
        continue

    # Pair every X file with the y file that carries the same suffix instead
    # of collecting two independent lists; a missing or extra file would
    # otherwise shift labels against data without any error.
    y_name = "y_" + f[len("X_"):]
    y_path = os.path.join(save_dir, y_name)
    if not os.path.exists(y_path):
        raise FileNotFoundError(f"Missing label file {y_name} for {f} in {save_dir}")

    X_part = np.load(os.path.join(save_dir, f))
    y_part = np.load(y_path)
    if len(X_part) != len(y_part):
        raise ValueError(
            f"{f} has {len(X_part)} epochs but {y_name} has {len(y_part)} labels"
        )

    X_list.append(X_part)
    y_list.append(y_part)

if not X_list:
    raise FileNotFoundError(f"No X_* files found in {save_dir}")

# Stack epochs from all recordings along the first axis.
X = np.vstack(X_list)
y = np.hstack(y_list)


# Print data summary + save data

In [None]:
# Report the combined arrays' dimensions, then the epoch count per label.
print("===== FINAL DATASET SUMMARY (from memmap) =====")
summary_rows = [
    ("X_mm shape:", X.shape),
    ("y_mm shape:", y.shape),
    ("Channels used:", X.shape[1]),
    ("Samples per epoch:", X.shape[2]),
]
for caption, value in summary_rows:
    print(caption, value)

print("\nLabel distribution:")
unique, counts = np.unique(y, return_counts=True)
for position in range(len(unique)):
    u, c = unique[position], counts[position]
    print(f"  Stage {u}: {c} epochs")



===== FINAL DATASET SUMMARY (from memmap) =====
X_mm shape: (414961, 2, 3000)
y_mm shape: (414961,)
Channels used: 2
Samples per epoch: 3000

Label distribution:
  Stage 0: 285433 epochs
  Stage 1: 21522 epochs
  Stage 2: 69132 epochs
  Stage 3: 13039 epochs
  Stage 4: 25835 epochs


In [None]:
# Destination folder for the combined dataset.
data_dir = "/preprocessed_data/"

# np.save does not create parent folders, so make sure the target exists.
os.makedirs(data_dir, exist_ok=True)

# os.path.join avoids the doubled "/" produced by plain concatenation.
np.save(os.path.join(data_dir, "X.npy"), X)
np.save(os.path.join(data_dir, "y.npy"), y)
