In [None]:
!pip install scipy
!pip install gdown

In [None]:
import gdown

folder_id = "1IBMYahLwJ5pY9b-f4gCm_LZPCcYccH6h"

# Fetch every file of the shared Google Drive folder (presumably the EEG recordings
# read further down), showing download progress and without using browser cookies.
gdown.download_folder(id=folder_id, quiet=False, use_cookies=False)

In [None]:
from scipy.io import loadmat
import numpy as np
import pandas as pd
import scipy.io
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.preprocessing import MinMaxScaler, StandardScaler

In [None]:
# Sampling rate (samples per second) and the integer ids of the three target classes.
sampling_freq = 128
FOCUSED_CLASS, UNFOCUSED_CLASS, DROWNSY_CLASS = 0, 1, 2

In [None]:
import os

# Directory containing the .mat recordings. The default is the Colab location this
# notebook was written for (presumably where gdown saved the folder); set the
# EEG_DATA_DIR environment variable to run it anywhere else.
path = os.environ.get("EEG_DATA_DIR", "/content/EEG Data")

In [None]:
import os

# Keep only .mat recordings (the only format get_EEG_data can load) and sort them:
# os.listdir returns entries in arbitrary order, which would make the dataset order,
# and therefore any downstream split, differ between machines and runs.
files = sorted(name for name in os.listdir(path) if name.lower().endswith('.mat'))
len(files)

In [None]:
# Names for the 25 columns of each recording's ``o.data`` matrix, in file order.
# The 14 electrode channels (ED_AF3 ... ED_AF4) are the ones kept for modelling.
columns = (
    ['ED_COUNTER', 'ED_INTERPOLATED', 'ED_RAW_CQ']
    + ['ED_' + name for name in (
        'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
        'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',
    )]
    + ['ED_GYROX', 'ED_GYROY', 'ED_TIMESTAMP', 'ED_ES_TIMESTAMP',
       'ED_FUNC_ID', 'ED_FUNC_VALUE', 'ED_MARKER', 'ED_SYNC_SIGNAL']
)

In [None]:
def get_state(timestamp, hz=128):
    """Return the attention-state class id for a 0-based sample index.

    Assumes each recording follows the dataset protocol of three consecutive
    10-minute phases in the order focused -> unfocused -> drowsy, which is
    also the order of the class ids FOCUSED_CLASS (0), UNFOCUSED_CLASS (1)
    and DROWNSY_CLASS (2). NOTE(review): confirm against the dataset's
    description if the data source changes.

    Args:
        timestamp: 0-based sample index within one recording.
        hz: sampling rate in samples per second; the default matches
            ``sampling_freq`` (128 Hz).

    Returns:
        FOCUSED_CLASS for samples in [0 min, 10 min), UNFOCUSED_CLASS for
        [10 min, 20 min), and DROWNSY_CLASS for every sample from 20 min on.
    """
    phase_samples = 10 * 60 * hz  # number of samples in one 10-minute phase
    # Half-open intervals: index 10*60*hz is the first sample of minute 10,
    # so it already belongs to the second phase.
    if timestamp < phase_samples:
        return FOCUSED_CLASS
    if timestamp < 2 * phase_samples:
        return UNFOCUSED_CLASS
    return DROWNSY_CLASS

# Shared StandardScaler (centre each column, scale to unit variance - sklearn's defaults).
# NOTE(review): every fit_transform call refits it, so afterwards it only holds the
# statistics of the most recent input.
scaler = StandardScaler(with_mean=True, with_std=True)

In [None]:
def get_EEG_data(data_root, filename):
    """Load one recording and return a labelled, per-sample DataFrame.

    The file is read with ``scipy.io.loadmat`` and its raw matrix
    ``mat["o"]["data"][0, 0]`` is named with the module-level ``columns``.
    Only the 14 electrode channels are kept; they are renamed without the
    ``ED_`` prefix and standardized per channel using this recording's own
    statistics.

    Args:
        data_root: directory containing the recording.
        filename: name of the ``.mat`` file inside ``data_root``.

    Returns:
        DataFrame with columns ``timestamp`` (0-based sample index),
        ``AF3`` ... ``AF4`` (standardized channel values) and ``state``
        (``get_state(timestamp)``), in that order.

    Raises:
        KeyError: if an expected electrode column is missing.
    """
    channel_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
                     'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']

    mat = scipy.io.loadmat(os.path.join(data_root, filename))
    raw_df = pd.DataFrame(mat["o"]["data"][0, 0], columns=columns)
    channels = raw_df[['ED_' + name for name in channel_names]]

    # A fresh scaler per recording: statistics come from this file alone and no
    # fitted state is left behind in a shared global.
    scaled = StandardScaler().fit_transform(channels)

    eeg_df = pd.DataFrame(scaled, columns=channel_names)
    # Sample index as the first column (equivalent to reset_index + rename).
    eeg_df.insert(0, 'timestamp', range(len(eeg_df)))
    eeg_df['state'] = eeg_df['timestamp'].map(get_state)
    return eeg_df

