<a href="https://colab.research.google.com/github/sameekshya1999/Sleep-Stage-Classification-Using-Deep-Learning-CNN-vs.-EEGNet-Attention-/blob/main/eegnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install numpy tensorflow keras mne urllib3 scikit-learn tqdm matplotlib seaborn


Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m93.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mne
Successfully installed mne-1.9.0


In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import Sequence
import mne
import urllib.request
import os
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
import gc

warnings.filterwarnings("ignore", category=DeprecationWarning)
mne.set_log_level('ERROR')

NUM_SUBJECTS = 20
NUM_NIGHTS = 2
BASE_URL = "https://physionet.org/files/sleep-edfx/1.0.0/"
TARGET_CHANNELS = ['EEG Fpz-Cz', 'EEG Pz-Oz']
EPOCH_DURATION = 30
BATCH_SIZE = 128
EPOCHS = 50
SAMPLING_RATE = 50
F1 = 8  # Number of temporal filters
D = 2   # Depth multiplier for depthwise convolution
F2 = 16 # Number of pointwise filters

TELEMETRY_SUBJECTS = [2, 4, 5, 6, 7, 12, 13]

print(f"scikit-learn version: {sklearn.__version__}")

def fetch_data(subject_id, night, record_type='PSG'):
    try:
        dataset_id = subject_id + 1
        folder = "sleep-cassette" if night == 1 else "sleep-telemetry"

        if night == 1:
            prefix = f"SC4{dataset_id:02d}"
        else:
            if subject_id not in TELEMETRY_SUBJECTS:
                return None
            telemetry_map = {2: 702, 4: 704, 5: 705, 6: 706, 7: 707, 12: 712, 13: 713}
            prefix = f"ST{telemetry_map.get(subject_id, 700 + dataset_id)}"

        file_name = f"{prefix}{night if night == 1 else 2}E0-PSG.edf" if record_type == 'PSG' else \
                    f"{prefix}{night if night == 1 else 2}EC-Hypnogram.edf"
        url = f"{BASE_URL}{folder}/{file_name}"
        local_file = os.path.join("sleep_edf", file_name)
        os.makedirs("sleep_edf", exist_ok=True)

        if not os.path.exists(local_file):
            urllib.request.urlretrieve(url, local_file)
            print(f"Downloaded {file_name}")
        return local_file
    except urllib.error.HTTPError as e:
        print(f"HTTP Error {e.code} fetching {file_name}: {e.reason}")
        return None
    except Exception as e:
        print(f"Error fetching {file_name}: {e}")
        return None

def get_available_subjects():
    available = []
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = []
        for subject_id in range(NUM_SUBJECTS):
            for night in range(1, NUM_NIGHTS + 1):
                futures.append((
                    subject_id,
                    night,
                    executor.submit(
                        lambda s, n: (
                            fetch_data(s, n, 'PSG') is not None and
                            fetch_data(s, n, 'Hypnogram') is not None
                        ),
                        subject_id, night
                    )
                ))

        for subject_id, night, future in tqdm(futures, desc="Checking availability"):
            if future.result():
                available.append((subject_id, night))
    print(f"Available subject-night pairs: {available}")
    return available

def augment_data(X):
    noise = np.random.normal(0, 0.01, X.shape)
    shift = np.random.randint(-50, 50)
    X_aug = np.roll(X + noise, shift, axis=1)
    return X_aug

def process_subject_night(subject_id, night):
    try:
        psg_file = fetch_data(subject_id, night, 'PSG')
        hypno_file = fetch_data(subject_id, night, 'Hypnogram')
        if psg_file is None or hypno_file is None:
            print(f"Skipping subject {subject_id}, night {night}: Missing files")
            return None, None

        raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
        available_channels = [ch for ch in TARGET_CHANNELS if ch in raw.ch_names]
        if len(available_channels) != len(TARGET_CHANNELS):
            print(f"Not all target channels found for subject {subject_id}, night {night}")
            return None, None
        raw.pick_channels(available_channels)

        raw.load_data()
        raw.filter(0.5, 40.0, l_trans_bandwidth=0.5, h_trans_bandwidth=10.0, verbose=False)
        raw.resample(SAMPLING_RATE, npad="auto")

        events = mne.make_fixed_length_events(raw, id=1, duration=EPOCH_DURATION)
        epochs_mne = mne.Epochs(raw, events, tmin=0, tmax=EPOCH_DURATION-1/raw.info['sfreq'],
                                picks=available_channels, baseline=None, preload=True)
        data = epochs_mne.get_data(units='uV')

        annotations = mne.read_annotations(hypno_file)
        labels = np.zeros(len(epochs_mne), dtype=int)
        stage_map = {
            'Sleep stage W': 0,
            'Sleep stage 1': 1,
            'Sleep stage 2': 2,
            'Sleep stage 3': 3,
            'Sleep stage 4': 3,
            'Sleep stage R': 4
        }

        for annot in annotations:
            onset = int(annot['onset'] / EPOCH_DURATION)
            duration = int(annot['duration'] / EPOCH_DURATION)
            stage = annot['description']
            if stage in stage_map:
                for i in range(max(0, onset), min(len(epochs_mne), onset + duration)):
                    labels[i] = stage_map[stage]

        data = (
            (data - np.mean(data, axis=(1, 2), keepdims=True)) /
            np.std(data, axis=(1, 2), keepdims=True)
        )
        X = data.transpose(0, 2, 1)
        X_aug = augment_data(X)
        X = np.concatenate([X, X_aug])
        labels = np.concatenate([labels, labels])

        del raw, epochs_mne, data
        gc.collect()

        print(f"Processed subject {subject_id}, night {night}: {X.shape[0]} epochs")
        return X, labels
    except Exception as e:
        print(f"Error processing subject {subject_id}, night {night}: {e}")
        return None, None

