In [None]:
import logging
import importlib
# Reload the logging module before configuring it; presumably so that the
# basicConfig call below takes effect inside the notebook kernel (see link).
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
# Root logger at INFO level.
log = logging.getLogger()
log.setLevel('INFO')
import sys

# Send timestamped log records to stdout so they show up in the cell output.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)

In [None]:
%%capture
import os
import site
# Put local source checkouts at the front of the import path.
# NOTE(review): these are absolute, machine-specific paths; the notebook will
# not find `reversible2` / `braindecode` on another machine unless they are
# installed or these paths are adapted.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# Automatically reload edited modules before executing each cell.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Same root-logger configuration as in the first cell (INFO level, stdout).
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
# Render figures inline as PNG images.
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Default figure size and font size for all plots in this notebook.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.high_gamma import load_file, create_inputs
from reversible2.high_gamma import load_train_test
# Turn on the cuDNN benchmark mode (global PyTorch setting).
th.backends.cudnn.benchmark = True
from reversible2.models import deep_invertible


In [None]:
# Names of the 22 sensors, one text row per row of the original layout.
sensor_names = (
    "Fz "
    "FC3 FC1 FCz FC2 FC4 "
    "C5 C3 C1 Cz C2 C4 C6 "
    "CP3 CP1 CPz CP2 CP4 "
    "P1 Pz P2 "
    "POz"
).split()

In [None]:
# Load the train and test inputs for subject 4.
# Both results are later indexed by class (e.g. train_inputs[0]), so they are
# presumably sequences with one tensor per class -- TODO confirm against
# reversible2.high_gamma.load_train_test.

train_inputs, test_inputs = load_train_test(
    subject_id=4,
    car=True,
    n_sensors=22,
    final_hz=256,
    start_ms=500,
    stop_ms=1500,
    half_before=True,
    only_load_given_sensors=False,
)

In [None]:
# Reduced copies of the data: for every class keep only the first 10 entries
# along dim 0 and the two entries at indices 7 and 8 along dim 1
# (presumably trials and sensors -- TODO confirm).
train_less, test_less = (
    [class_inputs[:10, 7:9].clone().contiguous() for class_inputs in inputs_per_class]
    for inputs_per_class in (train_inputs, test_inputs)
)


In [None]:
# Overwrite index 1 along dim 1 with zeros in every reduced train and test tensor
# (in place, through `.data`).
for class_inputs in itertools.chain(train_less, test_less):
    class_inputs.data[:, 1] = 0

In [None]:
from reversible2.models import larger_model

In [None]:
from reversible2.distribution import TwoClassIndependentDist

def create_model():
    """Build the invertible feature model used in this notebook.

    Reads the module-level ``n_chans`` and ``n_time`` at call time, so both
    must be defined before this function is called.

    Only the ``larger_model`` network is constructed. No other model is built
    first, so no unused network is allocated and no parameter initialisation
    is spent on a model that would be thrown away.

    Returns:
        The model returned by ``larger_model`` with an 11-sample kernel,
        a final FFT stage and constant-memory mode switched off.
    """
    return larger_model(n_chans, n_time, final_fft=True, kernel_length=11,
                        constant_memory=False)


def to_generator(feature_model):
    """Append a flattening step to ``feature_model`` and wrap it in a ``Node``.

    ``feature_model`` is modified in place: a ``ViewAs`` module named
    ``'flatten'`` mapping ``(-1, 16, 32)`` to ``(-1, 512)`` is added to it.

    Returns:
        ``Node(None, feature_model)``.
    """
    from reversible2.view_as import ViewAs
    n_maps, n_steps = 8*2, 32
    flatten = ViewAs((-1, n_maps, n_steps), (-1, n_maps * n_steps))
    feature_model.add_module('flatten', flatten)
    from reversible2.graph import Node
    return Node(None, feature_model)

def create_dist():
    """Create a ``TwoClassIndependentDist`` sized to one flattened example.

    The dimensionality is the product of all but the first dimension of
    the module-level ``train_less[0]``.
    """
    example_shape = train_less[0].size()[1:]
    n_dims = np.prod(example_shape)
    return TwoClassIndependentDist(n_dims)
    

from reversible2.invert import invert

