# Step 1. Getting the SEED dataset

## 1.1 load the matlab format files and take a look

In [7]:
import mne
from scipy import io
import os
import numpy as np

In [2]:
folderpath = "./Preprocessed_EEG"

In [8]:
raw_1_1 = scipy.io.loadmat('./Preprocessed_EEG/1_20131027.mat')
raw_1_2 = scipy.io.loadmat('./Preprocessed_EEG/1_20131030.mat')
raw_1_3 = scipy.io.loadmat('./Preprocessed_EEG/1_20131107.mat')
# raw_2_1 = scipy.io.loadmat('./Preprocessed_EEG/2_20140404.mat')
label = scipy.io.loadmat('./Preprocessed_EEG/label.mat')

In [4]:
raw_1_1.keys(), raw_1_2.keys(), raw_1_3.keys(), label.keys()

(dict_keys(['__header__', '__version__', '__globals__', 'djc_eeg1', 'djc_eeg2', 'djc_eeg3', 'djc_eeg4', 'djc_eeg5', 'djc_eeg6', 'djc_eeg7', 'djc_eeg8', 'djc_eeg9', 'djc_eeg10', 'djc_eeg11', 'djc_eeg12', 'djc_eeg13', 'djc_eeg14', 'djc_eeg15']),
 dict_keys(['__header__', '__version__', '__globals__', 'djc_eeg1', 'djc_eeg2', 'djc_eeg3', 'djc_eeg4', 'djc_eeg5', 'djc_eeg6', 'djc_eeg7', 'djc_eeg8', 'djc_eeg9', 'djc_eeg10', 'djc_eeg11', 'djc_eeg12', 'djc_eeg13', 'djc_eeg14', 'djc_eeg15']),
 dict_keys(['__header__', '__version__', '__globals__', 'djc_eeg1', 'djc_eeg2', 'djc_eeg3', 'djc_eeg4', 'djc_eeg5', 'djc_eeg6', 'djc_eeg7', 'djc_eeg8', 'djc_eeg9', 'djc_eeg10', 'djc_eeg11', 'djc_eeg12', 'djc_eeg13', 'djc_eeg14', 'djc_eeg15']),
 dict_keys(['__header__', '__version__', '__globals__', 'label']))

In [6]:
raw_1_1.keys()

dict_keys(['__header__', '__version__', '__globals__', 'djc_eeg1', 'djc_eeg2', 'djc_eeg3', 'djc_eeg4', 'djc_eeg5', 'djc_eeg6', 'djc_eeg7', 'djc_eeg8', 'djc_eeg9', 'djc_eeg10', 'djc_eeg11', 'djc_eeg12', 'djc_eeg13', 'djc_eeg14', 'djc_eeg15'])

In [7]:
raw_1_1['djc_eeg2'].shape, raw_1_2['djc_eeg2'].shape, raw_1_1['djc_eeg10'].shape, raw_1_3['djc_eeg10'].shape

((62, 46601), (62, 46601), (62, 47401), (62, 47401))

raw_1_1['djc_eeg1']:
- raw_1_1: the raw eeg data, the 1st subject(out of 15 subjects), the 1st experiments(out of 3 experiments)
- 'djc_eeg1': djc could possibly mean the name of the subject(which leaks the privacy XD), and 'eeg1' represents it is the 1st trial(out of 15 trials)
- 62: n_channels
- 47001: roughly 4 mins(240sec * 200 Hz sampling rate) the eeg signal during watching the movies

In [8]:
label['label'][0]

array([ 1,  0, -1, -1,  0,  1, -1,  0,  1,  1,  0, -1,  0,  1, -1],
      dtype=int16)

## 1.2 Create the torch dataset

### 1.2.1 cut compute windows

since a movie clip contains eeg signals of ~240 seconds with sf=200Hz, I decide to cut them into 4 sec windows

In [9]:
len_window = 200 * 4

raw_X = []
raw_y = []

prefix = 'djc_eeg'

# raw_1_1
for i in range(1, 16):
    data = raw_1_1[prefix + str(i)]
    # print(data.shape)
    n_windows = data.shape[1] // len_window
    # print(n_windows)
    reshaped_X = np.reshape(data[:, :n_windows * len_window], (62, len_window, n_windows))
    raw_X.append(reshaped_X)
    raw_y.append(np.array([label['label'][0][i-1] for j in range(n_windows)]))

