In [1]:
# Data Preparation & DataLoaders for AHI Regression

# This notebook prepares a dataset of per-window EEG spectrogram images for sleep apnea severity (AHI) regression.
# It loads subject-level metadata, performs stratified subject sampling, generates and quality-controls spectrograms,
# assigns subject-stratified cross-validation folds, and builds PyTorch DataLoaders for model training and evaluation.
# The resulting dataset enables robust, leakage-free training and validation of deep learning models for AHI prediction from EEG.

In [2]:
# 1. Imports & Configuration
import pandas as pd
import numpy as np
from pathlib import Path
from scipy.signal import stft
import matplotlib
matplotlib.use('Agg')  # headless backend for image export
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedGroupKFold
import pyedflib

# Spectrogram & QC Parameters (aligned with clinical standards)
WINDOW_SECONDS = 30    # epoch length in seconds (AASM standard)
OVERLAP_FRAC   = 0.5   # 50% overlap captures transitions between breathing events
NPERSEG        = 256   # STFT window length (~1s at 256Hz sampling)
NOVERLAP       = 128   # overlap between STFT segments for smooth spectrograms
MIN_STD        = 5.0   # drop flat epochs (low variance indicates sensor issues)
MAX_STD        = 200.0 # drop artifact epochs (high variance indicates noise)
N_SUBJECTS     = 100   # number of subjects for prototyping
RANDOM_STATE   = 26    # seed for reproducibility

# Paths
DATA_DIR     = Path("/Volumes/T9/projects/sleep-apnea-classification-using-eeg-spectrograms/data/sleep-heart-health-study")
OUT_DIR      = Path("/Volumes/T9/projects/sleep-apnea-classification-using-eeg-spectrograms/data/spectrograms")
RESNET_DIR   = OUT_DIR / "resnet224_color"
METADATA_CSV = OUT_DIR / 'window_metadata.csv'
SPLITS_CSV   = OUT_DIR / 'spectrogram_splits.csv'

