In [None]:
# Configure logging for the notebook: INFO level, timestamped lines, written to stdout.
import logging
import importlib
# Re-executes the logging module before configuring it (workaround from the linked answer).
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
# Root logger; its level is set to INFO here and once more through basicConfig below.
log = logging.getLogger()
log.setLevel('INFO')
import sys

# stream=sys.stdout sends log records to stdout instead of the default stderr.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)

In [None]:
%%capture
import os
import site
# Put the local source checkouts at the front of the import path so they are found first.
# `os.sys` is the standard `sys` module as imported by `os`.
# NOTE(review): absolute, user-specific paths; `site` is not used in the visible cells.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# autoreload mode 2: reload imported modules before running code, so edits to the
# project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Same logging configuration as in the first cell, repeated here.
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
# Render figures inline as PNG images.
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Notebook-wide plot defaults: wide, flat figures (12 x 1) and font size 14.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs
th.backends.cudnn.benchmark = True

In [None]:
# The 22 EEG channels used in this notebook; the recording is reordered to this list.
sensor_names = [
    'Fz',
    'FC3', 'FC1', 'FCz', 'FC2', 'FC4',
    'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6',
    'CP3', 'CP1', 'CPz', 'CP2', 'CP4',
    'P1', 'Pz', 'P2',
    'POz',
]
orig_train_cnt = load_file('/data/schirrmr/schirrmr/HGD-public/reduced/train/4.mat')
train_cnt = orig_train_cnt.reorder_channels(sensor_names)

# One entry per class (the later cells index this list by class).
train_inputs = create_inputs(train_cnt, final_hz=256, half_before=True,
                             start_ms=500, stop_ms=1500)

# The last `n_test_trials` entries of every class are held out as the test split.
# Kept in one place instead of repeating the literal in each slice.
n_test_trials = 40
# Slicing with a too-large count would silently yield an empty training split.
assert all(len(class_trials) > n_test_trials for class_trials in train_inputs), (
    "Every class needs more than {:d} trials for the train/test split".format(
        n_test_trials))
n_split = len(train_inputs[0]) - n_test_trials
test_inputs = [class_trials[-n_test_trials:] for class_trials in train_inputs]
train_inputs = [class_trials[:-n_test_trials] for class_trials in train_inputs]

In [None]:
# When `cuda` is set, replace every per-class tensor of both splits by its GPU copy.
cuda = True
if cuda:
    train_inputs, test_inputs = (
        [class_trials.cuda() for class_trials in train_inputs],
        [class_trials.cuda() for class_trials in test_inputs],
    )

In [None]:
from reversible2.invert import invert
from reversible2.scale import ScalingLayer
from reversible2.constantmemory import sequential_to_constant_memory
from reversible2.rfft import RFFT, IRFFT
from reversible2.constantmemory import clear_ctx_dicts

# Frequency-domain scaling block: view the input as rows of 256 samples, apply RFFT,
# a scaling layer, the inverse RFFT, and restore the (-1, 22, 256, 1) layout.
test_module = nn.Sequential(
    ViewAs((-1,22,256,1), (-1,256)),
    RFFT(),
    ScalingLayer((256,)),
    IRFFT(),
    ViewAs((-1,256),(-1,22,256,1)) )
test_module = sequential_to_constant_memory(test_module)
test_module.cuda()
test_module.zero_grad()

# Smoke test on the first class: forward pass, inversion and one backward pass.
out = test_module(train_inputs[0])
ins = invert(test_module, out)
# The freshly constructed block is expected to reproduce its input ...
assert th.allclose(out, train_inputs[0], rtol=1e-4, atol=1e-4)
# ... and inverting its output must give back the input as well. Previously `ins`
# was computed but never compared against anything.
assert th.allclose(ins, train_inputs[0], rtol=1e-4, atol=1e-4)
loss = th.norm(ins)
loss.backward()

# Drop the temporaries, cached contexts and gradients produced by the smoke test.
del out, ins, loss
clear_ctx_dicts(test_module)
test_module.zero_grad()

In [None]:
# Restore the previously trained feature model and class distribution.
# NOTE: th.load unpickles arbitrary Python objects; only load files from a trusted source.
feature_model = th.load('/data/schirrmr/schirrmr/reversible/models/notebooks/21ChansOT/feature_model.pkl')
class_dist = th.load('/data/schirrmr/schirrmr/reversible/models/notebooks/21ChansOT/class_dist.pkl')
feature_model.eval()

