In [1]:
# !wget https://www.bbci.de/competition/download/competition_iv/BCICIV_1_mat.zip
# !unzip -q BCICIV_1_mat.zip
# !rm -rf BCICIV_1_mat.zip

In [2]:
# !git clone https://github.com/XJTU-EEG/LibEER.git

In [3]:
import os
import gc
import math
import random
import scipy.io
import numpy as np
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import sys
sys.path.append('LibEER/LibEER/models')

In [4]:
from DGCNN import DGCNN

In [5]:
TRAIN = ['a', 'b', 'd', 'e', 'g']
VALID = ['c', 'f']

In [6]:
# a - left, foot
# b - left, right
# c - left, right
# d - left, right
# e - left, right
# f - left, foot
# g - left, right

In [7]:
LABELS = {
    'idle': 0,
    'left' : 1,
    'right': 2,
    'foot': 3
}

LR = 1e-2
EPOCHS = 1000
BATCH_SIZE = 64

In [8]:
def get_labels(data):
    N = len(data['cnt'])

    labels = np.zeros((N, 4), dtype=np.uint8)
    labels[:, 0] = 1

    cls_labels = [d[0] for d in data['nfo'][0][0][1][0].tolist()]
    timestamps = data['mrk'][0][0][0]
    cls_idx = data['mrk'][0][0][1]

    t1 = timestamps[np.where(cls_idx == 1)]
    l1 = [0]*4
    l1[LABELS[cls_labels[0]]] = 1
    l1 = np.asarray([l1]*100)
    for t in t1:
        labels[t-50:t+50] = l1 # change to t :t+100

    t2 = timestamps[np.where(cls_idx == -1)]
    l2 = [0]*4
    l2[LABELS[cls_labels[1]]] = 1
    l2 = np.asarray([l2]*100)
    for t in t2:
        labels[t-50:t+50] = l2

    return labels

def get_data(data_id, split_size):
    data = scipy.io.loadmat(f'BCICIV_calib_ds1{data_id}.mat')
    eegs = data['cnt']
    labels = get_labels(data)
    np.savetxt(f"labels_{data_id}.csv", labels, delimiter=",", fmt="%d")
    
    splits = range(split_size, len(eegs), split_size)

    eegs_split = np.array_split(eegs, splits)[:-1]
    labels_split = np.array_split(labels, splits)[:-1]
    return eegs_split, labels_split

In [9]:
train_eegs, train_labels, valid_eegs, valid_labels = [], [], [], []

for train_id in tqdm(TRAIN):
    # print(train_id)
    eegs_split, labels_split = get_data(train_id, split_size=200)
    train_eegs.extend(eegs_split)
    train_labels.extend(labels_split)

for valid_id in tqdm(VALID):
    eegs_split, labels_split = get_data(valid_id, split_size=100)
    valid_eegs.extend(eegs_split)
    valid_labels.extend(labels_split)

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00,  6.78it/s]
100%|██████████| 2/2 [00:00<00:00,  6.15it/s]


In [10]:
class EEGDataset(Dataset):
    def __init__(self, eegs, labels, split):
        self.eegs = eegs
        self.labels = labels
        self.split = split

    def __getitem__(self, i):
        eeg, label = self.eegs[i], self.labels[i]
        if self.split == 'train':
            start_idx = random.randint(0, 100)
            eeg = eeg[start_idx:start_idx+100]
            label = label[start_idx:start_idx+100]

        eeg = torch.from_numpy(eeg).float()
        #eeg = (eeg - MEAN)/STD
        eeg = (eeg - torch.mean(eeg, dim=0))/torch.std(eeg, dim=0)
        eeg = eeg.transpose(0, 1)

        # print(f"Single sample shape: {eeg.shape}")
        
        label_soft = torch.from_numpy(label).float().mean(dim=0)
        #label_hard = torch.argmax(label_soft)
        return eeg, label_soft

    def __len__(self):
        return len(self.eegs)

