In [1]:
from braindecode.models.util import models_dict


In [2]:
# Show the name of every architecture registered in Braindecode's model lookup.
model_names = list(models_dict)  # iterating a dict yields its keys
print(f'All the Braindecode models:\n{model_names}')


All the Braindecode models:
['ATCNet', 'Deep4Net', 'DeepSleepNet', 'EEGConformer', 'EEGITNet', 'EEGInception', 'EEGInceptionERP', 'EEGInceptionMI', 'EEGNetv1', 'EEGNetv4', 'EEGResNet', 'HybridNet', 'ShallowFBCSPNet', 'SleepStagerBlanco2020', 'SleepStagerChambon2018', 'SleepStagerEldele2021', 'TCN', 'TIDNet', 'USleep']


In [4]:
from braindecode.models import ShallowFBCSPNet
# Read the class docstring to see which constructor parameters the model takes.
shallow_doc = ShallowFBCSPNet.__doc__
print(shallow_doc)

Shallow ConvNet model from Schirrmeister et al 2017.

Model described in [Schirrmeister2017]_.

Parameters
----------
n_chans : int
    Number of EEG channels.
n_outputs : int
    Number of outputs of the model. This is the number of classes
    in the case of classification.
n_times : int
    Number of time samples of the input window.
n_filters_time: int
    Number of temporal filters.
filter_time_length: int
    Length of the temporal filter.
n_filters_spat: int
    Number of spatial filters.
pool_time_length: int
    Length of temporal pooling filter.
pool_time_stride: int
    Length of stride between temporal pooling filters.
final_conv_length: int | str
    Length of the final convolution layer.
    If set to "auto", length of the input signal must be specified.
conv_nonlin: callable
    Non-linear function to be used after convolution layers.
pool_mode: str
    Method to use on pooling layers. "max" or "mean".
pool_nonlin: callable
    Non-linear function to be used after poolin

In [5]:
# Build the network by hand, giving every signal-related dimension explicitly.
# Per the class docstring, final_conv_length='auto' requires the input length
# to be specified, which is why n_times is passed here.
shallow_kwargs = dict(
    n_chans=32,
    n_times=1000,
    n_outputs=2,
    final_conv_length='auto',
)
model = ShallowFBCSPNet(**shallow_kwargs)
print(model)  # prints the layer-by-layer summary table

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 32, 1000]             [1, 2]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 32, 1000]             [1, 32, 1000, 1]          --                        --
├─Rearrange (dimshuffle): 1-2            [1, 32, 1000, 1]          [1, 1, 1000, 32]          --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 1000, 32]          [1, 40, 976, 1]           52,240                    --
├─BatchNorm2d (bnorm): 1-4               [1, 40, 976, 1]           [1, 40, 976, 1]           80                        --
├─Expression (conv_nonlin_exp): 1-5      [1, 40, 976, 1]           [1, 40, 976, 1]           --                        --
├─AvgPool2d (pool): 1-6                  [1, 40, 976, 1]           [1, 40, 61, 1]            --                        [75, 1]
├─Express



In [6]:
import mne
import numpy as np

In [7]:
# Simulated dataset: random noise standing in for real EEG recordings.
SEED = 42        # fixed so that Restart & Run All regenerates the same X and y
N_EPOCHS = 100
N_CLASSES = 4
SFREQ = 256.     # sampling frequency in Hz
N_TIMES = 1024   # samples per epoch: 4 seconds at 256 Hz
CH_NAMES = ['C3', 'C4', 'Cz']

# A dedicated, seeded generator keeps the data reproducible without touching
# NumPy's global random state.
rng = np.random.default_rng(SEED)

info = mne.create_info(ch_names=CH_NAMES, sfreq=SFREQ, ch_types='eeg')
# Array layout is (epochs, channels, time samples).
X = rng.standard_normal((N_EPOCHS, len(CH_NAMES), N_TIMES))
epochs = mne.EpochsArray(X, info=info)
# One integer label in [0, N_CLASSES) per epoch.
y = rng.integers(0, N_CLASSES, size=N_EPOCHS)
print(epochs)

Not setting metadata
100 matching events found
No baseline correction applied
0 projection items activated
<EpochsArray | 100 events (all good), 0 – 3.996 s (baseline off), ~2.4 MiB, data loaded,
 '1': 100>


In [9]:
from skorch.dataset import ValidSplit
from braindecode import EEGClassifier

# Training needs a validation split; here 20% of the data is held out for it.
validation_split = ValidSplit(0.2)

# The architecture is selected by name; n_chans, n_times and n_outputs are
# not passed here.
net = EEGClassifier(
    'ShallowFBCSPNet',
    train_split=validation_split,
    module__final_conv_length='auto',
)

In [10]:
# Train on the MNE epochs and their integer labels; the call prints one row
# of validation metrics (accuracy, loss, duration) per training epoch.
net.fit(epochs, y)

  epoch    valid_acc    valid_loss     dur
-------  -----------  ------------  ------
      1       [36m0.2000[0m       [32m51.5983[0m  0.0380
      2       0.2000       51.5983  0.0110
      3       0.2000       51.5983  0.0100
      4       0.2000       51.5983  0.0117
      5       0.2000       51.5983  0.0116
      6       0.2000       51.5983  0.0070
      7       0.2000       51.5983  0.0090
      8       0.2000       51.5983  0.0080
      9       0.2000       51.5983  0.0060
     10       0.2000       51.5983  0.0090




In [11]:
# Inspect the network held by the fitted classifier.
fitted_module = net.module_
print(fitted_module)

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 3, 1024]              [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 3, 1024]              [1, 3, 1024, 1]           --                        --
├─Rearrange (dimshuffle): 1-2            [1, 3, 1024, 1]           [1, 1, 1024, 3]           --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 1024, 3]           [1, 40, 1000, 1]          5,840                     --
├─BatchNorm2d (bnorm): 1-4               [1, 40, 1000, 1]          [1, 40, 1000, 1]          80                        --
├─Expression (conv_nonlin_exp): 1-5      [1, 40, 1000, 1]          [1, 40, 1000, 1]          --                        --
├─AvgPool2d (pool): 1-6                  [1, 40, 1000, 1]          [1, 40, 62, 1]            --                        [75, 1]
├─Express

In [12]:
# Report the signal-related attributes found on the fitted module, one per
# line. The `=` specifier prints each expression text followed by its repr.
module_attribute_lines = [
    f'{net.module_.n_chans=}',
    f'{net.module_.n_times=}',
    f'{net.module_.n_outputs=}',
    f'{net.module_.input_window_seconds=}',
    f'{net.module_.sfreq=}',
    f'{net.module_.chs_info=}',
]
print('\n'.join(module_attribute_lines))

net.module_.n_chans=3
net.module_.n_times=1024
net.module_.n_outputs=4
net.module_.input_window_seconds=4.0
net.module_.sfreq=256.0
net.module_.chs_info=[{'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C3', 'scanno': 1, 'logno': 1}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C4', 'scanno': 2, 'logno': 2}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FI