<a href="https://colab.research.google.com/github/sameekshya1999/EEG-Sleep-Stage-Classification/blob/main/Sleep_Stage_Classification_using_EEG_Signals_with_CNN_and_EEGNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install mne
!pip install scikit-learn
!pip install seaborn
!pip install tqdm
!pip install matplotlib tensorflow



Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.4/7.4 MB 22.2 MB/s eta 0:00:00
Installing collected packages: mne
Successfully installed mne-1.9.0


# 🧠 Sleep Stage Classification using EEG Signals with CNN

This project implements a deep learning pipeline to classify sleep stages based on EEG data from the **Sleep-EDF dataset**. The pipeline includes:

- ✅ Automatic data download and preprocessing using MNE
- ✅ Epoching and normalization of EEG signals (Fpz-Cz, Pz-Oz channels)
- ✅ Label alignment using standard AASM sleep stages: **Wake, N1, N2, N3, REM**
- ✅ A **Convolutional Neural Network (CNN)** architecture built with TensorFlow/Keras
- ✅ Class distribution statistics, training history, and performance evaluation
- ✅ Output metrics: **Accuracy, Classification Report, Confusion Matrix**

### 🔬 Dataset:
- Source: [Sleep-EDF Database](https://physionet.org/content/sleep-edfx/1.0.0/)
- Type: Polysomnography (PSG) recordings from healthy adults
- Channels used: `EEG Fpz-Cz` and `EEG Pz-Oz`

### 📊 Evaluation:
- Stratified Train/Test Split (80/20)
- Final metrics include per-class precision, recall, and F1-score

---

**Author:** Samiksha BC  
**Project:** Master's Final Project – EEG-based Sleep Stage Classification  
**University:** Indiana University South Bend


In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import mne
import urllib.request
import os
from sklearn.model_selection import train_test_split
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# Keep notebook output readable: hide DeprecationWarnings and restrict MNE
# logging to errors only.
warnings.filterwarnings("ignore", category=DeprecationWarning)
mne.set_log_level('ERROR')

# Constants
NUM_SUBJECTS = 20  # subject ids 0 .. NUM_SUBJECTS-1 are probed for data
NUM_NIGHTS = 2  # nights 1 .. NUM_NIGHTS are probed for every subject
BASE_URL = "https://physionet.org/files/sleep-edfx/1.0.0/"  # download root; folder and file name are appended

def fetch_data(subject_id, night, record_type='PSG'):
    """Download one Sleep-EDF cassette file if needed and return its local path.

    Args:
        subject_id: Integer subject number, written as two digits in the file
            name (e.g. 3 -> ``SC403...``).
        night: Night number embedded in the file name.
        record_type: ``'PSG'`` selects the recording; any other value selects
            the hypnogram (sleep-stage annotation) file.

    Returns:
        The path of the file inside ``sleep_edf/``, or ``None`` when the file
        could not be downloaded.
    """
    import urllib.error  # local import: makes the HTTPError name explicit

    suffix = "E0-PSG.edf" if record_type == 'PSG' else "EC-Hypnogram.edf"
    # NOTE(review): the hypnogram scorer code is hard-coded as 'EC'.
    # Recordings whose hypnogram uses a different code will 404 and be
    # skipped - confirm against the PhysioNet file listing.
    file_name = f"SC4{subject_id:02d}{night}{suffix}"

    # SC* recordings of *both* nights live under sleep-cassette/. The previous
    # code requested night-2 files from sleep-telemetry/, so they never resolved.
    url = f"{BASE_URL}sleep-cassette/{file_name}"
    local_file = f"sleep_edf/{file_name}"
    partial_file = local_file + ".part"

    try:
        os.makedirs("sleep_edf", exist_ok=True)

        if not os.path.exists(local_file):
            # Download to a side file and rename on success, so an interrupted
            # transfer is never mistaken for a complete file on the next call.
            urllib.request.urlretrieve(url, partial_file)
            os.replace(partial_file, local_file)
            print(f"Downloaded: {file_name}")

        return local_file
    except urllib.error.HTTPError as e:
        # Expected while probing for recordings that do not exist.
        print(f"HTTP Error {e.code} fetching {file_name}: {e.reason}")
        return None
    except Exception as e:
        print(f"Error fetching {file_name}: {e}")
        return None
    finally:
        # Remove any truncated download left behind by a failure.
        if os.path.exists(partial_file):
            os.remove(partial_file)

def get_available_subjects():
    """Return the (subject_id, night) pairs whose PSG file can be fetched.

    Every combination of ``range(NUM_SUBJECTS)`` and nights
    ``1..NUM_NIGHTS`` is checked concurrently via ``fetch_data``; pairs are
    returned in submission order.
    """
    def _psg_present(subj, nt):
        # A pair counts as available when fetch_data yields a path.
        return fetch_data(subj, nt) is not None

    candidates = [
        (subj, nt)
        for subj in range(NUM_SUBJECTS)
        for nt in range(1, NUM_NIGHTS + 1)
    ]

    with ThreadPoolExecutor(max_workers=10) as pool:
        jobs = [(pair, pool.submit(_psg_present, *pair)) for pair in candidates]
        return [
            pair
            for pair, job in tqdm(jobs, desc="Checking availability")
            if job.result()
        ]

def process_subject_night(subject_id, night):
    """Load one recording and turn it into normalised, labelled 30 s epochs.

    The PSG is reduced to the two EEG channels, band-pass filtered
    (0.5-40 Hz), cut into consecutive 30-second windows and z-scored per
    window. Labels come from the hypnogram annotations.

    Args:
        subject_id: Subject number passed on to ``fetch_data``.
        night: Night number passed on to ``fetch_data``.

    Returns:
        ``(X, y)`` where ``X`` has shape (epochs, samples, channels) and ``y``
        holds stage ids 0-4 (Wake, N1, N2, N3, REM). ``(None, None)`` is
        returned when files or channels are missing, no epoch carries a known
        stage, or any error occurs.
    """
    try:
        psg_file = fetch_data(subject_id, night, 'PSG')
        hypno_file = fetch_data(subject_id, night, 'Hypnogram')
        if psg_file is None or hypno_file is None:
            return None, None

        raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
        required_channels = ['EEG Fpz-Cz', 'EEG Pz-Oz']
        available = [ch for ch in required_channels if ch in raw.ch_names]
        if len(available) < 2:
            print(f"Skipping subject {subject_id}, night {night}: missing required channels")
            return None, None
        raw.pick_channels(available)

        raw.filter(0.5, 40.0, l_trans_bandwidth=0.5, h_trans_bandwidth=10.0, verbose=False)
        data = raw.get_data(units='uV')
        sfreq = raw.info['sfreq']

        # Cut the recording into whole 30 s windows; a trailing partial window
        # is discarded. Result shape: (epochs, channels, samples).
        samples_per_epoch = int(30 * sfreq)
        n_epochs = data.shape[1] // samples_per_epoch
        usable = data[:, :n_epochs * samples_per_epoch]
        epochs = usable.reshape(data.shape[0], n_epochs, samples_per_epoch).transpose(1, 0, 2)

        annotations = mne.read_annotations(hypno_file)
        # -1 marks "no known stage". Starting from 0 would silently label
        # unscored / movement / '?' epochs as Wake.
        labels = np.full(n_epochs, -1, dtype=int)
        stage_map = {
            'Sleep stage W': 0, 'Sleep stage 1': 1,
            'Sleep stage 2': 2, 'Sleep stage 3': 3,
            'Sleep stage 4': 3, 'Sleep stage R': 4
        }

        for annot in annotations:
            stage_id = stage_map.get(annot['description'])
            if stage_id is None:
                continue
            onset = int(annot['onset'] / 30)
            duration = int(annot['duration'] / 30)
            first = max(0, onset)
            # Clamp so a negative end can never be read as a from-the-end index.
            last = max(first, min(n_epochs, onset + duration))
            labels[first:last] = stage_id

        # Keep only epochs that received a real stage label.
        scored = labels >= 0
        epochs = epochs[scored]
        labels = labels[scored]
        if len(labels) == 0:
            print(f"Skipping subject {subject_id}, night {night}: no scored epochs")
            return None, None

        # Per-epoch z-score over both channels; flat epochs (std == 0) are
        # left centred instead of becoming NaN/inf.
        mean = epochs.mean(axis=(1, 2), keepdims=True)
        std = epochs.std(axis=(1, 2), keepdims=True)
        std[std == 0] = 1.0
        epochs = (epochs - mean) / std

        X = epochs.transpose(0, 2, 1)
        y = labels

        return X, y
    except Exception as e:
        print(f"Error processing subject {subject_id} night {night}: {str(e)}")
        return None, None

def build_model(input_shape):
    """Build and compile the 1D CNN sleep-stage classifier.

    Args:
        input_shape: Shape of one example, passed to the ``Input`` layer.

    Returns:
        A compiled ``Sequential`` model with a 5-way softmax output.
    """
    # Three Conv1D + BatchNorm stages; only the first two are followed by
    # max-pooling.
    conv_stages = ((64, True), (128, True), (256, False))

    stack = [layers.Input(shape=input_shape)]
    for n_filters, pool_after in conv_stages:
        stack.append(layers.Conv1D(n_filters, 7, activation='relu', padding='same'))
        stack.append(layers.BatchNormalization())
        if pool_after:
            stack.append(layers.MaxPooling1D(2))

    stack.extend([
        layers.GlobalAveragePooling1D(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(5, activation='softmax'),
    ])

    network = models.Sequential(stack)
    adam = tf.keras.optimizers.Adam(learning_rate=0.001)
    network.compile(
        optimizer=adam,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

def main():
    """Run the full CNN experiment: load data, train, evaluate, save the plot.

    Prints the class distribution, test accuracy/loss and a per-class
    classification report, and writes ``2channel_confusion_matrix.png``.
    Returns early with a message when no data could be loaded.
    """
    stage_names = ['Wake', 'N1', 'N2', 'N3', 'REM']
    stage_ids = list(range(len(stage_names)))

    available = get_available_subjects()
    print(f"\nFound {len(available)} available subject-night combinations")
    if not available:
        print("No data available - check your internet connection")
        return

    all_X, all_y = [], []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_subject_night, s, n) for s, n in available]
        for future in tqdm(futures, desc="Processing data"):
            X, y = future.result()
            if X is not None and y is not None:
                all_X.append(X)
                all_y.append(y)

    if not all_X:
        print("No valid data processed")
        return

    X = np.concatenate(all_X)
    y = np.concatenate(all_y)

    print(f"\nFinal dataset: {X.shape[0]} epochs")
    print("Class distribution:")
    for i, stage in enumerate(stage_names):
        count = np.sum(y == i)
        print(f"{stage}: {count} ({(count/len(y))*100:.1f}%)")

    # Stratification needs at least two samples of every class that occurs;
    # otherwise fall back to a plain random split instead of raising.
    class_counts = np.bincount(y, minlength=len(stage_names))
    stratify = y if class_counts[class_counts > 0].min() >= 2 else None
    # NOTE(review): the split is per epoch, so epochs of one recording end up
    # in both train and test. A subject-wise split would give a stricter
    # estimate of generalisation.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=stratify, random_state=42
    )

    model = build_model((X.shape[1], X.shape[2]))
    print("\nTraining model...")
    # The test split doubles as validation data for monitoring only; no
    # callback selects a model based on it.
    model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=15,
        batch_size=64,
        verbose=1
    )

    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
    print(f"\nFinal Test Accuracy: {test_acc:.4f}")
    print(f"Final Test Loss: {test_loss:.4f}")

    # Per-class metrics
    y_pred_probs = model.predict(X_test)
    y_pred = np.argmax(y_pred_probs, axis=1)

    # Passing labels explicitly keeps the report and the matrix at five
    # rows/columns even when a stage is absent from y_test and y_pred;
    # without it the report raises and the tick labels no longer line up.
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, labels=stage_ids,
                                target_names=stage_names, zero_division=0))

    cm = confusion_matrix(y_test, y_pred, labels=stage_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=stage_names,
                yticklabels=stage_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.savefig('2channel_confusion_matrix.png')
    plt.close()

if __name__ == "__main__":
    main()


Checking availability:   0%|          | 0/40 [00:00<?, ?it/s]

Downloaded: SC4081E0-PSG.edf
Downloaded: SC4041E0-PSG.edf


Checking availability:   2%|▎         | 1/40 [09:53<6:25:41, 593.36s/it]

Downloaded: SC4001E0-PSG.edf
Downloaded: SC4091E0-PSG.edf
Downloaded: SC4051E0-PSG.edf
Downloaded: SC4061E0-PSG.edf
Downloaded: SC4031E0-PSG.edf


Checking availability:   8%|▊         | 3/40 [10:10<1:38:43, 160.09s/it]

Downloaded: SC4011E0-PSG.edf


Checking availability:  12%|█▎        | 5/40 [10:10<45:27, 77.92s/it]   

Downloaded: SC4021E0-PSG.edf


Checking availability:  38%|███▊      | 15/40 [10:12<07:00, 16.81s/it]

Downloaded: SC4071E0-PSG.edf


Checking availability:  52%|█████▎    | 21/40 [10:57<04:08, 13.08s/it]

Downloaded: SC4101E0-PSG.edf


Checking availability:  57%|█████▊    | 23/40 [11:47<04:15, 15.01s/it]

Downloaded: SC4111E0-PSG.edf


Checking availability:  62%|██████▎   | 25/40 [11:56<03:15, 13.03s/it]

Downloaded: SC4121E0-PSG.edf


Checking availability:  68%|██████▊   | 27/40 [13:08<03:52, 17.88s/it]

Downloaded: SC4131E0-PSG.edf
Downloaded: SC4161E0-PSG.edf


Checking availability:  72%|███████▎  | 29/40 [14:37<04:24, 24.02s/it]

Downloaded: SC4141E0-PSG.edf
Downloaded: SC4171E0-PSG.edf


Checking availability:  78%|███████▊  | 31/40 [14:51<02:58, 19.83s/it]

Downloaded: SC4151E0-PSG.edf


Checking availability:  92%|█████████▎| 37/40 [15:15<00:34, 11.65s/it]

Downloaded: SC4181E0-PSG.edf


Checking availability: 100%|██████████| 40/40 [15:17<00:00, 22.94s/it]


Downloaded: SC4191E0-PSG.edf

Found 20 available subject-night combinations


Processing data:   0%|          | 0/20 [00:00<?, ?it/s]

Downloaded: SC4031EC-Hypnogram.edf
Downloaded: SC4001EC-Hypnogram.edf
Downloaded: SC4041EC-Hypnogram.edf
Downloaded: SC4051EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=Fa

Downloaded: SC4061EC-Hypnogram.edf
Downloaded: SC4071EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
Processing data:  25%|██▌       | 5/20 [00:20<00:52,  3.51s/it]

Downloaded: SC4081EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)


