In [1]:
# Uncomment to download and unpack the BCI Competition IV dataset 1 files (BCICIV_calib_ds1*.mat):
# !wget https://www.bbci.de/competition/download/competition_iv/BCICIV_1_mat.zip
# !unzip -q BCICIV_1_mat.zip
# !rm -rf BCICIV_1_mat.zip

In [2]:
# Uncomment to fetch LibEER, which provides the DGCNN model imported below:
# !git clone https://github.com/XJTU-EEG/LibEER.git

In [3]:
import gc
import math
import os
import random

import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from tqdm import tqdm

import sys
sys.path.append('LibEER/LibEER/models')

In [4]:
from DGCNN import DGCNN

In [5]:
# Subject IDs of the calibration files (BCICIV_calib_ds1<id>.mat).
# Cross-subject split: five subjects for training, the other two for validation.
TRAIN = list('abdeg')
VALID = list('cf')

In [6]:
# Motor-imagery class pair recorded for each subject (calibration files):
# a - left, foot
# b - left, right
# c - left, right
# d - left, right
# e - left, right
# f - left, foot
# g - left, right

In [7]:
# Column index of each class in the one-hot label arrays ('idle' = no cue).
LABELS = {
    'idle': 0,
    'left': 1,
    'right': 2,
    'foot': 3,
}

LR = 1e-2          # AdamW learning rate
EPOCHS = 1000      # upper bound; the early-stopping training loop may stop sooner
BATCH_SIZE = 64

# Seed every RNG this notebook uses (random crops via `random`, numpy,
# torch model init and DataLoader shuffling) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [8]:
def get_labels(data, half_window=50, label_map=None):
    """Build per-sample one-hot class labels for one calibration recording.

    Every sample starts out labelled 'idle'. Around each cue marker in
    ``data['mrk']`` the samples ``[t - half_window, t + half_window)`` are
    overwritten with the one-hot vector of that cue's class. The window is
    clipped to the recording. Marker code ``1`` maps to the first class
    name listed in ``data['nfo']`` and marker code ``-1`` to the second.

    Args:
        data: dict returned by ``scipy.io.loadmat`` for a calibration file;
            must contain ``'cnt'``, ``'nfo'`` and ``'mrk'``.
        half_window: number of samples labelled on each side of a marker.
        label_map: class-name -> column-index mapping (must contain 'idle');
            defaults to the module-level ``LABELS``.

    Returns:
        ``np.uint8`` array of shape ``(len(data['cnt']), n_classes)`` holding
        one one-hot row per sample.
    """
    if label_map is None:
        label_map = LABELS
    n_samples = len(data['cnt'])
    n_classes = max(label_map.values()) + 1

    labels = np.zeros((n_samples, n_classes), dtype=np.uint8)
    labels[:, label_map['idle']] = 1

    class_names = [d[0] for d in data['nfo'][0][0][1][0].tolist()]
    timestamps = data['mrk'][0][0][0]
    cls_idx = data['mrk'][0][0][1]

    # Code 1 is processed before code -1, so where windows overlap the
    # second class wins.
    # TODO(author): consider labelling [t, t + 2*half_window) instead.
    for code, class_name in ((1, class_names[0]), (-1, class_names[1])):
        onehot = np.zeros(n_classes, dtype=np.uint8)
        onehot[label_map[class_name]] = 1
        for t in timestamps[np.where(cls_idx == code)]:
            t = int(t)  # plain int: no unsigned wrap-around in t - half_window
            # Clip to the recording. An unclipped negative start would wrap
            # to the end of the array and the assignment would fail to broadcast.
            start = max(t - half_window, 0)
            end = min(t + half_window, n_samples)
            labels[start:end] = onehot

    return labels

