# Notebook for training Transformer locally

It is highly recommended to use CUDA for training. Make sure you have an NVIDIA GPU, then run this command in a terminal:  
`nvidia-smi`  
Check the reported CUDA version (it should be 11.8 or newer). After that, install torch with CUDA support using this command:  
`pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118`


# Import libraries


In [83]:
import mne
import torch
import torch.nn as nn
import numpy as np
import os
import csv
from eeg_logger import logger

# Transformer module


In [84]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention, then a feed-forward network.

    Each of the two sub-layers is wrapped in a residual connection followed by
    layer normalization. ``nn.MultiheadAttention`` is created with its default
    ``batch_first=False``, so inputs use the sequence-first layout
    ``(seq_len, batch, d_model)``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.0):
        """
        :param int d_model: number of features of every input/output token
        :param int nhead: number of attention heads (``d_model`` must be divisible by it)
        :param int dim_feedforward: hidden width of the feed-forward sub-layer
        :param float dropout: dropout probability applied to the attention weights
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward layer: expand to dim_feedforward, then project back.
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply self-attention and the feed-forward network; the output has the shape of ``x``."""
        attn_output, _ = self.attn(x, x, x)
        # x + attn_output is a residual connection: a network design technique that
        # lets layers skip one another, which helps when training deeper networks.
        # The original input (x) is added to the attention output.
        x = self.norm1(x + attn_output)
        ff_output = self.ff(x)
        # Second residual connection, this time around the feed-forward sub-layer.
        x = self.norm2(x + ff_output)
        return x

# Positional encoding