class EEGDataGenerator(Sequence):
    def __init__(self, X, y, batch_size, augment=True, class_weights=None):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int32)
        self.batch_size = batch_size
        self.augment = augment
        self.class_weights = class_weights

    def __len__(self):
        return int(np.ceil(len(self.X) / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = min(start + self.batch_size, len(self.X))
        X_batch = self.X[start:end]
        y_batch = self.y[start:end]

        if self.augment:
            X_batch = augment_data(X_batch).astype(np.float32)

        sample_weights = np.ones_like(y_batch, dtype=np.float32)
        if self.class_weights:
            sample_weights = np.array([self.class_weights[label] for label in y_batch], dtype=np.float32)

        return X_batch, y_batch, sample_weights

class ExpandDimsLayer(layers.Layer):
    def call(self, inputs):
        return tf.expand_dims(inputs, axis=-1)

def build_eegnet(input_shape, nb_classes=5, F1=8, D=2, F2=16, dropout_rate=0.5):
    """
    Implementation of EEGNet as per Lawhern et al. (2018).
    input_shape: (time_steps, channels)
    """
    input_layer = layers.Input(shape=input_shape)

    # Block 1: Temporal convolution
    expanded_input = ExpandDimsLayer()(input_layer)
    block1 = layers.Conv2D(F1, (1, 64), padding='same', use_bias=False)(expanded_input)
    block1 = layers.BatchNormalization()(block1)
    block1 = layers.DepthwiseConv2D((input_shape[1], 1), depth_multiplier=D, padding='valid',
                                    use_bias=False, depthwise_constraint=tf.keras.constraints.max_norm(1.))(block1)
    block1 = layers.BatchNormalization()(block1)
    block1 = layers.Activation('elu')(block1)
    # Corrected pooling to pool across time dimension
    block1 = layers.AveragePooling2D((4, 1))(block1)
    block1 = layers.Dropout(dropout_rate)(block1)

    # Block 2: Separable convolution
    block2 = layers.SeparableConv2D(F2, (1, 16), padding='same', use_bias=False)(block1)
    block2 = layers.BatchNormalization()(block2)
    block2 = layers.Activation('elu')(block2)
    # Corrected pooling to pool across time dimension
    block2 = layers.AveragePooling2D((8, 1))(block2)
    block2 = layers.Dropout(dropout_rate)(block2)

    # Output layer
    flatten = layers.Flatten()(block2)
    output = layers.Dense(nb_classes, activation='softmax',
                         kernel_constraint=tf.keras.constraints.max_norm(0.25))(flatten)

    model = models.Model(inputs=input_layer, outputs=output)
    model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

def plot_training_curves(history):
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Train Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()

def evaluate_model(model, X_test, y_test):
    test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
    print(f"\nTest Accuracy: {test_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}")

    y_pred = model.predict(X_test, verbose=0)
    y_pred_classes = np.argmax(y_pred, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred_classes, average=None)
    stage_names = ['Wake', 'N1', 'N2', 'N3', 'REM']
    print("\nPer-class Metrics:")
    for i, stage in enumerate(stage_names):
        print(f"{stage}: Precision={precision[i]:.4f}, Recall={recall[i]:.4f}, F1={f1[i]:.4f}")

    cm = confusion_matrix(y_test, y_pred_classes)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=stage_names, yticklabels=stage_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.close()

def data_generator(available, batch_size=2000):
    for subject_id, night in available:
        X, y = process_subject_night(subject_id, night)
        if X is None or y is None:
            continue
        for i in range(0, len(X), batch_size):
            yield X[i:i+batch_size], y[i:i+batch_size]
        del X, y
        gc.collect()

def run_pipeline():
    available = get_available_subjects()
    if not available:
        return

    X_train, y_train, X_test, y_test = [], [], [], []
    for X_batch, y_batch in tqdm(data_generator(available), desc="Processing data"):
        if X_batch is None or y_batch is None:
            continue
        class_counts = np.bincount(y_batch)
        stratify = y_batch if min(class_counts[class_counts > 0]) >= 2 else None
        X_tr, X_te, y_tr, y_te = train_test_split(X_batch, y_batch, test_size=0.2, stratify=stratify, random_state=42)
        X_train.append(X_tr); y_train.append(y_tr)
        X_test.append(X_te); y_test.append(y_te)
        del X_batch, y_batch
        gc.collect()

    if not X_train:
        return

    X_train = np.concatenate(X_train)
    y_train = np.concatenate(y_train)
    X_test = np.concatenate(X_test)
    y_test = np.concatenate(y_test)

    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weight_dict = dict(enumerate(class_weights))

    model = build_eegnet(input_shape=(X_train.shape[1], X_train.shape[2]), nb_classes=5, F1=F1, D=D, F2=F2)
    train_generator = EEGDataGenerator(X_train, y_train, BATCH_SIZE, augment=True, class_weights=class_weight_dict)
    val_generator = EEGDataGenerator(X_test, y_test, BATCH_SIZE, augment=False, class_weights=class_weight_dict)

    history = model.fit(train_generator, validation_data=val_generator, epochs=EPOCHS, verbose=1)

    plot_training_curves(history)
    evaluate_model(model, X_test, y_test)

if __name__ == "__main__":
    run_pipeline()

scikit-learn version: 1.6.1


Checking availability:   2%|▎         | 1/40 [00:00<00:15,  2.46it/s]

HTTP Error 404 fetching ST7022E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7062E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7122E0-PSG.edf: Not Found
HTTP Error 404 fetching SC4021EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching ST7052E0-PSG.edf: Not Found
HTTP Error 404 fetching SC4141EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching ST7132E0-PSG.edf: Not Found
HTTP Error 404 fetching SC4011EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching ST7072E0-PSG.edf: Not Found
HTTP Error 404 fetching ST7042E0-PSG.edf: Not Found


Checking availability: 100%|██████████| 40/40 [00:00<00:00, 53.83it/s]


HTTP Error 404 fetching SC4171EC-Hypnogram.edf: Not Found
HTTP Error 404 fetching SC4191EC-Hypnogram.edf: Not Found
Available subject-night pairs: [(2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1), (9, 1), (10, 1), (11, 1), (12, 1), (14, 1), (15, 1), (17, 1), (19, 1)]


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 2, night 1: 5640 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 3, night 1: 5140 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 4, night 1: 5444 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 5, night 1: 5540 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 6, night 1: 5620 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 7, night 1: 5592 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 8, night 1: 5464 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 9, night 1: 5440 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 10, night 1: 5284 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 11, night 1: 5572 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 12, night 1: 5628 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 14, night 1: 5240 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 15, night 1: 5252 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 17, night 1: 5512 epochs


  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)
  raw = mne.io.read_raw_edf(psg_file, preload=False, verbose=False)


Processed subject 19, night 1: 5608 epochs


Processing data: 45it [01:02,  1.38s/it]


Epoch 1/50


  self._warn_if_super_not_called()


[1m512/513[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 54ms/step - accuracy: 0.6218 - loss: 1.3474

  self._warn_if_super_not_called()


[1m513/513[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 57ms/step - accuracy: 0.6221 - loss: 1.3469 - val_accuracy: 0.7075 - val_loss: 1.2804
Epoch 2/50
[1m513/513[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 48ms/step - accuracy: 0.7397 - loss: 1.0913 - val_accuracy: 0.6852 - val_loss: 1.0064
Epoch 3/50
[1m513/513[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 48ms/step - accuracy: 0.7501 - loss: 1.0373 - val_accuracy: 0.7610 - val_loss: 0.9624
Epoch 4/50
[1m513/513[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 48ms/step - accuracy: 0.7655 - loss: 1.0001 - val_accuracy: 0.7675 - val_loss: 0.9628
Epoch 5/50
[1m513/513[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 48ms/step - accuracy: 0.7727 - loss: 0.9608 - val_accuracy: 0.7878 - val_loss: 0.9292
Epoch 6/50
[1m513/513[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 48ms/step - accuracy: 0.7696 - loss: 