def get_data(data_id, split_size):
    """Load one calibration recording and cut it into equal-length windows.

    Reads ``BCICIV_calib_ds1{data_id}.mat`` from the working directory and
    builds per-sample labels with ``get_labels``. As a side effect, the labels
    are written to ``labels_{data_id}.csv``. Signal and labels are then split
    into consecutive, non-overlapping windows of exactly ``split_size``
    samples. Trailing samples that do not fill a whole window are dropped.

    Args:
        data_id: subject letter used in the file name.
        split_size: window length in samples.

    Returns:
        ``(eegs_split, labels_split)``: two lists of equal length. The i-th
        items are the signal window and label window covering the same samples.
    """
    data = scipy.io.loadmat(f'BCICIV_calib_ds1{data_id}.mat')
    eegs = data['cnt']
    labels = get_labels(data)
    np.savetxt(f"labels_{data_id}.csv", labels, delimiter=",", fmt="%d")

    # Start offset of every window that fits entirely inside the recording,
    # including a final window that ends exactly at the last sample.
    starts = range(0, len(eegs) - split_size + 1, split_size)
    eegs_split = [eegs[s:s + split_size] for s in starts]
    labels_split = [labels[s:s + split_size] for s in starts]
    return eegs_split, labels_split

In [9]:
def _load_windows(subject_ids, split_size):
    """Load each subject's recording and pool all their windows into two lists."""
    all_eegs, all_labels = [], []
    for subject_id in tqdm(subject_ids):
        eeg_windows, label_windows = get_data(subject_id, split_size=split_size)
        all_eegs += eeg_windows
        all_labels += label_windows
    return all_eegs, all_labels

# Training windows are 200 samples so EEGDataset can take random 100-sample
# crops; validation windows are already 100 samples long.
train_eegs, train_labels = _load_windows(TRAIN, split_size=200)
valid_eegs, valid_labels = _load_windows(VALID, split_size=100)

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00,  7.84it/s]
100%|██████████| 2/2 [00:00<00:00,  7.50it/s]


In [10]:
class EEGDataset(Dataset):
    """Map-style dataset of EEG windows paired with soft (averaged one-hot) labels.

    Each item is ``(eeg, label_soft)``:
      * ``eeg``: float tensor, transposed so axis 0 is the channel axis
        (presumably ``(channels, time)`` given ``(time, channels)`` input
        windows). Each channel is z-scored over the window.
      * ``label_soft``: float tensor holding, per class, the fraction of
        samples in the window that carry that one-hot label.

    In the ``'train'`` split, a random ``crop_len``-sample crop of each window
    is taken as augmentation. Other splits use the whole window.
    """

    def __init__(self, eegs, labels, split, crop_len=100, eps=1e-6):
        self.eegs = eegs          # list of windows, time along axis 0
        self.labels = labels      # list of one-hot label windows, time along axis 0
        self.split = split
        self.crop_len = crop_len  # crop length used in the 'train' split
        self.eps = eps            # lower bound on the std used for z-scoring

    def __getitem__(self, i):
        eeg, label = self.eegs[i], self.labels[i]
        if self.split == 'train':
            # randint is inclusive, so the crop may end exactly at the window end.
            start_idx = random.randint(0, max(len(eeg) - self.crop_len, 0))
            eeg = eeg[start_idx:start_idx + self.crop_len]
            label = label[start_idx:start_idx + self.crop_len]

        eeg = torch.from_numpy(eeg).float()
        # Per-channel z-score over time. Clamping the std makes a flat channel
        # produce zeros instead of NaN (0 / 0).
        std = torch.std(eeg, dim=0).clamp_min(self.eps)
        eeg = (eeg - torch.mean(eeg, dim=0)) / std
        eeg = eeg.transpose(0, 1)

        label_soft = torch.from_numpy(label).float().mean(dim=0)
        return eeg, label_soft

    def __len__(self):
        return len(self.eegs)

In [11]:
# Training loader: reshuffled every epoch; the ragged final batch is dropped.
dataset_train = EEGDataset(train_eegs, train_labels, 'train')
dataloader_train = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=True,
)
n_train = len(dataloader_train)

# Validation loader: fixed order, every window kept.
dataset_valid = EEGDataset(valid_eegs, valid_labels, 'valid')
dataloader_valid = DataLoader(
    dataset_valid,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=True,
)
n_valid = len(dataloader_valid)

In [12]:
def count_parameters(model):
    """Return how many trainable parameters ``model`` has, in millions."""
    trainable_sizes = [param.numel() for param in model.parameters() if param.requires_grad]
    return sum(trainable_sizes) / 1e6

In [13]:
# 59 electrodes, 100-sample input windows, one output per entry in LABELS.
n_classes = len(LABELS)
model = DGCNN(num_electrodes=59, in_channels=100, num_classes=n_classes, k=2,
              relu_is=1, layers=None, dropout_rate=0.2)