In [85]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    The table is stored as the non-trainable buffer ``pe`` with shape
    ``(max_len, 1, d_model)``; even feature columns hold sines and odd feature
    columns hold cosines.
    """

    def __init__(self, d_model, max_len=1000):
        """
        :param int d_model: number of features of every token (may be even or odd)
        :param int max_len: maximum sequence length the table is built for
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer odd column than even columns, so trim
        # div_term to match; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over the batch axis.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        :param x: tensor whose first axis is the sequence axis and whose last
            axis has ``d_model`` features; its length must not exceed ``max_len``
        """
        return x + self.pe[: x.size(0), :]

# Spatial transformer


In [86]:
class SpatialTransformer(nn.Module):
    """Transformer classifier that learns dependencies between channels.

    Each channel's whole time series is embedded as one token, so attention
    operates across the channel axis.
    """

    def __init__(self, input_size, d_model, nhead, num_classes):
        """
        :param int input_size: size of the last input axis (time points per channel)
        :param int d_model: embedding width used inside the Transformer blocks
        :param int nhead: number of attention heads per block
        :param int num_classes: number of output logits
        """
        super().__init__()
        self.embedding = nn.Linear(input_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_blocks = [TransformerBlock(d_model, nhead) for _ in range(3)]
        self.transformer = nn.Sequential(*encoder_blocks)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor):
        """Return class logits for a batch laid out as (batch, channels, time)."""
        # Swap the first two axes: (batch, channels, time) -> (channels, batch, time).
        tokens = x.permute(1, 0, 2)
        embedded = self.pos_encoder(self.embedding(tokens))
        encoded = self.transformer(embedded)
        # Average over the token (channel) axis to get one vector per trial.
        pooled = encoded.mean(dim=0)
        return self.fc(pooled)

# Temporal transformer


In [87]:
class TemporalTransformer(nn.Module):
    """Transformer classifier that attends across time points.

    Each time point's vector of channel values is embedded as one token, so
    ``input_size`` has to equal the number of channels.
    """

    def __init__(self, input_size, d_model, nhead, num_classes):
        """
        :param int input_size: size of the embedded axis (channels per time point)
        :param int d_model: embedding width used inside the Transformer blocks
        :param int nhead: number of attention heads per block
        :param int num_classes: number of output logits
        """
        super().__init__()
        self.embedding = nn.Linear(input_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_blocks = [TransformerBlock(d_model, nhead) for _ in range(3)]
        self.transformer = nn.Sequential(*encoder_blocks)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor):  # x shape: (batch, channels, time)
        """Return class logits for a batch laid out as (batch, channels, time)."""
        # Bring time to the front: (batch, channels, time) -> (time, batch, channels).
        tokens = x.permute(2, 0, 1)
        embedded = self.pos_encoder(self.embedding(tokens))
        encoded = self.transformer(embedded)
        # Average over the token (time) axis to get one vector per trial.
        pooled = encoded.mean(dim=0)
        return self.fc(pooled)

# Spatial CNN transformer

_In the spatial implementation of the CNN + Transformer
model, the CNN module included two
convolutional layers and one average pooling layer. In the first
convolutional layer, we used 64 kernels with the size of 1 × 16
(channel × time points) to extract EEG temporal information,
and adopted the SAME padding. The average pooling layer
had the pooling size of 1 × 32. The second convolutional
layer used 64 kernels with the size of 1 × 15, and adopted
the VALID padding_


In [88]:
class SpatialCNNTransformer(nn.Module):
    """CNN front-end plus Transformer over channels.

    The CNN extracts 64 temporal features per channel; the Transformer then
    models the relations between channels using one token per channel.
    """

    def __init__(self, input_size, d_model, nhead, num_classes):
        """
        :param int input_size: accepted for signature parity with the other
            models; it is not used because the CNN fixes the token width at 64
        :param int d_model: embedding width used inside the Transformer blocks
        :param int nhead: number of attention heads per block
        :param int num_classes: number of output logits
        """
        super().__init__()
        # Both kernels have height 1, so the channel axis passes through unchanged.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, (1, 16), padding="same"),
            nn.ReLU(),
            nn.AvgPool2d((1, 32)),
            nn.Conv2d(64, 64, (1, 15)),
            nn.ReLU(),
        )
        self.embedding = nn.Linear(64, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer = nn.Sequential(*[TransformerBlock(d_model, nhead) for _ in range(3)])
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor):  # x shape: (batch, 1, channels, time)
        """Return class logits for a batch laid out as (batch, 1, channels, time).

        The time axis needs at least 480 points (pooling by 32 must leave at
        least the 15 steps required by the second convolution).
        """
        features = self.cnn(x)  # (batch, 64, channels, reduced_time)
        # Collapse whatever time steps remain. When exactly one step is left
        # (the previously supported case) the mean equals the former
        # squeeze(3); longer inputs no longer break the permute below.
        features = features.mean(dim=3)  # (batch, 64, channels)
        x = features.permute(2, 0, 1)  # (channels, batch, 64)
        x = self.embedding(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = x.mean(dim=0)
        return self.fc(x)

# Preparing dataset


In [89]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import Accuracy


class EEGDataset(Dataset):
    """In-memory dataset pairing trial data with integer class labels."""

    def __init__(self, X, y, cnn_mode=False):
        """
        :param X: trial data; converted to a float32 tensor
        :param y: one label per trial; converted to an int64 tensor
        :param bool cnn_mode: when True, a singleton axis is inserted at
            position 1 of the data tensor
        """
        features = torch.tensor(X, dtype=torch.float32)
        # CNN models need a 4-D feature tensor, hence the extra axis.
        self.X = features.unsqueeze(1) if cnn_mode else features
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        """Number of trials (length of the first data axis)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair stored at ``idx``."""
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

# Training model

**Training params from paper:**

- _Empirically, the number of
  head in each multi-head attention layer was set to 8 [25]._
- _The dropout rate was set to 0.3._
- _The parameter of the position-wise fully connected feed-forward layer with a ReLU activation was set to 512._
- _The weight attenuation was 0.0001._
- _All the models used the Adam optimizer. The training epoch was set
  to 50._
- _We set the number of training epochs to 10_
- _The EEG data were transformed into 3D tensors (N, C, T), where N is the number of trials, C is the number of channels, and T is the time points._
- _In our Transformer-based models,
  we set dk = dv = 64, which was the same size as EEG
  channel numbers._


In [90]:
def load_subject_data(file_path: str) -> tuple[np.ndarray, np.ndarray]:
    """Read one subject's epochs file and return its data with binary labels.

    :param str file_path: path to the file with the subject's epochs
    :return: ``(X, y)`` where ``X`` is the array returned by
        ``epochs.get_data()`` and ``y`` holds one label per event: 0 where the
        event code is 2, otherwise 1
    """
    epochs = mne.read_epochs(file_path, preload=True)

    # Data layout: (number of epochs, channels, n_times).
    # For 3-second data: 3 seconds x 160 Hz = 480, n_times = 480
    # For 6-second data: 6 seconds x 160 Hz = 960, n_times = 960
    X = epochs.get_data()

    # The event code sits in the last column of the events array.
    event_codes = epochs.events[:, -1]
    # Labels should be numbered 0, 1, 2 ...
    y = np.array([int(code != 2) for code in event_codes])
    return X, y

In [91]:
# Folder that is listed for per-subject sub-folders containing the epochs files.
PREPROCESSED_DATA_DIR = "./preprocessed_data/Physionet"
# Embedding width passed to every model as d_model.
D_MODEL = 64
# Number of attention heads passed to every model as nhead.
NUM_HEADS = 8
# Number of output logits of every model.
NUM_CLASSES = 2
# Mini-batch size for the DataLoaders.
BATCH_SIZE = 32
# Number of passes over the training data in train_model.
NUM_EPOCHS = 50
# Weight decay handed to the Adam optimizer.
WEIGHT_DECAY = 0.0001
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 0.0007
# Names of the models fed with 3-D input; not referenced by the code visible in this notebook.
MODELS = {"SpatialTransformer", "TemporalTransformer"}
# Names of the model classes whose datasets are built with cnn_mode=True (4-D input).
CNN_MODELS = {"SpatialCNNTransformer"}


def train_model(model, train_loader, device: torch.device, *, num_epochs=None, lr=None, weight_decay=None):
    """Train ``model`` in place with Adam and cross-entropy loss.

    :param model: network mapping a batch of inputs to class logits
    :param train_loader: iterable yielding ``(X_batch, y_batch)`` tensor pairs
    :param torch.device device: device the model and every batch are moved to
    :param num_epochs: number of passes over ``train_loader``; ``None`` uses ``NUM_EPOCHS``
    :param lr: Adam learning rate; ``None`` uses ``LEARNING_RATE``
    :param weight_decay: Adam weight decay; ``None`` uses ``WEIGHT_DECAY``
    :return: list with the mean batch loss of every epoch (``nan`` for an
        epoch in which the loader yielded no batches)
    """
    # Resolve the defaults at call time so edits to the notebook constants
    # take effect without re-running this definition.
    num_epochs = NUM_EPOCHS if num_epochs is None else num_epochs
    lr = LEARNING_RATE if lr is None else lr
    weight_decay = WEIGHT_DECAY if weight_decay is None else weight_decay

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = torch.nn.CrossEntropyLoss()

    epoch_losses = []
    for _ in range(num_epochs):
        # Re-enter training mode every epoch in case an evaluation ran in between.
        model.train()
        running_loss = 0.0
        num_batches = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        epoch_losses.append(running_loss / num_batches if num_batches else float("nan"))

    return epoch_losses


def evaluate_model(model, test_loader, device) -> float:
    """Return the fraction of samples whose arg-max prediction equals the label.

    Matches are counted directly, so the result is valid for any number of
    classes rather than only for binary labels.

    :param model: network mapping a batch of inputs to class logits
    :param test_loader: iterable yielding ``(X_batch, y_batch)`` tensor pairs
    :param device: device every batch is moved to (the model is expected to be there already)
    :return: accuracy in ``[0, 1]``; ``0.0`` when the loader yields no samples
    """
    model.eval()
    correct = 0
    total = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            output = model(X_batch)
            preds = torch.argmax(output, dim=1)
            correct += (preds == y_batch).sum().item()
            total += y_batch.numel()

    return correct / total if total else 0.0

# Training within one patient

_Here we train and test the model only on data that belongs to a single patient_


In [92]:
def train_within_individual(model_class):
    """Train and evaluate a fresh model separately for each subject.

    Every subject's epochs are split 70/30 into train and test sets, a new
    ``model_class`` instance is trained on the train part, and the test
    accuracy is logged.

    :param model_class: one of the model classes of this notebook; it is
        called with ``input_size``, ``d_model``, ``nhead`` and ``num_classes``
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cnn_mode = model_class.__name__ in CNN_MODELS

    # os.listdir returns entries in arbitrary order; sort them so that the
    # "006" cut-off below really selects the subjects that precede it.
    for subj_folder in sorted(os.listdir(PREPROCESSED_DATA_DIR)):
        # 5 subjects for now
        if subj_folder[1:] == "006":
            break

        file_path = os.path.join(PREPROCESSED_DATA_DIR, subj_folder, f"PA{subj_folder[1:]}-3s-epo.fif")
        # Skip entries without an epochs file instead of abandoning all remaining subjects.
        if not os.path.exists(file_path):
            continue

        logger.info(f"\nTraining model for subject {subj_folder}...")
        X, y = load_subject_data(file_path)

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

        train_dataset = EEGDataset(X_train, y_train, cnn_mode)
        test_dataset = EEGDataset(X_test, y_test, cnn_mode)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        # X is (epochs, channels, time). TemporalTransformer embeds the channel
        # axis, so it needs the channel count; the other models take the time length.
        input_size = X.shape[1] if model_class.__name__ == "TemporalTransformer" else X.shape[2]
        model = model_class(input_size=input_size, d_model=D_MODEL, nhead=NUM_HEADS, num_classes=NUM_CLASSES)
        train_model(model, train_loader=train_loader, device=device)

        accuracy = evaluate_model(model, test_loader, device)
        logger.info(f"Accuracy for subject {subj_folder}: {accuracy * 100:.2f}%")


