In [1]:
%load_ext autoreload
%autoreload 2
import os
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')

# Read and Decode BBCI Data with Start-Stop-Markers

This tutorial shows how to read and decode BBCI data with start and stop markers.

## Setup logging to see outputs

In [2]:
import logging
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.DEBUG, stream=sys.stdout)
log = logging.getLogger()


## Load and preprocess data

This is a bit more complicated than before since we ahve to add breaks etc. Here I now opt to add breaks do all preprocessings per run and only later combine them together.

In [3]:
from braindecode.datautil.splitters import concatenate_sets
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne, add_breaks
from braindecode.datasets.bbci import load_bbci_sets_from_folder
from copy import deepcopy
from braindecode.mne_ext.signalproc import resample_cnt, mne_apply
from braindecode.datautil.signalproc import lowpass_cnt
from braindecode.datautil.signalproc import exponential_running_standardize

def create_cnts(folder, runs, name_to_start_code, name_to_stop_code, break_start_offset_ms,
              break_stop_offset_ms, break_start_code, break_stop_code):
    # Load data
    cnts = load_bbci_sets_from_folder(folder, runs)
    
    # Now do some preprocessings:
    # Adding breaks, resampling to 250 Hz, lowpass below 38, eponential standardization
    break_start_code = -1
    break_stop_code = -2
    
    new_cnts = []
    for cnt in cnts:
        # Only take some channels 
        #cnt = cnt.drop_channels(['STI 014']) # This would remove stimulus channel
        cnt = cnt.pick_channels(['C3', 'CPz' 'C4'])
        # add breaks
        new_events = add_breaks(
            cnt.info['events'], cnt.info['sfreq'],
            break_start_code=break_start_code, break_stop_code=break_stop_code,
            name_to_start_codes=name_to_start_code, name_to_stop_codes=name_to_stop_code,
            min_break_length_ms=5000, max_break_length_ms=9000)
        n_break_start_offset = int(cnt.info['sfreq'] * break_start_offset_ms / 1000.0)
        n_break_stop_offset = int(cnt.info['sfreq'] * break_stop_offset_ms / 1000.0)
        # lets add some offset to break start and stop
        new_events[new_events[:,2] == break_start_code, 0] += n_break_start_offset
        # 0.5 sec for break stop
        new_events[new_events[:,2] == break_stop_code, 0] +=  n_break_stop_offset
        cnt.info['events'] = new_events
        log.info("Preprocessing....")
        cnt = mne_apply(lambda a: lowpass_cnt(a, 38,cnt.info['sfreq'], axis=1), cnt)
        cnt = resample_cnt(cnt, 250)
        # mne apply will apply the function to the data (a 2d-numpy-array)
        # have to transpose data back and forth, since
        # exponential_running_standardize expects time x chans order
        # while mne object has chans x time order
        cnt = mne_apply(lambda a: exponential_running_standardize(
            a.T, init_block_size=1000,factor_new=0.001, eps=1e-4).T,
            cnt)
        new_cnts.append(cnt)
    return new_cnts

In [4]:
from collections import OrderedDict
name_to_start_code = OrderedDict([('Right Hand', 1), ('Feet', 4),
            ('Rotation', 8), ('Words', [10])])

name_to_stop_code = OrderedDict([('Right Hand', [20,21,22,23,24,28,30]),
        ('Feet', [20,21,22,23,24,28,30]),
        ('Rotation', [20,21,22,23,24,28,30]), 
        ('Words', [20,21,22,23,24,28,30])])

break_start_offset_ms = 1000
break_stop_offset_ms = -500
# pick some numbers that were not used before/do not exist in markers
break_start_code = -1
break_stop_code = -2
train_runs = [1,2,3]
train_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/', 
                         train_runs,
                         name_to_start_code,
                         name_to_stop_code, break_start_offset_ms,
                         break_stop_offset_ms, break_start_code, break_stop_code)

name_to_code_with_breaks = deepcopy(name_to_start_code)
name_to_code_with_breaks['Break'] = break_start_code
name_to_stop_code_with_breaks = deepcopy(name_to_stop_code)
name_to_stop_code_with_breaks['Break'] = break_stop_code

2017-07-04 12:41:13,177 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/AnLaNBD1S001R01_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=151350
    Range : 0 ... 151349 =      0.000 ...   605.396 secs
Ready.
2017-07-04 12:41:14,266 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/AnLaNBD1S001R02_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=153500
    Range : 0 ... 153499 =      0.000 ...   613.996 secs
Ready.
2017-07-04 12:41:15,356 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/AnLaNBD1S001R03_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=180700
    Range : 0 ... 180699 =      0.000 ...   722.796 secs
Ready.
2017-07-04 12:41:16,723 INFO : Preprocessing....
2017-07-04 12:41:16,728 INFO : Just copying data, no resampling, since new sampling rate same.
2017-07-04 12:41:16,822 INFO : Preprocessing....
2017-07-04 12:41:16,827 INFO : Just copyin

In [5]:
test_runs = [9,10]
test_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/', test_runs, name_to_start_code,
           name_to_stop_code, break_start_offset_ms,
              break_stop_offset_ms, break_start_code, break_stop_code)

2017-07-04 12:41:16,994 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/AnLaNBD1S001R09_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=152050
    Range : 0 ... 152049 =      0.000 ...   608.196 secs