# Prepend the layers of test_module to the module of the model's starting node, so
# that they run before the layers that were already there.
start_node = feature_model.find_starting_node()
start_node.module = nn.Sequential(
    *test_module.children(),
    *start_node.module.children())

In [None]:
# Two separate Adam optimizers: the feature model is trained with a learning rate of
# 1e-3, the class distribution parameters with a ten times larger one (1e-2).
optim_model = th.optim.Adam(feature_model.parameters(), lr=1e-3, betas=(0.9,0.999))
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-2, betas=(0.9,0.999))

### Alternative: optimize only the scaling factors (disabled)

In [None]:
#optim_model = th.optim.Adam(test_module.parameters(), lr=1e-2, betas=(0.9,0.999))


In [None]:
%%writefile plot.py
import torch as th
import matplotlib.pyplot as plt
import numpy as np
from reversible2.util import var_to_np
from reversible2.plot import display_close
from matplotlib.patches import Ellipse
import seaborn

def _predict_labels(dists, outs_per_class):
    """For each class' outputs, return per example the index of the most likely distribution."""
    return [
        np.argmax(
            var_to_np(th.stack([dist.log_prob(outs) for dist in dists], dim=-1)),
            axis=1)
        for outs in outs_per_class]


def plot_outs(feature_model, train_inputs, test_inputs, class_dist):
    """Print accuracies and show scatter plots of the class dimensions of the outputs.

    For the train and the test set, the inputs are passed through ``feature_model``
    and the outputs are restricted to ``class_dist.i_class_inds``. Every example is
    assigned to the class whose diagonal Gaussian gives it the highest
    log-probability, once using the learned distributions
    (``class_dist.get_mean_std``) and once using Gaussians fitted to mean/std of
    the training outputs. One figure per set is shown through ``display_close``.

    Parameters
    ----------
    feature_model : callable
        Called on a batch of inputs; its result is indexed as ``[:, dims]``.
    train_inputs, test_inputs : list
        One batch of inputs per class.
    class_dist : object
        Provides ``i_class_inds`` and ``get_mean_std(i_class)``.

    Notes
    -----
    The scatter plot shows the first two selected dimensions and only has names
    for two classes ("Right", "Rest").
    """
    class_names = ["Right", "Rest"]
    class_inds = class_dist.i_class_inds
    with th.no_grad():
        # Gaussians fitted to mean/std of the training encodings, as a reference
        # for the learned class distributions.
        data_cls_dists = []
        for class_ins in train_inputs:
            this_class_outs = feature_model(class_ins)[:, class_inds]
            data_cls_dists.append(
                th.distributions.MultivariateNormal(
                    th.mean(this_class_outs, dim=0),
                    covariance_matrix=th.diag(th.std(this_class_outs, dim=0) ** 2)))

        for setname, set_inputs in (("Train", train_inputs), ("Test", test_inputs)):
            c_outs = [feature_model(ins)[:, class_inds] for ins in set_inputs]

            # Learned per-class Gaussians, restricted to the class dimensions.
            cls_dists = []
            for i_class in range(len(c_outs)):
                mean, std = class_dist.get_mean_std(i_class)
                cls_dists.append(
                    th.distributions.MultivariateNormal(
                        mean[class_inds],
                        covariance_matrix=th.diag(std[class_inds] ** 2)))

            pred_labels_per_class = _predict_labels(cls_dists, c_outs)
            data_pred_labels_per_class = _predict_labels(
                data_cls_dists[:len(cls_dists)], c_outs)

            # True labels in the same (class by class) order as the predictions.
            labels = np.concatenate([np.ones(len(ins)) * i_cls
                                     for i_cls, ins in enumerate(set_inputs)])
            acc = np.mean(labels == np.concatenate(pred_labels_per_class))
            data_acc = np.mean(labels == np.concatenate(data_pred_labels_per_class))

            print("{:s} Accuracy: {:.1f}%".format(setname, acc * 100))
            fig = plt.figure(figsize=(5,5))
            ax = plt.gca()
            for i_class, class_outs in enumerate(c_outs):
                o = var_to_np(class_outs).squeeze()
                incorrect_pred_mask = pred_labels_per_class[i_class] != i_class
                plt.scatter(o[:,0], o[:,1], s=20, alpha=0.75, label=class_names[i_class])
                assert len(incorrect_pred_mask) == len(o)
                # Misclassified examples get a black cross on top.
                plt.scatter(o[incorrect_pred_mask,0], o[incorrect_pred_mask,1], marker='x', color='black',
                           alpha=1, s=5)
                means, stds = class_dist.get_mean_std(i_class)
                means = var_to_np(means)[class_inds]
                stds = var_to_np(stds)[class_inds]
                for sigma in [0.5,1,2,3]:
                    # Ellipse expects the full axis lengths (diameters), so a
                    # contour at `sigma` standard deviations is 2 * sigma * std wide.
                    ellipse = Ellipse(means, 2 * sigma * stds[0], 2 * sigma * stds[1])
                    ax.add_artist(ellipse)
                    ellipse.set_edgecolor(seaborn.color_palette()[i_class])
                    ellipse.set_facecolor("None")
            # Empirical mean of each class' outputs, in the palette colors after the class colors.
            for i_class, class_outs in enumerate(c_outs):
                o = var_to_np(class_outs).squeeze()
                plt.scatter(np.mean(o[:,0]), np.mean(o[:,1]),
                           color=seaborn.color_palette()[i_class+2], s=80, marker="^",
                           label=["Right Mean", "Rest Mean"][i_class])

            plt.title("{:6s} Accuracy:        {:.1f}%\n"
                      "From data mean/std: {:.1f}%".format(setname, acc * 100, data_acc * 100))
            plt.legend(bbox_to_anchor=(1,1,0,0))
            display_close(fig)