# train_within_individual(SpatialTransformer)

# Cross patient training

_Here the model is trained on data from all other subjects and tested on the held-out subject_


In [96]:
def train_cross_individual(model_class, test_subjects=("S001",)):
    """Leave-one-subject-out evaluation: train on all other subjects, test on one.

    For every entry of ``test_subjects`` a fresh ``model_class`` instance is
    trained on the data of all remaining subjects and evaluated on the held-out
    one. Per-subject accuracy and the sample-weighted overall accuracy are logged.

    :param model_class: one of the model classes of this notebook; it is
        called with ``input_size``, ``d_model``, ``nhead`` and ``num_classes``
    :param test_subjects: folder names of the subjects to hold out, one run each
        (one subject for now by default)
    :raises ValueError: if ``test_subjects`` is empty or no other subject is
        available for training
    :raises KeyError: if a test subject has no epochs file
    """
    if not test_subjects:
        raise ValueError("test_subjects must contain at least one subject")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cnn_mode = model_class.__name__ in CNN_MODELS
    total_correct = 0
    total_samples = 0

    # Load all data once
    subject_data = {}
    for subj_folder in sorted(os.listdir(PREPROCESSED_DATA_DIR)):
        file_path = os.path.join(PREPROCESSED_DATA_DIR, subj_folder, f"PA{subj_folder[1:]}-3s-epo.fif")
        if os.path.exists(file_path):
            subject_data[subj_folder] = load_subject_data(file_path)

    for test_subj in test_subjects:
        if test_subj not in subject_data:
            raise KeyError(f"No epochs file found for test subject {test_subj}")

        logger.info(f"\nTesting on subject {test_subj}...")

        X_test, y_test = subject_data[test_subj]

        # Everything except the held-out subject forms the training set.
        train_parts = [data for subj, data in subject_data.items() if subj != test_subj]
        if not train_parts:
            raise ValueError(f"No training subjects available besides {test_subj}")

        X_train = np.concatenate([X for X, _ in train_parts], axis=0)
        y_train = np.concatenate([y for _, y in train_parts], axis=0)

        train_dataset = EEGDataset(X_train, y_train, cnn_mode)
        test_dataset = EEGDataset(X_test, y_test, cnn_mode)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        # X_train is (epochs, channels, time). TemporalTransformer embeds the
        # channel axis, so it needs the channel count; the other models take the time length.
        input_size = X_train.shape[1] if model_class.__name__ == "TemporalTransformer" else X_train.shape[2]
        model = model_class(input_size=input_size, d_model=D_MODEL, nhead=NUM_HEADS, num_classes=NUM_CLASSES)
        train_model(model, train_loader=train_loader, device=device)

        accuracy = evaluate_model(model, test_loader, device)
        logger.info(f"Accuracy on subject {test_subj}: {accuracy * 100:.2f}%")

        # Weight each subject by its number of test samples for the overall score.
        total_samples += len(test_dataset)
        total_correct += accuracy * len(test_dataset)

    overall_accuracy = total_correct / total_samples
    logger.info(f"\nOverall cross-subject accuracy: {overall_accuracy * 100:.2f}%")