Ready.
2017-07-04 12:41:18,075 INFO : Loading /home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/AnLaNBD1S001R10_1-1_250Hz.BBCI.mat
Creating RawArray with float64 data, n_channels=64, n_times=151100
    Range : 0 ... 151099 =      0.000 ...   604.396 secs
Ready.
2017-07-04 12:41:19,260 INFO : Preprocessing....
2017-07-04 12:41:19,264 INFO : Just copying data, no resampling, since new sampling rate same.
2017-07-04 12:41:19,358 INFO : Preprocessing....
2017-07-04 12:41:19,363 INFO : Just copying data, no resampling, since new sampling rate same.


## Create the model

We already create the model now, since we need to know the receptive field size for properly cutting out the data to predict. We need to cut out data starting at -receptive_field_size samples before the first sample we want to predict.

In [6]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds, to_dense_prediction_model

# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = True
set_random_seeds(seed=20170629, cuda=cuda)

# This will determine how many crops are processed in parallel
input_time_length = 650
in_chans = train_cnts[0].get_data().shape[0]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=5, input_time_length=input_time_length,
                        final_conv_length=29).create_network()
to_dense_prediction_model(model)

if cuda:
    model.cuda()
from braindecode.torch_ext.util import np_to_var
import numpy as np
# determine output size
test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))
n_receptive_field = input_time_length - n_preds_per_input
receptive_field_ms = n_receptive_field * 1000.0 / train_cnts[0].info['sfreq']
print("Receptive field: {:d}/{:.2f} (samples/ms)".format(n_receptive_field,
                                                      receptive_field_ms))

132 predictions per input/trial
Receptive field: 518/2072.00 (samples/ms)


### Create SignalAndTarget Sets

In [7]:
train_sets = [create_signal_target_from_raw_mne(cnt, name_to_code_with_breaks, [-receptive_field_ms,0], 
                                         name_to_stop_code_with_breaks) for cnt in train_cnts]
train_set = concatenate_sets(train_sets)

2017-07-04 12:41:22,900 INFO : Trial per class:
Counter({'Right Hand': 29, 'Words': 21, 'Feet': 19, 'Break': 18, 'Rotation': 4})
2017-07-04 12:41:22,904 INFO : Trial per class:
Counter({'Feet': 31, 'Words': 26, 'Break': 20, 'Right Hand': 18, 'Rotation': 6})
2017-07-04 12:41:22,909 INFO : Trial per class:
Counter({'Feet': 38, 'Words': 29, 'Break': 23, 'Right Hand': 22, 'Rotation': 7})


In [8]:
test_sets = [create_signal_target_from_raw_mne(cnt, name_to_code_with_breaks, [-receptive_field_ms,0], 
                                         name_to_stop_code_with_breaks) for cnt in test_cnts]
test_set = concatenate_sets(test_sets)

2017-07-04 12:41:22,930 INFO : Trial per class:
Counter({'Feet': 24, 'Right Hand': 24, 'Words': 19, 'Break': 19, 'Rotation': 10})
2017-07-04 12:41:22,934 INFO : Trial per class:
Counter({'Feet': 30, 'Right Hand': 22, 'Words': 21, 'Break': 20, 'Rotation': 8})


In [9]:
from braindecode.datautil.splitters import split_into_two_sets

train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)


## Setup optimizer and iterator

In [10]:
from torch import optim

optimizer = optim.Adam(model.parameters())

In [11]:
from braindecode.datautil.iterators import CropsFromTrialsIterator
iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,
                                  n_preds_per_input=n_preds_per_input)

## Setup Monitors, Loss function, Stop Criteria

In [12]:
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th
from braindecode.torch_ext.modules import Expression


loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2)[:,:,0], targets)

model_constraint = None
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor(),]
stop_criterion = MaxEpochs(20)
exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,
          monitors, stop_criterion, remember_best_column='valid_misclass',
          run_after_early_stop=True, batch_modifier=None, cuda=cuda)

## Run experiment

In [13]:
exp.run()

2017-07-04 12:41:23,035 INFO : Run until first stop...
2017-07-04 12:41:23,620 INFO : Epoch 0
2017-07-04 12:41:23,621 INFO : train_loss                15.20389
2017-07-04 12:41:23,622 INFO : valid_loss                16.27810
2017-07-04 12:41:23,622 INFO : test_loss                 15.07834
2017-07-04 12:41:23,623 INFO : train_sample_misclass     0.93548
2017-07-04 12:41:23,623 INFO : valid_sample_misclass     0.98502
2017-07-04 12:41:23,624 INFO : test_sample_misclass      0.91691
2017-07-04 12:41:23,625 INFO : train_misclass            0.93976
2017-07-04 12:41:23,625 INFO : valid_misclass            0.96774
2017-07-04 12:41:23,626 INFO : test_misclass             0.90863
2017-07-04 12:41:23,627 INFO : runtime                   0.00000
2017-07-04 12:41:23,627 INFO : 
2017-07-04 12:41:23,628 INFO : New best valid_misclass: 0.967742
2017-07-04 12:41:23,629 INFO : 
2017-07-04 12:41:24,679 INFO : Epoch 1
2017-07-04 12:41:24,681 INFO : train_loss                1.69720
2017-07-04 12:41:24,

We arrive only at 38.6% accuracy. With only 3 sensors and 3 training runs, cannot expect too much great performance :)