In [None]:
import pandas as pd
# Starts out empty; the training loop adds one row of metrics per epoch.
df = pd.DataFrame()

from reversible2.training import OTTrainer
# The trainer gets both models together with their respective optimizers.
trainer = OTTrainer(feature_model, class_dist,
                optim_model, optim_dist)

In [None]:
from reversible2.ot_exact import ot_euclidean_loss_for_samples

In [None]:
# Output dimensions used as class dimensions; plot_outs classifies on these and
# scatter-plots the first two of them.
class_dist.i_class_inds = [0,1]

In [None]:
from reversible2.constantmemory import clear_ctx_dicts
from reversible2.timer import Timer
from plot import plot_outs
from reversible2.gradient_penalty import gradient_penalty


i_start_epoch_out = 401
n_epochs = 1001
# Evaluate and plot roughly every 5% of the run. Clamped to 1 so that runs with
# fewer than 20 epochs do not end in a modulo-by-zero.
report_every = max(1, n_epochs // 20)
for i_epoch in range(n_epochs):
    epoch_row = {}
    with Timer(name='EpochLoop', verbose=False) as loop_time:
        # The trainer is told to use the loss on the outputs only from
        # epoch i_start_epoch_out onwards.
        loss_on_outs = i_epoch >= i_start_epoch_out
        result = trainer.train(train_inputs, loss_on_outs=loss_on_outs)

    epoch_row.update(result)
    epoch_row['runtime'] = loop_time.elapsed_secs * 1000  # milliseconds
    is_report_epoch = i_epoch % report_every == 0
    if is_report_epoch:
        # On report epochs, additionally record the OT loss in input space between
        # the real inputs and inverted samples of each class.
        for i_class in range(len(train_inputs)):
            with th.no_grad():
                class_ins = train_inputs[i_class]
                samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 4)
                inverted = feature_model.invert(samples)
                clear_ctx_dicts(feature_model)
                # Only as many generated examples as there are real ones are compared.
                ot_loss_in = ot_euclidean_loss_for_samples(class_ins.view(class_ins.shape[0], -1),
                                                           inverted.view(inverted.shape[0], -1)[:(len(class_ins))])
                epoch_row['ot_loss_in_{:d}'.format(i_class)] = ot_loss_in.item()
    # DataFrame.append was removed in pandas 2.0; pd.concat is available in old and
    # new versions and keeps `df` current after every epoch.
    df = pd.concat([df, pd.DataFrame([epoch_row])], ignore_index=True)
    if not is_report_epoch:
        continue

    print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
    print("Loop Time: {:.0f} ms".format(loop_time.elapsed_secs * 1000))
    display(df.iloc[-3:])
    plot_outs(feature_model, train_inputs, test_inputs,
             class_dist)
    # Current standard deviations of the class distribution (class dims first).
    fig = plt.figure(figsize=(8,2))
    plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                             th.exp(class_dist.non_class_log_stds)))),
            marker='o')
    display_close(fig)
    # Reciprocal of exp(log_factors) of the scaling layer.
    fig = plt.figure(figsize=(8,2))
    plt.plot(1 / np.exp(var_to_np(test_module[2].F.log_factors)))
    plt.title("Factors per FFT component")
    display_close(fig)



