In [None]:
# Send INFO-level log records of the root logger to stdout with timestamps.
import importlib
import logging
import sys

# Reload logging so the root logger starts without handlers; otherwise
# logging.basicConfig below would have no effect if something (presumably the
# notebook environment) already configured it.
# see https://stackoverflow.com/a/21475297/1469195
importlib.reload(logging)

log = logging.getLogger()
log.setLevel('INFO')
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
# Setup cell: make the local project checkouts importable, configure logging,
# autoreload and plotting, and import everything used below.
# `%%capture` suppresses this cell's output and must stay the first line.
import os
import site
# NOTE(review): hard-coded absolute paths on one user's machine. Each path is
# inserted at position 0, so the one inserted last is searched first.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Repeats the logging setup of the previous cell. logging.basicConfig has no
# effect if the root logger already has handlers (e.g. after that cell ran).
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
# Plotting: inline PNG figures, wide-and-short default figure size, darkgrid.
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

# NOTE(review): several names imported in this cell (e.g. site, cm,
# sliced_from_samples, RandomState, F, Variable, copy, math, itertools,
# np_to_var, SubsampleSplitter, AdditiveBlock, load_file, create_inputs) are
# not used by the cells shown here; numpy is also imported twice.
from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.high_gamma import load_file, create_inputs
from reversible2.high_gamma import load_train_test
# Let cuDNN benchmark and select convolution algorithms; faster, but can make
# GPU results non-deterministic.
th.backends.cudnn.benchmark = True
from reversible2.models import deep_invertible


In [None]:
# The 22 EEG sensor labels, one string per row of the original listing.
# NOTE(review): not referenced by the cells shown here; its length matches
# n_sensors=22 passed to load_train_test below — confirm the intended link.
sensor_names = (
    'Fz '
    'FC3 FC1 FCz FC2 FC4 '
    'C5 C3 C1 Cz C2 C4 C6 '
    'CP3 CP1 CPz CP2 CP4 '
    'P1 Pz P2 '
    'POz'
).split()

In [None]:
# Load the train and test data of subject 4 via
# reversible2.high_gamma.load_train_test (see that function for the exact
# meaning of each argument below).
# Later cells index the results by class (train_inputs[i_class]) and read
# shape[1] / shape[2] as channels / time, so each result is presumably a list
# with one tensor per class, shaped (trials, channels, time, ...) — TODO confirm.
# The model and distribution are created in later cells, not here.
train_inputs, test_inputs = load_train_test(
    subject_id=4,
    car=True,
    n_sensors=22,
    final_hz=256,
    start_ms=500,
    stop_ms=1500,
    half_before=True,
    only_load_given_sensors=False,
)

In [None]:
# Shrink the problem to a toy size: for every per-class tensor keep the first
# 10 entries along dim 0 and indices 7-8 along dim 1 (channels, as read in the
# model-setup cell). clone() gives each subset its own storage, so in-place
# edits to the subsets leave train_inputs / test_inputs untouched.
# NOTE(review): if the loaded channel order follows sensor_names, these are
# C3 and C1 — TODO confirm.
train_less, test_less = (
    [class_tensor[:10, 7:9].clone().contiguous() for class_tensor in split]
    for split in (train_inputs, test_inputs)
)


In [None]:
# Zero channel index 1 of every reduced train/test tensor, so only channel 0
# varies. Written in place under no_grad instead of through the discouraged
# `.data` attribute. The tensors are clones, so train_inputs / test_inputs are
# not modified.
with th.no_grad():
    for t in train_less + test_less:
        t[:, 1] = 0

In [None]:
from reversible2.distribution import TwoClassIndependentDist

def create_model(n_chan_pad=0, filter_length_time=11):
    """Build the invertible feature model wrapped in a reversible2 graph Node.

    Reads the notebook globals ``n_chans`` and ``n_time``, which must be
    defined before this function is called.

    Parameters
    ----------
    n_chan_pad : int, optional
        Passed to ``deep_invertible`` (default 0).
    filter_length_time : int, optional
        Passed to ``deep_invertible`` (default 11).

    Returns
    -------
    Node
        ``Node(None, feature_model)``, where ``feature_model`` is the
        ``deep_invertible`` network with an extra ``'flatten'`` ViewAs module
        that views (-1, 16, 32) as (-1, 512).
    """
    from reversible2.view_as import ViewAs
    from reversible2.graph import Node
    feature_model = deep_invertible(
        n_chans, n_time, n_chan_pad, filter_length_time)
    # NOTE(review): the (16, 32) shape is hard-coded. It presumably matches
    # deep_invertible's output for the reduced inputs used in this notebook —
    # TODO confirm. Its 16*32 elements should equal the dimensionality that
    # create_dist gives the latent distribution.
    feature_model.add_module('flatten',
                             ViewAs((-1, 8*2, 32), (-1, 8*2*32)))
    return Node(None, feature_model)

def create_dist():
    """Return a TwoClassIndependentDist sized to one flattened example.

    The dimensionality is the product of all non-leading dims of the first
    reduced training tensor (notebook global ``train_less``).
    """
    example_shape = train_less[0].size()[1:]
    n_features = np.prod(example_shape)
    return TwoClassIndependentDist(n_features)
    

from reversible2.invert import invert

class ModelAndDist(nn.Module):
    """Pair an invertible feature model with a per-class latent distribution.

    Both are stored as attributes of an ``nn.Module``. When they are modules
    themselves, they are registered as submodules and can be handled as one
    unit (e.g. ``.parameters()``).
    """

    def __init__(self, model, dist):
        super(ModelAndDist, self).__init__()
        self.model = model
        self.dist = dist

    def get_examples(self, i_class, n_samples):
        """Sample latents of class ``i_class`` and map them back to input space.

        Parameters
        ----------
        i_class : int
            Class whose latent distribution is sampled.
        n_samples : int
            Number of samples to draw.

        Returns
        -------
        The inverted samples, i.e. generated examples in input space.
        """
        samples = self.dist.get_samples(i_class=i_class, n_samples=n_samples)
        # Use the model's own inverse when it defines one; otherwise fall back
        # to the generic reversible2 ``invert`` helper.
        if hasattr(self.model, 'invert'):
            return self.model.invert(samples)
        return invert(self.model, samples)