for d in (OUT_DIR, RESNET_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Helper function to map AHI to clinical severity bins
def map_ahi_to_label(ahi: float) -> str:
    if ahi < 5:
        return "healthy"
    elif ahi < 15:
        return "mild"
    elif ahi < 30:
        return "moderate"
    else:
        return "severe"

In [3]:
# 2. Load metadata and subsample subjects
meta = (
    pd.read_csv(DATA_DIR / "shhs1-dataset-0.21.0.csv", usecols=["nsrrid","ahi_a0h4"] )
      .rename(columns={"nsrrid":"subject_id","ahi_a0h4":"ahi"})
)
meta['subject_id'] = meta['subject_id'].astype(str)
# filter subjects with available EDF files
available = {f.stem.split('-')[1] for f in DATA_DIR.rglob('*.edf')}
meta = meta[meta['subject_id'].isin(available)].reset_index(drop=True)
# add severity label for stratified sampling
meta['severity'] = meta['ahi'].apply(map_ahi_to_label)
# stratified subsampling by severity category
sss = StratifiedShuffleSplit(n_splits=1, train_size=N_SUBJECTS, random_state=RANDOM_STATE)
train_idx, _ = next(sss.split(meta['subject_id'], meta['severity']))
meta = meta.iloc[train_idx].reset_index(drop=True)
print(f"[INFO] Selected {len(meta)} subjects for analysis")

[INFO] Selected 100 subjects for analysis


In [4]:
# Log severity distribution after subsampling
severity_counts = meta['severity'].value_counts().sort_index()
print(f"[INFO] Severity distribution after subsampling: {severity_counts.to_dict()}")

[INFO] Severity distribution after subsampling: {'healthy': 49, 'mild': 30, 'moderate': 14, 'severe': 7}


In [5]:
# 3. Compute global contrast limits for spectrogram color mapping
limits = []
def find_edf(sid: str) -> Path:
    for ext in ("edf","EDF"):
        for f in DATA_DIR.rglob(f"*.{ext}"):
            if sid in f.stem:
                return f
    return None

def read_eeg_channel(
    edf_path: Path,
    prefs: tuple[str, ...] = ("EEG(sec)", "EEG2", "EEG 2", "EEG sec")
) -> tuple[np.ndarray, float]:
    """
    Read only the intended EEG channels matching specified preferences; raise if missing.
    """
    with pyedflib.EdfReader(str(edf_path)) as f:
        labels = f.getSignalLabels()
        idx = None
        # Search in order of given preference list
        for pref in prefs:
            idx = next(
                (i for i, ch in enumerate(labels) if pref.lower() in ch.lower()),
                None
            )
            if idx is not None:
                break
        if idx is None:
            raise RuntimeError(
                f"No matching EEG channels in {edf_path.name}. Available channels: {labels}"
            )
        # Read selected channel
        signal = f.readSignal(idx)
        fs = f.getSampleFrequency(idx)
    return signal, fs

for sid in meta['subject_id']:
    edf_file = find_edf(sid)
    if not edf_file:
        continue
    sig, fs = read_eeg_channel(edf_file)
    win, hop = int(WINDOW_SECONDS*fs), int(WINDOW_SECONDS*fs*(1-OVERLAP_FRAC))
    n_epochs = (len(sig) - win) // hop + 1
    for i in range(n_epochs):
        segment = sig[i*hop : i*hop+win]
        _,_,Zxx = stft(segment, fs=fs, nperseg=NPERSEG, noverlap=NOVERLAP)
        db = 20*np.log10(np.abs(Zxx) + 1e-6)
        limits.append(np.percentile(db, [1,99]))
vmin, vmax = np.percentile(np.vstack(limits), [1,99])
print(f"[INFO] Global spectrogram limits: vmin={vmin:.2f} dB, vmax={vmax:.2f} dB")

[INFO] Global spectrogram limits: vmin=-53.68 dB, vmax=32.78 dB


In [6]:
# 4. Generate and quality-control spectrogram windows
entries = []
for sid, ahi in zip(meta['subject_id'], meta['ahi']):
    edf_file = find_edf(sid)
    if not edf_file:
        print(f"[WARN] Missing EDF for subject {sid}")
        continue
    sig, fs = read_eeg_channel(edf_file)
    win, hop = int(WINDOW_SECONDS*fs), int(WINDOW_SECONDS*fs*(1-OVERLAP_FRAC))
    n_epochs = (len(sig) - win) // hop + 1
    dropped_low = dropped_high = saved = 0
    print(f"[INFO] Processing {sid}: {n_epochs} windows")
    for i in range(n_epochs):
        segment = sig[i*hop : i*hop+win]
        _,_,Zxx = stft(segment, fs=fs, nperseg=NPERSEG, noverlap=NOVERLAP)
        db = 20*np.log10(np.abs(Zxx) + 1e-6)
        std = db.std()
        if std <= MIN_STD:
            dropped_low += 1
            continue
        if std >= MAX_STD:
            dropped_high += 1
            continue
        fig, ax = plt.subplots(figsize=(2.24,2.24), dpi=100)
        ax.pcolormesh(db, cmap='viridis', vmin=vmin, vmax=vmax)
        ax.axis('off')
        out_dir = RESNET_DIR / sid
        out_dir.mkdir(parents=True, exist_ok=True)
        filepath = out_dir / f"{sid}_{i:04d}.png"
        fig.savefig(filepath, bbox_inches='tight', pad_inches=0)
        plt.close(fig)
        entries.append({'subject_id': sid, 'ahi': ahi, 'spectrogram_path': str(filepath)})
        saved += 1
    print(f"[INFO] {sid}: saved={saved}, dropped_low={dropped_low}, dropped_high={dropped_high}")
window_df = pd.DataFrame(entries)
window_df.to_csv(METADATA_CSV, index=False)
print(f"[INFO] Window metadata saved to {METADATA_CSV}")

[INFO] Processing 201701: 2105 windows
[INFO] 201701: saved=2105, dropped_low=0, dropped_high=0
[INFO] Processing 201824: 2111 windows
[INFO] 201824: saved=2111, dropped_low=0, dropped_high=0
[INFO] Processing 204499: 2135 windows
[INFO] 204499: saved=2135, dropped_low=0, dropped_high=0
[INFO] Processing 204685: 2117 windows
[INFO] 204685: saved=2117, dropped_low=0, dropped_high=0
[INFO] Processing 201408: 2101 windows
[INFO] 201408: saved=2101, dropped_low=0, dropped_high=0
[INFO] Processing 200509: 2137 windows
[INFO] 200509: saved=2137, dropped_low=0, dropped_high=0
[INFO] Processing 205321: 1797 windows
[INFO] 205321: saved=1797, dropped_low=0, dropped_high=0
[INFO] Processing 200320: 1877 windows
[INFO] 200320: saved=1877, dropped_low=0, dropped_high=0
[INFO] Processing 200957: 2137 windows
[INFO] 200957: saved=2137, dropped_low=0, dropped_high=0
[INFO] Processing 203025: 2101 windows
[INFO] 203025: saved=2101, dropped_low=0, dropped_high=0
[INFO] Processing 201264: 1797 windows
[

In [7]:
window_df = pd.read_csv(METADATA_CSV)
window_df = window_df[window_df['spectrogram_path'].apply(lambda p: Path(p).exists())].reset_index(drop=True)
window_df.to_csv(METADATA_CSV, index=False)
print(f"[INFO] Updated metadata saved to {METADATA_CSV} (after removing missing spectrograms)")

[INFO] Updated metadata saved to /Volumes/T9/projects/sleep-apnea-classification-using-eeg-spectrograms/data/spectrograms/window_metadata.csv (after removing missing spectrograms)


In [8]:
# 5. Assign windows to stratified folds
window_df['severity'] = window_df['ahi'].apply(map_ahi_to_label)
window_df['fold'] = -1
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
for fold, (_, idx_test) in enumerate(sgkf.split(window_df, window_df['severity'], groups=window_df['subject_id'])):
    window_df.loc[idx_test, 'fold'] = fold
split_df = window_df[['subject_id','fold']].drop_duplicates()
split_df.to_csv(SPLITS_CSV, index=False)
print(f"[INFO] Fold assignments saved to {SPLITS_CSV}")
# Log fold distribution to verify balance
fold_counts = window_df['fold'].value_counts().sort_index()
print(f"[INFO] Window count per fold: {fold_counts.to_dict()}")

[INFO] Fold assignments saved to /Volumes/T9/projects/sleep-apnea-classification-using-eeg-spectrograms/data/spectrograms/spectrogram_splits.csv
[INFO] Window count per fold: {0: 33300, 1: 32616, 2: 31426, 3: 35844, 4: 33292}


In [9]:
# 6. Prepare DataLoaders for model training
def make_loader(df, batch_size=32, shuffle=False):
    class SpectrogramWindowDataset(Dataset):
        def __init__(self, df, transform=None):
            self.df = df.reset_index(drop=True)
            self.ahi_map = self.df.groupby('subject_id')['ahi'].first().to_dict()
            self.df['ahi_target'] = self.df['subject_id'].map(self.ahi_map)
            self.transform = transform or transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
            ])
        def __len__(self): return len(self.df)
        def __getitem__(self, idx):
            row = self.df.iloc[idx]
            img = Image.open(row['spectrogram_path']).convert('RGB')
            x = self.transform(img)
            y = torch.tensor(row['ahi_target'], dtype=torch.float32)
            return x, y
    return DataLoader(SpectrogramWindowDataset(df), batch_size=batch_size,
                      shuffle=shuffle, num_workers=4, pin_memory=True)

# Create loaders for training and validation
loaders = {}
for fold in range(5):
    df_fold = window_df[window_df['fold'] == fold]
    loaders[f'train_fold_{fold}'] = make_loader(window_df[window_df['fold'] != fold], shuffle=True)
    loaders[f'val_fold_{fold}']   = make_loader(df_fold, shuffle=False)

print("[INFO] DataLoaders are ready.")

[INFO] DataLoaders are ready.