In [None]:
from reversible2.plot import plot_head_signals_tight
# Generate signals for both classes by inverting samples of the class distribution.
inverted_per_class = []
for i_class in range(2):
    with th.no_grad():
        samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 4)
        inverted = feature_model.invert(samples)
        inverted = var_to_np(inverted).squeeze()
        signals = var_to_np(train_inputs[i_class]).squeeze()
    clear_ctx_dicts(feature_model)
    # The list was created but never filled before; keep every class' result.
    inverted_per_class.append(inverted)
# After the loop, `inverted` and `signals` hold the arrays of the last class (index 1).
    

In [None]:
import matplotlib.style as style
# Switch the figures that follow to poster styling with a colorblind-friendly palette.
# NOTE(review): matplotlib 3.6 renamed this style to 'seaborn-v0_8-poster' and later
# removed the old name - confirm against the installed matplotlib version.
style.use('seaborn-poster')
seaborn.set_context('poster')

seaborn.set_palette("colorblind", )



In [None]:
# Amplitude spectra, averaged over the first two axes, of `signals` and `inverted`.
# Both variables must have been set by an earlier cell.
# The frequency axis assumes 256 samples at a spacing of 1/256 s (bins in Hz).
fig = plt.figure(figsize=(8,2))
# first curve: `signals`
plt.plot(np.fft.rfftfreq(256,d=1.0/256.0), np.mean(np.abs(np.fft.rfft(signals)), axis=(0,1)),
         color=seaborn.color_palette()[2])
# second curve: `inverted`
plt.plot(np.fft.rfftfreq(256,d=1.0/256.0), np.mean(np.abs(np.fft.rfft(inverted)), axis=(0,1)),
         color=seaborn.color_palette()[3])
plt.xlabel("Frequency [Hz]")
plt.ylabel("Amplitude")

In [None]:
# Stand-alone repeat of the evaluation and plots that the training loop produces.
# Relies on `epoch_row`, `i_epoch`, `loop_time` and `df` still being defined from
# running the training loop.
for i_class in range(len(train_inputs)):
    with th.no_grad():
        class_ins = train_inputs[i_class]
        # Four times as many samples as real examples are drawn and inverted ...
        samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 4)
        inverted = feature_model.invert(samples)
        clear_ctx_dicts(feature_model)
        # ... but only the first len(class_ins) of them enter the OT loss.
        ot_loss_in = ot_euclidean_loss_for_samples(class_ins.view(class_ins.shape[0], -1),
                                                   inverted.view(inverted.shape[0], -1)[:(len(class_ins))])
        # NOTE: written into epoch_row only; this cell does not add a row to df.
        epoch_row['ot_loss_in_{:d}'.format(i_class)] = ot_loss_in.item()
print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
print("Loop Time: {:.0f} ms".format(loop_time.elapsed_secs * 1000))
display(df.iloc[-3:])
plot_outs(feature_model, train_inputs, test_inputs,
         class_dist)
# Standard deviations of the class distribution: class dims followed by non-class dims.
fig = plt.figure(figsize=(8,2))
plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                         th.exp(class_dist.non_class_log_stds)))),
        marker='o')
display_close(fig)
# Reciprocal of exp(log_factors) of the third layer of test_module.
fig = plt.figure(figsize=(8,2))
plt.plot(1 / np.exp(var_to_np(test_module[2].F.log_factors)))
plt.title("Factors per FFT component")
display_close(fig)

In [None]:
# Show the raw values plotted as "Factors per FFT component": 1 / exp(log_factors).
1/th.exp(test_module[2].F.log_factors)

In [None]:
from reversible2.plot import plot_head_signals_tight

