In [None]:
import sys
import importlib
import logging

# Re-executing the `logging` module installs a fresh root logger with no
# handlers. Without that, logging.basicConfig below would be skipped, because
# basicConfig does nothing when the root logger already has handlers.
# See https://stackoverflow.com/a/21475297/1469195
importlib.reload(logging)

# Root logger: INFO and above, timestamped, written to stdout.
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)
log = logging.getLogger()
log.setLevel('INFO')

In [None]:
%%capture
# Imports and global notebook configuration (output of this cell is hidden by
# the %%capture magic on the cell's first line).
import os
import site
# Presumably makes the project packages (reversible2, braindecode) importable.
# Each insert(0, ...) prepends, so the path inserted LAST is searched FIRST:
# explaining/reversible takes priority over code/reversible.
# NOTE(review): absolute, user-specific paths - this only works on the original
# author's machine; consider a configurable base directory.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# Automatically reload edited modules before executing each cell.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# NOTE(review): duplicates the logging setup of the previous cell. basicConfig is
# a no-op when the root logger already has handlers, so this only has an effect
# if that cell was not run.
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
# Plotting defaults: inline PNG figures, wide-and-short default size.
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

# NOTE(review): several names imported below (e.g. site, sliced_from_samples,
# RandomState, Variable, math, itertools) are not used in the visible cells.
from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# NOTE(review): numpy is already imported above; this re-import is redundant.
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
# Let cuDNN benchmark and choose convolution algorithms; per the PyTorch
# reproducibility notes this may make runs non-deterministic.
th.backends.cudnn.benchmark = True

In [None]:
from reversible2.high_gamma import load_train_test, to_signal_target

# Options forwarded verbatim to load_train_test.
# NOTE(review): exact meaning of each option is defined in reversible2.high_gamma.
load_kwargs = dict(
    subject_id=4,
    car=True,
    n_sensors=22,
    final_hz=256,
    start_ms=500,
    stop_ms=1500,
    half_before=True,
    only_load_given_sensors=False,
)
train_inputs, test_inputs = load_train_test(**load_kwargs)

# Intended to select GPU execution for the models below.
cuda = True
# The second set is built from test_inputs and is later used as validation data.
train_set, valid_set = to_signal_target(train_inputs, test_inputs)

In [None]:
class ScaleAndShift(nn.Module):
    """Invertible elementwise affine map ``y = (x + add) * factor``.

    Parameters
    ----------
    factor : float, default 0.2
        Multiplicative scale applied after the shift. Must be non-zero so that
        ``invert`` is defined.
    add : float, default 0
        Additive shift applied before scaling.

    Raises
    ------
    ValueError
        If ``factor`` is zero (the map would not be invertible).
    """
    def __init__(self, factor=0.2, add=0):
        super(ScaleAndShift, self).__init__()
        if factor == 0:
            raise ValueError("factor must be non-zero for ScaleAndShift to be invertible")
        self.factor = factor
        self.add = add

    def forward(self, x):
        """Return ``(x + add) * factor``."""
        return (x + self.add) * self.factor

    def invert(self, y):
        """Exact inverse of ``forward``: return ``y / factor - add``."""
        return (y / self.factor) - self.add


class ZeroPadChans(nn.Module):
    """Invertibly pad dimension 1 with ``n_per_side`` zero slices on each side.

    ``forward`` concatenates zeros before and after the input along dim 1;
    ``invert`` slices them off again, so ``invert(forward(x))`` returns ``x``.

    Parameters
    ----------
    n_per_side : int
        Number of zero slices added on each side of dim 1. ``0`` is a no-op.

    Raises
    ------
    ValueError
        If ``n_per_side`` is negative.
    """
    def __init__(self, n_per_side):
        super(ZeroPadChans, self).__init__()
        if n_per_side < 0:
            raise ValueError("n_per_side must be >= 0, got {}".format(n_per_side))
        self.n_per_side = n_per_side

    def forward(self, x):
        # Build the zero block from x's shape explicitly rather than slicing x:
        # zeros_like(x[:, :n]) under-pads when x has fewer than n slices in dim 1.
        pad_shape = (x.shape[0], self.n_per_side) + tuple(x.shape[2:])
        zeros = x.new_zeros(pad_shape)  # same dtype/device as x
        return th.cat((zeros, x, zeros), dim=1)

    def invert(self, y):
        # Use an explicit end index: y[:, n:-n] is EMPTY for n == 0 (slice 0:-0).
        end = y.shape[1] - self.n_per_side
        return y[:, self.n_per_side:end]

In [None]:
from copy import deepcopy

from braindecode.torch_ext.modules import Expression
from braindecode.torch_ext.optimizers import AdamW
from reversible2.models import WrappedModel
from reversible2.rfft import RFFT

# Dimensions taken from the loaded data: dim 1 of train_set.X is used as the
# channel axis and dim 2 as the time axis.
n_chans = train_set.X.shape[1]
n_classes = 2
input_time_length = train_set.X.shape[2]