In [None]:
import ot

from reversible2.ot_exact import get_matched_samples

def flatten_2d(a):
    """Flatten all dimensions after the first, returning shape (len(a), -1).

    Uses ``reshape`` so non-contiguous inputs work (a view is returned when
    possible, otherwise a copy). The feature count is computed explicitly, so
    an empty batch gives shape (0, F) and a 1-D input gives shape (n, 1).
    """
    n_features = int(np.prod(a.shape[1:]))  # np.prod(()) == 1 for 1-D input
    return a.reshape(len(a), n_features)

In [None]:
from reversible2.constantmemory import clear_ctx_dicts
def set_dist_to_empirical(feature_model, class_dist, train_inputs):
    """Set each class's distribution to the empirical mean/std of its features.

    For every class index, runs ``feature_model`` on that class's inputs
    without gradient tracking, computes mean and std over dim 0, stores them
    with ``class_dist.set_mean_std`` and checks that ``get_mean_std`` returns
    the same values.

    Parameters
    ----------
    feature_model : callable
        Maps a batch of inputs to features.
    class_dist
        Object providing ``set_mean_std(i_class, mean, std)`` and
        ``get_mean_std(i_class)``.
    train_inputs : sequence
        One input batch per class, indexed by class.

    Raises
    ------
    AssertionError
        If the values read back differ from those set (e.g. a NaN std, which
        ``th.std`` produces for a class with a single example).

    ``clear_ctx_dicts(feature_model)`` runs in ``finally``, so it is called
    even if a forward pass or check fails.
    """
    try:
        with th.no_grad():
            for i_class, class_inputs in enumerate(train_inputs):
                this_outs = feature_model(class_inputs)
                mean = th.mean(this_outs, dim=0)
                std = th.std(this_outs, dim=0)
                class_dist.set_mean_std(i_class, mean, std)
                # Just check the distribution stored what was set.
                setted_mean, setted_std = class_dist.get_mean_std(i_class)
                assert th.allclose(mean, setted_mean)
                assert th.allclose(std, setted_std)
    finally:
        clear_ctx_dicts(feature_model)


In [None]:
# Build model and distribution for the reduced data, initialise the
# distribution to the empirical feature statistics, and create the optimiser.
# Seed torch and numpy so model initialisation (and, presumably, latent
# sampling) repeats across Restart & Run All. cudnn.benchmark=True can still
# introduce GPU nondeterminism.
th.manual_seed(0)
np.random.seed(0)

# create_model() reads these two globals.
n_chans = train_less[0].shape[1]
n_time = train_less[0].shape[2]
model = create_model()

dist = create_dist()
model_and_dist = ModelAndDist(model, dist)

set_dist_to_empirical(model_and_dist.model, model_and_dist.dist, train_less)

# Two parameter groups: the distribution uses Adam's default learning rate
# (1e-3), the feature model a larger 1e-2.
optim = th.optim.Adam([{'params': dist.parameters()},
                       {'params': list(model.parameters()),
                        'lr': 1e-2}])

In [None]:
# Train so that generated examples of class `i_class` match that class's real
# trials under exact optimal transport, plotting progress periodically.
i_class = 1
n_epochs = 101
# Plot roughly 20 times; max(1, ...) avoids a modulo-by-zero for n_epochs < 20.
plot_every = max(1, n_epochs // 20)
class_ins = train_less[i_class]
class_ins_flat = flatten_2d(class_ins)  # loop-invariant
n_chans_plot, n_time_plot = class_ins.shape[1], class_ins.shape[2]
palette = seaborn.color_palette()
for i_epoch in range(n_epochs):
    # Draw 3 generated examples per real trial from the same class as the
    # real data (use i_class, not a literal, so the two always agree).
    examples = model_and_dist.get_examples(i_class, len(class_ins) * 3)
    matched_examples = get_matched_samples(class_ins_flat, flatten_2d(examples))
    # Mean L2 distance between each real trial and its matched examples;
    # matched_examples is presumably (n_trials, n_matched, n_features) — TODO confirm.
    loss = th.mean(th.norm(class_ins_flat.unsqueeze(1) - matched_examples, p=2, dim=2))
    optim.zero_grad()
    loss.backward()
    optim.step()

    if i_epoch % plot_every == 0:
        display_text("OT {:.1E}".format(loss.item()))
        # One panel per real trial (5x2 grid for the 10 kept trials): the real
        # signal plus its matched examples, channel 0 / 1 in palette colors 0 / 1.
        fig, axes = plt.subplots(5, 2, figsize=(16, 12), sharex=True, sharey=True)
        for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
            ax.plot(var_to_np(signal).squeeze().T)
            for ex in var_to_np(matched.view(len(matched), n_chans_plot, n_time_plot)):
                ax.plot(ex[0], color=palette[0], lw=0.5, alpha=0.7)
                ax.plot(ex[1], color=palette[1], lw=0.5, alpha=0.7)
        display_close(fig)
        # exp of the distribution's class_log_stds, i.e. presumably the
        # learned per-class standard deviations.
        fig = plt.figure()
        plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)).T)
        display_close(fig)

In [None]:
# TODO: plot a real signal together with its matched example
# TODO: track how the matched examples move during training
# TODO: repeat this for a network trained on different examples and compare the resulting OT matching