# Mean amplitude spectrum of generated signals, one array per class.
mean_bps_per_class = []
for i_class in range(len(train_inputs)):
    # Generation is only needed for plotting: run it without autograd and release
    # the cached contexts afterwards, like the other generation cells do.
    with th.no_grad():
        samples = class_dist.get_samples(i_class, 400)
        inverted = feature_model.invert(samples)
    clear_ctx_dicts(feature_model)
    # Amplitudes of the real FFT, averaged over the generated examples (axis 0).
    mean_bps_per_class.append(
        np.mean(np.abs(np.fft.rfft(var_to_np(inverted.squeeze()))), axis=0))

# One curve per class, averaged over axis 1 of the stacked spectra.
fig = plt.figure(figsize=(8,3))
plt.plot(np.mean(mean_bps_per_class, axis=1).T)
display_close(fig)

fig = plot_head_signals_tight(np.stack(mean_bps_per_class, axis=-1), sensor_names=sensor_names,
                             figsize=(20,12));

In [None]:
from reversible2.plot import plot_head_signals_tight

# Mean log-amplitude spectrum of generated signals, one array per class.
mean_bps_per_class = []
for i_class in range(len(train_inputs)):
    # Generation is only needed for plotting: run it without autograd and release
    # the cached contexts afterwards, like the other generation cells do.
    with th.no_grad():
        samples = class_dist.get_samples(i_class, 400)
        inverted = feature_model.invert(samples)
    clear_ctx_dicts(feature_model)
    # Log of the real-FFT amplitudes, averaged over the generated examples (axis 0).
    mean_bps_per_class.append(
        np.mean(np.log(np.abs(np.fft.rfft(var_to_np(inverted.squeeze())))), axis=0))

# One curve per class, averaged over axis 1 of the stacked spectra.
fig = plt.figure(figsize=(8,3))
plt.plot(np.mean(mean_bps_per_class, axis=1).T)
display_close(fig)

fig = plot_head_signals_tight(np.stack(mean_bps_per_class, axis=-1), sensor_names=sensor_names,
                             figsize=(20,12));

In [None]:
from reversible2.plot import plot_head_signals_tight

# Mean log-amplitude spectrum of the real training signals, one array per class.
mean_bps_per_class = []
for i_class in range(len(train_inputs)):
    # This cell looks at real data only, so no samples are drawn from class_dist
    # (the earlier version drew 400 samples here and never used them).
    # The variable name is kept from the neighbouring cells although it holds the
    # training inputs, not inverted samples.
    inverted = train_inputs[i_class]
    mean_bps_per_class.append(
        np.mean(np.log(np.abs(np.fft.rfft(var_to_np(inverted.squeeze())))), axis=0))

# One curve per class, averaged over axis 1 of the stacked spectra.
fig = plt.figure(figsize=(8,3))
plt.plot(np.mean(mean_bps_per_class, axis=1).T)
display_close(fig)

fig = plot_head_signals_tight(np.stack(mean_bps_per_class, axis=-1), sensor_names=sensor_names,
                             figsize=(20,12));

In [None]:
from reversible2.plot import plot_head_signals_tight

# Mean amplitude spectrum of generated signals, one array per class.
mean_bps_per_class = []
for i_class in range(len(train_inputs)):
    # Generation is only needed for plotting: run it without autograd and release
    # the cached contexts afterwards, like the other generation cells do.
    with th.no_grad():
        samples = class_dist.get_samples(i_class, 400)
        inverted = feature_model.invert(samples)
    clear_ctx_dicts(feature_model)
    # Amplitudes of the real FFT, averaged over the generated examples (axis 0).
    mean_bps_per_class.append(
        np.mean(np.abs(np.fft.rfft(var_to_np(inverted.squeeze()))), axis=0))

# One curve per class, averaged over axis 1 of the stacked spectra.
fig = plt.figure(figsize=(8,3))
plt.plot(np.mean(mean_bps_per_class, axis=1).T)
display_close(fig)

fig = plot_head_signals_tight(np.stack(mean_bps_per_class, axis=-1), sensor_names=sensor_names,
                             figsize=(20,12));

In [None]:
# Debug: gradient stored on the factors of the third layer in the start node.
# NOTE(review): other cells read this layer's parameters as `test_module[2].F.log_factors`;
# confirm that the `.module.factors` attribute path exists.
start_node.module[2].module.factors.grad

In [None]:
# Debug: list the current gradient (or None) of every parameter of the feature model.
[p.grad for p in feature_model.parameters()]