In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True' # To prevent the kernel from dying.

In [2]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
GPU device: NVIDIA GeForce GTX 1650
Number of GPUs: 1


In [3]:
import braindecode
from braindecode.models.util import models_dict
print(f"All the Braindecode models:\n{list(models_dict.keys())}")

All the Braindecode models:
['ATCNet', 'AttentionBaseNet', 'AttnSleep', 'BDTCN', 'BENDR', 'BIOT', 'BrainModule', 'CTNet', 'ContraWR', 'Deep4Net', 'DeepSleepNet', 'EEGConformer', 'EEGITNet', 'EEGInceptionERP', 'EEGInceptionMI', 'EEGMiner', 'EEGNeX', 'EEGNet', 'EEGSimpleConv', 'EEGSym', 'EEGTCNet', 'FBCNet', 'FBLightConvNet', 'FBMSNet', 'IFNet', 'LUNA', 'Labram', 'MEDFormer', 'MSVTNet', 'PBT', 'REVE', 'SCCNet', 'SPARCNet', 'SSTDPN', 'ShallowFBCSPNet', 'SignalJEPA', 'SignalJEPA_Contextual', 'SignalJEPA_PostLocal', 'SignalJEPA_PreLocal', 'SincShallowNet', 'SleepStagerBlanco2020', 'SleepStagerChambon2018', 'SyncNet', 'TIDNet', 'TSception', 'USleep']


In [4]:
from braindecode.models import ShallowFBCSPNet
from braindecode.util import set_random_seeds

In [5]:
print(ShallowFBCSPNet.__doc__)

# Set random seed to ensure reproducible initialization below
seed = 20240205
cuda = torch.cuda.is_available()
set_random_seeds(seed=seed, cuda=cuda)

Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    :bdg-success:`Convolution`

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
        :align: center
        :alt: ShallowNet Architecture

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    n_filters_time: int
        Number of temporal filters.
    filter_time_length: int
        Length of the temporal filter.
    n_filters_spat: int
        Number of spatial filters.
    pool_time_length: int
        Length of temporal pooling filter.
    pool_time_stride: int
        Length of stride between temporal pooling filters.
    final_conv_length: int | str
        Length of the final convolution layer.
        If set to "auto", length of the input signal must be specified.
    conv_nonlin: callable
        Non-linear function to be used after convolution layers.
    pool_mode: str
        Method to use on pooling

In [6]:
model = ShallowFBCSPNet(
    n_chans=32,
    n_times=1000,
    n_outputs=2,
    final_conv_length="auto",
)
print(model)

Layer (type (var_name):depth-idx)             Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)             [1, 32, 1000]             [1, 2]                    --                        --
├─Ensure4d (ensuredims): 1-1                  [1, 32, 1000]             [1, 32, 1000, 1]          --                        --
├─Rearrange (dimshuffle): 1-2                 [1, 32, 1000, 1]          [1, 1, 1000, 32]          --                        --
├─CombinedConv (conv_time_spat): 1-3          [1, 1, 1000, 32]          [1, 40, 976, 1]           52,240                    --
├─BatchNorm2d (bnorm): 1-4                    [1, 40, 976, 1]           [1, 40, 976, 1]           80                        --
├─Expression (conv_nonlin_exp): 1-5           [1, 40, 976, 1]           [1, 40, 976, 1]           --                        --
├─AvgPool2d (pool): 1-6                       [1, 40, 976, 1]           [1, 40, 61, 1]            -- 

In [7]:
import mne
import numpy as np

info = mne.create_info(ch_names=["C3", "C4", "Cz"], sfreq=256.0, ch_types="eeg")
X = np.random.randn(100, 3, 1024)  # 100 epochs, 3 channels, 4 seconds (@256Hz)
epochs = mne.EpochsArray(X, info=info)
y = np.random.randint(0, 4, size=100)  # 4 classes
print(epochs)

Not setting metadata
100 matching events found
No baseline correction applied
0 projection items activated
<EpochsArray | 100 events (all good), 0 – 3.996 s (baseline off), ~2.4 MiB, data loaded,
 '1': 100>


In [8]:
from skorch.dataset import ValidSplit

from braindecode import EEGClassifier

net = EEGClassifier(
    "ShallowFBCSPNet",
    module__final_conv_length="auto",
    train_split=ValidSplit(0.2),
)

In [9]:
from braindecode.datautil import infer_signal_properties

sig_props = infer_signal_properties(epochs, y, mode="classification")
print(f"Inferred signal properties:\n{sig_props}")

Inferred signal properties:
{'n_outputs': 4, 'n_times': 1024, 'sfreq': 256.0, 'chs_info': [{'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C3', 'scanno': 1, 'logno': 1}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C4', 'scanno': 2, 'logno': 2}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'Cz', 'scanno': 3, 'logno': 3}]}


In [10]:
net.fit(epochs, y)

  epoch    valid_acc    valid_loss     dur
-------  -----------  ------------  ------
      1       [36m0.3000[0m       [32m15.6130[0m  0.0269
      2       0.3000       15.6130  0.0067
      3       0.3000       15.6130  0.0092
      4       0.3000       15.6130  0.0085
      5       0.3000       15.6130  0.0119
      6       0.3000       15.6130  0.0071
      7       0.3000       15.6130  0.0094
      8       0.3000       15.6130  0.0120
      9       0.3000       15.6130  0.0247
     10       0.3000       15.6130  0.0231


0,1,2
,module,'ShallowFBCSPNet'
,criterion,<class 'torch...sEntropyLoss'>
,cropped,False
,callbacks,
,iterator_train__shuffle,True
,iterator_train__drop_last,True
,aggregate_predictions,True
,optimizer,<class 'torch.optim.sgd.SGD'>
,lr,0.01
,max_epochs,10


In [11]:
print(net.module_)

Layer (type (var_name):depth-idx)             Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)             [1, 3, 1024]              [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1                  [1, 3, 1024]              [1, 3, 1024, 1]           --                        --
├─Rearrange (dimshuffle): 1-2                 [1, 3, 1024, 1]           [1, 1, 1024, 3]           --                        --
├─CombinedConv (conv_time_spat): 1-3          [1, 1, 1024, 3]           [1, 40, 1000, 1]          5,840                     --
├─BatchNorm2d (bnorm): 1-4                    [1, 40, 1000, 1]          [1, 40, 1000, 1]          80                        --
├─Expression (conv_nonlin_exp): 1-5           [1, 40, 1000, 1]          [1, 40, 1000, 1]          --                        --
├─AvgPool2d (pool): 1-6                       [1, 40, 1000, 1]          [1, 40, 62, 1]            -- 