In [1]:
%load_ext autoreload
%autoreload 2
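import sys
# Make a local Braindecode checkout importable; adjust this path to your
# machine, or drop these lines if Braindecode is installed as a package.
sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')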

# Using the Experiment Class

Braindecode provides a convenience `Experiment` class that spares you from writing your own training loop. It expects a training, a validation, and a test set, and trains as follows:

1. Train on the training set until a given stop criterion is fulfilled.
2. Reset to the best epoch, i.e. reset the parameters of the model and the optimizer to their state at the best epoch ("best" according to a given criterion).
3. Continue training on the combined training + validation set until the loss on the validation set is as low as the training loss was at the best epoch, or until the ConvNet has been trained for twice as many epochs as the best epoch (to prevent infinite training).
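
The stopping logic of steps 2 and 3 can be sketched in plain Python. This is a minimal illustration rather than Braindecode's implementation: the function names `best_epoch_index` and `second_phase_should_stop` are hypothetical, the per-epoch values are made up, and letting later epochs win ties is an assumption of this sketch.

```python
def best_epoch_index(values):
    """Return the index of the epoch with the lowest monitored value.

    Assumption of this sketch: on ties, the later epoch wins.
    """
    best = 0
    for i, value in enumerate(values):
        if value <= values[best]:
            best = i
    return best


def second_phase_should_stop(valid_loss, target_train_loss, epoch, best_epoch):
    """Decide whether training on train + valid data should end.

    Stops once the validation loss has reached the training loss of the
    best epoch, or once twice as many epochs as the best epoch have run.
    """
    reached_target = valid_loss <= target_train_loss
    ran_too_long = epoch >= 2 * best_epoch
    return reached_target or ran_too_long


# Hypothetical per-epoch values, for illustration only
valid_misclass = [0.50, 0.58, 0.42, 0.46]
train_loss = [0.87, 0.40, 0.21, 0.15]

best = best_epoch_index(valid_misclass)  # epoch with lowest valid_misclass
target = train_loss[best]                # training loss to reach in phase two
print(best, target)
print(second_phase_should_stop(0.30, target, epoch=3, best_epoch=best))
print(second_phase_should_stop(0.20, target, epoch=3, best_epoch=best))
```

The epoch cap is what guarantees termination when the validation loss never reaches the target.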

<div class='alert alert-warning'>

It is not necessary to use the Experiment class to use the remaining functionality of Braindecode. Feel free to ignore it :)

</div>

## Load data

In [2]:
import mne
from mne.io import concatenate_raws

# 5,6,9,10,13,14 are codes for executed and imagined hands/feet
subject_id = 1
event_codes = [5,6,9,10,13,14]

# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

# Load each of the files
parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')
         for path in physionet_paths]

# Concatenate them
raw = concatenate_raws(parts)

# Find the events in this dataset
events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

# Use only EEG channels
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

# Extract trials, only using EEG channels
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,
                baseline=None, preload=True)

Removing orphaned offset at the beginning of the file.
179 events found
Events id: [1 2 3]
90 matching events found
Loading data for 90 events and 497 original time points ...
0 bad epochs dropped


## Convert data to Braindecode Format

In [3]:
import numpy as np
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.datautil.splitters import split_into_two_sets
# Convert data from volt to microvolt
# PyTorch expects float32 for input and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1

train_set = SignalAndTarget(X[:60], y=y[:60])
test_set = SignalAndTarget(X[60:], y=y[60:])

train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)
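
The resulting set sizes can be checked with a small NumPy sketch on synthetic data. It uses the shapes reported in the loading output above (90 trials, 497 time points) and assumes 64 EEG channels. Plain slicing stands in for `split_into_two_sets`, so this illustrates the split rather than reproducing the library's implementation.

```python
import numpy as np

n_trials, n_chans, n_times = 90, 64, 497  # n_chans is an assumption
rng = np.random.RandomState(0)

# Synthetic signals in volt and event codes 2 (hands) / 3 (feet)
signals_volt = rng.randn(n_trials, n_chans, n_times) * 1e-5
event_codes = rng.choice([2, 3], size=n_trials)

X = (signals_volt * 1e6).astype(np.float32)  # volt -> microvolt
y = (event_codes - 2).astype(np.int64)       # 2,3 -> 0,1

# First 60 trials for training + validation, remaining 30 for testing
X_train_valid, y_train_valid = X[:60], y[:60]
X_test, y_test = X[60:], y[60:]

# Keep the first 80% of the 60 trials for training, the rest for validation
n_train = int(round(len(X_train_valid) * 0.8))
X_train, X_valid = X_train_valid[:n_train], X_train_valid[n_train:]
y_train, y_valid = y_train_valid[:n_train], y_train_valid[n_train:]

print(X_train.shape, X_valid.shape, X_test.shape)
```

Sizes of 48, 12, and 30 trials are consistent with the misclassification rates in the training log further down, e.g. 0.02083 = 1/48 for the training set and 0.58333 = 7/12 for the validation set.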

## Create the model

In [4]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model

# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)

# This will determine how many crops are processed in parallel
input_time_length = 450
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,
                        final_conv_length=12).create_network()
to_dense_prediction_model(model)

if cuda:
    model.cuda()

In [5]:
from torch import optim

optimizer = optim.Adam(model.parameters())

In [6]:
from braindecode.torch_ext.util import np_to_var
# determine output size
test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))

187 predictions per input/trial
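
With 187 predictions per 450-sample input, the receptive field of the network spans 450 - 187 + 1 = 264 samples. The sketch below shows one way to tile a 497-sample trial with such inputs so that every possible prediction is produced at least once. The helper `crop_windows` is hypothetical and only illustrates the idea; it is not the code of `CropsFromTrialsIterator`.

```python
def crop_windows(n_samples, input_time_length, n_preds_per_input):
    """Return (start, stop) sample indices of the crops covering one trial.

    Consecutive crops are shifted by n_preds_per_input samples, and a final
    crop is aligned with the end of the trial so no prediction is missed.
    """
    if n_samples < input_time_length:
        raise ValueError("trial is shorter than the network input")
    stops = list(range(input_time_length, n_samples + 1, n_preds_per_input))
    if stops[-1] != n_samples:
        stops.append(n_samples)
    return [(stop - input_time_length, stop) for stop in stops]


input_time_length = 450
n_preds_per_input = 187
receptive_field = input_time_length - n_preds_per_input + 1

windows = crop_windows(497, input_time_length, n_preds_per_input)
print(receptive_field)  # 264
print(windows)          # [(0, 450), (47, 497)]
```

In this case two overlapping crops cover each trial, so some predictions are computed in both crops.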


## Setup Experiment

Now we need to set up everything for the experiment: iterator, loss function, monitors, and stop criterion.

In [7]:
from braindecode.experiments.experiment import Experiment
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th
from braindecode.torch_ext.modules import Expression
# Iterator is used to iterate over datasets both for training
# and evaluation
iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,
                                  n_preds_per_input=n_preds_per_input)

# Loss function takes predictions as they come out of the network and the targets
# and returns a loss
loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2, keepdim=False), targets)

# Can be used to apply a constraint to the model; it should then be an object
# with an apply method that accepts a module
model_constraint = None
# Monitors log the training progress
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor(),]
# Stop criterion determines when the first stop happens
stop_criterion = MaxEpochs(20)
exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,
          monitors, stop_criterion, remember_best_column='valid_misclass',
          run_after_early_stop=True, batch_modifier=None, cuda=cuda)
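
The loss function defined above averages the predictions over the time axis before computing the negative log-likelihood. A NumPy sketch of the same computation follows; it assumes predictions of shape (batch, classes, time) that hold log-probabilities, and the name `cropped_nll_loss` is hypothetical.

```python
import numpy as np


def cropped_nll_loss(log_preds, targets):
    """Negative log-likelihood of time-averaged log-probabilities.

    log_preds: array of shape (batch, classes, time)
    targets:   integer class labels of shape (batch,)
    """
    mean_log_preds = log_preds.mean(axis=2)  # -> (batch, classes)
    # Pick the averaged log-probability of the true class for each example
    picked = mean_log_preds[np.arange(len(targets)), targets]
    return -picked.mean()


# Uniform predictions over two classes give a loss of log(2)
log_preds = np.full((4, 2, 187), np.log(0.5))
targets = np.array([0, 1, 1, 0])
loss = cropped_nll_loss(log_preds, targets)
print(loss)
```

Averaging over time first means each input contributes one term to the loss, regardless of how many predictions it yields.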

## Run experiment

In [8]:
# Python logging must be set up before running, otherwise no output is shown
import logging
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.DEBUG, stream=sys.stdout)
exp.run()

2017-10-18 14:35:26,018 INFO : Run until first stop...
2017-10-18 14:35:27,899 INFO : Epoch 0
2017-10-18 14:35:27,904 INFO : train_loss                0.86930
2017-10-18 14:35:27,908 INFO : valid_loss                0.74838
2017-10-18 14:35:27,909 INFO : test_loss                 0.69756
2017-10-18 14:35:27,910 INFO : train_sample_misclass     0.53894
2017-10-18 14:35:27,912 INFO : valid_sample_misclass     0.47103
2017-10-18 14:35:27,913 INFO : test_sample_misclass      0.44251
2017-10-18 14:35:27,914 INFO : train_misclass            0.60417
2017-10-18 14:35:27,915 INFO : valid_misclass            0.50000
2017-10-18 14:35:27,916 INFO : test_misclass             0.40000
2017-10-18 14:35:27,917 INFO : runtime                   0.00000
2017-10-18 14:35:27,919 INFO : 
2017-10-18 14:35:27,922 INFO : New best valid_misclass: 0.500000
2017-10-18 14:35:27,924 INFO : 
2017-10-18 14:35:30,830 INFO : Time only for training updates: 2.87s
2017-10-18 14:35:32,663 INFO : Epoch 1
2017-10-18 14:35:32

2017-10-18 14:36:09,672 INFO : Epoch 10
2017-10-18 14:36:09,674 INFO : train_loss                0.25623
2017-10-18 14:36:09,675 INFO : valid_loss                0.89555
2017-10-18 14:36:09,676 INFO : test_loss                 0.75208
2017-10-18 14:36:09,678 INFO : train_sample_misclass     0.06088
2017-10-18 14:36:09,679 INFO : valid_sample_misclass     0.55637
2017-10-18 14:36:09,680 INFO : test_sample_misclass      0.41765
2017-10-18 14:36:09,681 INFO : train_misclass            0.02083
2017-10-18 14:36:09,682 INFO : valid_misclass            0.58333
2017-10-18 14:36:09,683 INFO : test_misclass             0.33333
2017-10-18 14:36:09,684 INFO : runtime                   3.40419
2017-10-18 14:36:09,685 INFO : 
2017-10-18 14:36:11,601 INFO : Time only for training updates: 1.90s
2017-10-18 14:36:12,948 INFO : Epoch 11
2017-10-18 14:36:12,950 INFO : train_loss                0.24021
2017-10-18 14:36:12,951 INFO : valid_loss                0.90580
2017-10-18 14:36:12,953 INFO : test_los

2017-10-18 14:36:43,296 INFO : valid_loss                1.05598
2017-10-18 14:36:43,297 INFO : test_loss                 0.88760
2017-10-18 14:36:43,298 INFO : train_sample_misclass     0.00847
2017-10-18 14:36:43,300 INFO : valid_sample_misclass     0.51448
2017-10-18 14:36:43,301 INFO : test_sample_misclass      0.33066
2017-10-18 14:36:43,302 INFO : train_misclass            0.00000
2017-10-18 14:36:43,303 INFO : valid_misclass            0.50000
2017-10-18 14:36:43,305 INFO : test_misclass             0.36667
2017-10-18 14:36:43,306 INFO : runtime                   3.33038
2017-10-18 14:36:43,307 INFO : 
2017-10-18 14:36:43,308 INFO : Setup for second stop...
2017-10-18 14:36:43,314 INFO : Train loss to reach 0.09315
2017-10-18 14:36:43,315 INFO : Run until second stop...
2017-10-18 14:36:44,773 INFO : Epoch 19
2017-10-18 14:36:44,774 INFO : train_loss                0.25265
2017-10-18 14:36:44,776 INFO : valid_loss                0.89064
2017-10-18 14:36:44,777 INFO : test_loss  

2017-10-18 14:37:21,337 INFO : valid_sample_misclass     0.03387
2017-10-18 14:37:21,338 INFO : test_sample_misclass      0.25856
2017-10-18 14:37:21,339 INFO : train_misclass            0.00000
2017-10-18 14:37:21,339 INFO : valid_misclass            0.00000
2017-10-18 14:37:21,340 INFO : test_misclass             0.26667
2017-10-18 14:37:21,341 INFO : runtime                   3.99343
2017-10-18 14:37:21,342 INFO : 
2017-10-18 14:37:23,868 INFO : Time only for training updates: 2.51s
2017-10-18 14:37:25,405 INFO : Epoch 29
2017-10-18 14:37:25,406 INFO : train_loss                0.12664
2017-10-18 14:37:25,407 INFO : valid_loss                0.21718
2017-10-18 14:37:25,408 INFO : test_loss                 0.74796
2017-10-18 14:37:25,409 INFO : train_sample_misclass     0.04274
2017-10-18 14:37:25,410 INFO : valid_sample_misclass     0.08623
2017-10-18 14:37:25,411 INFO : test_sample_misclass      0.26381
2017-10-18 14:37:25,411 INFO : train_misclass            0.03333
2017-10-18 14:

In this case, we arrive at 80.0% accuracy. Training stops once the validation loss falls below 0.09315, the training loss at the best epoch.
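
The log reports two misclassification rates per set: `*_sample_misclass` from `MisclassMonitor` and `*_misclass` from `CroppedTrialMisclassMonitor`. One plausible reading, sketched below in NumPy under that assumption, is that the former scores every single prediction while the latter first averages all predictions of a trial. The function names and the toy values are hypothetical; accuracy is simply one minus the misclassification rate.

```python
import numpy as np


def sample_misclass(log_preds, targets):
    """Fraction of individual predictions that disagree with the trial label."""
    pred_labels = log_preds.argmax(axis=1)  # -> (trials, time)
    return np.mean(pred_labels != targets[:, None])


def trial_misclass(log_preds, targets):
    """Fraction of trials misclassified after averaging predictions over time."""
    pred_labels = log_preds.mean(axis=2).argmax(axis=1)  # -> (trials,)
    return np.mean(pred_labels != targets)


# Toy example: 2 trials, 2 classes, 4 predictions per trial
probs_class1 = np.array([[0.6, 0.4, 0.3, 0.2],   # trial 0, true class 0
                         [0.9, 0.8, 0.4, 0.7]])  # trial 1, true class 1
log_preds = np.log(np.stack([1 - probs_class1, probs_class1], axis=1))
targets = np.array([0, 1])

per_sample = sample_misclass(log_preds, targets)
per_trial = trial_misclass(log_preds, targets)
print(per_sample, per_trial, 1 - per_trial)
```

In the toy data one prediction per trial is wrong, yet both trials are classified correctly once the predictions are averaged, which is why the trial-wise rate can be lower than the sample-wise rate.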

## Dataset References


This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:

    Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.

[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community. It is further described in:

    Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.