# ---- Configuration shared by every repetition (set once, not per loop) ----
n_iters = 5  # number of independently initialised training runs
n_chan_pad = 0
# The first additive coupling maps (n_chans + n_chan_pad) // 2 channels to
# n_filters_start // 2 channels, so the two counts must agree. Derive the filter
# count from the data instead of hard-coding 22.
n_filters_start = n_chans + n_chan_pad
n_filters_conv = n_filters_start
conv_time_length = 25  # kernel length of the first temporal conv
conv_stride = 1
block_filter_length = 11  # kernel length of the conv in each later stage
nonlin = F.elu
pool_length = 3
pool_stride = 1
# Every stage starts with SubsampleSplitter([2, 1]). The FFT reshape below
# expects n_filters_start * 2**n_stages channels and
# input_time_length // 2**n_stages time steps afterwards.
n_stages = 3
downsample_factor = 2 ** n_stages
assert input_time_length % downsample_factor == 0, (
    "input_time_length ({:d}) must be divisible by {:d}".format(
        input_time_length, downsample_factor))

lr = 1 * 0.001
weight_decay = 0.5 * 0.01
max_epochs = 50


def _conv_pool_block(in_chans, out_chans, filter_length, stride=1):
    """Return Conv2d -> nonlin -> MaxPool2d, used as one additive-coupling function.

    The conv kernel is ``(filter_length, 1)``, so it only mixes along dim 2.
    With stride 1 and odd kernel lengths, the padding keeps dim 2's length.
    """
    return nn.Sequential(
        nn.Conv2d(
            in_chans,
            out_chans,
            (filter_length, 1),
            stride=(stride, 1),
            padding=(filter_length // 2, 0),
        ),
        Expression(nonlin),
        nn.MaxPool2d(
            kernel_size=(pool_length, 1),
            stride=(pool_stride, 1),
            padding=(pool_length // 2, 0),
        ),
    )


def add_conv_pool_block(model, n_filters_before, n_filters, filter_length, block_nr):
    """Append one downsampling stage to ``model``.

    Adds ``split_inc_<block_nr>`` (a ``SubsampleSplitter([2, 1])``) followed by
    ``conv_res_<block_nr>``, an ``AdditiveBlock`` whose two coupling functions
    map ``n_filters // 2`` channels to ``n_filters // 2`` channels.

    ``n_filters_before`` is currently unused; it was only needed by a disabled
    channel-padding step and is kept so existing calls keep working.
    """
    suffix = "_{:d}".format(block_nr)
    model.add_module('split_inc' + suffix,
                     SubsampleSplitter([2, 1], chunk_chans_first=False, checkerboard=False))
    model.add_module(
        "conv_res" + suffix,
        AdditiveBlock(
            _conv_pool_block(n_filters // 2, n_filters // 2, filter_length),
            _conv_pool_block(n_filters // 2, n_filters // 2, filter_length),
            switched_order=False,
        ))


def build_model():
    """Build a freshly initialised classifier and wrap it in ``WrappedModel``.

    Reads the configuration globals defined at the top of this cell.
    """
    model = nn.Sequential()
    model.add_module('padchan', ZeroPadChans(n_chan_pad // 2))
    model.add_module(
        "conv_time",
        AdditiveBlock(
            _conv_pool_block((n_chans + n_chan_pad) // 2, n_filters_start // 2,
                             conv_time_length, stride=conv_stride),
            _conv_pool_block((n_chans + n_chan_pad) // 2, n_filters_start // 2,
                             conv_time_length, stride=conv_stride),
            switched_order=False,
        ))

    # Stages 2..n_stages+1, doubling the channel count each time.
    n_filters_before = n_filters_conv
    for i_stage in range(n_stages):
        n_filters = n_filters_start * 2 ** (i_stage + 1)
        add_conv_pool_block(model, n_filters_before, n_filters,
                            block_filter_length, i_stage + 2)
        n_filters_before = n_filters

    final_n_filters = n_filters_start * downsample_factor
    final_time_length = input_time_length // downsample_factor
    # The FFT is applied to a (-1, time) view; the channel axis is restored after.
    model.add_module('reshape_for_fft', ViewAs(
        (-1, final_n_filters, final_time_length, 1),
        (-1, final_time_length)))
    model.add_module('fft', RFFT())
    model.add_module('unreshape_for_fft', ViewAs(
        (-1, final_time_length),
        (-1, final_n_filters, final_time_length)))
    model.add_module('scaledown', ScaleAndShift())
    # Take the first n_classes channels at index 0 of the last axis as class
    # scores. The default argument binds n_classes now rather than at call time.
    model.add_module("select_dims", Expression(lambda x, n=n_classes: x[:, :n, 0]))
    model.add_module("softmax", nn.LogSoftmax(dim=1))
    return WrappedModel(model)


dfs = []
for _ in range(n_iters):
    model = build_model()
    # Honour the `cuda` flag from the data-loading cell instead of always moving
    # to the GPU. NOTE(review): confirm WrappedModel.fit also works on CPU.
    if cuda:
        model.cuda()
    # Train a copy so `model` keeps its untrained initial weights.
    model_to_train = deepcopy(model)
    optimizer = AdamW(model_to_train.parameters(), lr=lr,
                      weight_decay=weight_decay)
    model_to_train.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)
    model_to_train.fit(train_set.X, train_set.y, epochs=max_epochs, batch_size=64,
                       scheduler='cosine',
                       validation_data=(valid_set.X, valid_set.y))
    dfs.append(model_to_train.epochs_df)
    

In [None]:
import pandas as pd

# One row per training run: the last row of each run's epochs_df. All runs train
# for the same number of epochs, so their last rows presumably share the same
# index label. Keying by run number (outer index level "run") keeps the rows
# attributable to their run; the original index is kept as the inner level.
pd.concat([df.iloc[-1:] for df in dfs], keys=range(len(dfs)), names=['run'])