# BCI Competition IV Dataset 2b: CNN vs GCN

In [3]:
import numpy as np
import mne
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import glob

## Load BCI Competition IV Dataset 2b (2-class motor imagery)

In [None]:
# Annotation description -> class index. Fixing the mapping here (instead of
# deriving it from whichever events a file happens to contain) guarantees that
# a given cue always receives the same label, even if a recording contains
# only one of the two cues.
MI_CLASS_INDEX = {'769': 0, '770': 1}

data_list = []
labels_list = []

train_files = sorted(glob.glob('BCI_2b/*T.gdf'))
print(f'Found {len(train_files)} training files')

for train_file in train_files:
    raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)
    events, event_id = mne.events_from_annotations(raw, verbose=False)

    # Keep only the motor-imagery cue annotations; skip files without any.
    mi_event_ids = {k: v for k, v in event_id.items() if k in MI_CLASS_INDEX}
    if not mi_event_ids:
        continue

    # NOTE(review): every channel in the recording is epoched as-is (the
    # captured output reports 6 per file) -- confirm whether non-EEG channels
    # should be dropped before training.
    epochs = mne.Epochs(raw, events, event_id=mi_event_ids,
                        tmin=0, tmax=4, baseline=None, preload=True, verbose=False)

    data = epochs.get_data()

    # Translate MNE's integer event codes into the fixed class indices.
    code_to_class = {code: MI_CLASS_INDEX[name] for name, code in mi_event_ids.items()}
    labels = np.array([code_to_class[code] for code in epochs.events[:, -1]])

    data_list.append(data)
    labels_list.append(labels)
    print(f'{train_file}: {data.shape[0]} epochs, {data.shape[1]} channels')

if not data_list:
    # np.concatenate would fail with an opaque message on an empty list.
    raise ValueError("No motor-imagery epochs were loaded from 'BCI_2b/*T.gdf'")

X = np.concatenate(data_list, axis=0)
y = np.concatenate(labels_list, axis=0)

print(f'\nTotal: {X.shape} - 2 classes (left/right hand)')
num_channels = X.shape[1]
num_timepoints = X.shape[2]

Found 27 training files


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0101T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0102T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0103T.gdf: 160 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0201T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0202T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0203T.gdf: 160 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0301T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0302T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0303T.gdf: 160 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0401T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0402T.gdf: 140 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0403T.gdf: 160 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


BCI_2b\B0501T.gdf: 120 epochs, 6 channels


  raw = mne.io.read_raw_gdf(train_file, preload=True, verbose=False)


## Prepare Data

In [None]:
# Hold out 20% of the epochs for testing, keeping the class balance.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)


def _standardize_per_channel(scaler, epochs, fit=False):
    """Standardize (n_epochs, n_channels, n_times) data with one mean/std per channel.

    StandardScaler normalizes columns, so the channel axis is moved last and
    every (epoch, time) sample becomes a row. Reshaping to (-1, n_times)
    instead would treat each time index as a feature and pool statistics
    across channels, leaving channels of different amplitude on different
    scales.
    """
    n_channels, n_times = epochs.shape[1], epochs.shape[2]
    samples = epochs.transpose(0, 2, 1).reshape(-1, n_channels)
    scaled = scaler.fit_transform(samples) if fit else scaler.transform(samples)
    restored = scaled.reshape(-1, n_times, n_channels).transpose(0, 2, 1)
    return np.ascontiguousarray(restored)


# Fit the scaler on the training split only, then reuse it on the test split.
scaler = StandardScaler()
X_train = _standardize_per_channel(scaler, X_train, fit=True)
X_test = _standardize_per_channel(scaler, X_test)

# One-hot targets for the softmax / categorical cross-entropy models.
y_train_cat = keras.utils.to_categorical(y_train, 2)
y_test_cat = keras.utils.to_categorical(y_test, 2)

print(f'Train: {X_train.shape}, Test: {X_test.shape}')

## Compute Graph Structure