In [11]:
dataset_train = EEGDataset(train_eegs, train_labels, 'train')
dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, num_workers=0,
                              shuffle=True, pin_memory=True, drop_last=True)
n_train = len(dataloader_train)

dataset_valid = EEGDataset(valid_eegs, valid_labels, 'valid')
dataloader_valid = DataLoader(dataset_valid, batch_size=BATCH_SIZE, num_workers=0,
                              shuffle=False, pin_memory=True, drop_last=False)
n_valid = len(dataloader_valid)

In [12]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

In [13]:
model = DGCNN(num_electrodes=59, in_channels=100, num_classes=4, k=2, relu_is=1, layers=None, dropout_rate=0.2)

print(f'Number of parameters: {count_parameters(model):.2f}M')
device = torch.device("mps" if torch.backends.mps.is_available() else
                      "cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Calculate class weights based on label distribution
labels = torch.stack([label for _, label in dataset_train])
label_counts = torch.bincount(torch.argmax(labels, dim=1))
class_weights = 1.0 / label_counts
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# criterion = nn.CrossEntropyLoss()

#grad_scaler = torch.amp.GradScaler('cuda')
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, betas=[0.9, 0.999], weight_decay=0.001)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR,
#                                                steps_per_epoch=10, epochs=EPOCHS//10,
#                                                pct_start=0.1)


config/model_param/DGCNN.yaml may not exist or not available
DGCNN Model, Parameters:

k (The order of Chebyshev polynomials):                         2
relu_is (The type of B_Relu func):                              1
layers (The channels of each layers):        None                
dropout rate:                                                 0.2

Not Using Default Setting, the performance may be not the best
Starting......
Number of parameters: 0.98M


In [14]:
# for x, y in dataloader_train:
#     print("Input shape:", x.shape)
#     print("Label shape:", y.shape)
#     break

In [15]:
for name, param in model.named_parameters():
    print(f"{name}: {param.mean().item():.4f}")

adj: 0.0008
adj_bias: 0.0100
graphConvs.0.weight: -0.0006
fc.weight: 0.0000
fc.bias: 0.1000
fc2.weight: 0.0008
fc2.bias: 0.1000
b_relus.0.bias: -0.0002


In [16]:
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Hyperparameters
patience = 10
best_val_loss = float('inf')
counter = 0

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