Downloaded: SC4091EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)


Downloaded: SC4101EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)


Downloaded: SC4111EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)


Downloaded: SC4121EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
Processing data:  30%|███       | 6/20 [00:59<03:09, 13.55s/it]

Downloaded: SC4131EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
Processing data:  50%|█████     | 10/20 [01:16<01:17,  7.79s/it]

Downloaded: SC4151EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)


Downloaded: SC4161EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)


Downloaded: SC4181EC-Hypnogram.edf


  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=True, verbose=False)
Processing data: 100%|██████████| 20/20 [01:33<00:00,  4.68s/it]



Final dataset: 40834 epochs
Class distribution:
Wake: 28490 (69.8%)
N1: 869 (2.1%)
N2: 6528 (16.0%)
N3: 2305 (5.6%)
REM: 2642 (6.5%)

Training model...
Epoch 1/15
511/511 ━━━━━━━━━━━━━━━━━━━━ 39s 60ms/step - accuracy: 0.8885 - loss: 0.3212 - val_accuracy: 0.9137 - val_loss: 0.2518
Epoch 2/15
511/511 ━━━━━━━━━━━━━━━━━━━━ 25s 50ms/step - accuracy: 0.9362 - loss: 0.1786 - val_accuracy: 0.9344 - val_loss: 0.1722
Epoch 3/15
511/511 ━━━━━━━━━━━━━━━━━━━━ 25s 50ms/step - accuracy: 0.9403 - loss: 0.1644 - val_accuracy: 0.9336 - val_loss: 0.1877
Epoch 4/15
511/511 ━━━━━━━━━━━━━━━━━━━━ 41s 50ms/step - accuracy: 0.9454 - loss: 0.1534 - val_accuracy: 0.9491 - val_loss: 0.1379
Epoch 5/15
511/511 ━━━━━━━━━━━━━━━━━━━━ 41s 50ms/step - accuracy: 0.9437 - loss: 0.1500 - val_accuracy: 0.9456 - val_loss: 0.1423
Epoch 6/15
511/511

