In [None]:
import importlib
import logging
import sys

# The logging module is reloaded before configuring it;
# see https://stackoverflow.com/a/21475297/1469195 for the background.
importlib.reload(logging)
log = logging.getLogger()
log.setLevel('INFO')

# Send timestamped INFO-level records to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s %(levelname)s : %(message)s',
)

In [None]:
%%capture
# Setup cell: import paths, autoreload, logging, plotting defaults and all
# shared imports. The %%capture magic above silences this cell's output.
import os
import site
# Prepend two absolute, machine-specific directories to the import path
# (presumably the checkouts providing `reversible2` and `braindecode` -
# TODO confirm). NOTE(review): these paths only exist on the original machine.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')


# Re-import edited modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Root logger at INFO level, timestamped records written to stdout.
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
# Inline PNG figures; default figure size and font size are set here and the
# seaborn style is applied afterwards (keep this order when editing).
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs
# Turn on cuDNN benchmark mode.
th.backends.cudnn.benchmark = True

In [None]:
# Load one recording and restrict it to the selected sensor(s).
sensor_names = ['C3',]
orig_train_cnt = load_file('/data/schirrmr/schirrmr/HGD-public/reduced/train/4.mat')
train_cnt = orig_train_cnt.reorder_channels(sensor_names)

# Per-class inputs (the list is indexed by class further down in the notebook).
train_inputs = create_inputs(
    train_cnt,
    final_hz=64,
    half_before=True,
    start_ms=500,
    stop_ms=1500,
)

# Double indices 7..12 along dim 1 of the rFFT of the second entry and
# transform back. (Presumably dim 1 holds the frequency bins, i.e.
# t.squeeze() is (trials, time) - TODO confirm.)
t = train_inputs[1]
t_fft = th.rfft(t.squeeze(), signal_ndim=1)
t_fft[:, 7:13] *= 2
train_inputs[1] = (
    th.irfft(t_fft, signal_sizes=[64], signal_ndim=1)
    .unsqueeze(1)
    .unsqueeze(-1)
    .detach()
    .clone()
)

# The last 40 entries of every list element become the test set,
# everything before them stays in the training set.
n_split = len(train_inputs[0]) - 40
test_inputs = [class_trials[-40:] for class_trials in train_inputs]
train_inputs = [class_trials[:-40] for class_trials in train_inputs]

In [None]:
# Sanity check: display the shape of the first training tensor.
train_inputs[0].shape

In [None]:
# Notebook-wide switch: when set, every train/test tensor is moved to the GPU.
cuda = True
if cuda:
    train_inputs = [class_trials.cuda() for class_trials in train_inputs]
    test_inputs = [class_trials.cuda() for class_trials in test_inputs]

In [None]:
from reversible2.graph import Node
from reversible2.branching import CatChans, ChunkChans, Select
from reversible2.constantmemory import graph_to_constant_memory

from copy import deepcopy
from reversible2.graph import Node
from reversible2.distribution import TwoClassDist
from reversible2.wrap_invertible import WrapInvertible
from reversible2.blocks import dense_add_no_switch, conv_add_3x3_no_switch
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
import torch as th
from reversible2.splitter import SubsampleSplitter
from reversible2.models import smaller_model

# Options forwarded to smaller_model below.
final_fft = False
constant_memory = False

# Seed before constructing the model.
set_random_seeds(2019011641, cuda)

# Dimensions 1 and 2 of the first training tensor give channel and time counts.
n_chans, n_time = train_inputs[0].shape[1:3]
feature_model = smaller_model(n_chans, n_time, final_fft, constant_memory)

In [None]:
# When True, feature_model.data_init(...) is called on the concatenated
# training tensors of both classes.
data_zero_init = False
# When True, the per-class mean/std of class_dist are set from the feature
# model's outputs on the training data.
set_distribution_to_empirical = True

In [None]:
from reversible2.constantmemory import clear_ctx_dicts
from reversible2.ot_exact import ot_euclidean_loss_for_samples
from reversible2.distribution import TwoClassIndependentDist

# Optional initialisation of the model on the training tensors of both classes.
if data_zero_init:
    feature_model.data_init(th.cat((train_inputs[0], train_inputs[1]), dim=0))

# Check that forward + inverse is really identical
t_out = feature_model(train_inputs[0][:2])
inverted = feature_model.invert(t_out)
clear_ctx_dicts(feature_model)
assert th.allclose(train_inputs[0][:2], inverted, rtol=1e-3, atol=1e-4), (
    "feature_model.invert(feature_model(x)) does not reproduce x"
)