class ModelAndDist(nn.Module):
    """Pairs an invertible model with a class-conditional distribution.

    Attributes are stored under ``model`` and ``dist``.
    """

    def __init__(self, model, dist):
        super(ModelAndDist, self).__init__()
        self.model = model
        self.dist = dist

    def get_examples(self, i_class, n_samples,):
        """Sample from the distribution of one class and invert through the model.

        Args:
            i_class: Class index passed on to ``dist.get_samples``.
            n_samples: Number of samples to draw.

        Returns:
            The result of inverting the drawn samples through the model.
        """
        samples = self.dist.get_samples(i_class=i_class, n_samples=n_samples)
        # Use the model's own ``invert`` method when it has one; otherwise
        # fall back to the generic ``invert`` helper. (With the branches the
        # other way round, a model without ``invert`` always raised
        # AttributeError.)
        if hasattr(self.model, 'invert'):
            examples = self.model.invert(samples)
        else:
            examples = invert(self.model, samples)
        return examples

import ot

from reversible2.ot_exact import get_matched_samples

def flatten_2d(a):
    """View ``a`` as 2D: the first dimension is kept, all others are merged."""
    n_rows = len(a)
    return a.view(n_rows, -1)

from reversible2.constantmemory import clear_ctx_dicts
def set_dist_to_empirical(feature_model, class_dist, inputs):
    """Set each class's mean/std in ``class_dist`` from the model outputs.

    For every class, the inputs are moved to the GPU and passed through
    ``feature_model`` without gradient tracking; the mean and standard
    deviation of the outputs over dim 0 are stored via
    ``class_dist.set_mean_std`` and read back to verify they were stored.
    Finally ``clear_ctx_dicts`` is called on ``feature_model``.

    Args:
        feature_model: Callable model producing outputs for a batch of inputs.
        class_dist: Object with ``set_mean_std`` / ``get_mean_std``.
        inputs: Sequence with one input tensor per class.
    """
    for class_idx, class_inputs in enumerate(inputs):
        with th.no_grad():
            outputs = feature_model(class_inputs.cuda())
            empirical_mean = th.mean(outputs, dim=0)
            empirical_std = th.std(outputs, dim=0)
            class_dist.set_mean_std(class_idx, empirical_mean, empirical_std)
            # Read the values back to make sure they were stored as given.
            stored_mean, stored_std = class_dist.get_mean_std(class_idx)
            assert th.allclose(empirical_mean, stored_mean)
            assert th.allclose(empirical_std, stored_std)
    clear_ctx_dicts(feature_model)

In [None]:
# Input dimensions (dims 1 and 2 of the reduced training data);
# create_model() reads these globals.
n_chans, n_time = train_less[0].shape[1], train_less[0].shape[2]

# Build model and distribution on the GPU and bundle them.
model = create_model()
#model = to_generator(model)
model.cuda()
dist = create_dist()
dist.cuda()
model_and_dist = ModelAndDist(model, dist)

# Initialise the class distributions from the model outputs on the training data.
set_dist_to_empirical(model, dist, train_less)

# One Adam optimizer with two parameter groups: a larger learning rate for
# the distribution parameters, a smaller one for the model parameters.
optim = th.optim.Adam([
    dict(params=dist.parameters(), lr=1e-2),
    dict(params=list(model.parameters()), lr=1e-3),
])

