In [None]:
# Automatically reload edited Python modules (e.g. a local braindecode checkout)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
import os
import sys

# Location of a local braindecode checkout that should shadow any installed copy.
# Set the BRAINDECODE_DIR environment variable to override this machine-specific default.
BRAINDECODE_DIR = os.environ.get(
    'BRAINDECODE_DIR',
    '/users/viral/Neutron6/viral/repo/braindecode/code/braindecode/')
# Remove any earlier entry first so re-running this cell does not accumulate
# duplicates, then put the checkout at the front of the import search path.
while BRAINDECODE_DIR in sys.path:
    sys.path.remove(BRAINDECODE_DIR)
sys.path.insert(0, BRAINDECODE_DIR)

# Using the Experiment Class

Braindecode provides a convenience `Experiment` class, which removes the need to write your own training loop. It expects a training, a validation and a test set and trains as follows:

1. Train on training set until a given stop criterion is fulfilled
2. Reset to the best epoch, i.e. reset the parameters of the model and the optimizer to their state at the best epoch ("best" according to a given criterion).
3. Continue training on the combined training + validation set until the loss on the validation set is as low as it was on the training set at the best epoch (or until the ConvNet has been trained for twice as many epochs as the best epoch, to prevent infinite training).

<div class='alert alert-warning'>

It is not necessary to use the Experiment class to use the remaining functionality of Braindecode. Feel free to ignore it :)

</div>

## Load data

In [None]:
import mne
import numpy as np
from mne.io import concatenate_raws


# Continuous EEG recording in EDF format.
path = './data/viral_train.edf'
raw = mne.io.read_raw_edf(path)

# Find the events in this dataset
events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')
# NOTE(review): computed but not used below -- the epochs are built from the
# hardcoded channel indices 2..14 instead. Confirm which selection is intended.
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                                  exclude='bads')

# Cut a 0..1 s window after every event; code 1 -> 'target', code 2 -> 'nontarget'.
event_id = dict(target=1, nontarget=2)
epoched = mne.Epochs(raw, events, event_id, tmin=0, tmax=1, proj=False,
                     picks=np.arange(2, 15), baseline=None, preload=True)

In [None]:
# Show the event-name -> event-code mapping the epochs were created with.
epoched.event_id

In [None]:
import matplotlib.pyplot as plt

# Plot when each event occurs in the recording. Passing the epochs' event_id
# mapping labels the event types by name ('target' / 'nontarget') rather than
# by raw integer code.
fig = mne.viz.plot_events(events, raw.info['sfreq'], raw.first_samp,
                          event_id=epoched.event_id, show=False)
plt.show()

## Convert data to Braindecode Format

In [None]:
import numpy as np
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.datautil.splitters import split_into_two_sets
# Convert data from volts to microvolts (x 1e6).
# PyTorch expects float32 for input and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
# Event codes 1 (target) and 2 (nontarget) -> class labels 0 and 1.
y = (epoched.events[:, 2] - 1).astype(np.int64)
assert len(X) == len(y)
assert set(np.unique(y)) <= {0, 1}, "unexpected event codes in epochs"

data = SignalAndTarget(X, y)
# Hold out 20 % of the trials for testing, then 20 % of the remainder for validation.
train_set, test_set = split_into_two_sets(data, first_set_fraction=0.8)
train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)

In [None]:
# Shape of each split: (trials, channels, time samples).
# print(...) with a single formatted string works under both Python 2 and 3.
print("train: {}".format(train_set.X.shape))
print("test:  {}".format(test_set.X.shape))
print("valid: {}".format(valid_set.X.shape))


## Create the model

In [None]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.eegnet import EEGNet
from torch import nn
import torch as th
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model

# Use the GPU only when one is actually available, so the notebook also runs
# on CPU-only machines. Set to False explicitly to force CPU.
cuda = th.cuda.is_available()
set_random_seeds(seed=20170629, cuda=cuda)

# Length in samples of each network input window; this determines how many
# crops are processed in parallel.
# NOTE(review): 129 equals the full epoch (tmin=0..tmax=1 s, both ends inclusive)
# only if the sampling rate is 128 Hz -- confirm against raw.info['sfreq'].
input_time_length = 129
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = EEGNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,
               final_conv_length=12).create_network()
# Make the network emit several predictions along the time axis for each input window.
to_dense_prediction_model(model)

if cuda:
    model.cuda()

In [None]:
# Display the network's module structure.
model

In [None]:
from torch import optim

# Adam over every model parameter, using PyTorch's default hyperparameters.
optimizer = optim.Adam(params=model.parameters())

In [None]:
from braindecode.torch_ext.util import np_to_var

# Run a dummy batch of two all-ones trials through the network and read the
# size of output axis 2, i.e. how many predictions one input window yields.
dummy_batch = np.ones((2, in_chans, input_time_length, 1), dtype=np.float32)
test_input = np_to_var(dummy_batch)
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
out_array = out.cpu().data.numpy()
n_preds_per_input = out_array.shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))

## Set up the experiment

Now we need to set up everything for the experiment: the iterator, the loss function, the monitors and the stop criterion.

In [None]:
from braindecode.experiments.experiment import Experiment
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th
from braindecode.torch_ext.modules import Expression

# Batches crops of the trials; used both for training and for evaluation.
iterator = CropsFromTrialsIterator(
    batch_size=32,
    input_time_length=input_time_length,
    n_preds_per_input=n_preds_per_input,
)


def loss_function(preds, targets):
    """NLL loss of the network output averaged over axis 2.

    Receives the predictions exactly as they come out of the network together
    with the integer targets, and returns the loss.
    NOTE(review): the trailing ``[:, :, 0]`` drops a leftover singleton axis;
    whether that axis exists depends on the network's output rank and on
    torch's ``keepdim`` default for ``mean`` -- confirm against the model.
    """
    averaged = th.mean(preds, dim=2)
    return F.nll_loss(averaged[:, :, 0], targets)


# No weight constraint. To add one, pass an object whose apply method accepts the module.
model_constraint = None
# Monitors log the training progress.
monitors = [
    LossMonitor(),
    MisclassMonitor(col_suffix='sample_misclass'),
    CroppedTrialMisclassMonitor(input_time_length),
    RuntimeMonitor(),
]
# The first training phase stops after this many epochs.
stop_criterion = MaxEpochs(20)
exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function,
                 optimizer, model_constraint, monitors, stop_criterion,
                 remember_best_column='valid_misclass',
                 run_after_early_stop=True, batch_modifier=None, cuda=cuda)

## Run experiment

In [None]:
# need to setup python logging before to be able to see anything
import logging
import sys

# Training progress is reported through the logging module, so a stdout handler
# has to be configured before the run or nothing is shown.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s : %(message)s')
exp.run()