# raw_1_2
for i in range(1, 16):
    data = raw_1_2[prefix + str(i)]
    # print(data.shape)
    n_windows = data.shape[1] // len_window
    # print(n_windows)
    reshaped_X = np.reshape(data[:, :n_windows * len_window], (62, len_window, n_windows))
    raw_X.append(reshaped_X)
    raw_y.append(np.array([label['label'][0][i-1] for j in range(n_windows)]))
    # print(reshaped_X.shape)   # print(reshaped_X.shape)

# raw_1_3
for i in range(1, 16):
    data = raw_1_3[prefix + str(i)]
    # print(data.shape)
    n_windows = data.shape[1] // len_window
    # print(n_windows)
    reshaped_X = np.reshape(data[:, :n_windows * len_window], (62, len_window, n_windows))
    raw_X.append(reshaped_X)
    raw_y.append(np.array([label['label'][0][i-1] for j in range(n_windows)]))
    # print(reshaped_X.shape)   # print(reshaped_X.shape)

In [10]:
concat_X = np.concatenate(raw_X, axis=2)
print(concat_X.shape)
concat_y = np.concatenate(raw_y)
print(concat_y.shape)

(62, 800, 2526)
(2526,)


### 1.2.2 Encode the labels

In [11]:
from sklearn.preprocessing import LabelEncoder
import pandas as pd

In [12]:
concat_y

array([ 1,  1,  1, ..., -1, -1, -1], dtype=int16)

In [13]:
le = LabelEncoder()
y = le.fit_transform(concat_y)
y = pd.get_dummies(y)
y

Unnamed: 0,0,1,2
0,0,0,1
1,0,0,1
2,0,0,1
3,0,0,1
4,0,0,1
...,...,...,...
2521,1,0,0
2522,1,0,0
2523,1,0,0
2524,1,0,0


one-hot encoding
- positive:  1 => [0, 0 ,1]
- neutral:   0 => [0, 1, 0]
- negative: -1 => [1, 0, 0]

In [14]:
y.shape, concat_X.shape

((2526, 3), (62, 800, 2526))

# Step 2. Create EEG-conformer model

In [38]:
import torch
# from braindecode.models import EEGConformer
from braindecode.util import set_random_seeds

from eegconformer import EEGConformer

In [36]:
cuda = torch.cuda.is_available()
print(cuda)
device = "cuda" if cuda else "cpu"
device

True


'cuda'

In [37]:
cuda_version = torch.version.cuda
cuda_version

'11.7'

In [45]:
seed = 20240216
set_random_seeds(seed=seed, cuda=cuda)

model = EEGConformer(
    n_outputs=3,
    n_chans=62,
    n_times=800, # input_winodw_samples
    input_window_seconds=4,
    sfreq=200,
)

print(model)



Layer (type (var_name):depth-idx)                            Input Shape               Output Shape              Param #                   Kernel Shape
EEGConformer (EEGConformer)                                  [1, 62, 800]              [1, 3]                    --                        --
├─_PatchEmbedding (patch_embedding): 1-1                     [1, 1, 62, 800]           [1, 47, 40]               --                        --
│    └─Sequential (shallownet): 2-1                          [1, 1, 62, 800]           [1, 40, 1, 47]            --                        --
│    │    └─Conv2d (0): 3-1                                  [1, 1, 62, 800]           [1, 40, 62, 776]          1,040                     [1, 25]
│    │    └─Conv2d (1): 3-2                                  [1, 40, 62, 776]          [1, 40, 1, 776]           99,240                    [62, 1]
│    │    └─BatchNorm2d (2): 3-3                             [1, 40, 1, 776]           [1, 40, 1, 776]           80             

In [46]:
if cuda:
    model.cuda()

# Step 3. Training the model

## 3.1 prepare the train set / test set

In [47]:
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib_inline

In [53]:
concat_X.shape, y.shape

((62, 800, 2526), (2526, 3))

In [55]:
X = concat_X.transpose((2, 0, 1))
X.shape

(2526, 62, 800)

In [56]:
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=777)

In [58]:
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((1768, 62, 800), (758, 62, 800), (1768, 3), (758, 3))

## 3.2 Training

In [61]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

x_train_tensor = torch.from_numpy(X_train).to(torch.float32).to(device)
y_train_tensor = torch.from_numpy(y_train.values).to(torch.float32).to(device)
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

x_test_tensor = torch.from_numpy(X_test).to(torch.float32).to(device)
y_test_tensor = torch.from_numpy(y_test.values).to(torch.float32).to(device)
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [62]:
train_acc_list = []
train_loss_list = []
test_acc_list = []
test_loss_list = []