In [None]:
# Load every recording as a labelled per-sample DataFrame (one list entry per file,
# in the order of `files`).
dataset = [get_EEG_data(path, filename) for filename in files]

In [None]:
(type(dataset), type(dataset[0]))  # expected: a list holding one DataFrame per recording

In [None]:
def split_epochs(data, hz, epoch_length=4, step_size=0.125):
    """Cut a per-sample DataFrame into fixed-length, slightly overlapping windows.

    Args:
        data: DataFrame with one row per sample (e.g. from ``get_EEG_data``).
        hz: sampling rate in samples per second.
        epoch_length: window length in seconds.
        step_size: overlap between consecutive windows, in seconds. Despite
            the name this is NOT the stride: windows start every
            ``epoch_length - step_size`` seconds (with the defaults at 128 Hz:
            512-sample windows starting every 496 samples).

    Returns:
        List of row slices ``data.iloc[start:start + epoch_length*hz]``.
        Only complete windows are produced; a trailing partial window is
        dropped, and data shorter than one window yields ``[]``.

    Raises:
        ValueError: if the stride is not positive (``step_size >= epoch_length``),
            which would otherwise never advance through the data.
    """
    stride = int(epoch_length * hz - step_size * hz)
    window = int(epoch_length * hz)
    if stride <= 0:
        raise ValueError(
            f"step_size ({step_size}s) must be smaller than epoch_length ({epoch_length}s)"
        )

    # Last start that still leaves room for a full window; range() is empty when negative.
    last_start = data.shape[0] - window
    return [data.iloc[start:start + window] for start in range(0, last_start + 1, stride)]

In [None]:
# Flatten the windows of every recording into a single list of epoch DataFrames.
epochs_data = []
for eeg in dataset:
  epochs = split_epochs(eeg, sampling_freq)
  epochs_data.extend(epochs)

In [None]:
epochs_data[0].shape[0]  # rows (samples) per epoch: 4 s * 128 Hz = 512 with the defaults

In [None]:
(len(epochs_data), type(epochs_data), type(epochs_data[0]))  # epoch count, container and element types

In [None]:
from torch.utils.data import Dataset # Added this import

class EEGDataset(Dataset):
  """Map-style dataset of fixed-length EEG epochs with one label per epoch.

  Each input DataFrame is one epoch (e.g. an element returned by
  ``split_epochs``). Its label is the most frequent value of
  ``target_column``; its features are all remaining columns except
  ``target_column`` and ``timestamp``.

  Args:
    dataframes: sequence of epoch DataFrames; after dropping the label and
      timestamp columns they must all have the same shape.
    target_column: name of the label column.
    wavelet, level: accepted for API compatibility; currently unused.

  Attributes:
    data: float32 tensor of shape (num_epochs, rows_per_epoch, num_features).
    targets: int64 tensor of shape (num_epochs,).
  """

  def __init__(self, dataframes, target_column='state', wavelet='db6', level=4):
    features = []
    targets = []
    # NOTE(review): never used inside this class; kept for code that may read `.scaler`.
    self.scaler = StandardScaler()

    print(f"Processing {len(dataframes)} dataframes...")

    for df in dataframes:
      # Majority label of the epoch; on ties, mode() is sorted so the smallest id wins.
      targets.append(df[target_column].mode().iloc[0])
      feature = df.drop(columns=[target_column, 'timestamp'], errors='ignore')
      features.append(feature.to_numpy(dtype=np.float32))

    # Stack into one contiguous array before converting: torch.tensor() on a
    # Python list of ndarrays is extremely slow (PyTorch warns about it).
    if features:
      self.data = torch.from_numpy(np.stack(features))
    else:
      self.data = torch.empty(0, dtype=torch.float32)
    self.targets = torch.tensor(targets, dtype=torch.long)

  def __len__(self):
    """Number of epochs (one label per epoch)."""
    return len(self.targets)

  def __getitem__(self, idx):
    """Return the ``(features, label)`` pair of epoch ``idx``."""
    return self.data[idx], self.targets[idx]

In [None]:
# Wrap the epoch DataFrames as tensors: features (epochs, samples, channels) + one label each.
dataset = EEGDataset(dataframes=epochs_data)

In [None]:
(dataset.targets, len(dataset))  # per-epoch labels and number of epochs

In [None]:
dataset[0][0].shape  # features of the first epoch: (samples, channels), e.g. (512, 14)

In [None]:
class TransformerClassifier(nn.Module):
    def __init__(self, input_dim=14, num_classes=3, seq_len=512, d_model=64, nhead=4, num_layers=2):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(1, seq_len, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=2 * d_model,
            dropout=0.1,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1), :]
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
        return self.classifier(x)

In [None]:
!pip install pytorch-lightning

In [None]:
from pytorch_lightning.core.module import LightningModule # Added this import