In [None]:
!pip install mne
!pip install scikit-learn
!pip install seaborn
!pip install tqdm
!pip install matplotlib
!pip install tensorflow




# 🧠 EEGNet with Attention for Sleep Stage Classification

This notebook implements a full deep learning pipeline for classifying sleep stages using EEG data from the **Sleep-EDF dataset**. The model is an **EEGNet-inspired 1D convolutional network with multi-head temporal attention**, designed for efficient training and improved per-class performance, especially on hard-to-classify stages such as **N1** and **REM**.

---

## ✅ **Key Features of This Notebook:**

- 📥 **Automated data download** from PhysioNet's Sleep-EDF (Cassette + Telemetry) datasets
- 🔄 **Data preprocessing**: filtering, resampling, epoching (30s), channel selection (`Fpz-Cz`, `Pz-Oz`)
- 📊 **Data augmentation** with on-the-fly noise and time-shifted copies
- 🧠 **EEGNet-based model** enhanced with **multi-head temporal attention**
- ⚖️ **Class balancing** using sample weights
- 🧪 **Evaluation**: accuracy, per-class F1, precision/recall, and confusion matrix
- 📉 **Visualization** of training loss/accuracy curves and confusion matrix

---

## 📁 **Dataset**
- **Source**: [PhysioNet Sleep-EDF](https://physionet.org/content/sleep-edfx/1.0.0/)
- **EEG Channels Used**: `EEG Fpz-Cz`, `EEG Pz-Oz`
- **Sleep Stages Classified**:
  - 0 → Wake
  - 1 → N1
  - 2 → N2
  - 3 → N3 (merged S3 + S4)
  - 4 → REM

---

## 🚀 Hardware & Performance
- Uses **mixed precision training** for GPU memory efficiency
- Optimized for systems with **12GB RAM and ~15GB GPU VRAM**
- Processes recordings one subject-night at a time and handles a final batch that is smaller than the configured batch size

---

## 📌 Author
- **Samiksha BC**, Indiana University South Bend  
- Final Year Independent Study Project | Master of Science in Computer Science

---




In [None]:



'''
EEGNet with Attention for Sleep Stage Classification

This script implements a deep learning pipeline for classifying sleep stages
(Wake, N1, N2, N3, REM) using EEG data from the Sleep-EDF dataset. It includes:
- Memory-efficient data fetching and preprocessing with MNE
- A custom EEGNet model with a temporal attention layer
- On-the-fly and static augmentation for ~75,000-90,000 epochs
- Batch processing and mixed precision for GPU efficiency
- Comprehensive evaluation with accuracy, F1-score, and confusion matrix
- Visualizations of training curves

Optimized for 12GB RAM and 15GB GPU RAM, using ~5-7GB GPU memory.
'''

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import Sequence
import mne
import urllib.request
import os
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
import gc

from tensorflow.keras.mixed_precision import set_global_policy
# Use the Keras 'mixed_float16' policy for every layer created after this point.
set_global_policy('mixed_float16')

# Keep notebook output readable: hide DeprecationWarnings and restrict MNE
# logging to errors only.
warnings.filterwarnings("ignore", category=DeprecationWarning)
mne.set_log_level('ERROR')

NUM_SUBJECTS = 20  # subject ids 0 .. NUM_SUBJECTS-1 are probed for data
NUM_NIGHTS = 2  # nights 1 .. NUM_NIGHTS are probed for every subject
BASE_URL = "https://physionet.org/files/sleep-edfx/1.0.0/"  # download root; folder and file name are appended
TARGET_CHANNELS = ['EEG Fpz-Cz', 'EEG Pz-Oz']  # EEG channels kept from each recording
EPOCH_DURATION = 30  # length of one scored window, in seconds
BATCH_SIZE = 128  # batch size of the training/validation generators
EPOCHS = 50  # number of training epochs passed to model.fit
SAMPLING_RATE = 50  # recordings are resampled to this rate (Hz) before epoching

# Subject ids for which fetch_data attempts a night-2 (telemetry) download;
# all other subjects are skipped for night 2.
TELEMETRY_SUBJECTS = [2, 4, 5, 6, 7, 12, 13]

print(f"scikit-learn version: {sklearn.__version__}")

def fetch_data(subject_id, night, record_type='PSG'):
    """Download one Sleep-EDF file if needed and return its local path.

    Night 1 is looked up in ``sleep-cassette`` (``SC4..`` files); any other
    night is looked up in ``sleep-telemetry`` (``ST7..`` files) and only for
    subjects listed in ``TELEMETRY_SUBJECTS``.

    Args:
        subject_id: Zero-based subject index.
        night: Night number; 1 selects cassette data, anything else telemetry.
        record_type: ``'PSG'`` selects the recording; any other value selects
            the hypnogram (sleep-stage annotation) file.

    Returns:
        The path of the file inside ``sleep_edf/``, or ``None`` when the file
        is not applicable or could not be downloaded.
    """
    import urllib.error  # local import: makes the HTTPError name explicit

    # Pre-set so the error handlers can always format a message.
    file_name = "<file name not built>"
    partial_file = None
    try:
        # NOTE(review): this offset means subject_id 0 maps to 'SC401', so the
        # 'SC400' recordings are never requested - confirm this is intended.
        dataset_id = subject_id + 1

        if night == 1:
            folder = "sleep-cassette"
            prefix = f"SC4{dataset_id:02d}"
            night_tag = 1
        else:
            if subject_id not in TELEMETRY_SUBJECTS:
                return None
            folder = "sleep-telemetry"
            telemetry_map = {2: 702, 4: 704, 5: 705, 6: 706, 7: 707, 12: 712, 13: 713}
            prefix = f"ST{telemetry_map.get(subject_id, 700 + dataset_id)}"
            night_tag = 2

        # NOTE(review): both suffixes are hard-coded. In the recorded run every
        # telemetry request and several cassette hypnogram requests returned
        # HTTP 404, so the real names presumably use other recording/scorer
        # codes - verify against the PhysioNet file listing.
        suffix = "E0-PSG.edf" if record_type == 'PSG' else "EC-Hypnogram.edf"
        file_name = f"{prefix}{night_tag}{suffix}"
        url = f"{BASE_URL}{folder}/{file_name}"
        local_file = os.path.join("sleep_edf", file_name)
        os.makedirs("sleep_edf", exist_ok=True)

        if not os.path.exists(local_file):
            # Download to a side file and rename on success, so an interrupted
            # transfer is never mistaken for a complete file on the next call.
            partial_file = local_file + ".part"
            urllib.request.urlretrieve(url, partial_file)
            os.replace(partial_file, local_file)
            print(f"Downloaded {file_name}")
        return local_file
    except urllib.error.HTTPError as e:
        print(f"HTTP Error {e.code} fetching {file_name}: {e.reason}")
        return None
    except Exception as e:
        print(f"Error fetching {file_name}: {e}")
        return None
    finally:
        # Remove any truncated download left behind by a failure.
        if partial_file is not None and os.path.exists(partial_file):
            os.remove(partial_file)

def get_available_subjects():
    """Return the (subject_id, night) pairs that have both PSG and hypnogram.

    Every combination of ``range(NUM_SUBJECTS)`` and nights
    ``1..NUM_NIGHTS`` is checked concurrently via ``fetch_data``. The
    resulting list is printed and returned in submission order.
    """
    def _has_both_files(subj, nt):
        # The hypnogram is only requested when the PSG was obtained.
        return (
            fetch_data(subj, nt, 'PSG') is not None and
            fetch_data(subj, nt, 'Hypnogram') is not None
        )

    candidates = [
        (subj, nt)
        for subj in range(NUM_SUBJECTS)
        for nt in range(1, NUM_NIGHTS + 1)
    ]

    with ThreadPoolExecutor(max_workers=10) as pool:
        jobs = [(pair, pool.submit(_has_both_files, *pair)) for pair in candidates]
        available = [
            pair
            for pair, job in tqdm(jobs, desc="Checking availability")
            if job.result()
        ]

    print(f"Available subject-night pairs: {available}")
    return available

def augment_data(X):
    """Return a noisy, circularly shifted copy of ``X``.

    Gaussian noise (mean 0, std 0.01) is added element-wise, then the result
    is rolled along axis 1 by one random offset drawn from [-50, 50) that is
    shared by the whole array. ``X`` itself is not modified.
    """
    jitter = np.random.normal(loc=0, scale=0.01, size=X.shape)
    offset = np.random.randint(-50, 50)
    noisy = X + jitter
    return np.roll(noisy, offset, axis=1)

def process_subject_night(subject_id, night, augment=False):
    """Load one recording and turn it into normalised, labelled epochs.

    The PSG is reduced to ``TARGET_CHANNELS``, band-pass filtered
    (0.5-40 Hz), resampled to ``SAMPLING_RATE`` and cut into consecutive
    windows of ``EPOCH_DURATION`` seconds, each z-scored on its own. Labels
    come from the hypnogram annotations.

    Args:
        subject_id: Subject index passed on to ``fetch_data``.
        night: Night number passed on to ``fetch_data``.
        augment: When True, append one ``augment_data`` copy of every epoch
            (doubling the output). Off by default: the caller splits the
            returned epochs randomly into train and test, so static copies
            would put near-duplicates of training epochs into the test set.
            Training-time augmentation is available on the fly in
            ``EEGDataGenerator``.

    Returns:
        ``(X, labels)`` where ``X`` has shape (epochs, samples, channels) and
        ``labels`` holds stage ids 0-4 (Wake, N1, N2, N3, REM).
        ``(None, None)`` is returned when files or channels are missing, no
        epoch carries a known stage, or any error occurs.
    """
    try:
        psg_file = fetch_data(subject_id, night, 'PSG')
        hypno_file = fetch_data(subject_id, night, 'Hypnogram')
        if psg_file is None or hypno_file is None:
            print(f"Skipping subject {subject_id}, night {night}: Missing files")
            return None, None

        raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
        # Require every target channel: a recording with fewer channels would
        # yield arrays that cannot be concatenated with the others.
        missing = [ch for ch in TARGET_CHANNELS if ch not in raw.ch_names]
        if missing:
            print(f"Skipping subject {subject_id}, night {night}: missing channels {missing}")
            return None, None
        channels = list(TARGET_CHANNELS)
        raw.pick_channels(channels)

        raw.load_data()
        raw.filter(0.5, 40.0, l_trans_bandwidth=0.5, h_trans_bandwidth=10.0, verbose=False)
        raw.resample(SAMPLING_RATE, npad="auto")

        events = mne.make_fixed_length_events(raw, id=1, duration=EPOCH_DURATION)
        epochs_mne = mne.Epochs(raw, events, tmin=0, tmax=EPOCH_DURATION-1/raw.info['sfreq'],
                                picks=channels, baseline=None, preload=True)
        data = epochs_mne.get_data(units='uV')

        annotations = mne.read_annotations(hypno_file)
        stage_map = {
            'Sleep stage W': 0,
            'Sleep stage 1': 1,
            'Sleep stage 2': 2,
            'Sleep stage 3': 3,
            'Sleep stage 4': 3,
            'Sleep stage R': 4
        }

        # One slot per fixed-length window; -1 marks "no known stage".
        # Starting from 0 would silently label unscored / movement / '?'
        # windows as Wake.
        n_windows = len(events)
        window_stage = np.full(n_windows, -1, dtype=int)
        for annot in annotations:
            stage_id = stage_map.get(annot['description'])
            if stage_id is None:
                continue
            onset = int(annot['onset'] / EPOCH_DURATION)
            duration = int(annot['duration'] / EPOCH_DURATION)
            first = max(0, onset)
            # Clamp so a negative end can never be read as a from-the-end index.
            last = max(first, min(n_windows, onset + duration))
            window_stage[first:last] = stage_id

        # Epochs may drop windows; `selection` maps the kept epochs back to
        # their original window index so labels stay aligned.
        labels = window_stage[epochs_mne.selection]

        scored = labels >= 0
        data = data[scored]
        labels = labels[scored]
        if len(labels) == 0:
            print(f"Skipping subject {subject_id}, night {night}: no scored epochs")
            return None, None

        # Per-epoch z-score over both channels; flat epochs (std == 0) are
        # left centred instead of becoming NaN/inf.
        mean = np.mean(data, axis=(1, 2), keepdims=True)
        std = np.std(data, axis=(1, 2), keepdims=True)
        std[std == 0] = 1.0
        data = (data - mean) / std

        X = data.transpose(0, 2, 1)
        if augment:
            X = np.concatenate([X, augment_data(X)])
            labels = np.concatenate([labels, labels])

        del raw, epochs_mne, data
        gc.collect()

        print(f"Processed subject {subject_id}, night {night}: {X.shape[0]} epochs")
        return X, labels
    except Exception as e:
        print(f"Error processing subject {subject_id}, night {night}: {e}")
        return None, None

class EEGDataGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(X, y, sample_weight)`` batches.

    Args:
        X: Epoch array; stored as float32.
        y: Integer stage labels, one per epoch; stored as int32.
        batch_size: Number of epochs per batch (the last batch may be smaller).
        augment: When True, every batch is passed through ``augment_data``.
        class_weights: Optional mapping ``label -> weight`` used to build the
            per-sample weights. Labels missing from the mapping get weight 1.
        **kwargs: Forwarded to the ``Sequence`` base class.

    Raises:
        ValueError: If ``X`` and ``y`` differ in length or ``batch_size`` is
            not positive.
    """

    def __init__(self, X, y, batch_size, augment=True, class_weights=None, **kwargs):
        # Keras warns (see `_warn_if_super_not_called`) when the base
        # initialiser is skipped.
        super().__init__(**kwargs)
        if len(X) != len(y):
            raise ValueError(f"X and y must have the same length, got {len(X)} and {len(y)}")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int32)
        self.batch_size = batch_size
        self.augment = augment
        self.class_weights = class_weights

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.X) / self.batch_size))

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(X_batch, y_batch, sample_weights)``."""
        start = idx * self.batch_size
        end = min(start + self.batch_size, len(self.X))
        X_batch = self.X[start:end]
        y_batch = self.y[start:end]

        if self.augment:
            # augment_data returns float64; cast back for the model.
            X_batch = augment_data(X_batch).astype(np.float32)

        if self.class_weights:
            # A label absent from the mapping (e.g. a stage that appears only
            # in the validation data) falls back to a neutral weight rather
            # than aborting training with a KeyError.
            sample_weights = np.array(
                [self.class_weights.get(int(label), 1.0) for label in y_batch],
                dtype=np.float32,
            )
        else:
            sample_weights = np.ones_like(y_batch, dtype=np.float32)

        return X_batch, y_batch, sample_weights

class TemporalAttention(layers.Layer):
    """Self-attention block: multi-head attention, residual add, layer norm.

    Args:
        heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        **kwargs: Standard ``Layer`` arguments such as ``name`` or ``dtype``.
    """

    def __init__(self, heads=4, key_dim=24, **kwargs):
        # Forwarding kwargs lets callers name the layer and lets Keras
        # rebuild it from a saved config.
        super().__init__(**kwargs)
        self.heads = heads
        self.key_dim = key_dim
        self.multi_head = layers.MultiHeadAttention(num_heads=heads, key_dim=key_dim)
        self.norm = layers.LayerNormalization()
        self.add = layers.Add()

    def call(self, inputs):
        """Attend over ``inputs`` (query = value = inputs) and normalise."""
        attn_output = self.multi_head(inputs, inputs)
        out = self.add([inputs, attn_output])
        return self.norm(out)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({"heads": self.heads, "key_dim": self.key_dim})
        return config

def build_eegnet_attention(input_shape):
    """Build and compile the convolution + temporal-attention classifier.

    Args:
        input_shape: Shape of one example, passed to the ``Input`` layer.

    Returns:
        A compiled functional ``Model`` with a 5-way softmax output.
    """
    inputs = layers.Input(shape=input_shape)

    # Two identical Conv1D -> BatchNorm -> MaxPool stages, differing only in
    # the number of filters.
    features = inputs
    for n_filters in (64, 128):
        features = layers.Conv1D(n_filters, 7, padding='same', activation='relu')(features)
        features = layers.BatchNormalization()(features)
        features = layers.MaxPooling1D(pool_size=2)(features)

    features = TemporalAttention(heads=4, key_dim=24)(features)
    features = layers.BatchNormalization()(features)
    pooled = layers.GlobalAveragePooling1D()(features)

    hidden = layers.Dense(128, activation='relu')(pooled)
    hidden = layers.Dropout(0.3)(hidden)
    # The output layer is explicitly float32.
    outputs = layers.Dense(5, activation='softmax', dtype='float32')(hidden)

    network = models.Model(inputs=inputs, outputs=outputs)
    adam = tf.keras.optimizers.Adam(0.0005)
    network.compile(
        optimizer=adam,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

def plot_training_curves(history):
    """Save train/validation accuracy and loss curves to ``training_curves.png``.

    Args:
        history: Object whose ``history`` dict holds the ``accuracy``,
            ``val_accuracy``, ``loss`` and ``val_loss`` series.
    """
    # (history key, display label) for the left and right panel.
    panels = (
        ('accuracy', 'Accuracy'),
        ('loss', 'Loss'),
    )

    plt.figure(figsize=(12, 4))
    for position, (key, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[key], label=f'Train {label}')
        plt.plot(history.history[f'val_{key}'], label=f'Validation {label}')
        plt.title(f'Model {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()

    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()

def evaluate_model(model, X_test, y_test):
    """Print test metrics and save the confusion matrix plot.

    Prints overall accuracy/loss and per-stage precision, recall and F1, then
    writes ``confusion_matrix.png``.

    Args:
        model: Compiled model whose ``evaluate`` returns (loss, accuracy).
        X_test: Test epochs.
        y_test: True stage ids for ``X_test``.
    """
    stage_names = ['Wake', 'N1', 'N2', 'N3', 'REM']
    stage_ids = list(range(len(stage_names)))

    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
    print(f"\nTest Accuracy: {test_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}")

    y_pred = model.predict(X_test, verbose=0)
    y_pred_classes = np.argmax(y_pred, axis=1)

    # Fixing the label set keeps all outputs five entries long even when a
    # stage is absent from both y_test and the predictions; otherwise the
    # indexing below raises IndexError and the heatmap ticks misalign.
    # zero_division=0 reports 0 for such stages without warnings.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_test, y_pred_classes, labels=stage_ids, average=None, zero_division=0
    )
    print("\nPer-class Metrics:")
    for i, stage in enumerate(stage_names):
        print(f"{stage}: Precision={precision[i]:.4f}, Recall={recall[i]:.4f}, F1={f1[i]:.4f}")

    cm = confusion_matrix(y_test, y_pred_classes, labels=stage_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=stage_names, yticklabels=stage_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

def data_generator(available, batch_size=2000):
    """Yield ``(X, y)`` chunks of at most ``batch_size`` epochs per recording.

    Recordings are processed one at a time via ``process_subject_night``;
    recordings that fail to load are skipped. Memory is released after each
    recording has been fully yielded.

    Args:
        available: Iterable of ``(subject_id, night)`` pairs.
        batch_size: Maximum number of epochs per yielded chunk.
    """
    for subj, nt in available:
        signals, stages = process_subject_night(subj, nt)
        if signals is not None and stages is not None:
            for begin in range(0, len(signals), batch_size):
                chunk = slice(begin, begin + batch_size)
                yield signals[chunk], stages[chunk]
            del signals, stages
            gc.collect()

def run_pipeline():
    """Run the attention-model experiment end to end.

    Loads every available recording chunk by chunk, splits each chunk 80/20
    into train and test, trains with balanced class weights, then saves the
    training curves and evaluates on the test split. Returns early with a
    message when no data is available.
    """
    available = get_available_subjects()
    if not available:
        print("No subject-night pairs available - nothing to train on")
        return

    X_train, y_train, X_test, y_test = [], [], [], []
    for X_batch, y_batch in tqdm(data_generator(available), desc="Processing data"):
        if X_batch is None or y_batch is None:
            continue

        n_samples = len(y_batch)
        if n_samples < 2:
            # Too small to split; keep it for training instead of raising.
            X_train.append(X_batch); y_train.append(y_batch)
            continue

        # A stratified split needs >= 2 samples per present class and at
        # least one slot per class on both sides; otherwise split randomly.
        class_counts = np.bincount(y_batch)
        present_counts = class_counts[class_counts > 0]
        n_test = int(np.ceil(0.2 * n_samples))
        can_stratify = (
            present_counts.min() >= 2
            and n_test >= len(present_counts)
            and n_samples - n_test >= len(present_counts)
        )
        stratify = y_batch if can_stratify else None
        X_tr, X_te, y_tr, y_te = train_test_split(X_batch, y_batch, test_size=0.2, stratify=stratify, random_state=42)
        X_train.append(X_tr); y_train.append(y_tr)
        X_test.append(X_te); y_test.append(y_te)
        del X_batch, y_batch
        gc.collect()

    if not X_train or not X_test:
        print("No valid data processed")
        return

    X_train = np.concatenate(X_train)
    y_train = np.concatenate(y_train)
    X_test = np.concatenate(X_test)
    y_test = np.concatenate(y_test)

    # Key the weights by the actual class id. Enumerating them would assign
    # weights to the wrong stages whenever a stage is absent from y_train.
    present_classes = np.unique(y_train)
    class_weights = compute_class_weight('balanced', classes=present_classes, y=y_train)
    class_weight_dict = {int(c): float(w) for c, w in zip(present_classes, class_weights)}

    model = build_eegnet_attention(input_shape=(X_train.shape[1], X_train.shape[2]))
    train_generator = EEGDataGenerator(X_train, y_train, BATCH_SIZE, augment=True, class_weights=class_weight_dict)
    # The test split doubles as validation data for monitoring only; no
    # callback selects a model based on it.
    val_generator = EEGDataGenerator(X_test, y_test, BATCH_SIZE, augment=False, class_weights=class_weight_dict)

    history = model.fit(train_generator, validation_data=val_generator, epochs=EPOCHS, verbose=1)

    plot_training_curves(history)
    evaluate_model(model, X_test, y_test)

if __name__ == "__main__":
    run_pipeline()







scikit-learn version: 1.6.1


Checking availability:  15%|█▌        | 6/40 [00:00<00:02, 14.60it/s]

HTTP Error 404 fetching SC4021EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching ST7132E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7052E0-PSG.edf: Not Found
HTTP Error 404 fetching SC4011EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching SC4141EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching ST7062E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7122E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7072E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7042E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7022E0-PSG.edf: Not Found


Checking availability:  82%|████████▎ | 33/40 [00:00<00:00, 79.19it/s]

HTTP Error 404 fetching SC4171EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching SC4191EC-Hypnogram.edf: Not Found


Checking availability:  82%|████████▎ | 33/40 [00:17<00:00, 79.19it/s]

Downloaded SC4201E0-PSG.edf


Checking availability: 100%|██████████| 40/40 [01:32<00:00,  2.30s/it]


Downloaded SC4201EC-Hypnogram.edf
Available subject-night pairs: [(2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1), (9, 1), (10, 1), (11, 1), (12, 1), (14, 1), (15, 1), (17, 1), (19, 1)]


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 2, night 1: 5640 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 3, night 1: 5140 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 4, night 1: 5444 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 5, night 1: 5540 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 6, night 1: 5620 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 7, night 1: 5592 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 8, night 1: 5464 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 9, night 1: 5440 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 10, night 1: 5284 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 11, night 1: 5572 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 12, night 1: 5628 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 14, night 1: 5240 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 15, night 1: 5252 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 17, night 1: 5512 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 19, night 1: 5608 epochs


Processing data: 45it [01:10,  1.57s/it]


Epoch 1/50


  self._warn_if_super_not_called()


513/513 ━━━━━━━━━━━━━━━━━━━━ 39s 50ms/step - accuracy: 0.7763 - loss: 0.7913 - val_accuracy: 0.8523 - val_loss: 0.5651
Epoch 2/50
513/513 ━━━━━━━━━━━━━━━━━━━━ 16s 31ms/step - accuracy: 0.8893 - loss: 0.4789 - val_accuracy: 0.8712 - val_loss: 0.4684
Epoch 3/50
513/513 ━━━━━━━━━━━━━━━━━━━━ 16s 32ms/step - accuracy: 0.8993 - loss: 0.4367 - val_accuracy: 0.8910 - val_loss: 0.4174
Epoch 4/50
513/513 ━━━━━━━━━━━━━━━━━━━━ 16s 31ms/step - accuracy: 0.9051 - loss: 0.4154 - val_accuracy: 0.8766 - val_loss: 0.4941
Epoch 5/50
513/513 ━━━━━━━━━━━━━━━━━━━━ 16s 32ms/step - accuracy: 0.9143 - loss: 0.3788 - val_accuracy: 0.9207 - val_loss: 0.5186
Epoch 6/50
513/513 ━━━━━━━━━━━━━━━━━━━━ 16s 32ms/step - accuracy: 0.9212 - loss: 0.3528 - val_accuracy: 0.9116 - val_loss: 0.4058
Epoch 7/50
513/513