In [None]:
# Train model and distribution of one class by matching generated examples
# to the real inputs of that class.
i_class = 1
n_epochs = 2001
class_ins = train_less[i_class].cuda()
# Flattened real inputs do not change during training, so compute them once.
flat_class_ins = flatten_2d(class_ins)
# Plot roughly 20 times over the run; at least every epoch for short runs
# (avoids a modulo-by-zero when n_epochs < 20).
plot_every = max(1, n_epochs // 20)
for i_epoch in range(n_epochs):
    # Draw 20 generated examples per real input for the trained class.
    examples = model_and_dist.get_examples(i_class, len(class_ins) * 20)
    # NOTE(review): matched_examples is indexed as (real input, match, feature)
    # below -- confirm against reversible2.ot_exact.get_matched_samples.
    matched_examples = get_matched_samples(flat_class_ins, flatten_2d(examples))
    # Mean Euclidean distance between each real input and its matched examples.
    loss = th.mean(th.norm(flat_class_ins.unsqueeze(1) - matched_examples, p=2, dim=2))
    optim.zero_grad()
    loss.backward()
    optim.step()

    if i_epoch % plot_every == 0:
        display_text("OT {:.1E}".format(loss.item()))
        # One panel per real signal, overlaid with its matched generated examples.
        fig, axes = plt.subplots(5,2, figsize=(16,12), sharex=True, sharey=True)
        for ax, signal, matched in zip(axes.flatten(), class_ins, matched_examples):
            ax.plot(var_to_np(signal).squeeze().T)
            for ex in var_to_np(matched.view(len(matched), class_ins.shape[1], class_ins.shape[2])):
                ax.plot(ex[0], color=seaborn.color_palette()[0], lw=0.5, alpha=0.7)
                ax.plot(ex[1], color=seaborn.color_palette()[1], lw=0.5, alpha=0.7)
        display_close(fig)
        # Current standard deviations of the trained class's distribution.
        fig = plt.figure()
        plt.plot(var_to_np(th.exp(model_and_dist.dist.class_log_stds)[i_class]))
        display_close(fig)

In [None]:
from reversible2.ot_exact import ot_euclidean_loss_for_samples


def _flat_generated(n_samples):
    """Flattened generated examples for the current ``i_class``."""
    return flatten_2d(model_and_dist.get_examples(i_class, n_samples))


def _flat_real_on_gpu(inputs_per_class):
    """Flattened real inputs of the current ``i_class``, moved to the GPU."""
    return flatten_2d(inputs_per_class[i_class].cuda())


n_train = len(train_less[i_class])
# (output template, first sample set, second sample set); the sample sets are
# wrapped in lambdas so they are created only when their row is evaluated,
# first argument before second, in the listed order.
_ot_comparisons = [
    ("Train-Test:    {:.1E}",
     lambda: flatten_2d(train_less[i_class]),
     lambda: flatten_2d(test_less[i_class])),
    ("Fake-Test:     {:.1E}",
     lambda: _flat_generated(n_train),
     lambda: _flat_real_on_gpu(test_less)),
    ("Fake-Train:    {:.1E}",
     lambda: _flat_generated(n_train),
     lambda: _flat_real_on_gpu(train_less)),
    ("Fake*10-Test:  {:.1E}",
     lambda: _flat_generated(n_train * 10),
     lambda: _flat_real_on_gpu(test_less)),
    ("Fake*10-Train: {:.1E}",
     lambda: _flat_generated(n_train * 10),
     lambda: _flat_real_on_gpu(train_less)),
    # "Zero" rows: generated examples scaled down by 1e-4.
    ("Zero-Train:    {:.1E}",
     lambda: 1e-4 * _flat_generated(n_train),
     lambda: _flat_real_on_gpu(train_less)),
    ("Zero-Test:     {:.1E}",
     lambda: 1e-4 * _flat_generated(n_train),
     lambda: _flat_real_on_gpu(test_less)),
]
for template, make_first, make_second in _ot_comparisons:
    print(template.format(
        ot_euclidean_loss_for_samples(make_first(), make_second())))

In [None]:

    
from reversible2.grids import create_th_grid
def create_grid(mean, std, n_grid_points,i_dims):
    """Build a 2D grid of points around ``mean`` along two chosen dimensions.

    The two dimensions in ``i_dims`` are varied from ``mean - std`` to
    ``mean + std`` in ``n_grid_points`` steps each; all other dimensions are
    kept at ``mean``.

    Returns:
        Tuple ``(full_grid, mins, maxs)`` where ``mins`` / ``maxs`` hold the
        lower and upper grid bounds of the two chosen dimensions.
    """
    first_dim, second_dim = i_dims
    chosen = [first_dim, second_dim]
    mins = mean[chosen] - std[chosen]
    maxs = mean[chosen] + std[chosen]
    # Evenly spaced values along each of the two chosen dimensions.
    axis_0_vals, axis_1_vals = (
        th.linspace(mins[k].item(), maxs[k].item(),
                    n_grid_points, device=mean.device)
        for k in range(2)
    )

    grid = create_th_grid(axis_0_vals, axis_1_vals)

    # Start with the mean at every grid position, then overwrite the two
    # chosen dimensions with the grid coordinates.
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    full_grid = mean.repeat(n_rows, n_cols, 1)
    full_grid.data[:,:,first_dim] = grid[:,:,0]
    full_grid.data[:,:,second_dim] = grid[:,:,1]
    return full_grid, mins, maxs


def plot_grid(inverted_grid, mins, maxs):
    """Plot one small curve per grid cell, laid out on the 2D grid.

    Args:
        inverted_grid: Tensor indexed as ``[i_x][i_y]`` giving one curve per
            grid position (first two dims are the grid dims).
        mins: Lower bounds of the two grid dimensions.
        maxs: Upper bounds of the two grid dimensions.

    Returns:
        The created matplotlib figure.
    """
    # Grid size comes from the argument itself, not from a global.
    n_x, n_y = inverted_grid.shape[:2]
    # Width and height of a single grid cell.
    x_len, y_len = var_to_np(maxs - mins) / np.array([n_x, n_y])

    max_abs = th.max(th.abs(inverted_grid))

    # Scale curves so the largest amplitude fills 90% of half a cell height.
    y_factor =  (y_len / (2*max_abs)).item() * 0.9
    fig = plt.figure(figsize=(32,32))

    for i_x in range(n_x):
        for i_y in range(n_y):
            # Each curve spans the middle 80% of its cell horizontally and is
            # centred vertically in the cell.
            x_start = mins[0].item() + x_len * i_x + 0.1 * x_len
            x_end = mins[0].item() + x_len * i_x + 0.9 * x_len
            y_center = mins[1].item() + y_len * i_y + 0.5 * y_len

            curve = var_to_np(inverted_grid[i_x][i_y])
            # Only label the first curve so the legend has a single entry.
            label = ''
            if i_x == 0 and i_y == 0:
                label = 'Generated data'
            plt.plot(np.linspace(x_start, x_end, len(curve)),
                     curve * y_factor + y_center, color='black',
                    label=label)
    return fig

def plot_two_dim_embedding(model_and_dist, i_dims, i_class, inputs=None, autoscale=False, ):
    """Plot generated curves on a 2D grid over two embedding dimensions.

    A 30x30 grid around the class mean (+- one std) is created along the two
    dimensions in ``i_dims``, inverted through ``model_and_dist.model`` and
    plotted. If ``inputs`` is given, their encodings are added as a scatter.

    Args:
        model_and_dist: Object with ``model`` and ``dist`` attributes.
        i_dims: The two embedding dimensions to vary.
        i_class: Class whose mean/std define the grid.
        inputs: Optional inputs whose encodings are overlaid.
        autoscale: Whether the axes may rescale when the encodings are added.

    Returns:
        The created matplotlib figure.
    """
    mean, std = model_and_dist.dist.get_mean_std(i_class)
    grid, mins, maxs = create_grid(mean, std, n_grid_points=30,i_dims=i_dims,)
    # Invert with the model that was passed in, not whatever the global
    # `model` happens to be bound to.
    inverted = invert(model_and_dist.model,
                      grid.view(-1, grid.shape[-1])).squeeze()[:,0] # ignore empty chan
    # Restore the two grid dimensions from the grid itself.
    n_x, n_y = grid.shape[0], grid.shape[1]
    inverted_grid = inverted.view(n_x, n_y, -1)
    fig = plot_grid(inverted_grid, mins, maxs)
    if inputs is not None:
        add_outs_to_grid_plot(model_and_dist, inputs, i_dims=i_dims, autoscale=autoscale, )
    return fig

def add_outs_to_grid_plot(model_and_dist, inputs, i_dims, autoscale):
    """Scatter the model encodings of ``inputs`` onto the current axes.

    Only the two dimensions in ``i_dims`` are plotted. ``autoscale`` sets
    whether the current axes may rescale to fit the points. A legend is
    added afterwards.
    """
    encodings = model_and_dist.model(inputs)[:,i_dims]
    plt.gca().set_autoscale_on(autoscale)
    xs = var_to_np(encodings[:,0])
    ys = var_to_np(encodings[:,1])
    plt.scatter(xs, ys, s=200, alpha=0.75, label='Encodings')

    plt.legend(fontsize=14)

In [None]:
# Embedding dimensions of the trained class, sorted by decreasing std.
mean, std = model_and_dist.dist.get_mean_std(i_class)
dims_by_std = np.argsort(var_to_np(std))[::-1].copy()
i_dims = dims_by_std[:2]
# Largest vs. second-largest std dimension. The class index is i_class
# throughout, so grid, dimensions and inputs all refer to the same class.
fig = plot_two_dim_embedding(model_and_dist, i_dims, i_class, class_ins, autoscale=False)
# Largest vs. third-largest std dimension.
fig = plot_two_dim_embedding(model_and_dist, i_dims=dims_by_std[[0,2]],
                             i_class=i_class, inputs=class_ins, autoscale=True)

In [None]:
from reversible2.graph import Node
from braindecode.torch_ext.modules import Expression
from reversible2.rfft import RFFT
from braindecode.torch_ext.optimizers import AdamW
from reversible2.scale import scale_to_unit_var
from reversible2.high_gamma import load_train_test, to_signal_target
from reversible2.scale import ScaleAndShift


# Train a two-class classifier on top of a freshly created feature model,
# repeated n_iters times; the per-epoch results of every run are collected in dfs.
train_set, valid_set = to_signal_target(train_inputs, test_inputs)
# NOTE(review): this rebinds the global n_chans that create_model() reads;
# n_time is not updated here and keeps its value from the earlier cell.
n_chans = train_set.X.shape[1]
# n_classes and input_time_length are not used further in this cell.
n_classes = 2
input_time_length = train_set.X.shape[2]
n_iters = 5
dfs = []
for _ in range (n_iters):
    # n_chan_pad and filter_length_time are not used further in this cell.
    n_chan_pad = 0
    filter_length_time = 11
    model = create_model()
    # Classification head: take the first two output dims, apply
    # ScaleAndShift and turn them into log-probabilities.
    # NOTE(review): presumably Node chains the head after the feature model --
    # confirm against reversible2.graph.Node.
    model = Node(model, nn.Sequential(
        Expression(lambda x: x[:,:2,].unsqueeze(2)),
        ScaleAndShift(),
        Expression(lambda x: x.squeeze(2)),
            nn.LogSoftmax(dim=1)
                    ))
    #model.add_module("select_dims", Expression(lambda x: x[:,:2,]))
    #
    #model.add_module("softmax", nn.LogSoftmax(dim=1))
    from reversible2.models import WrappedModel
    model = WrappedModel(model)

    model.cuda()
    
    # Temporarily attach scale_to_unit_var as forward hook to every module that
    # has a log_factor attribute, run one forward pass on the first class's
    # training inputs, then remove the hooks again.
    # NOTE(review): clearing _forward_hooks removes ALL forward hooks of these
    # modules, not only the one registered here.
    for module in model.network.modules():
        if hasattr(module, 'log_factor'):
            module._forward_hooks.clear()
            module.register_forward_hook(scale_to_unit_var)
    model.network(train_inputs[0].cuda());
    for module in model.network.modules():
        if hasattr(module, 'log_factor'):
            module._forward_hooks.clear()

        


    # Train a deep copy so that `model` itself is not modified by fitting.
    from copy import deepcopy
    model_to_train = deepcopy(model)
    lr = 1 * 1e-4
    weight_decay = 0.5 * 1e-3
    optimizer = AdamW(model_to_train.parameters(), lr=lr,
                      weight_decay=weight_decay)

    max_epochs = 50
    model_to_train.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1, )
    model_to_train.fit(train_set.X, train_set.y, epochs=max_epochs, batch_size=64,
              scheduler='cosine',
              validation_data=(valid_set.X, valid_set.y), )
    # Keep the per-epoch results of this run.
    dfs.append(model_to_train.epochs_df)
    


In [None]:
# Output shape of the network for the first class's training inputs.
# When the cells are run in order, `model` is the wrapped model from the last
# loop iteration above (the fitted copy lives in `model_to_train`).
model.network(train_inputs[0].cuda()).shape

In [None]:
import pandas as pd
# Last row of every run's epochs dataframe, stacked into one table.
pd.concat([run_df.tail(1) for run_df in dfs])

In [None]:
# Next steps:
# - Plot the real signals together with their matched examples.
# - See how they move.
# - Try it for another network trained on different examples and see what the OT looks like then.
