In [1]:
%load_ext autoreload
%autoreload 2
import os
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')

# Read and Decode BBCI Data with Start-Stop-Markers

This tutorial shows how to read and decode BBCI data with start and stop markers. The data loading part is more complicated and it is advised to read the other tutorials before.

## Setup logging to see outputs

In [2]:
import logging
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.DEBUG, stream=sys.stdout)
log = logging.getLogger()


## Load and preprocess data

This is a bit more complicated than before since we have to add breaks etc. Here I now opt to add breaks do all preprocessings per run and only later combine them together.

In [3]:
import numpy as np
from braindecode.datautil.splitters import concatenate_sets
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne, add_breaks
from braindecode.datasets.bbci import load_bbci_sets_from_folder
from copy import deepcopy
from braindecode.mne_ext.signalproc import resample_cnt, mne_apply
from braindecode.datautil.signalproc import lowpass_cnt
from braindecode.datautil.signalproc import exponential_running_standardize

def create_cnts(folder, runs,):
    # Load data
    cnts = load_bbci_sets_from_folder(folder, runs)
    
    # Now do some preprocessings:
    # Resampling to 250 Hz, lowpass below 38, eponential standardization
    
    new_cnts = []
    for cnt in cnts:
        # Only take some channels 
        #cnt = cnt.drop_channels(['STI 014']) # This would remove stimulus channel
        cnt = cnt.pick_channels(['C3', 'CPz', 'C4'])
        log.info("Preprocessing....")
        cnt = mne_apply(lambda a: lowpass_cnt(a, 38,cnt.info['sfreq'], axis=1), cnt)
        cnt = resample_cnt(cnt, 250)
        # mne apply will apply the function to the data (a 2d-numpy-array)
        # have to transpose data back and forth, since
        # exponential_running_standardize expects time x chans order
        # while mne object has chans x time order
        cnt = mne_apply(lambda a: exponential_running_standardize(
            a.T, init_block_size=1000,factor_new=0.001, eps=1e-4).T,
            cnt)
        new_cnts.append(cnt)
    return new_cnts

In [4]:
from collections import OrderedDict

train_runs = [1,2,3]
train_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/', 
                         train_runs,)

name_to_start_code = OrderedDict([('Right Hand', 1), ('Feet', 4),
            ('Rotation', 8), ('Words', [10])])

name_to_stop_code = OrderedDict([('Right Hand', [20,21,22,23,24,28,30]),
        ('Feet', [20,21,22,23,24,28,30]),
        ('Rotation', [20,21,22,23,24,28,30]), 
        ('Words', [20,21,22,23,24,28,30])])


2017-11-03 17:54:16,149 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/AnLaNBD1S001R01_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=151350
    Range : 0 ... 151349 =      0.000 ...   605.396 secs
Ready.
2017-11-03 17:54:17,481 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/AnLaNBD1S001R02_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=153500
    Range : 0 ... 153499 =      0.000 ...   613.996 secs
Ready.
2017-11-03 17:54:18,843 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/AnLaNBD1S001R03_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=180700
    Range : 0 ... 180699 =      0.000 ...   722.796 secs
Ready.
2017-11-03 17:54:20,609 INFO : Preprocessing....
2017-11-03 17:54:20,626 INFO : Just copying data, no resampling, since new sampling rate same.
2017-11-03 17:54:20,801 INFO : Preprocessing....
2017-11-03 17:54:20,809 INFO : Just copyin

In [5]:
test_runs = [9,10]
test_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/', test_runs,)

2017-11-03 17:54:21,179 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/AnLaNBD1S001R09_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=152050
    Range : 0 ... 152049 =      0.000 ...   608.196 secs
Ready.
2017-11-03 17:54:22,618 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/AnLaNBD1S001R10_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=151100
    Range : 0 ... 151099 =      0.000 ...   604.396 secs
Ready.
2017-11-03 17:54:24,103 INFO : Preprocessing....
2017-11-03 17:54:24,112 INFO : Just copying data, no resampling, since new sampling rate same.
2017-11-03 17:54:24,277 INFO : Preprocessing....
2017-11-03 17:54:24,284 INFO : Just copying data, no resampling, since new sampling rate same.


## Create the model

We already create the model now, since we need to know the receptive field size for properly cutting out the data to predict. We need to cut out data starting at -receptive_field_size samples before the first sample we want to predict.

In [6]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model

# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = True
set_random_seeds(seed=20170629, cuda=cuda)

# This will determine how many crops are processed in parallel
input_time_length = 650
in_chans = train_cnts[0].get_data().shape[0]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=5, input_time_length=input_time_length,
                        final_conv_length=29).create_network()
to_dense_prediction_model(model)

if cuda:
    model.cuda()
from braindecode.torch_ext.util import np_to_var
import numpy as np
# determine output size
test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))
n_receptive_field = input_time_length - n_preds_per_input
receptive_field_ms = n_receptive_field * 1000.0 / train_cnts[0].info['sfreq']
print("Receptive field: {:d}/{:.2f} (samples/ms)".format(n_receptive_field,
                                                      receptive_field_ms))

132 predictions per input/trial
Receptive field: 518/2072.00 (samples/ms)


### Create SignalAndTarget Sets

In [7]:
from braindecode.datautil.trial_segment import create_signal_target_with_breaks_from_mne

break_start_offset_ms = 1000
break_stop_offset_ms = -500

train_sets = [create_signal_target_with_breaks_from_mne(
    cnt, name_to_start_code, [0,0], 
    name_to_stop_code, min_break_length_ms=1000, max_break_length_ms=10000,
    break_epoch_ival_ms=[500,-500],
    prepad_trials_to_n_samples=input_time_length) 
              for cnt in train_cnts]
train_set = concatenate_sets(train_sets)

2017-11-03 17:54:27,739 INFO : Trial per class:
Counter({'Break': 72, 'Right Hand': 29, 'Words': 21, 'Feet': 19, 'Rotation': 4})
2017-11-03 17:54:27,766 INFO : Trial per class:
Counter({'Break': 80, 'Feet': 31, 'Words': 26, 'Right Hand': 18, 'Rotation': 6})
2017-11-03 17:54:27,796 INFO : Trial per class:
Counter({'Break': 95, 'Feet': 38, 'Words': 29, 'Right Hand': 22, 'Rotation': 7})


In [8]:
test_sets = [create_signal_target_with_breaks_from_mne(
    cnt, name_to_start_code, [0,0], 
    name_to_stop_code, min_break_length_ms=1000, max_break_length_ms=10000,
    break_epoch_ival_ms=[500,-500],
    prepad_trials_to_n_samples=input_time_length) 
              for cnt in test_cnts]
test_set = concatenate_sets(test_sets)

2017-11-03 17:54:27,856 INFO : Trial per class:
Counter({'Break': 76, 'Feet': 24, 'Right Hand': 24, 'Words': 19, 'Rotation': 10})
2017-11-03 17:54:27,884 INFO : Trial per class:
Counter({'Break': 80, 'Feet': 30, 'Right Hand': 22, 'Words': 21, 'Rotation': 8})


In [9]:
from braindecode.datautil.splitters import split_into_two_sets

train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)


## Setup optimizer and iterator

In [10]:
from torch import optim

optimizer = optim.Adam(model.parameters())

In [11]:
from braindecode.datautil.iterators import CropsFromTrialsIterator
iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,
                                  n_preds_per_input=n_preds_per_input)

## Setup Monitors, Loss function, Stop Criteria

In [12]:
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th
from braindecode.torch_ext.modules import Expression
from braindecode.torch_ext.losses import log_categorical_crossentropy


loss_function = log_categorical_crossentropy

model_constraint = None
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor(),]
stop_criterion = MaxEpochs(20)
exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,
          monitors, stop_criterion, remember_best_column='valid_misclass',
          run_after_early_stop=True, batch_modifier=None, cuda=cuda)

## Run experiment

In [13]:
exp.run()

2017-11-03 17:54:28,100 INFO : Run until first stop...
2017-11-03 17:54:29,008 INFO : Epoch 0
2017-11-03 17:54:29,011 INFO : train_loss                6.69229
2017-11-03 17:54:29,012 INFO : valid_loss                6.46568
2017-11-03 17:54:29,013 INFO : test_loss                 7.03033
2017-11-03 17:54:29,014 INFO : train_sample_misclass     0.82260
2017-11-03 17:54:29,015 INFO : valid_sample_misclass     0.80963
2017-11-03 17:54:29,016 INFO : test_sample_misclass      0.84529
2017-11-03 17:54:29,017 INFO : train_misclass            0.84673
2017-11-03 17:54:29,018 INFO : valid_misclass            0.83838
2017-11-03 17:54:29,019 INFO : test_misclass             0.87261
2017-11-03 17:54:29,020 INFO : runtime                   0.00000
2017-11-03 17:54:29,021 INFO : 
2017-11-03 17:54:29,023 INFO : New best valid_misclass: 0.838384
2017-11-03 17:54:29,024 INFO : 
2017-11-03 17:54:29,883 INFO : Time only for training updates: 0.76s
2017-11-03 17:54:30,662 INFO : Epoch 1
2017-11-03 17:54:30