# One distribution dimension per element of a single flattened input.
class_dist = TwoClassIndependentDist(np.prod(train_inputs[0].size()[1:]))
# Honour the notebook-wide `cuda` switch, like the inputs and the seeding do,
# instead of moving to the GPU unconditionally.
if cuda:
    class_dist.cuda()
class_dist

In [None]:
# Initialise each class distribution with the mean/std of the feature
# model's outputs on that class's training data.
if set_distribution_to_empirical:
    with th.no_grad():
        for i_class in range(2):
            this_outs = feature_model(train_inputs[i_class])
            mean = this_outs.mean(dim=0)
            std = this_outs.std(dim=0)
            class_dist.set_mean_std(i_class, mean, std)
            # Read the values back to verify they were stored as given.
            setted_mean, setted_std = class_dist.get_mean_std(i_class)
            assert th.allclose(mean, setted_mean)
            assert th.allclose(std, setted_std)
    clear_ctx_dicts(feature_model)

In [None]:
# Separate Adam optimizers: the distribution parameters use a ten times
# larger learning rate than the feature model.
optim_model = th.optim.Adam(
    feature_model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
)
optim_dist = th.optim.Adam(
    class_dist.parameters(),
    lr=1e-2,
    betas=(0.9, 0.999),
)

In [None]:
# Passed to CLFTrainer as `outs_loss`; while it is None, no classifier,
# classifier optimizer or CLFTrainer is created and the training loop skips
# all classifier steps.
clf_loss = None
# AND-ed with the epoch threshold to form the `loss_on_outs` argument of
# OTTrainer.train in the training loop.
ot_on_class_dims = False

In [None]:
from reversible2.training import CLFTrainer
from reversible2.classifier import SubspaceClassifier
from reversible2.monitor import compute_accs
from reversible2.monitor import compute_clf_accs
# The classifier, its optimizer and its trainer only exist when a
# classification loss is configured.
if clf_loss is not None:
    clf = SubspaceClassifier(2, 10, np.prod(train_inputs[0].shape[1:]))
    # Honour the notebook-wide `cuda` switch instead of always moving to GPU.
    if cuda:
        clf.cuda()

    optim_clf = th.optim.Adam(clf.parameters(), lr=1e-3)
    clf_trainer = CLFTrainer(
        feature_model,
        clf,
        class_dist,
        optim_model,
        optim_clf,
        optim_dist,
        outs_loss=clf_loss,
    )

In [None]:

import pandas as pd

from reversible2.constantmemory import clear_ctx_dicts
from reversible2.timer import Timer
from reversible2.training import OTTrainer

# Collects one row of metrics per epoch in the training loop.
df = pd.DataFrame()

trainer = OTTrainer(feature_model, class_dist, optim_model, optim_dist)

# Total number of epochs, and the first epoch from which `loss_on_outs`
# evaluates to True in the training loop.
n_epochs = 1001
i_start_epoch_out = 401

