In [1]:
%matplotlib inline


# Custom Dataset Example

This example shows how to convert data X and y as numpy arrays to a braindecode
compatible data format.


In [2]:
# Authors: Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD (3-clause)

import mne

from braindecode.datasets import create_from_X_y

To set up the example, we first fetch some data using mne:



In [3]:
# 5, 6, 7, 10, 13, 14 are codes for executed and imagined hands/feet
subject_id = 22
event_codes = [5, 6, 9, 10, 13, 14]
# event_codes = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(
    subject_id, event_codes, update_path=False)

# Load each of the files
parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')
         for path in physionet_paths]

Extracting EDF parameters from D:\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S022\S022R05.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from D:\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S022\S022R06.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from D:\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S022\S022R09.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from D:\mne_data\MNE-eegbci-data\files\eegmmidb\1.0.0\S022\S022R10.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from D:\mne_data\M

We take the required data, targets and additional information sampling
frequency and channel names from the loaded data. Note that this data and
information can originate from any source.



In [4]:
X = [raw.get_data() for raw in parts]
y = event_codes
sfreq = parts[0].info["sfreq"]
ch_names = parts[0].info["ch_names"]

Convert to data format compatible with skorch and braindecode:



In [5]:
windows_dataset = create_from_X_y(
    X, y, drop_last_window=False, sfreq=sfreq, ch_names=ch_names,
    window_stride_samples=500,
    window_size_samples=500,
)

windows_dataset.description  # look as dataset description

Creating RawArray with float64 data, n_channels=64, n_times=20000
    Range : 0 ... 19999 =      0.000 ...   124.994 secs
Ready.
Creating RawArray with float64 data, n_channels=64, n_times=20000
    Range : 0 ... 19999 =      0.000 ...   124.994 secs
Ready.
Creating RawArray with float64 data, n_channels=64, n_times=20000
    Range : 0 ... 19999 =      0.000 ...   124.994 secs
Ready.
Creating RawArray with float64 data, n_channels=64, n_times=20000
    Range : 0 ... 19999 =      0.000 ...   124.994 secs
Ready.
Creating RawArray with float64 data, n_channels=64, n_times=20000
    Range : 0 ... 19999 =      0.000 ...   124.994 secs
Ready.
Creating RawArray with float64 data, n_channels=64, n_times=20000
    Range : 0 ... 19999 =      0.000 ...   124.994 secs
Ready.
Adding metadata with 4 columns
40 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 40 events and 500 original time points ...
0 bad epochs dropped
Adding metad

Unnamed: 0,target
0,5
1,6
2,9
3,10
4,13
5,14


You can manipulate the dataset



In [6]:
print(len(windows_dataset))  # get the number of samples

240


You can now index the data



In [7]:
i = 0
x_i, y_i, window_ind = windows_dataset[0]
n_channels, n_times = x_i.shape  # the EEG data
_, start_ind, stop_ind = window_ind
print(f"n_channels={n_channels}  -- n_times={n_times} -- y_i={y_i}")
print(f"start_ind={start_ind} -- stop_ind={stop_ind}")

Using data from preloaded Raw for 1 events and 500 original time points ...
n_channels=64  -- n_times=500 -- y_i=5
start_ind=0 -- stop_ind=500


In [9]:
windows_dataset.description

Unnamed: 0,target
0,5
1,6
2,9
3,10
4,13
5,14


In [10]:
windows_dataset.get_metadata()

Unnamed: 0,i_window_in_trial,i_start_in_trial,i_stop_in_trial,target
0,0,0,500,5
1,1,500,1000,5
2,2,1000,1500,5
3,3,1500,2000,5
4,4,2000,2500,5
...,...,...,...,...
35,35,17500,18000,14
36,36,18000,18500,14
37,37,18500,19000,14
38,38,19000,19500,14


In [11]:
windows_dataset.datasets

[<braindecode.datasets.base.WindowsDataset at 0x21527298d68>,
 <braindecode.datasets.base.WindowsDataset at 0x215272f09b0>,
 <braindecode.datasets.base.WindowsDataset at 0x215272f0ef0>,
 <braindecode.datasets.base.WindowsDataset at 0x2152e96a320>,
 <braindecode.datasets.base.WindowsDataset at 0x2152e96a438>,
 <braindecode.datasets.base.WindowsDataset at 0x2152e96afd0>]

In [12]:
windows_dataset.datasets[0]

<braindecode.datasets.base.WindowsDataset at 0x21527298d68>

In [15]:
windows_dataset.datasets[0][0]

Using data from preloaded Raw for 1 events and 500 original time points ...


(array([[ 1.10e-05,  1.10e-05, -1.00e-06, ..., -6.80e-05, -9.80e-05,
         -8.50e-05],
        [ 2.80e-05,  2.70e-05,  2.80e-05, ..., -9.00e-05, -1.16e-04,
         -1.15e-04],
        [ 2.00e-06,  1.00e-06, -5.00e-06, ..., -5.50e-05, -9.20e-05,
         -9.10e-05],
        ...,
        [-5.90e-05, -4.30e-05, -3.00e-05, ..., -1.12e-04, -1.05e-04,
         -9.80e-05],
        [-6.80e-05, -6.10e-05, -5.30e-05, ..., -7.90e-05, -7.10e-05,
         -6.40e-05],
        [-5.10e-05, -3.90e-05, -2.70e-05, ..., -5.40e-05, -5.40e-05,
         -5.20e-05]], dtype=float32),
 5,
 [0, 0, 500])

In [16]:
import torch
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
# Extract number of chans and time steps from dataset
n_chans = windows_dataset[0][0].shape[0]
input_window_samples = windows_dataset[0][0].shape[1]

model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
)

# Send model to GPU
if cuda:
    model.cuda()

Using data from preloaded Raw for 1 events and 500 original time points ...
Using data from preloaded Raw for 1 events and 500 original time points ...


In [17]:
print(n_chans, n_classes, input_window_samples)

64 4 500


In [25]:
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGClassifier
# These values we found good for shallow network:
lr = 0.0625 * 0.01
weight_decay = 0

# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001

batch_size = 64
n_epochs = 4

clf = EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    #train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
# Model training for a specified number of epochs. `y` is None as it is already supplied
# in the dataset.
clf.fit(windows_dataset.datasets[0], y=windows_dataset.datasets[0].y, epochs=n_epochs)

TypeError: Cannot convert this data type to a numpy array.

In [22]:
windows_dataset.get_metadata()['target']

0      5
1      5
2      5
3      5
4      5
      ..
35    14
36    14
37    14
38    14
39    14
Name: target, Length: 240, dtype: int64

In [27]:
windows_dataset

<braindecode.datasets.base.BaseConcatDataset at 0x21527344ac8>