In [None]:
import importlib
import logging
import sys

# Re-execute the logging module so that basicConfig below actually installs a
# handler even if the kernel had already configured logging before
# (see https://stackoverflow.com/a/21475297/1469195).
importlib.reload(logging)

# Root logger at INFO, printing timestamped records to stdout.
log = logging.getLogger()
log.setLevel('INFO')
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
import os
import site
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


%load_ext autoreload
%autoreload 2
import numpy as np
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs
from reversible2.invert import invert

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs

In [None]:
# Load recording 4.mat of the reduced HGD-public training split and restrict it
# to channel C3; create_inputs then yields one input tensor per class
# (indexed as train_inputs[i_class] further below).
orig_train_cnt = load_file(
    '/data/schirrmr/schirrmr/HGD-public/reduced/train/4.mat')
train_cnt = orig_train_cnt.reorder_channels(['C3'])
train_inputs = create_inputs(
    train_cnt,
    final_hz=64,
    half_before=True,
)

In [None]:
# Same preprocessing as for the training set, applied to the test split of 4.mat.
orig_test_cnt = load_file(
    '/data/schirrmr/schirrmr/HGD-public/reduced/test/4.mat')
test_cnt = orig_test_cnt.reorder_channels(['C3'])
test_inputs = create_inputs(
    test_cnt,
    final_hz=64,
    half_before=True,
)

In [None]:
# Use the GPU only when one is present, so the notebook also runs on
# CPU-only machines. Every later `.cuda()` call is gated on this flag.
cuda = th.cuda.is_available()
if cuda:
    train_inputs = [class_inputs.cuda() for class_inputs in train_inputs]
    test_inputs = [class_inputs.cuda() for class_inputs in test_inputs]

In [None]:
from reversible2.distribution import TwoClassDist

from reversible2.blocks import dense_add_block, conv_add_block_3x3
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from reversible2.ot_exact import ot_euclidean_loss_for_samples
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter


set_random_seeds(2019011641, cuda)
# Invertible feature network (it is inverted with `invert` below). Per the
# shape comments (channels x time), each SubsampleSplitter halves the time axis
# and doubles the channels until a 64-dim vector remains. That vector is
# flattened and passed through dense additive blocks and a final RFFT.
feature_model = nn.Sequential(
    SubsampleSplitter(stride=[2,1],chunk_chans_first=False),# 2 x 32
    conv_add_block_3x3(2,32),
    conv_add_block_3x3(2,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 4 x 16
    conv_add_block_3x3(4,32),
    conv_add_block_3x3(4,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 8 x 8
    conv_add_block_3x3(8,32),
    conv_add_block_3x3(8,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 16 x 4
    conv_add_block_3x3(16,32),
    conv_add_block_3x3(16,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 32 x 2
    conv_add_block_3x3(32,32),
    conv_add_block_3x3(32,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 64 x 1
    ViewAs((-1,64,1, 1), (-1,64)),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    RFFT(),
)
if cuda:
    feature_model.cuda()
device = next(feature_model.parameters()).device
# NOTE(review): arguments presumably are (n class dims, n non-class dims);
# 2 + 62 = 64 matches the model's 64 outputs — confirm against TwoClassDist.
class_dist = TwoClassDist(2,62)
# Follow the `cuda` flag like the model does. An unconditional .cuda()
# crashes on CPU-only runs and would put the two on different devices.
if cuda:
    class_dist.cuda()
optim_model = th.optim.Adam(feature_model.parameters())
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-2)

In [None]:
# Class (and its training inputs) used by the single-class timing cells below.
i_class, class_ins = 0, train_inputs[0]

In [None]:
from reversible2.timer import Timer

In [None]:
# Time sampling, inversion and a forward pass for one class. CUDA launches
# kernels asynchronously, so synchronize before each timer exits. Otherwise the
# GPU work would still be running when the (presumably wall-clock) Timer stops.
with Timer(name='all'):
    with Timer(name='samples'):
        # Draw 5x as many latent samples as there are real trials of this class.
        samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
        if cuda:
            th.cuda.synchronize()
    with Timer(name='invert'):
        inverted = invert(feature_model, samples)
        if cuda:
            th.cuda.synchronize()
    with Timer(name='forward'):
        outs = feature_model(class_ins)
        if cuda:
            th.cuda.synchronize()
    # Uncomment to also time the OT losses used in training:
    #with Timer(name='ot_out'):
    #    ot_loss_out = ot_euclidean_loss_for_samples(outs[:,:2].squeeze(), samples[:,:2].squeeze())
    #with Timer(name='ot_in'):
    #    ot_loss_in = ot_euclidean_loss_for_samples(class_ins.squeeze(), inverted.squeeze())


In [None]:
from timeit import default_timer

# Wall-clock timestamps after each module of the forward pass (`times`) and of
# the inverse pass in reverse module order (`times_inv`), relative to `start` /
# `start_inv`. CUDA is asynchronous, so synchronize before the start time and
# after every module. Otherwise the deltas mostly reflect kernel-launch
# overhead rather than compute.
if cuda:
    th.cuda.synchronize()
x = class_ins
times = []
start = default_timer()
for module in feature_model.children():
    x = module(x)
    if cuda:
        th.cuda.synchronize()
    times.append(default_timer())

times_inv = []
start_inv = default_timer()
x = samples
for module in list(feature_model.children())[::-1]:
    # Invert one module at a time by wrapping it in its own Sequential.
    x = invert(nn.Sequential(module), x)
    if cuda:
        th.cuda.synchronize()
    times_inv.append(default_timer())

In [None]:
# Cumulative forward-pass time in ms after each module.
1000 * (np.array(times) - start)

In [None]:
# Cumulative inverse-pass time in ms after each (reverse-ordered) module.
1000 * (np.array(times_inv) - start_inv)

In [None]:
# (module class name, forward duration in ms) for each module in order.
list(zip((type(m).__name__ for m in feature_model.children()),
         np.diff((np.array(times) - start) * 1000, prepend=0)))

In [None]:
# (module class name, inverse duration in ms). The inverse pass ran last module
# first, so the durations are reversed to line up with forward module order.
list(zip((type(m).__name__ for m in feature_model.children()),
         np.diff((np.array(times_inv) - start_inv) * 1000, prepend=0)[::-1]))

In [None]:
# Per-module forward vs. inverse runtime. Inverse durations are reversed to
# forward module order so both curves share the same x positions.
module_names = [m.__class__.__name__ for m in feature_model.children()]
forward_ms = np.diff(np.insert(np.array(times) - start, 0, 0) * 1000)
inverse_ms = np.diff(np.insert(np.array(times_inv) - start_inv, 0, 0) * 1000)[::-1]
# Explicit figsize: the notebook-wide default (12, 1) is too flat for this plot.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(forward_ms, marker='o', label='forward')
ax.plot(inverse_ms, marker='o', label='inverse')
ax.set_xticks(range(len(module_names)))
ax.set_xticklabels(module_names, rotation=90)
ax.set(xlabel='module', ylabel='time [ms]', title='Per-module runtime')
ax.legend();

In [None]:
# NOTE(review): imports from a module named `plot` (not reversible2.plot) — confirm.
from plot import plot_outs

n_epochs = 2001
# Report about 20 times per run. max(1, ...) avoids a ZeroDivisionError when
# n_epochs < 20.
print_every = max(1, n_epochs // 20)
for i_epoch in range(n_epochs):
    optim_model.zero_grad()
    optim_dist.zero_grad()
    # One backward per class: gradients accumulate across classes (and each
    # class's graph is freed early), then a single optimizer step follows.
    for i_class in range(len(train_inputs)):
        class_ins = train_inputs[i_class]
        samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
        inverted = invert(feature_model, samples)
        outs = feature_model(class_ins)
        # OT in output space on the first two dims: model outputs vs. samples.
        ot_loss_out = ot_euclidean_loss_for_samples(outs[:,:2].squeeze(), samples[:,:2].squeeze())
        # OT in input space: real inputs vs. inverted samples.
        ot_loss_in = ot_euclidean_loss_for_samples(class_ins.squeeze(), inverted.squeeze())

        # Map the other class's outputs to this class via class_dist. Match them
        # to this class's inputs (after inversion) and to its samples (output space).
        other_class_ins = train_inputs[1-i_class]
        other_outs = feature_model(other_class_ins)
        changed_outs = class_dist.change_to_other_class(other_outs, i_class_from=1-i_class, i_class_to=i_class)
        changed_inverted = invert(feature_model, changed_outs)
        ot_transformed_in = ot_euclidean_loss_for_samples(class_ins.squeeze(), changed_inverted.squeeze())
        ot_transformed_out = ot_euclidean_loss_for_samples(changed_outs[:,:2].squeeze(), samples[:,:2].squeeze())
        loss = ot_loss_in + ot_loss_out + ot_transformed_in + ot_transformed_out
        loss.backward()
    optim_model.step()
    optim_dist.step()
    if i_epoch % print_every == 0:
        # The loss values printed are those of the last class in the loop above.
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("Loss: {:E}".format(loss.item()))
        print("OT Loss In: {:E}".format(ot_loss_in.item()))
        print("OT Loss Out: {:E}".format(ot_loss_out.item()))
        print("Transformed OT Loss In: {:E}".format(ot_transformed_in.item()))
        print("Transformed OT Loss Out: {:E}".format(ot_transformed_out.item()))
        plot_outs(feature_model, train_inputs, test_inputs,
                 class_dist)
        # Standard deviations of class_dist: class dims first, then non-class dims.
        fig = plt.figure(figsize=(8,2))
        plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                                 th.exp(class_dist.non_class_log_stds)))),
                marker='o')
        plt.title('class_dist stds (class dims, then non-class dims)')
        display_close(fig)