In [None]:
# Every `report_every` epochs the input-space OT loss is added to the row
# and a progress table is displayed. max(..., 1) keeps the modulo below
# valid when n_epochs < 20.
report_every = max(n_epochs // 20, 1)

for i_epoch in range(n_epochs):
    epoch_row = {}
    with Timer(name="EpochLoop", verbose=False) as loop_time:
        # Loss on the outputs is only requested from epoch i_start_epoch_out on.
        loss_on_outs = i_epoch >= i_start_epoch_out
        result = trainer.train(
            train_inputs, loss_on_outs=(loss_on_outs and ot_on_class_dims)
        )
        if clf_loss is not None:
            result_clf = clf_trainer.train(train_inputs, loss_on_outs=loss_on_outs)
            epoch_row.update(result_clf)

    epoch_row.update(result)
    epoch_row["runtime"] = loop_time.elapsed_secs * 1000  # milliseconds
    acc_results = compute_accs(feature_model, train_inputs, test_inputs, class_dist)
    epoch_row.update(acc_results)
    if clf_loss is not None:
        clf_accs = compute_clf_accs(clf, feature_model, train_inputs, test_inputs)
        epoch_row.update(clf_accs)

    is_report_epoch = i_epoch % report_every == 0
    if is_report_epoch:
        # OT loss in input space between real inputs and inverted samples.
        for i_class in range(len(train_inputs)):
            with th.no_grad():
                class_ins = train_inputs[i_class]
                # NOTE(review): four times as many samples as inputs are drawn
                # and inverted, but only the first len(class_ins) are compared.
                samples = class_dist.get_samples(
                    i_class, len(train_inputs[i_class]) * 4
                )
                inverted = feature_model.invert(samples)
                clear_ctx_dicts(feature_model)
                ot_loss_in = ot_euclidean_loss_for_samples(
                    class_ins.view(class_ins.shape[0], -1),
                    inverted.view(inverted.shape[0], -1)[: (len(class_ins))],
                )
                epoch_row["ot_loss_in_{:d}".format(i_class)] = ot_loss_in.item()

    # DataFrame.append was removed in pandas 2.0; concat works on old and new versions.
    df = pd.concat([df, pd.DataFrame([epoch_row])], ignore_index=True)

    if is_report_epoch:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("Loop Time: {:.0f} ms".format(loop_time.elapsed_secs * 1000))
        # reindex (instead of .loc) tolerates columns that were never logged,
        # e.g. the classifier metrics while clf_loss is None; they show as NaN.
        display(df.iloc[-3:].reindex(columns=[
            'test_acc', 'test_clf_acc', 'test_data_acc', 'train_acc',
            'train_clf_acc', 'train_data_acc',
            'ot_loss_in_0', 'ot_loss_in_1', 'ot_out_loss',
            'runtime', 'subspace_loss_0', 'subspace_loss_1',
            'clf_loss_0', 'clf_loss_1',
            'g_grad', 'g_grad_norm', 'g_loss',
        ]))

In [None]:
import matplotlib.style as style

# Switch matplotlib and seaborn to their "poster" presets and use seaborn's
# "colorblind" palette for everything plotted below.
style.use('seaborn-poster')
seaborn.set_context('poster')
seaborn.set_palette("colorblind")

In [None]:
# Plot trial 0 of class 0 and export it.
# NOTE(review): a later cell saves trial 1 of class 0 to this same
# signal3.png path and overwrites the file - confirm which is intended.
with seaborn.axes_style("white"):
    plt.figure(figsize=(8,2))
    plt.plot(var_to_np(train_inputs[0][0].squeeze()))
    plt.axis('off')
    # Fix the y-range before saving so the exported PNG uses the same scale
    # as the figure shown inline (limits set after savefig never reach the file).
    plt.ylim(-2.5,2.5)
    plt.savefig('/data/schirrmr/schirrmr/ohbm-rom/signal3.png', dpi=300 ,transparent=True,
               bbox_inches='tight', pad_inches=0)

In [None]:
# Plot trial 0 of class 1 (second palette color) and export it.
# NOTE(review): later cells also write to signal4.png and overwrite this
# file - confirm which is intended.
with seaborn.axes_style("white"):
    plt.figure(figsize=(8,2))
    plt.plot(var_to_np(train_inputs[1][0].squeeze()), color=seaborn.color_palette()[1])
    plt.axis('off')
    # Fix the y-range before saving so the exported PNG uses the same scale
    # as the figure shown inline (limits set after savefig never reach the file).
    plt.ylim(-2.5,2.5)
    plt.savefig('/data/schirrmr/schirrmr/ohbm-rom/signal4.png', dpi=300 ,transparent=True,
               bbox_inches='tight', pad_inches=0)

In [None]:
# Plot trial 1 of class 0 and export it.
# NOTE(review): this writes to the same signal3.png path as the earlier
# class-0 / trial-0 cell and overwrites that file - confirm this is intended.
with seaborn.axes_style("white"):
    plt.figure(figsize=(8,2))
    plt.plot(var_to_np(train_inputs[0][1].squeeze()))
    plt.axis('off')
    # Fix the y-range before saving so the exported PNG uses the same scale
    # as the figure shown inline (limits set after savefig never reach the file).
    plt.ylim(-2.5,2.5)
    plt.savefig('/data/schirrmr/schirrmr/ohbm-rom/signal3.png', dpi=300 ,transparent=True,
               bbox_inches='tight', pad_inches=0)

In [None]:
# Plot trial 6 of class 1 (second palette color) and export it.
# NOTE(review): signal4.png is also written by other cells of this notebook,
# so the file on disk depends on execution order - confirm the intended name.
with seaborn.axes_style("white"):
    plt.figure(figsize=(8,2))
    plt.plot(var_to_np(train_inputs[1][6].squeeze()), color=seaborn.color_palette()[1])
    plt.axis('off')
    # Fix the y-range before saving so the exported PNG uses the same scale
    # as the figure shown inline (limits set after savefig never reach the file).
    plt.ylim(-2.5,2.5)
    plt.savefig('/data/schirrmr/schirrmr/ohbm-rom/signal4.png', dpi=300 ,transparent=True,
               bbox_inches='tight', pad_inches=0)

In [None]:
# Draw one dummy line per class in its palette color, just to obtain a legend.
for class_color in seaborn.color_palette()[:2]:
    plt.plot([0, 1], color=class_color)
plt.legend(["Class 1", "Class 2"], bbox_to_anchor=(1, 1, 0, 0))

In [None]:
# Export trials 2..7 of class 0; file numbers are even (4, 6, ..., 14).
# NOTE(review): signal4.png is also written by the class-1 cells above and
# gets overwritten here - confirm the intended numbering.
for i_example in range(2,8):
    with seaborn.axes_style("white"):
        plt.figure(figsize=(8,2))
        plt.plot(var_to_np(train_inputs[0][i_example].squeeze()))
        plt.axis('off')
        # Fix the y-range before saving so the exported PNG uses the same
        # scale as the inline figure (limits set after savefig never reach the file).
        plt.ylim(-2.5,2.5)
        i_file = i_example * 2
        plt.savefig('/data/schirrmr/schirrmr/ohbm-rom/signal{:d}.png'.format(i_file), dpi=300 ,transparent=True,
                   bbox_inches='tight', pad_inches=0)

In [None]:
# Export trials 2..7 of class 1 (second palette color); file numbers are
# odd (5, 7, ..., 15).
for i_example in range(2,8):
    with seaborn.axes_style("white"):
        plt.figure(figsize=(8,2))
        plt.plot(var_to_np(train_inputs[1][i_example].squeeze()), color=seaborn.color_palette()[1])
        plt.axis('off')
        # Fix the y-range before saving so the exported PNG uses the same
        # scale as the inline figure (limits set after savefig never reach the file).
        plt.ylim(-4,4)
        i_file = i_example * 2 + 1
        plt.savefig('/data/schirrmr/schirrmr/ohbm-rom/signal{:d}.png'.format(i_file), dpi=300 ,transparent=True,
                   bbox_inches='tight', pad_inches=0)

In [None]:
plt.figure(figsize=(4,4))
rng = RandomState(45984985)
means = [[-2,2], [3,-3]]
stds = [[1,0.5], [0.75,1]]
with seaborn.axes_style("white"):
    for i_class in range(2):
        x,y = (rng.randn(2,500) * np.array(stds[i_class])[:,None]) + np.array(means[i_class])[:,None]
        plt.scatter(x,y, alpha=.15, color=seaborn.color_palette()[i_class])
        plt.scatter(means[i_class][0], means[i_class][1], color=seaborn.color_palette()[i_class], 
                   label="Class {:d}".format(i_class))
        plt.scatter(means[i_class][0], means[i_class][1], color=lighten_color(seaborn.color_palette()[i_class], amount=0.75), 
                   label="Mean {:d}".format(i_class),
                    marker='x')
        #plt.scatter(x[:3], y[:3], color=seaborn.color_palette()[i_class], marker='x')
plt.legend(bbox_to_anchor=(0.67,1.52,0,0))
plt.axis('off');
plt.savefig('/data/schirrmr/schirrmr/ohbm-rom/dists.png', dpi=300 ,transparent=True,
                   bbox_inches='tight', pad_inches=0)

In [None]:
def lighten_color(color, amount=0.5):
    """
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        c = mc.cnames[color]
    except:
        c = color
    c = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])

In [None]:
# Two stacked panels, one per class, drawn in a lightened class color.
# NOTE(review): the legend says "Mean", but the plotted data is trial 0 of
# each class - presumably a stand-in for the poster figure; confirm.
with seaborn.axes_style("white"):
    fig, axes = plt.subplots(2, 1, figsize=(8, 3), sharex=True, sharey=True)
    for i_class, ax in enumerate(axes):
        ax.plot(
            var_to_np(train_inputs[i_class][0].squeeze()),
            color=lighten_color(seaborn.color_palette()[i_class], amount=0.75),
        )
        ax.axis('off')
    # One legend for the whole figure, built from the lines of both panels.
    ls = axes[0].get_lines() + axes[1].get_lines()
    fig.legend(
        ls,
        ["Mean 1", "Mean 2"],
        fontsize=24,
        loc='upper right',
        bbox_to_anchor=(1.3, 0.8),
    )
    plt.tight_layout(pad=0)
    plt.savefig(
        '/data/schirrmr/schirrmr/ohbm-rom/means.png',
        dpi=300,
        transparent=True,
        bbox_inches='tight',
        pad_inches=0,
    )

In [None]:
import ot

# Toy illustration of an optimal-transport matching between two 2-D point
# sets: Gaussian "generated" points and uniformly drawn "real" points.
rng = RandomState(21434)
prior = rng.randn(80, 2) * 0.5
outputs = rng.rand(80, 2) * np.array([1, 4]) + np.array([[-2, -2]])

fig = plt.figure(figsize=(7, 7))

plt.plot(
    outputs[:, 0], outputs[:, 1],
    ls='', marker='o', label='Real EEG Signals',
    color=seaborn.color_palette()[2],
)
plt.plot(
    prior[:, 0], prior[:, 1],
    ls='', marker='o', label='Generated Signals',
    color=seaborn.color_palette()[3],
)

#plt.xlim(-3, 1.5)
#plt.ylim(-2.5, 2.5)

# Pairwise squared Euclidean distances, shape (len(prior), len(outputs)).
dists = np.square(prior[:, None] - outputs[None]).sum(axis=2)

# Transport plan from ot.emd with empty weight lists; for every prior
# point take the output point with the largest entry in its row.
t_map = ot.emd([], [], dists)
matchings = np.argmax(t_map, axis=1)

# One thin line per matched pair; only the first line gets a legend label.
for i_line, (prior_p, out_p) in enumerate(zip(prior, outputs[matchings])):
    label = 'Matching' if i_line == 0 else ''
    plt.plot(
        [prior_p[0], out_p[0]],
        [prior_p[1], out_p[1]],
        color='black',
        lw=0.25,
        label=label,
    )
plt.legend(fontsize=22, bbox_to_anchor=(0, 0, 1.05, 1.3))
plt.axis('off')

plt.savefig(
    '/data/schirrmr/schirrmr/ohbm-rom/OT.png',
    dpi=300,
    transparent=True,
    bbox_inches='tight',
    pad_inches=0,
)


In [None]:
# Smallest coordinate (over both columns) of the `outputs` points; the same
# expression is used for the left x-limit in the next cell.
np.min(outputs)

In [None]:
import ot

# Same illustration after an "update": every generated point is moved 70 %
# of the way towards the real point with the same index.
new_prior = prior * 0.3 + outputs * 0.7

rng = RandomState(21434)

fig = plt.figure(figsize=(7, 7))

plt.plot(
    outputs[:, 0], outputs[:, 1],
    ls='', marker='o', label='Real EEG Signals',
    color=seaborn.color_palette()[2],
)
plt.plot(
    new_prior[:, 0], new_prior[:, 1],
    ls='', marker='o', label='Generated Signals',
    color=seaborn.color_palette()[3],
)

# x-range derived from the real points and the *original* prior.
# NOTE(review): presumably this keeps the scale of the previous figure - confirm.
plt.xlim(np.min(outputs) - 0.2, np.max(prior))

# Pairwise squared Euclidean distances, shape (len(new_prior), len(outputs)).
dists = np.square(new_prior[:, None] - outputs[None]).sum(axis=2)

# Transport plan from ot.emd with empty weight lists; for every generated
# point take the output point with the largest entry in its row.
t_map = ot.emd([], [], dists)
matchings = np.argmax(t_map, axis=1)

# One thin line per matched pair; only the first line gets a legend label.
for i_line, (prior_p, out_p) in enumerate(zip(new_prior, outputs[matchings])):
    label = 'Matching' if i_line == 0 else ''
    plt.plot(
        [prior_p[0], out_p[0]],
        [prior_p[1], out_p[1]],
        color='black',
        lw=0.25,
        label=label,
    )
plt.legend(fontsize=22, bbox_to_anchor=(0, 0, 1.05, 1.3))
plt.axis('off')

plt.savefig(
    '/data/schirrmr/schirrmr/ohbm-rom/OT_updated.png',
    dpi=300,
    transparent=True,
    bbox_inches='tight',
    pad_inches=0,
)