print(f'Number of parameters: {count_parameters(model):.2f}M')

# Prefer Apple MPS, then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else
                      "cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Inverse-frequency class weights from the argmax of each training soft label.
# One pass over dataset_train, so the counts reflect one random crop per window.
labels = torch.stack([dataset_train[i][1] for i in range(len(dataset_train))])
# minlength keeps one count per class even if the highest class index never
# wins the argmax. Clamping avoids an infinite weight for an absent class.
label_counts = torch.bincount(torch.argmax(labels, dim=1), minlength=n_classes)
class_weights = 1.0 / label_counts.clamp_min(1).float()
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
loss_name = "Weighted_Cross_Entropy"

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=0.001)


config/model_param/DGCNN.yaml may not exist or not available
DGCNN Model, Parameters:

k (The order of Chebyshev polynomials):                         2
relu_is (The type of B_Relu func):                              1
layers (The channels of each layers):        None                
dropout rate:                                                 0.2

Not Using Default Setting, the performance may be not the best
Starting......
Number of parameters: 0.98M


In [14]:
# for x, y in dataloader_train:
#     print("Input shape:", x.shape)
#     print("Label shape:", y.shape)
#     break

In [15]:
# Sanity check after initialisation: print the mean of every parameter tensor.
for param_name, tensor in model.named_parameters():
    mean_value = tensor.mean().item()
    print(f"{param_name}: {mean_value:.4f}")

adj: 0.0025
adj_bias: 0.0100
graphConvs.0.weight: 0.0003
fc.weight: -0.0000
fc.bias: 0.1000
fc2.weight: -0.0015
fc2.bias: 0.1000
b_relus.0.bias: 0.0003


In [16]:
# Train with early stopping on the mean validation loss; LR is halved on plateaus.
patience = 10                 # epochs without improvement before stopping
best_val_loss = float('inf')
best_epoch = None             # epoch whose weights are in the checkpoint
counter = 0                   # epochs since the last improvement
best_ckpt_path = f"DGCNN_{loss_name}_best_model.pth"
bar_format = "{l_bar}{bar:10}{r_bar}{bar:-10b}"
class_ids = sorted(LABELS.values())

# `verbose` is deprecated in recent PyTorch; the current LR is shown in the
# training progress bar instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

for epoch in range(1, EPOCHS + 1):
    # ---- Training ----
    model.train()
    cur_lr = f"LR : {optimizer.param_groups[0]['lr']:.2E}"

    pbar_train = tqdm(dataloader_train, total=n_train, bar_format=bar_format)
    mloss_train, mloss_val = 0.0, 0.0

    for i, (x, y) in enumerate(pbar_train):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        y_hat = model(x)
        loss = criterion(y_hat, y)  # y is a soft (class-probability) target
        loss.backward()
        optimizer.step()

        mloss_train += loss.item()

        gpu_mem = (
            f"Mem : {torch.cuda.memory_reserved() / 1E9:.3g}GB" if torch.cuda.is_available()
            else "MPS Active" if torch.backends.mps.is_available()
            else "CPU Mode"
        )

        pbar_train.set_description(f"Epoch {epoch}/{EPOCHS}  {gpu_mem}  {cur_lr}  Loss: {mloss_train / (i + 1):.4f}")

    # ---- Validation (no autograd graph needed anywhere) ----
    model.eval()
    pbar_val = tqdm(dataloader_valid, total=n_valid, bar_format=bar_format)

    y_true, y_preds = [], []
    with torch.no_grad():
        for i, (x, y) in enumerate(pbar_val):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)

            mloss_val += criterion(y_hat, y).item()
            y_preds.append(y_hat)
            y_true.append(y)

            pbar_val.set_description(f"Val Loss: {mloss_val / (i + 1):.4f}")

    # Hard labels: dominant class of each soft target, and top logit of each
    # prediction (softmax is monotonic, so it does not change the argmax).
    y_true = np.argmax(torch.cat(y_true).cpu().numpy(), axis=1)
    y_preds = torch.cat(y_preds).argmax(dim=1).cpu().numpy()

    accuracy = np.mean(y_true == y_preds)
    # zero_division=0 reports 0 for ill-defined per-class scores without
    # emitting an UndefinedMetricWarning every epoch.
    mets = precision_recall_fscore_support(y_true, y_preds, labels=class_ids,
                                           average=None, zero_division=0)

    print(f'Validation Accuracy: {accuracy:.4f}')
    print(f'Precision: {[f"{p:.4f}" for p in mets[0]]}')
    print(f'Recall: {[f"{r:.4f}" for r in mets[1]]}')
    print(f'FScore: {[f"{f:.4f}" for f in mets[2]]}')

    # ---- Early stopping & checkpointing ----
    avg_val_loss = mloss_val / n_valid

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        counter = 0
        torch.save(model.state_dict(), best_ckpt_path)
        # Store the predictions that belong to the saved checkpoint.
        np.savetxt("validation_predictions.csv", y_preds, delimiter=",", fmt="%d")
        print("Best model saved!")
    else:
        counter += 1
        print(f"Early stopping counter: {counter}/{patience}")

    if counter >= patience:
        print("Validation loss has not improved for multiple epochs, stopping early!")
        break

    # Adjust learning rate
    scheduler.step(avg_val_loss)