# Run the cross-subject experiment with the CNN + Transformer model; training starts when this cell executes.
train_cross_individual(SpatialCNNTransformer)

Reading c:\projects\eeg_transformer\preprocessed_data\Physionet\S001\PA001-3s-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    2000.00 ...    5000.00 ms
        0 CTF compensation matrices available
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Reading c:\projects\eeg_transformer\preprocessed_data\Physionet\S002\PA002-3s-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    2000.00 ...    5000.00 ms
        0 CTF compensation matrices available
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Reading c:\projects\eeg_transformer\preprocessed_data\Physionet\S003\PA003-3s-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    2000.00 ...    5000.00 ms
        0 CTF compensation matrices available
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
R

[34m2025-04-21 20:43:50,091 - INFO - 
Testing on subject S001...[0m
[34m2025-04-21 20:46:38,660 - INFO - Accuracy on subject S001: 80.95%[0m
[34m2025-04-21 20:46:38,661 - INFO - 
Overall cross-subject accuracy: 80.95%[0m


# 5-Fold Cross validation


In [95]:
from sklearn.model_selection import KFold


def train_5_fold_cross_validation(model_class, n_splits=5):
    """Subject-wise K-fold cross-validation (5 folds by default).

    Subjects, not individual trials, are split into folds, so no subject
    contributes to both the training and the test set of a fold. A fresh
    ``model_class`` instance is trained per fold; per-fold accuracy and the
    unweighted mean over the folds are logged.

    :param model_class: one of the model classes of this notebook; it is
        called with ``input_size``, ``d_model``, ``nhead`` and ``num_classes``
    :param int n_splits: number of folds; there must be at least this many
        subjects with an epochs file
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    cnn_mode = model_class.__name__ in CNN_MODELS
    results = []

    # Load all data once
    subject_data = {}
    for subj_folder in sorted(os.listdir(PREPROCESSED_DATA_DIR)):
        file_path = os.path.join(PREPROCESSED_DATA_DIR, subj_folder, f"PA{subj_folder[1:]}-3s-epo.fif")
        if os.path.exists(file_path):
            subject_data[subj_folder] = load_subject_data(file_path)

    # KFold yields positional indices; resolve them through this list built once.
    subject_ids = list(subject_data)

    # Split the subjects into folds using KFold
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    for fold, (train_idx, test_idx) in enumerate(kf.split(subject_ids)):
        logger.info(f"Fold {fold + 1}/{n_splits}")

        # Gather the per-subject arrays belonging to each side of the split.
        train_parts = [subject_data[subject_ids[idx]] for idx in train_idx]
        test_parts = [subject_data[subject_ids[idx]] for idx in test_idx]

        X_train = np.concatenate([X for X, _ in train_parts], axis=0)
        y_train = np.concatenate([y for _, y in train_parts], axis=0)
        X_test = np.concatenate([X for X, _ in test_parts], axis=0)
        y_test = np.concatenate([y for _, y in test_parts], axis=0)

        # Create DataLoader for train and test datasets
        train_dataset = EEGDataset(X_train, y_train, cnn_mode)
        test_dataset = EEGDataset(X_test, y_test, cnn_mode)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        # X_train is (epochs, channels, time). TemporalTransformer embeds the
        # channel axis, so it needs the channel count; the other models take the time length.
        input_size = X_train.shape[1] if model_class.__name__ == "TemporalTransformer" else X_train.shape[2]
        model = model_class(input_size=input_size, d_model=D_MODEL, nhead=NUM_HEADS, num_classes=NUM_CLASSES)
        train_model(model, train_loader=train_loader, device=device)

        # Evaluate model on the test data
        accuracy = evaluate_model(model, test_loader, device)
        logger.info(f"Accuracy for fold {fold + 1}: {accuracy * 100:.2f}%")
        results.append(accuracy)

    # Calculate the average accuracy across all folds
    average_accuracy = np.mean(results)
    logger.info(f"Average accuracy across {n_splits} folds: {average_accuracy * 100:.2f}%")


# Run the subject-wise 5-fold cross-validation with the CNN + Transformer model; training starts when this cell executes.
train_5_fold_cross_validation(SpatialCNNTransformer)

cuda
Reading c:\projects\eeg_transformer\preprocessed_data\Physionet\S001\PA001-3s-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    2000.00 ...    5000.00 ms
        0 CTF compensation matrices available
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Reading c:\projects\eeg_transformer\preprocessed_data\Physionet\S002\PA002-3s-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    2000.00 ...    5000.00 ms
        0 CTF compensation matrices available
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activated
Reading c:\projects\eeg_transformer\preprocessed_data\Physionet\S003\PA003-3s-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    2000.00 ...    5000.00 ms
        0 CTF compensation matrices available
Not setting metadata
42 matching events found
No baseline correction applied
0 projection items activa

[34m2025-04-21 20:22:38,933 - INFO - Fold 1/5[0m
  return F.conv2d(
[34m2025-04-21 20:24:48,234 - INFO - Accuracy for fold 1: 70.29%[0m
[34m2025-04-21 20:24:48,234 - INFO - Fold 2/5[0m
[34m2025-04-21 20:27:01,751 - INFO - Accuracy for fold 2: 76.98%[0m
[34m2025-04-21 20:27:01,757 - INFO - Fold 3/5[0m
[34m2025-04-21 20:29:31,401 - INFO - Accuracy for fold 3: 72.34%[0m
[34m2025-04-21 20:29:31,401 - INFO - Fold 4/5[0m
[34m2025-04-21 20:31:57,021 - INFO - Accuracy for fold 4: 75.96%[0m
[34m2025-04-21 20:31:57,021 - INFO - Fold 5/5[0m
[34m2025-04-21 20:34:06,154 - INFO - Accuracy for fold 5: 67.12%[0m
[34m2025-04-21 20:34:06,154 - INFO - Average accuracy across 5 folds: 72.54%[0m
