In [26]:
import mne
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.utils.class_weight import compute_class_weight
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset



# Load raw signal
edf_path = "/Users/veeralpatel/ECE284FinalProject/sleep-cassette/SC4001E0-PSG.edf"
ann_path = "/Users/veeralpatel/ECE284FinalProject/sleep-cassette/SC4001EC-Hypnogram.edf"
raw = mne.io.read_raw_edf(edf_path, preload=True)
annotations = mne.read_annotations(ann_path)
raw.set_annotations(annotations)
raw.pick_channels(["EEG Fpz-Cz", "EOG horizontal", "Temp rectal", "EMG submental"])
raw.resample(100)
print("Raw annotation labels (before mapping):")
print(set(raw.annotations.description))



Extracting EDF parameters from /Users/veeralpatel/ECE284FinalProject/sleep-cassette/SC4001E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7949999  =      0.000 ... 79499.990 secs...


  raw = mne.io.read_raw_edf(edf_path, preload=True)
  raw = mne.io.read_raw_edf(edf_path, preload=True)
  raw = mne.io.read_raw_edf(edf_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Sampling frequency of the instance is already 100.0, returning unmodified.
Raw annotation labels (before mapping):
{'Sleep stage ?', 'Sleep stage 2', 'Sleep stage 1', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage 3', 'Sleep stage W'}


  raw.set_annotations(annotations)


In [27]:
# Step 1: Remap annotations to sleep stages, skip "?"
stage_map = {
    "Sleep stage W": "W",
    "Sleep stage 1": "N1",
    "Sleep stage 2": "N2",
    "Sleep stage 3": "N3",
    "Sleep stage 4": "N3",
    "Sleep stage R": "R",
    "Sleep stage ?": None
}

cleaned_onset, cleaned_duration, cleaned_desc = [], [], []

for onset, duration, desc in zip(raw.annotations.onset, raw.annotations.duration, raw.annotations.description):
    new_desc = stage_map.get(desc)
    if new_desc is not None:
        cleaned_onset.append(onset)
        cleaned_duration.append(duration)
        cleaned_desc.append(new_desc)

raw.set_annotations(mne.Annotations(
    onset=cleaned_onset,
    duration=cleaned_duration,
    description=cleaned_desc
))


Unnamed: 0,General,General.1
,Filename(s),SC4001E0-PSG.edf
,MNE object type,RawEDF
,Measurement date,1989-04-24 at 16:13:00 UTC
,Participant,X
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,22:05:00 (HH:MM:SS)
,Sampling frequency,100.00 Hz
,Time points,7950000
,Channels,Channels


In [28]:
events, event_id = mne.events_from_annotations(raw)
valid_event_id = {k: v for k, v in event_id.items() if k in ["W", "N1", "N2", "N3", "R"]}
epochs = mne.Epochs(raw, events, event_id=valid_event_id,
                    tmin=0, tmax=30.0, baseline=None, preload=True)

stage_to_class = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "R": 4}
inverse_id = {v: stage_to_class[k] for k, v in valid_event_id.items()}
y_labels = [inverse_id[e[-1]] for e in epochs.events]

X = epochs.get_data()
X_tensor = torch.tensor(X, dtype=torch.float32).permute(0, 2, 1)
y_tensor = torch.tensor(y_labels, dtype=torch.long)

print("✅ Final Shapes:", X_tensor.shape, y_tensor.shape)
print("✅ Unique Labels:", set(y_tensor.numpy()))


Used Annotations descriptions: ['N1', 'N2', 'N3', 'R', 'W']
Not setting metadata
153 matching events found
No baseline correction applied
0 projection items activated


Using data from preloaded Raw for 153 events and 3001 original time points ...
0 bad epochs dropped
✅ Final Shapes: torch.Size([153, 3001, 4]) torch.Size([153])
✅ Unique Labels: {0, 1, 2, 3, 4}


In [29]:
# Define CNN-BiLSTM model
class CNN_BiLSTM_Model(nn.Module):
    def __init__(self, input_channels=4, num_classes=5):
        super(CNN_BiLSTM_Model, self).__init__()
        self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(128)
        self.lstm = nn.LSTM(input_size=128, hidden_size=128, bidirectional=True, batch_first=True)
        self.fc1 = nn.Linear(256, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = F.relu(self.bn2(self.conv2(x)))
        x = x.permute(0, 2, 1)
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [30]:
# Train model on Sleep-EDF
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = CNN_BiLSTM_Model(input_channels=4, num_classes=5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

EPOCHS = 10
for epoch in range(EPOCHS):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for xb, yb in loader:
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (out.argmax(1) == yb).sum().item()
        total += yb.size(0)

    acc = correct / total
    print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}, Acc = {acc:.4f}")


Epoch 1: Loss = 7.7364, Acc = 0.4575
Epoch 2: Loss = 7.2166, Acc = 0.4641
Epoch 3: Loss = 6.8304, Acc = 0.4641
Epoch 4: Loss = 6.5309, Acc = 0.4641
Epoch 5: Loss = 6.5129, Acc = 0.4641
Epoch 6: Loss = 6.5131, Acc = 0.4641
Epoch 7: Loss = 6.4771, Acc = 0.4641
Epoch 8: Loss = 6.5056, Acc = 0.4641
Epoch 9: Loss = 6.4849, Acc = 0.4641
Epoch 10: Loss = 6.4333, Acc = 0.4641


In [31]:
torch.save(model.state_dict(), "sleep_edf_pretrained.pt")