# The loop leaves the last epoch's weights in `model`; restore the best ones.
if best_epoch is not None:
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device, weights_only=True))
    print(f"Restored best model from epoch {best_epoch} (val loss {best_val_loss:.4f})")


Epoch 1/1000  MPS Active  LR : 1.00E-02  Loss: 0.4192: 100%|██████████| 74/74 [00:01<00:00, 37.82it/s]
Val Loss: 0.0589: 100%|██████████| 60/60 [00:00<00:00, 76.86it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.1008
Precision: ['0.8900', '0.0000', '0.0320', '0.0343']
Recall: ['0.0783', '0.0000', '0.4600', '0.7100']
FScore: ['0.1439', '0.0000', '0.0598', '0.0654']
Best model saved!


Epoch 2/1000  MPS Active  LR : 1.00E-02  Loss: 0.0557: 100%|██████████| 74/74 [00:01<00:00, 43.34it/s]
Val Loss: 0.0829: 100%|██████████| 60/60 [00:00<00:00, 76.98it/s]


Validation Accuracy: 0.0491
Precision: ['0.8462', '0.0515', '0.0165', '0.0000']
Recall: ['0.0032', '0.8350', '0.0900', '0.0000']
FScore: ['0.0064', '0.0970', '0.0278', '0.0000']
Early stopping counter: 1/10


Epoch 3/1000  MPS Active  LR : 1.00E-02  Loss: 0.0453: 100%|██████████| 74/74 [00:01<00:00, 48.12it/s]
Val Loss: 0.0484: 100%|██████████| 60/60 [00:00<00:00, 78.35it/s]


Validation Accuracy: 0.1847
Precision: ['0.8842', '0.0345', '0.0467', '0.0295']
Recall: ['0.1768', '0.0050', '0.2300', '0.7700']
FScore: ['0.2946', '0.0087', '0.0776', '0.0569']
Best model saved!


Epoch 4/1000  MPS Active  LR : 1.00E-02  Loss: 0.0433: 100%|██████████| 74/74 [00:01<00:00, 50.10it/s]
Val Loss: 0.0693: 100%|██████████| 60/60 [00:00<00:00, 70.14it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.0302
Precision: ['0.0000', '0.0000', '0.0276', '0.0833']
Recall: ['0.0000', '0.0000', '1.0000', '0.1500']
FScore: ['0.0000', '0.0000', '0.0536', '0.1071']
Early stopping counter: 1/10


Epoch 5/1000  MPS Active  LR : 1.00E-02  Loss: 0.0448: 100%|██████████| 74/74 [00:01<00:00, 46.82it/s]
Val Loss: 0.0542: 100%|██████████| 60/60 [00:00<00:00, 78.25it/s]


Validation Accuracy: 0.1236
Precision: ['0.8631', '0.2500', '0.0000', '0.0236']
Recall: ['0.1146', '0.0050', '0.0000', '0.7900']
FScore: ['0.2024', '0.0098', '0.0000', '0.0459']
Early stopping counter: 2/10


Epoch 6/1000  MPS Active  LR : 1.00E-02  Loss: 0.0464: 100%|██████████| 74/74 [00:01<00:00, 48.76it/s]
Val Loss: 0.0554: 100%|██████████| 60/60 [00:00<00:00, 77.61it/s]


Validation Accuracy: 0.6872
Precision: ['0.8992', '1.0000', '0.0000', '0.0439']
Recall: ['0.7555', '0.0050', '0.0000', '0.4100']
FScore: ['0.8211', '0.0100', '0.0000', '0.0794']
Early stopping counter: 3/10


Epoch 7/1000  MPS Active  LR : 1.00E-02  Loss: 0.0442: 100%|██████████| 74/74 [00:01<00:00, 48.45it/s]
Val Loss: 0.0968: 100%|██████████| 60/60 [00:00<00:00, 72.95it/s]


Validation Accuracy: 0.1800
Precision: ['0.9048', '0.0537', '0.0000', '0.0000']
Recall: ['0.1504', '0.8650', '0.0000', '0.0000']
FScore: ['0.2579', '0.1012', '0.0000', '0.0000']
Early stopping counter: 4/10


Epoch 8/1000  MPS Active  LR : 5.00E-03  Loss: 0.0427: 100%|██████████| 74/74 [00:01<00:00, 48.97it/s]
Val Loss: 0.0588: 100%|██████████| 60/60 [00:00<00:00, 79.93it/s]


Validation Accuracy: 0.6618
Precision: ['0.9006', '0.0678', '0.0317', '0.0515']
Recall: ['0.7253', '0.0800', '0.1700', '0.1500']
FScore: ['0.8035', '0.0734', '0.0534', '0.0767']
Early stopping counter: 5/10


Epoch 9/1000  MPS Active  LR : 5.00E-03  Loss: 0.0379: 100%|██████████| 74/74 [00:01<00:00, 49.61it/s]
Val Loss: 0.0541: 100%|██████████| 60/60 [00:00<00:00, 77.70it/s]


Validation Accuracy: 0.1304
Precision: ['0.8975', '0.0514', '0.0000', '0.0254']
Recall: ['0.1052', '0.5100', '0.0000', '0.3600']
FScore: ['0.1884', '0.0934', '0.0000', '0.0475']
Early stopping counter: 6/10


Epoch 10/1000  MPS Active  LR : 5.00E-03  Loss: 0.0398: 100%|██████████| 74/74 [00:01<00:00, 51.27it/s]
Val Loss: 0.0826: 100%|██████████| 60/60 [00:00<00:00, 75.40it/s]


Validation Accuracy: 0.6106
Precision: ['0.8966', '0.0637', '0.0369', '0.0000']
Recall: ['0.6631', '0.2250', '0.2000', '0.0000']
FScore: ['0.7624', '0.0993', '0.0623', '0.0000']
Early stopping counter: 7/10


Epoch 11/1000  MPS Active  LR : 5.00E-03  Loss: 0.0393: 100%|██████████| 74/74 [00:01<00:00, 49.00it/s]
Val Loss: 0.0554: 100%|██████████| 60/60 [00:00<00:00, 79.89it/s]


Validation Accuracy: 0.0525
Precision: ['0.8784', '0.0530', '0.0515', '0.0196']
Recall: ['0.0191', '0.4550', '0.0700', '0.3700']
FScore: ['0.0373', '0.0949', '0.0593', '0.0373']
Early stopping counter: 8/10


Epoch 12/1000  MPS Active  LR : 2.50E-03  Loss: 0.0368: 100%|██████████| 74/74 [00:01<00:00, 50.10it/s]
Val Loss: 0.1018: 100%|██████████| 60/60 [00:00<00:00, 80.39it/s]


Validation Accuracy: 0.5865
Precision: ['0.9067', '0.1273', '0.0263', '0.0588']
Recall: ['0.6409', '0.0700', '0.3300', '0.0200']
FScore: ['0.7509', '0.0903', '0.0487', '0.0299']
Early stopping counter: 9/10


Epoch 13/1000  MPS Active  LR : 2.50E-03  Loss: 0.0371: 100%|██████████| 74/74 [00:01<00:00, 50.72it/s]
Val Loss: 0.0747: 100%|██████████| 60/60 [00:00<00:00, 76.19it/s]

Validation Accuracy: 0.4285
Precision: ['0.8995', '0.0561', '0.0320', '0.0417']
Recall: ['0.4538', '0.1700', '0.3600', '0.1500']
FScore: ['0.6033', '0.0844', '0.0588', '0.0652']
Early stopping counter: 10/10
Validation loss has increased for multiple epochs, stopping early!