for epoch in range(1, EPOCHS + 1):
    model.train()
    cur_lr = f"LR : {optimizer.param_groups[0]['lr']:.2E}"

    pbar_train = tqdm(dataloader_train, total=n_train, bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")
    mloss_train, mloss_val = 0.0, 0.0

    for i, (x, y) in enumerate(pbar_train):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        mloss_train += loss.item()

        gpu_mem = (
            f"Mem : {torch.cuda.memory_reserved() / 1E9:.3g}GB" if torch.cuda.is_available()
            else "MPS Active" if torch.backends.mps.is_available()
            else "CPU Mode"
        )

        pbar_train.set_description(f"Epoch {epoch}/{EPOCHS}  {gpu_mem}  {cur_lr}  Loss: {mloss_train / (i + 1):.4f}")

    # Validation
    model.eval()
    pbar_val = tqdm(dataloader_valid, total=n_valid, bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")

    y_true, y_preds = [], []
    for i, (x, y) in enumerate(pbar_val):
        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            y_hat = model(x)

        loss = criterion(y_hat, y)
        mloss_val += loss.item()
        y_preds.append(y_hat)
        y_true.append(y)

        pbar_val.set_description(f"Val Loss: {mloss_val / (i + 1):.4f}")

    y_true = torch.cat(y_true).cpu().numpy()
    y_true = np.argmax(y_true, axis=1)
    y_preds = F.softmax(torch.cat(y_preds), dim=1).argmax(dim=1).cpu().numpy()

    accuracy = np.mean(y_true == y_preds)  
    mets = precision_recall_fscore_support(y_true, y_preds, labels=[0, 1, 2, 3], average=None)

    print(f'Validation Accuracy: {accuracy:.4f}')
    np.savetxt("validation_predictions.csv", y_preds, delimiter=",", fmt="%d")

    print(f'Precision: {[f"{p:.4f}" for p in mets[0]]}')
    print(f'Recall: {[f"{r:.4f}" for r in mets[1]]}')
    print(f'FScore: {[f"{f:.4f}" for f in mets[2]]}')

    # Early Stopping & Save Best Model
    avg_val_loss = mloss_val / len(dataloader_valid)
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
        torch.save(model.state_dict(), "best_model.pth")
        print("Best model saved!")
    else:
        counter += 1
        print(f"Early stopping counter: {counter}/{patience}")

    if counter >= patience:
        print("Validation loss has increased for multiple epochs, stopping early!")
        break

    # Adjust learning rate
    scheduler.step(avg_val_loss)


Epoch 1/1000  MPS Active  LR : 1.00E-02  Loss: 0.5325: 100%|██████████| 74/74 [00:01<00:00, 40.52it/s]
Val Loss: 0.0796: 100%|██████████| 60/60 [00:00<00:00, 70.06it/s]


Validation Accuracy: 0.2585
Precision: ['0.9026', '0.0557', '0.0238', '0.0000']
Recall: ['0.2419', '0.7950', '0.0100', '0.0000']
FScore: ['0.3815', '0.1041', '0.0141', '0.0000']
Best model saved!


Epoch 2/1000  MPS Active  LR : 1.00E-02  Loss: 0.0516: 100%|██████████| 74/74 [00:01<00:00, 46.32it/s]
Val Loss: 0.0480: 100%|██████████| 60/60 [00:00<00:00, 71.91it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.0328
Precision: ['0.9062', '0.0000', '0.0300', '0.0253']
Recall: ['0.0085', '0.0000', '0.0300', '0.9300']
FScore: ['0.0168', '0.0000', '0.0300', '0.0492']
Best model saved!


Epoch 3/1000  MPS Active  LR : 1.00E-02  Loss: 0.0411: 100%|██████████| 74/74 [00:01<00:00, 49.58it/s]
Val Loss: 0.0476: 100%|██████████| 60/60 [00:00<00:00, 79.04it/s]


Validation Accuracy: 0.0459
Precision: ['0.8442', '0.0657', '0.0226', '0.0251']
Recall: ['0.0191', '0.1400', '0.1000', '0.7200']
FScore: ['0.0373', '0.0895', '0.0368', '0.0486']
Best model saved!


Epoch 4/1000  MPS Active  LR : 1.00E-02  Loss: nan: 100%|██████████| 74/74 [00:01<00:00, 50.44it/s]   
Val Loss: nan: 100%|██████████| 60/60 [00:00<00:00, 67.74it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 1/10


Epoch 5/1000  MPS Active  LR : 1.00E-02  Loss: nan: 100%|██████████| 74/74 [00:01<00:00, 45.97it/s]
Val Loss: nan: 100%|██████████| 60/60 [00:00<00:00, 78.01it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 2/10


Epoch 6/1000  MPS Active  LR : 1.00E-02  Loss: nan: 100%|██████████| 74/74 [00:01<00:00, 46.61it/s]
Val Loss: nan: 100%|██████████| 60/60 [00:00<00:00, 67.82it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 3/10


Epoch 7/1000  MPS Active  LR : 1.00E-02  Loss: nan: 100%|██████████| 74/74 [00:01<00:00, 45.99it/s]
Val Loss: nan: 100%|██████████| 60/60 [00:00<00:00, 76.41it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Accuracy: 0.8950
Precision: ['0.8950', '0.0000', '0.0000', '0.0000']
Recall: ['1.0000', '0.0000', '0.0000', '0.0000']
FScore: ['0.9446', '0.0000', '0.0000', '0.0000']
Early stopping counter: 4/10


Epoch 8/1000  MPS Active  LR : 5.00E-03  Loss: nan:   7%|▋         | 5/74 [00:00<00:01, 49.77it/s]