In [63]:
def train(model, x_train, y_train, x_test, y_test, save_path='./model_transformer/', n_epochs=100):
    """
    the training function.
    Attributes:
    - model: the instance of the network
    - save_path: the path to which the model state will be saved. None means w/o saving.

    Return: 
    the best model if save_path is not None, the last model otherwise
    """
    train_acc_list.clear()
    train_loss_list.clear()
    test_acc_list.clear()
    test_loss_list.clear()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('the model will be trained on: ', device)

    best_accuracy = 0.0
    
    for epoch in range(n_epochs):
        training_loss = 0.0
        testing_loss = 0.0
        correct = 0
        total = 0
        
        model.train()
        for inputs, labels in tqdm(train_loader):
            # print(inputs.shape, labels.shape)
            inputs = inputs.to(device)
            labels = labels.to(device)
            # print("lables.size", labels.shape)
            optimizer.zero_grad()
            
            outputs = model(inputs)
            # print(outputs.shape)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            training_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            _, label = torch.max(labels, 1)
            # print(predicted.shape)
            total += labels.size(0)
            correct += (predicted == label).sum().item()
        
        train_loss = training_loss / len(train_loader)
        train_loss_list.append(train_loss)
        train_accuracy = correct / total
        train_acc_list.append(train_accuracy)

        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                testing_loss += loss.item()


                _, predicted = torch.max(outputs.data, 1)
                _, label = torch.max(labels, 1)
                total += labels.size(0)
                correct += (predicted == label).sum().item()
            
            test_loss = testing_loss / len(test_loader)
            test_loss_list.append(test_loss)
            test_accuracy = correct / total
            test_acc_list.append(test_accuracy)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                if save_path is not None:
                    torch.save(model.state_dict(), save_path + 'best_model.pth')
                    print("best_model found, best acc: ", best_accuracy)
        
        print(f"Epoch {epoch+1}/{n_epochs} - Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f} - Test Loss: {test_loss:.4f} - Test Accuracy: {test_accuracy:.4f}")
    
    if save_path is not None: 
        model.load_state_dict(torch.load(save_path + 'best_model.pth'))
    
    return model

In [66]:
best_model = train(model, X_train, X_test, y_train, y_test, n_epochs=10)

the model will be trained on:  cuda


100%|██████████| 56/56 [00:01<00:00, 31.56it/s]


best_model found, best acc:  1.0
Epoch 1/10 - Train Loss: 0.0203 - Train Accuracy: 0.9943 - Test Loss: 0.0002 - Test Accuracy: 1.0000


100%|██████████| 56/56 [00:01<00:00, 31.64it/s]


Epoch 2/10 - Train Loss: 0.0177 - Train Accuracy: 0.9943 - Test Loss: 0.0016 - Test Accuracy: 0.9987


100%|██████████| 56/56 [00:01<00:00, 31.60it/s]


Epoch 3/10 - Train Loss: 0.0205 - Train Accuracy: 0.9949 - Test Loss: 0.0022 - Test Accuracy: 0.9987


100%|██████████| 56/56 [00:01<00:00, 30.49it/s]


Epoch 4/10 - Train Loss: 0.0481 - Train Accuracy: 0.9836 - Test Loss: 0.0000 - Test Accuracy: 1.0000


100%|██████████| 56/56 [00:01<00:00, 31.79it/s]


Epoch 5/10 - Train Loss: 0.0556 - Train Accuracy: 0.9893 - Test Loss: 0.0000 - Test Accuracy: 1.0000


100%|██████████| 56/56 [00:01<00:00, 31.51it/s]


Epoch 6/10 - Train Loss: 0.1126 - Train Accuracy: 0.9678 - Test Loss: 0.0160 - Test Accuracy: 0.9947


100%|██████████| 56/56 [00:01<00:00, 31.95it/s]


Epoch 7/10 - Train Loss: 0.0231 - Train Accuracy: 0.9943 - Test Loss: 0.0001 - Test Accuracy: 1.0000


100%|██████████| 56/56 [00:01<00:00, 31.94it/s]


Epoch 8/10 - Train Loss: 0.0162 - Train Accuracy: 0.9972 - Test Loss: 0.0000 - Test Accuracy: 1.0000


100%|██████████| 56/56 [00:01<00:00, 30.39it/s]


Epoch 9/10 - Train Loss: 0.0083 - Train Accuracy: 0.9972 - Test Loss: 0.0000 - Test Accuracy: 1.0000


100%|██████████| 56/56 [00:01<00:00, 31.50it/s]


Epoch 10/10 - Train Loss: 0.0344 - Train Accuracy: 0.9910 - Test Loss: 0.0000 - Test Accuracy: 1.0000