In [None]:
def compute_adjacency_from_pearson(data):
    """Build a channel adjacency matrix from absolute Pearson correlations.

    Args:
        data: Array of shape (n_epochs, n_channels, n_timepoints).

    Returns:
        float32 array of shape (n_channels, n_channels) holding |r| between
        every pair of channels, with ones on the diagonal. Pairs involving a
        zero-variance channel (whose correlation is undefined) get weight 0.

    Raises:
        ValueError: If ``data`` is not 3-dimensional.
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(
            f'expected data of shape (epochs, channels, timepoints), got {data.shape}')

    n_channels = data.shape[1]
    # One row per channel: all epochs of a channel concatenated along time.
    per_channel = data.transpose(1, 0, 2).reshape(n_channels, -1)

    # A constant channel produces 0/0 inside corrcoef; silence the warning and
    # map the resulting NaNs to "no connection" so they cannot poison the
    # Laplacian computed from this matrix.
    with np.errstate(divide='ignore', invalid='ignore'):
        correlation_matrix = np.corrcoef(per_channel)
    adjacency = np.abs(np.nan_to_num(correlation_matrix, nan=0.0))

    # corrcoef collapses to a scalar for a single channel; keep it a matrix.
    adjacency = np.atleast_2d(adjacency)
    np.fill_diagonal(adjacency, 1.0)
    return adjacency.astype(np.float32)

def compute_laplacian(adjacency):
    """Return the rescaled symmetric-normalized graph Laplacian as float32.

    Forms L = I - D^{-1/2} A D^{-1/2} from the adjacency matrix A, then maps
    it to (2 / lambda_max) * L - I, where lambda_max is the largest
    eigenvalue of L.
    """
    n_nodes = len(adjacency)
    identity = np.eye(n_nodes)

    # Node degrees; the small epsilon guards the inverse square root.
    degree = np.sum(adjacency, axis=1)
    inv_sqrt_degree = np.diag(1.0 / np.sqrt(degree + 1e-6))

    normalized_adjacency = inv_sqrt_degree @ adjacency @ inv_sqrt_degree
    laplacian = identity - normalized_adjacency

    # eigvalsh returns eigenvalues in ascending order, so the last is the largest.
    eigenvalues = np.linalg.eigvalsh(laplacian)
    largest_eigenvalue = eigenvalues[-1]

    rescaled = (2.0 / largest_eigenvalue) * laplacian - identity
    return rescaled.astype(np.float32)

# Derive the channel graph from the training split only, then turn it into the
# rescaled Laplacian used by the graph convolution.
adjacency = compute_adjacency_from_pearson(data=X_train)
L_rescaled = compute_laplacian(adjacency=adjacency)
print('Adjacency: {}, Laplacian: {}'.format(adjacency.shape, L_rescaled.shape))

## Chebyshev Graph Convolution

In [None]:
class ChebyshevGraphConv(layers.Layer):
    """Graph convolution using a Chebyshev polynomial expansion of the Laplacian.

    The layer is called on a pair ``[x, L_rescaled]``. The last axis of ``x``
    is the feature axis (it sizes the weight matrices); ``L_rescaled`` is
    left-multiplied onto ``x``, so it is presumably the (nodes, nodes)
    rescaled Laplacian with ``x`` shaped (batch, nodes, features) -- confirm
    against the caller.

    Args:
        num_filters: Number of output features per node.
        K: Number of Chebyshev terms (polynomial order K - 1). Must be >= 1.
    """

    def __init__(self, num_filters, K=3, **kwargs):
        super().__init__(**kwargs)
        if K < 1:
            # Fail here rather than with an IndexError on theta[0] in call().
            raise ValueError(f'K must be >= 1, got {K}')
        self.num_filters = num_filters
        self.K = K

    def build(self, input_shape):
        # input_shape is the pair of shapes for [x, L_rescaled]; one weight
        # matrix is created per Chebyshev term.
        feature_dim = input_shape[0][-1]
        self.theta = [self.add_weight(shape=(feature_dim, self.num_filters),
                                      initializer='glorot_uniform', name=f'theta_{k}')
                      for k in range(self.K)]
        super().build(input_shape)

    def call(self, inputs):
        x, L_rescaled = inputs

        # T_0(L) x = x
        Tx_prev = x
        out = tf.matmul(Tx_prev, self.theta[0])
        if self.K == 1:
            return out

        # T_1(L) x = L x
        Tx_curr = tf.matmul(L_rescaled, x)
        out += tf.matmul(Tx_curr, self.theta[1])

        # Recurrence: T_k(L) x = 2 L T_{k-1}(L) x - T_{k-2}(L) x
        for k in range(2, self.K):
            Tx_next = 2 * tf.matmul(L_rescaled, Tx_curr) - Tx_prev
            out += tf.matmul(Tx_next, self.theta[k])
            Tx_prev, Tx_curr = Tx_curr, Tx_next
        return out

    def get_config(self):
        # Expose constructor arguments so the layer can be rebuilt from a
        # saved model config.
        config = super().get_config()
        config.update({'num_filters': self.num_filters, 'K': self.K})
        return config

## CNN Model

In [None]:
def create_cnn(input_shape, num_classes):
    """Build and compile the temporal 1-D CNN baseline.

    The input's two axes are swapped before the convolutions. Two
    Conv/BatchNorm/MaxPool/Dropout stages are followed by a third convolution,
    global average pooling and a dense softmax head. The model is compiled
    with Adam and categorical cross-entropy.
    """
    inputs = layers.Input(shape=input_shape)
    features = layers.Permute((2, 1))(inputs)

    # Two pooled convolution stages: (filters, kernel width).
    for n_filters, kernel_width in ((64, 50), (128, 25)):
        features = layers.Conv1D(n_filters, kernel_width,
                                 activation='relu', padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.MaxPooling1D(2)(features)
        features = layers.Dropout(0.3)(features)

    # Final convolution stage, collapsed by global average pooling.
    features = layers.Conv1D(256, 10, activation='relu', padding='same')(features)
    features = layers.BatchNormalization()(features)
    features = layers.GlobalAveragePooling1D()(features)

    # Classification head.
    hidden = layers.Dense(128, activation='relu')(features)
    hidden = layers.Dropout(0.5)(hidden)
    outputs = layers.Dense(num_classes, activation='softmax')(hidden)

    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# Instantiate the CNN baseline for two classes and report its size.
cnn_model = create_cnn(input_shape=(num_channels, num_timepoints), num_classes=2)
print('CNN: {:,} parameters'.format(cnn_model.count_params()))

## GCN Model

In [None]:
def create_gcn(input_shape, num_classes, num_channels, L_rescaled):
    """Build and compile the hybrid CNN + graph-convolution model.

    The model takes two inputs. The first goes through a temporal CNN branch;
    the second is averaged over its last axis and passed through a Chebyshev
    graph convolution with the supplied Laplacian. Both feature sets are
    concatenated and fed to a dense softmax head.

    Args:
        input_shape: Shape of one sample for the CNN branch; its last entry
            is also used as the last dimension of the graph-branch input.
        num_classes: Number of softmax outputs.
        num_channels: Number of graph nodes (first dimension of the
            graph-branch input).
        L_rescaled: Rescaled graph Laplacian, converted to a float32 constant.

    Returns:
        A compiled ``keras.Model`` taking ``[cnn_input, graph_input]``.
    """
    inputs = layers.Input(shape=input_shape)
    x = layers.Permute((2, 1))(inputs)

    x = layers.Conv1D(64, 50, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(2)(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv1D(128, 25, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling1D()(x)
    cnn_features = x

    # Size the graph-branch input from the function's own arguments rather
    # than from a notebook-level global, so the builder works for any
    # input_shape it is given.
    channel_input = layers.Input(shape=(num_channels, input_shape[-1]))
    # Collapse the last axis to a single feature per node.
    channel_avg = layers.Lambda(
        lambda signal: tf.reduce_mean(signal, axis=-1, keepdims=True))(channel_input)

    L_tensor = tf.constant(L_rescaled, dtype=tf.float32)
    graph_features = ChebyshevGraphConv(64, K=2)([channel_avg, L_tensor])
    graph_features = layers.Flatten()(graph_features)

    combined = layers.Concatenate()([cnn_features, graph_features])
    x = layers.Dense(128, activation='relu')(combined)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model([inputs, channel_input], outputs)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Instantiate the hybrid CNN + GCN model on the same input shape and report its size.
gcn_model = create_gcn(
    input_shape=(num_channels, num_timepoints),
    num_classes=2,
    num_channels=num_channels,
    L_rescaled=L_rescaled,
)
print('GCN: {:,} parameters'.format(gcn_model.count_params()))

## Train CNN

In [None]:
# Stop once the monitored metric has not improved for 10 epochs (restoring the
# best weights), and halve the learning rate after 5 epochs without improvement.
cnn_callbacks = [
    keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5),
]

# 20% of the training data is held back by Keras for validation.
cnn_fit_options = dict(
    validation_split=0.2,
    epochs=30,
    batch_size=64,
    callbacks=cnn_callbacks,
)
cnn_history = cnn_model.fit(X_train, y_train_cat, **cnn_fit_options)

## Train GCN

In [None]:
# Same training schedule as the CNN: early stopping with best-weight restore
# and learning-rate halving on plateau.
gcn_callbacks = [
    keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5),
]

# The model has two inputs; both receive the same training array.
gcn_train_inputs = [X_train, X_train]
gcn_fit_options = dict(
    validation_split=0.2,
    epochs=30,
    batch_size=64,
    callbacks=gcn_callbacks,
)
gcn_history = gcn_model.fit(gcn_train_inputs, y_train_cat, **gcn_fit_options)

## Results

In [None]:
# Evaluate both models on the held-out test split; the GCN takes the same
# array for each of its two inputs.
cnn_loss, cnn_acc = cnn_model.evaluate(X_test, y_test_cat, verbose=0)
gcn_loss, gcn_acc = gcn_model.evaluate([X_test, X_test], y_test_cat, verbose=0)

print('='*70)
print('FINAL RESULTS')
print('='*70)
print(f'CNN Accuracy: {cnn_acc*100:.2f}%')
print(f'GCN Accuracy: {gcn_acc*100:.2f}%')
print('='*70)

# Report the margin in percentage points. Equal accuracies are reported as a
# tie instead of being credited to the CNN by a fall-through else branch.
if gcn_acc > cnn_acc:
    print(f'WINNER: GCN (+{(gcn_acc-cnn_acc)*100:.2f}%)')
elif cnn_acc > gcn_acc:
    print(f'WINNER: CNN (+{(cnn_acc-gcn_acc)*100:.2f}%)')
else:
    print('TIE: CNN and GCN reached the same test accuracy')