2017-11-03 17:54:43,461 INFO : test_misclass             0.46497
2017-11-03 17:54:43,462 INFO : runtime                   1.57837
2017-11-03 17:54:43,462 INFO : 
2017-11-03 17:54:44,271 INFO : Time only for training updates: 0.71s
2017-11-03 17:54:45,020 INFO : Epoch 10
2017-11-03 17:54:45,021 INFO : train_loss                0.64797
2017-11-03 17:54:45,022 INFO : valid_loss                1.28604
2017-11-03 17:54:45,023 INFO : test_loss                 1.41996
2017-11-03 17:54:45,023 INFO : train_sample_misclass     0.24798
2017-11-03 17:54:45,024 INFO : valid_sample_misclass     0.44050
2017-11-03 17:54:45,025 INFO : test_sample_misclass      0.47793
2017-11-03 17:54:45,026 INFO : train_misclass            0.23869
2017-11-03 17:54:45,027 INFO : valid_misclass            0.41414
2017-11-03 17:54:45,027 INFO : test_misclass             0.44904
2017-11-03 17:54:45,028 INFO : runtime                   1.56668
2017-11-03 17:54:45,029 INFO : 
2017-11-03 17:54:45,847 INFO : Time only for tr

2017-11-03 17:55:00,944 INFO : Epoch 20
2017-11-03 17:55:00,945 INFO : train_loss                0.50380
2017-11-03 17:55:00,946 INFO : valid_loss                1.28838
2017-11-03 17:55:00,947 INFO : test_loss                 1.54519
2017-11-03 17:55:00,947 INFO : train_sample_misclass     0.17429
2017-11-03 17:55:00,948 INFO : valid_sample_misclass     0.43744
2017-11-03 17:55:00,949 INFO : test_sample_misclass      0.48286
2017-11-03 17:55:00,950 INFO : train_misclass            0.17085
2017-11-03 17:55:00,950 INFO : valid_misclass            0.38384
2017-11-03 17:55:00,951 INFO : test_misclass             0.42994
2017-11-03 17:55:00,952 INFO : runtime                   1.58936
2017-11-03 17:55:00,953 INFO : 
2017-11-03 17:55:00,954 INFO : Setup for second stop...
2017-11-03 17:55:00,957 INFO : Train loss to reach 0.60028
2017-11-03 17:55:00,958 INFO : Run until second stop...
2017-11-03 17:55:01,805 INFO : Epoch 12
2017-11-03 17:55:01,806 INFO : train_loss                0.71588
20

2017-11-03 17:55:18,757 INFO : train_sample_misclass     0.21574
2017-11-03 17:55:18,758 INFO : valid_sample_misclass     0.28888
2017-11-03 17:55:18,759 INFO : test_sample_misclass      0.51284
2017-11-03 17:55:18,760 INFO : train_misclass            0.19517
2017-11-03 17:55:18,760 INFO : valid_misclass            0.28283
2017-11-03 17:55:18,761 INFO : test_misclass             0.48089
2017-11-03 17:55:18,762 INFO : runtime                   1.88302
2017-11-03 17:55:18,763 INFO : 
2017-11-03 17:55:19,771 INFO : Time only for training updates: 0.89s
2017-11-03 17:55:20,627 INFO : Epoch 22
2017-11-03 17:55:20,628 INFO : train_loss                0.65365
2017-11-03 17:55:20,629 INFO : valid_loss                0.82897
2017-11-03 17:55:20,629 INFO : test_loss                 1.40076
2017-11-03 17:55:20,630 INFO : train_sample_misclass     0.25455
2017-11-03 17:55:20,631 INFO : valid_sample_misclass     0.33560
2017-11-03 17:55:20,632 INFO : test_sample_misclass      0.48546
2017-11-03 17:

We arrive at about 54% accuracy. With only 3 sensors and 3 training runs, we cannot get much better :)