class LitTransformer(LightningModule):
    """Lightning wrapper that trains ``TransformerClassifier`` with cross-entropy.

    Besides logging through ``self.log``, it keeps plain Python lists of the
    per-epoch metrics (``train_losses``, ``train_accuracies``, ``val_losses``,
    ``val_accuracies``) so they can be plotted after training.

    Args:
        input_dim: number of features per time step.
        num_classes: number of target classes (3 for the focused / unfocused /
            drowsy class ids used in this notebook).
        seq_len: maximum sequence length covered by the positional encoding.
        lr: Adam learning rate.
    """

    def __init__(self, input_dim=14, num_classes=3, seq_len=512, lr=1e-3):
        super().__init__()
        # Stores the constructor arguments on self.hparams (lr is read in configure_optimizers).
        self.save_hyperparameters()
        self.model = TransformerClassifier(input_dim, num_classes, seq_len)
        self.criterion = nn.CrossEntropyLoss()

        # Per-epoch history, one entry per completed epoch
        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []

        # Per-batch training values, averaged and cleared in on_train_epoch_end
        self._train_epoch_losses = []
        self._train_epoch_accs = []

    def forward(self, x):
        """Return class logits from the wrapped classifier."""
        return self.model(x)

    def _loss_and_accuracy(self, batch):
        """Compute cross-entropy loss and batch accuracy for an ``(X, y)`` batch."""
        X, y = batch
        logits = self(X)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch)

        # Store per-batch metrics to average at the end of the epoch
        self._train_epoch_losses.append(loss.item())
        self._train_epoch_accs.append(acc.item())

        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """Append the epoch's mean training loss/accuracy to the history and reset."""
        if self._train_epoch_losses:
            avg_loss = np.mean(self._train_epoch_losses)
            avg_acc = np.mean(self._train_epoch_accs)

            self.train_losses.append(avg_loss)
            self.train_accuracies.append(avg_acc)

            print(f"\nEpoch {self.current_epoch + 1}: train_loss={avg_loss:.4f}, train_acc={avg_acc * 100:.2f}%")

            self._train_epoch_losses.clear()
            self._train_epoch_accs.clear()

    def validation_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)
        return {'val_loss': loss, 'val_acc': acc}

    def on_validation_epoch_end(self):
        """Append the epoch-level validation loss/accuracy to the history."""
        # Lightning runs a few validation batches before training starts (the
        # sanity check) and calls this hook for them as well. Recording that run
        # would add an untrained entry to the front of val_losses/val_accuracies
        # and shift them one epoch against the training history.
        if self.trainer.sanity_checking:
            return

        metrics = self.trainer.callback_metrics
        val_loss = metrics['val_loss'].item()
        val_acc = metrics['val_acc'].item()

        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        print(f"Epoch {self.current_epoch + 1}: val_loss={val_loss:.4f}, val_acc={val_acc * 100:.2f}%")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
# 80 / 10 / 10 train / validation / test split of the epoch dataset.
train_size = int(0.8 * len(dataset))
val_size = int(0.10 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder (~10%) goes to test

# A seeded generator makes the split identical across kernel restarts.
# NOTE(review): epochs are split at random, so windows from the same recording
# (consecutive windows even share samples) end up in every subset; a per-file
# split would give a less optimistic estimate of generalisation.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Only the training loader shuffles; evaluation order does not affect the metrics.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# 14 EEG channels per step, 3 attention-state classes, 512-sample (4 s at 128 Hz) epochs.
model = LitTransformer(input_dim=14, num_classes=3, seq_len=512, lr=1e-3)

In [None]:
from pytorch_lightning import Trainer

In [None]:
trainer = Trainer(
    max_epochs=10,
    accelerator='auto',
    devices=1,
)

# Run training with per-epoch validation. Without this call the model is never
# trained and the metric histories plotted afterwards stay empty.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Ensure consistent length
min_len = min(len(model.train_losses), len(model.val_losses))
epochs = range(1, min_len + 1)

plt.figure(figsize=(12, 5))

# ------------------------------
# Plot LOSS
# ------------------------------
plt.subplot(1, 2, 1)
plt.plot(epochs, model.train_losses[:min_len], label='Train Loss', marker='o')
plt.plot(epochs, model.val_losses[:min_len], label='Validation Loss', marker='o')
plt.title('Epoch vs Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# ------------------------------
# Plot ACCURACY
# ------------------------------
min_len_acc = min(len(model.train_accuracies), len(model.val_accuracies))
epochs_acc = range(1, min_len_acc + 1)

plt.subplot(1, 2, 2)
plt.plot(epochs_acc, [x * 100 for x in model.train_accuracies[:min_len_acc]],
         label='Train Accuracy', marker='o')
plt.plot(epochs_acc, [x * 100 for x in model.val_accuracies[:min_len_acc]],
         label='Validation Accuracy', marker='o', color='green')
plt.title('Epoch vs Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()