In [None]:
import importlib
import logging
import sys

# Re-executing the `logging` module gives a fresh root logger without the
# handlers the kernel may already have attached, so that basicConfig below
# actually installs our stdout handler (see the linked answer).
# https://stackoverflow.com/a/21475297/1469195
importlib.reload(logging)

log = logging.getLogger()  # root logger
log.setLevel('INFO')
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
import os
import site
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


%load_ext autoreload
%autoreload 2
import numpy as np
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
th.backends.cudnn.benchmark = True

In [None]:
from reversible2.high_gamma import load_train_test, to_signal_target

# Run all models on the GPU.
cuda = True

# Subject 4, 22 sensors, resampled to 256 Hz, window 500-1500 ms.
train_inputs, test_inputs = load_train_test(
    subject_id=4,
    car=True,
    n_sensors=22,
    final_hz=256,
    start_ms=500,
    stop_ms=1500,
    half_before=True,
    only_load_given_sensors=False,
)
# NOTE(review): `valid_set` is built from `test_inputs` — presumably the held-out
# test trials; if so, keep it out of any model-selection decisions.
train_set, valid_set = to_signal_target(train_inputs, test_inputs)

In [None]:
from reversible2.models import larger_model
from braindecode.torch_ext.optimizers import AdamW
import torch.nn.functional as F
from reversible2.models import add_softmax, add_bnorm_before_relu, add_dropout_before_convs
from braindecode.torch_ext.util import set_random_seeds

# Channel and time dimensions are read from the first entry of train_inputs.
n_chans = train_inputs[0].shape[1]
n_time = train_inputs[0].shape[2]

# Seed *before* building the network so the random weight initialisation is
# reproducible, not just the later training iterator / dropout.
set_random_seeds(2019011641, cuda)

feature_model = larger_model(n_chans=n_chans, n_time=n_time, final_fft=True,
                             kernel_length=9, constant_memory=False)
feature_model = add_softmax(feature_model)
add_bnorm_before_relu(feature_model)  # modifies feature_model in place (return value unused)
add_dropout_before_convs(feature_model, p_conv=0.5, p_full=0.5)
feature_model.cuda()

from braindecode.models.base import BaseModel
class WrappedModel(BaseModel):
    """Expose an already-built torch module through braindecode's BaseModel API.

    ``create_network`` is presumably what BaseModel calls to obtain its
    network; here it simply hands back the module given to the constructor,
    so a custom network can be trained with ``compile``/``fit``.

    Parameters
    ----------
    network : torch.nn.Module
        The ready-made network to wrap.
    """

    def __init__(self, network):
        # BaseModel.__init__ is intentionally not called (as in the original).
        self.given_network = network

    def create_network(self):
        """Return the wrapped module itself (not a copy)."""
        wrapped = self.given_network
        return wrapped

from torch import nn
from braindecode.torch_ext.util import set_random_seeds

# Optimisation hyperparameters.
lr = 1e-3
weight_decay = 1e-3
max_epochs = 30

# Re-seed right before training so batching/dropout randomness is reproducible.
set_random_seeds(2019011641, cuda)
model = WrappedModel(feature_model)
model.cuda()
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)
model.fit(
    train_set.X, train_set.y,
    epochs=max_epochs,
    batch_size=64,
    scheduler='cosine',
    validation_data=(valid_set.X, valid_set.y),
)

## Simpler Model

In [None]:
def transpose_1st_and_3rd_dim(x):
    """Swap axes 1 and 3 of a 4-D tensor; axes 0 and 2 stay where they are.

    The result is what ``Tensor.permute`` returns (a view, no copy). Inputs
    that are not 4-D are rejected by ``permute``.
    """
    axis_order = (0, 3, 2, 1)
    return x.permute(*axis_order)

In [None]:

def create_chain_modules(n_chans, n_time, n_blocks_per_stage, n_filters_per_stage):
    """Build the list of invertible modules that reduce the time axis to 1.

    Each stage appends a ``SubsampleSplitter(stride=[2, 1])``. The code
    assumes this halves the time length and doubles the channel count: it
    tracks ``cur_n_time //= 2`` and uses ``2 ** (stage + 1) * n_chans``
    features. After the splitter come ``n_blocks_per_stage`` additive
    blocks: 3x3 conv blocks while time > 1, dense blocks (after a
    ``ViewAs`` flatten) once time reaches 1.

    Parameters
    ----------
    n_chans : int
        Number of channels entering the chain.
    n_time : int
        Length of the time axis entering the chain.
    n_blocks_per_stage : int
        Additive blocks appended after each splitter.
    n_filters_per_stage : sequence of int
        Hidden width of the blocks in each stage; needs at least one entry
        per halving of ``n_time``.

    Returns
    -------
    list of torch.nn.Module

    Raises
    ------
    ValueError
        If blocks are requested but ``n_filters_per_stage`` has fewer
        entries than there are stages.
    """
    # Local imports: the function must not rely on a later notebook cell having
    # put these names into the global namespace.
    from reversible2.blocks import dense_add_no_switch, conv_add_3x3_no_switch
    from reversible2.splitter import SubsampleSplitter
    from reversible2.view_as import ViewAs

    # Number of halvings until the time axis is 1 (same rule as the build loop).
    n_stages = 0
    remaining = n_time
    while remaining > 1:
        remaining //= 2
        n_stages += 1
    if n_blocks_per_stage > 0 and len(n_filters_per_stage) < n_stages:
        raise ValueError(
            "n_filters_per_stage has {} entries but n_time={} needs {} stages".format(
                len(n_filters_per_stage), n_time, n_stages))

    modules = []
    cur_n_time = n_time
    for i_stage in range(n_stages):
        # Only the very first splitter (acting on the raw input) uses
        # chunk_chans_first=False.
        chunk_chans_first = cur_n_time != n_time
        modules.append(SubsampleSplitter(stride=[2, 1], chunk_chans_first=chunk_chans_first))
        cur_n_time //= 2
        n_features = 2 ** (i_stage + 1) * n_chans

        if cur_n_time == 1:
            # Flatten the (N, C, 1, 1) output so dense blocks can follow.
            modules.append(ViewAs((-1, n_features, 1, 1), (-1, n_features)))

        if n_blocks_per_stage > 0:
            # The block type is fixed within a stage, so choose it once.
            make_block = conv_add_3x3_no_switch if cur_n_time > 1 else dense_add_no_switch
            n_filters = n_filters_per_stage[i_stage]
            modules.extend(make_block(n_features, n_filters) for _ in range(n_blocks_per_stage))
    return modules

In [None]:
from reversible2.blocks import dense_add_no_switch, conv_add_3x3_no_switch
from reversible2.view_as import ViewAs
from reversible2.splitter import SubsampleSplitter
from reversible2.rfft import RFFT
from braindecode.torch_ext.modules import Expression

from braindecode.torch_ext.util import set_random_seeds

n_blocks_per_stage = 3
n_filters_per_stage = [32, 48, 64, 128, 192, 256, 384, 512]

# Seed before any layer is instantiated so weight initialisation is reproducible.
set_random_seeds(2019011641, cuda)

modules = [
    # NOTE(review): assumes the input is (batch, n_chans, n_time, 1) — TODO confirm.
    # After the swap the electrodes sit on the last axis, so the temporal conv
    # below sees a single input channel.
    Expression(transpose_1st_and_3rd_dim),
    # Temporal filtering: 25 filters of length 11; padding keeps n_time unchanged.
    nn.Conv2d(1, 25, (11, 1), stride=1, padding=(5, 0)),
    # Spatial filtering: collapse the electrode axis back into n_chans feature maps.
    nn.Conv2d(25, n_chans, (1, n_chans)),
]
modules.extend(create_chain_modules(n_chans, n_time, n_blocks_per_stage, n_filters_per_stage))

feature_model = nn.Sequential(*modules, RFFT())
feature_model = add_softmax(feature_model)
add_bnorm_before_relu(feature_model)  # modifies feature_model in place (return value unused)
add_dropout_before_convs(feature_model, p_conv=0.2, p_full=0.5)
feature_model.cuda()

from torch import nn
from braindecode.torch_ext.util import set_random_seeds

# Optimisation hyperparameters for the simpler model.
lr = 5e-4
weight_decay = 1e-5
max_epochs = 30

# Re-seed right before training so batching/dropout randomness is reproducible.
set_random_seeds(2019011641, cuda)
model = WrappedModel(feature_model)
model.cuda()
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)
model.fit(
    train_set.X, train_set.y,
    epochs=max_epochs,
    batch_size=64,
    scheduler='cosine',
    validation_data=(valid_set.X, valid_set.y),
)

## Deep Model

In [None]:
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.util import set_random_seeds

n_classes = 2
# nll_loss indexes the network output with the labels, so every label must be
# a valid class index for the hard-coded n_classes.
assert set(np.unique(train_set.y)).issubset(range(n_classes)), \
    "train labels fall outside [0, n_classes)"

# Seed before the network is built so its weight initialisation (and the
# training run below) is reproducible, matching the other model cells.
set_random_seeds(2019011641, cuda)
model = Deep4Net(n_chans, n_classes,
                 input_time_length=train_set.X.shape[2],
                 pool_time_length=2,
                 pool_time_stride=2,
                 final_conv_length='auto')
model.cuda()
lr = 1 * 0.01
weight_decay = 0.5 * 0.001
optimizer = AdamW(model.parameters(), lr=lr,
                  weight_decay=weight_decay)

max_epochs = 30
model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)
model.fit(train_set.X, train_set.y, epochs=max_epochs, batch_size=64,
          scheduler='cosine',
          validation_data=(valid_set.X, valid_set.y))