In [None]:
%%capture
import os
import site
os.sys.path.insert(0, '/home/schirrmr/code/reversible/reversible2/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')
%cd /home/schirrmr/


%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Root logger: everything at INFO level and above is emitted.
log = logging.getLogger()
log.setLevel('INFO')
import sys
# Write log records to stdout so that they show up in the cell output
# (this cell itself runs under %%capture, later cells do not).
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Notebook-wide plotting defaults: a wide, very flat default figure and a
# larger font. Cells that need a different size pass figsize explicitly.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
# Use seaborn's 'darkgrid' style for all following figures.
seaborn.set_style('darkgrid')

from reversible.sliced import sliced_from_samples

from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
from reversible.plot import create_bw_image
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible.revnet import ResidualBlock, invert, SubsampleSplitter, ViewAs, ReversibleBlockOld
from spectral_norm import spectral_norm
from conv_spectral_norm import conv_spectral_norm

def display_text(text, fontsize=18):
    """Show ``text`` as a headline in the notebook output.

    A very flat matplotlib figure without axes is created, ``text`` is set
    as its title, the figure is handed to the notebook's ``display`` and
    then closed again.

    Parameters
    ----------
    text : str
        The text to show.
    fontsize : int, optional
        Font size of the title (default 18).
    """
    fig = plt.figure(figsize=(12,0.1))
    try:
        plt.title(text, fontsize=fontsize)
        plt.axis('off')
        display(fig)
    finally:
        # Close the figure even if rendering fails; this helper is called
        # repeatedly during training and must not pile up open figures.
        plt.close(fig)

In [None]:
from braindecode.datasets.bbci import BBCIDataset
from braindecode.mne_ext.signalproc import mne_apply
def load_file(filename):
    """Load one BBCI recording and subtract the mean over axis 0.

    The file is read with ``BBCIDataset(filename).load()``, the channel
    named ``'STI 014'`` is dropped, and ``mne_apply`` is used to subtract,
    from every entry, the mean taken over axis 0 of the data array.

    NOTE(review): axis 0 is presumably the channel axis, which would make
    this a common average reference -- confirm against ``mne_apply``.

    Parameters
    ----------
    filename : str
        Path of the recording to load.

    Returns
    -------
    The re-referenced continuous recording as returned by ``mne_apply``.
    """
    def subtract_axis0_mean(data):
        # keepdims=True keeps the mean broadcastable against ``data``.
        axis0_mean = np.mean(data, keepdims=True, axis=0)
        return data - axis0_mean

    recording = BBCIDataset(filename).load()
    recording = recording.drop_channels(['STI 014'])
    return mne_apply(subtract_axis0_mean, recording)

In [None]:
from collections import OrderedDict
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

def create_set(cnt):
    """Preprocess a continuous recording and cut it into labelled trials.

    Steps, in order:

    1. resample to 250,
    2. exponential running standardization (applied on the transposed
       data array),
    3. resample to 32 and then back to 64,
    4. cut trials around the markers with ``ival = [500, 1500]``.

    NOTE(review): the 32 -> 64 round trip presumably removes content above
    16 Hz while keeping 64 samples per second; ``ival`` is presumably in
    milliseconds relative to the marker -- confirm against braindecode.

    The function is used for both the train and the test recording, so
    the log messages deliberately do not name a particular set.

    Parameters
    ----------
    cnt
        Continuous recording, e.g. as returned by ``load_file``.

    Returns
    -------
    The dataset returned by ``create_signal_target_from_raw_mne``; its
    labels follow the order of ``marker_def`` below.
    """
    from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
    from braindecode.datautil.signalproc import exponential_running_standardize

    # Marker codes per class.
    marker_def = OrderedDict([('Right Hand', [1]), ('Left Hand', [2],),
                             ('Rest', [3]), ('Feet', [4])])
    ival = [500,1500]

    log.info("Resampling...")
    cnt = resample_cnt(cnt, 250.0)
    log.info("Standardizing...")
    # The standardization function is applied to the transposed array, so
    # the data is transposed before and after the call.
    cnt = mne_apply(lambda a: exponential_running_standardize(a.T ,factor_new=1e-3, init_block_size=1000, eps=1e-4).T,
                         cnt)
    cnt = resample_cnt(cnt, 32.0)
    cnt = resample_cnt(cnt, 64.0)

    dataset = create_signal_target_from_raw_mne(cnt, marker_def, ival)
    return dataset

def create_inputs(dataset):
    """Build the per-class network inputs from a trial dataset.

    Trials with label 0 and label 2 are selected (presumably 'Right Hand'
    and 'Rest', going by the marker order used in ``create_set``). For
    each class at most the first 160 trials and only the first channel
    are kept, a trailing singleton axis is appended, and the result is
    moved to the GPU as float32.

    Parameters
    ----------
    dataset
        Object with trial data ``X`` and labels ``y``.

    Returns
    -------
    list
        ``[inputs_label_0, inputs_label_2]``.
    """
    def to_gpu_input(trials):
        # up to 160 trials, first channel only, extra trailing axis
        return np_to_var(trials[:160,0:1,:,None], dtype=np.float32).cuda()

    trials_label_0 = dataset.X[dataset.y == 0]
    trials_label_2 = dataset.X[dataset.y == 2]
    return [to_gpu_input(trials_label_0), to_gpu_input(trials_label_2)]


In [None]:
# Training data: load the recording, arrange the channels as C3, C4,
# preprocess and cut it into trials, then build the per-class GPU inputs.
# Note that create_inputs only keeps the first channel (index 0).
train_cnt = load_file('/data/schirrmr/schirrmr/HGD-public/reduced/train/4.mat')
train_cnt = train_cnt.reorder_channels(['C3', 'C4'])
train_set = create_set(train_cnt)
train_inputs = create_inputs(train_set)

In [None]:
# Test data: same pipeline as for the training recording, applied to the
# test file of the same subject directory.
test_cnt = load_file('/data/schirrmr/schirrmr/HGD-public/reduced/test/4.mat')
test_cnt = test_cnt.reorder_channels(['C3', 'C4'])
test_set = create_set(test_cnt)
test_inputs = create_inputs(test_set)

In [None]:
# Median amplitude spectrum of the real training trials, one line per class.
fig = plt.figure(figsize=(8,4))
for i_class in range(len(train_inputs)):
    ins = var_to_np(train_inputs[i_class].squeeze())
    n_times = ins.shape[1]
    bps = np.abs(np.fft.rfft(ins))
    # With d = 1 / n_times the bins count cycles per trial window; this
    # equals Hz only if a trial spans one second -- assumed here, given the
    # window and sampling rate used in create_set.
    plt.plot(np.fft.rfftfreq(n_times, d=1/n_times), np.median(bps, axis=0))

plt.title("Spectrum")
plt.xlabel('Frequency [Hz]')
plt.ylabel('Amplitude')
# Only real data is plotted here, one line per class in the order of
# train_inputs, so there must be exactly one label per class. The previous
# four-entry legend labelled the second (Rest) line as 'Fake Right'.
plt.legend(['Real Right', 'Real Rest'])
display(fig)
plt.close(fig)

In [None]:
from matplotlib.lines import Line2D
# Overlay all training trials, coloured by class.
palette = seaborn.color_palette()
plt.figure(figsize=(10,6))
for i_class in range(2):
    class_signals = var_to_np(train_inputs[i_class].squeeze())
    # transpose so that every trial becomes one line
    plt.plot(class_signals.T, color=palette[i_class], lw=0.5)
# One proxy line per class, so the legend does not list every single trial.
lines = [Line2D([0], [0], color=palette[i_class]) for i_class in range(2)]
plt.legend(lines, ['Right', 'Rest'], bbox_to_anchor=(1,1,0,0))
plt.title('Input signals')

In [None]:
def rev_block(n_c, n_i_c):
    """Create a convolutional ``ReversibleBlockOld``.

    The block receives two independently initialised sub-networks of the
    same architecture: Conv2d((3,1)) -> ReLU -> Conv2d((3,1)), mapping
    ``n_c // 2`` channels to ``n_i_c`` channels and back to ``n_c // 2``.

    Parameters
    ----------
    n_c : int
        Number of channels of the block; each sub-network works on half.
    n_i_c : int
        Number of intermediate channels inside each sub-network.
    """
    half_c = n_c // 2

    def conv_branch():
        # padding (1,0) keeps the length of the first spatial axis
        return nn.Sequential(
            nn.Conv2d(half_c, n_i_c, (3,1), stride=1, padding=(1,0), bias=True),
            nn.ReLU(),
            nn.Conv2d(n_i_c, half_c, (3,1), stride=1, padding=(1,0), bias=True))

    first_branch = conv_branch()
    second_branch = conv_branch()
    return ReversibleBlockOld(first_branch, second_branch)
def dense_rev_block(n_c, n_i_c):
    """Create a fully connected ``ReversibleBlockOld``.

    The block receives two independently initialised sub-networks of the
    same architecture: Linear -> ReLU -> Linear, mapping ``n_c // 2``
    features to ``n_i_c`` hidden units and back to ``n_c // 2``.

    Parameters
    ----------
    n_c : int
        Number of features of the block; each sub-network works on half.
    n_i_c : int
        Number of hidden units inside each sub-network.
    """
    half_c = n_c // 2

    def dense_branch():
        return nn.Sequential(
            nn.Linear(half_c, n_i_c, bias=True),
            nn.ReLU(),
            nn.Linear(n_i_c, half_c, bias=True))

    first_branch = dense_branch()
    second_branch = dense_branch()
    return ReversibleBlockOld(first_branch, second_branch)
    
def res_block(n_c, n_i_c):
    """Create a ``ResidualBlock`` around a small convolutional network.

    The wrapped network is Conv2d((3,1)) -> ReLU -> Conv2d((3,1)), mapping
    ``n_c`` channels to ``n_i_c`` channels and back to ``n_c``.

    Parameters
    ----------
    n_c : int
        Number of input and output channels.
    n_i_c : int
        Number of intermediate channels.
    """
    # padding (1,0) keeps the length of the first spatial axis
    expand = nn.Conv2d(n_c, n_i_c, (3,1), stride=1, padding=(1,0), bias=True)
    nonlin = nn.ReLU()
    reduce = nn.Conv2d(n_i_c, n_c, (3,1), stride=1, padding=(1,0), bias=True)
    return ResidualBlock(nn.Sequential(expand, nonlin, reduce))

In [None]:
from rfft import RFFT, Interleave

from discriminator import ProjectionDiscriminator
from reversible.revnet import SubsampleSplitter, ViewAs
from reversible.util import set_random_seeds
from reversible.revnet import init_model_params
from torch.nn import ConstantPad2d
import torch as th
from conv_spectral_norm import conv_spectral_norm
from disttransform import DistTransformResNet


# Seed before building the model so the weight initialisation is repeatable.
set_random_seeds(2019011641, True)

# Model: view the (-1,1,64,1) inputs as (-1,64), apply the RFFT module,
# then a stack of identical dense reversible blocks.
n_dense_blocks = 6
feature_model = nn.Sequential(
    ViewAs((-1,1,64,1), (-1,64)),
    RFFT(),
    *[dense_rev_block(64,64) for _ in range(n_dense_blocks)]
)
feature_model.cuda()



from reversible.training import hard_init_std_mean
# One learnable diagonal Gaussian per class in the latent space.
n_dims = train_inputs[0].shape[2]
n_clusters = len(train_inputs)
means_per_cluster = [th.autograd.Variable(th.ones(n_dims).cuda(), requires_grad=True)
                     for _ in range(n_clusters)]
# keep in mind this is in log domain so 0 is std 1
stds_per_cluster = [th.autograd.Variable(th.zeros(n_dims).cuda(), requires_grad=True)
                    for _ in range(n_clusters)]

# Initialise each Gaussian from the latent statistics of its class under
# the freshly initialised model. This is pure initialisation, so no autograd
# graph is needed: without no_grad the forward passes are tracked and the
# global `this_outs` keeps the last graph alive on the GPU.
with th.no_grad():
    for i_class in range(n_clusters):
        this_outs = feature_model(train_inputs[i_class])
        means_per_cluster[i_class].data = th.mean(this_outs, dim=0).view(-1).data
        stds_per_cluster[i_class].data = th.log(th.std(this_outs, dim=0),).view(-1).data



from copy import deepcopy
# Two Adam optimizers with first-moment decay switched off (betas=(0,0.9)):
# one for the network weights, one with a larger learning rate for the
# per-class latent means and log-stds.
optimizer = th.optim.Adam(
    [dict(params=list(feature_model.parameters()), lr=1e-3, weight_decay=0)],
    betas=(0,0.9))

optim_dist = th.optim.Adam(
    [dict(params=means_per_cluster + stds_per_cluster, lr=1e-2, weight_decay=0)],
    betas=(0,0.9))


In [None]:

from reversible.gaussian import get_gauss_samples
from reversible.uniform import get_uniform_samples

from reversible.gaussian import get_gauss_samples
from reversible.uniform import get_uniform_samples
from reversible.revnet import invert 
import pandas as pd
from gradient_penalty import gradient_penalty
import time


# Per-epoch results table, filled by the training loop.
df = pd.DataFrame()
# NaN placeholders for the logged quantities; the training loop in this
# notebook only overwrites g_loss and g_grad.
g_grad = d_grad = np.nan
g_loss, d_loss, gradient_loss = (
    np_to_var([np.nan],dtype=np.float32) for _ in range(3))

In [None]:
def invert_hierarchical(features):
    """Run ``features`` backwards through the global ``feature_model``.

    Thin wrapper around ``invert``; returns whatever ``invert`` returns
    for ``feature_model`` and the given latent ``features``.
    """
    reconstructed = invert(feature_model, features)
    return reconstructed


def get_samples(n_samples, i_class):
    """Draw latent samples from the Gaussian of class ``i_class``.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw.
    i_class : int
        Index into ``means_per_cluster`` / ``stds_per_cluster``.

    Returns
    -------
    The result of ``get_gauss_samples`` called with ``truncate_to=3``.
    """
    # stds_per_cluster holds log standard deviations, hence the exp
    return get_gauss_samples(
        n_samples,
        means_per_cluster[i_class],
        th.exp(stds_per_cluster[i_class]),
        truncate_to=3)
import ot
from reversible.util import ensure_on_same_device, np_to_var, var_to_np

def ot_euclidean_loss_for_samples(samples_a, samples_b):
    """Optimal-transport loss between two sample sets, Euclidean cost.

    The pairwise Euclidean distance matrix is computed in torch. The
    transport plan is computed by ``ot.emd`` on a numpy copy of that
    matrix, so it enters the loss as a constant and gradients flow only
    through the distances.

    Parameters
    ----------
    samples_a, samples_b
        2-dimensional tensors (samples x features) with the same number
        of features.

    Returns
    -------
    Scalar tensor: sum over all pairs of transport mass times distance.
    """
    pairwise = samples_a.unsqueeze(1) - samples_b.unsqueeze(0)
    squared_dists = th.sum(pairwise * pairwise, dim=2)
    # clamp before the sqrt to stay away from sqrt(0)
    dists = th.sqrt(th.clamp(squared_dists, min=1e-6))

    # empty weight lists are passed to ot.emd for both sample sets
    plan = ot.emd([], [], var_to_np(dists))
    # sometimes weird low values, try to prevent them
    min_mass = 1.0 / dists.numel()
    plan = plan * (plan > min_mass)

    plan = np_to_var(plan, dtype=np.float32)
    dists, plan = ensure_on_same_device(dists, plan)
    return th.sum(plan * dists)

In [None]:
# Training loop.
#
# Each epoch:
#   1. draw latent samples per class, invert them to input space and
#      minimise the optimal-transport distance to the real training trials,
#   2. (without gradients) measure the same distance with fresh samples on
#      the train and test inputs and append one row to `df`,
#   3. every `display_every` epochs show diagnostic tables and plots.
import ot
from matplotlib import cycler

n_epochs = 5001
# Show diagnostics about 20 times per run. max() keeps the modulus non-zero,
# so that runs with n_epochs < 20 do not raise ZeroDivisionError.
display_every = max(1, n_epochs // 20)
rng = RandomState(349384)
for i_epoch in range(n_epochs):
    start_time = time.time()
    optimizer.zero_grad()
    optim_dist.zero_grad()
    # Gradients of all classes are accumulated before one optimizer step.
    for i_class in range(len(train_inputs)):
        this_inputs = train_inputs[i_class]
        # five generated samples per real trial
        n_samples = len(this_inputs) * 5
        samples = get_samples(n_samples, i_class)
        inverted = invert_hierarchical(samples)
        g_loss = ot_euclidean_loss_for_samples(
            this_inputs.view(this_inputs.shape[0], -1),
            inverted.view(inverted.shape[0], -1))
        g_loss.backward()
    # Mean (over parameter tensors) of the summed squared gradients.
    g_grad = np.mean([th.sum(p.grad ** 2).item()
                      for p in feature_model.parameters()])
    dist_grad = np.mean([th.sum(p.grad ** 2).item()
                         for p in means_per_cluster + stds_per_cluster])
    optimizer.step()
    optim_dist.step()
    with th.no_grad():
        # ---- evaluation: OT distance between real trials and as many
        # freshly sampled generated trials, per set and class ----
        sample_wd_row = {}
        for setname, setinputs in [('train', train_inputs), ('test', test_inputs)]:
            for i_class in range(len(setinputs)):
                this_inputs = setinputs[i_class]
                n_samples = len(this_inputs)
                samples = get_samples(n_samples, i_class)
                inverted = invert_hierarchical(samples)
                in_np = var_to_np(this_inputs).reshape(len(this_inputs), -1)
                fake_np = var_to_np(inverted).reshape(len(inverted), -1)

                dist = np.sqrt(np.sum(np.square(in_np[:,None] - fake_np[None]), axis=2))
                match_matrix = ot.emd([],[], dist)
                cost = np.sum(dist * match_matrix)
                sample_wd_row[setname + '_sampled_wd' + str(i_class)] = cost
        end_time = time.time()
        # NOTE: g_loss is the loss of the last class of the loop above only.
        epoch_row = {
            'g_loss': g_loss.item(),
            'g_grad': g_grad,
            'dist_grad': dist_grad,
            'runtime': end_time - start_time,
        }
        epoch_row.update(sample_wd_row)
        # DataFrame.append was removed in pandas 2.0; concat works on old
        # and new versions. Rows from earlier runs of this cell are kept.
        epoch_df = pd.DataFrame([epoch_row])
        df = epoch_df if df.empty else pd.concat([df, epoch_df], ignore_index=True)

        if i_epoch % display_every == 0:
            display_text("Epoch {:d}".format(i_epoch))
            display(df.iloc[-5:])

            # ---- learned standard deviations per latent dimension ----
            print("stds\n", var_to_np(th.exp(th.stack(stds_per_cluster))))
            fig = plt.figure(figsize=(8,4))
            plt.plot(var_to_np(th.exp(th.stack(stds_per_cluster))).squeeze().T)
            plt.title("Standard deviation\nper dimension")
            display(fig)
            plt.close(fig)

            # ---- median amplitude spectra: real (solid) vs generated
            # (dashed), plotted in the order of the legend below ----
            fig = plt.figure(figsize=(8,4))
            set_inputs = train_inputs
            for i_class in range(len(set_inputs)):
                ins = var_to_np(set_inputs[i_class].squeeze())
                n_times = ins.shape[1]
                bps = np.abs(np.fft.rfft(ins))
                plt.plot(np.fft.rfftfreq(n_times, d=1/n_times), np.median(bps, axis=0))

                n_samples = 5000
                samples = get_samples(n_samples, i_class)
                inverted = var_to_np(invert_hierarchical(samples).squeeze())
                bps = np.abs(np.fft.rfft(inverted))
                plt.plot(np.fft.rfftfreq(inverted.shape[1], d=1/n_times), np.median(bps, axis=0),
                         color=seaborn.color_palette()[i_class], ls='--')
            plt.title("Spectrum")
            plt.xlabel('Frequency [Hz]')
            plt.ylabel('Amplitude')
            plt.legend(['Real Right', 'Fake Right', 'Real Rest', 'Fake Rest'])
            display(fig)
            plt.close(fig)

            # ---- per class: input-space scatter, generated traces and
            # latent-space scatter ----
            set_inputs = train_inputs
            for i_class in range(len(set_inputs)):
                # first two input dimensions, generated vs real
                fig = plt.figure(figsize=(5,5))
                n_samples = 5000
                samples = get_samples(n_samples, i_class)
                inverted = var_to_np(invert_hierarchical(samples).squeeze())
                real = var_to_np(set_inputs[i_class].squeeze())
                plt.plot(inverted[:,0], inverted[:,1],
                         ls='', marker='o', color=seaborn.color_palette()[i_class + 2],
                         alpha=0.5, markersize=2)
                plt.plot(real[:,0], real[:,1],
                         ls='', marker='o', color=seaborn.color_palette()[i_class])
                display(fig)
                plt.close(fig)

                # first 1000 generated signals
                fig = plt.figure(figsize=(8,3))
                plt.plot(inverted[:1000].T, color=seaborn.color_palette()[0], lw=0.5)
                display(fig)
                plt.close(fig)

                # The two latent dimensions with the largest log-std of
                # class 0. NOTE(review): class 0 is used for every class,
                # presumably so that all classes show the same dimensions.
                i_dims = np.argsort(var_to_np(stds_per_cluster[0]))[::-1][:2]
                samples = get_samples(5000, i_class)
                outs = feature_model(set_inputs[i_class])
                fig = plt.figure(figsize=(3,3))
                plt.plot(var_to_np(samples)[:,i_dims[0]].squeeze(),
                         var_to_np(samples)[:,i_dims[1]].squeeze(), marker='o', ls='')
                plt.plot(var_to_np(outs)[:,i_dims[0]].squeeze(),
                         var_to_np(outs)[:,i_dims[1]].squeeze(), marker='o', ls='')
                plt.legend(["Fake", "Real"])
                display(fig)
                plt.close(fig)

            # ---- traversals along the four latent dimensions with the
            # largest log-std (maximum over classes) ----
            i_dims = (np.argsort(np.max(var_to_np(th.stack(stds_per_cluster)), axis=0))[::-1][:4])
            set_inputs = train_inputs
            cmap = plt.cm.coolwarm
            for i_dim in i_dims:
                display_text("Dimension {:d}".format(i_dim))
                examples_per_class = []
                outs_per_class = []
                for i_class in range(2):
                    mean = means_per_cluster[i_class]
                    std = th.exp(stds_per_cluster[i_class])
                    # 21 values from mean - 2 std to mean + 2 std in this
                    # dimension; all other dimensions stay at the mean
                    i_f_vals = th.linspace((mean[i_dim] - 2 * std[i_dim]).item(),
                                           (mean[i_dim] + 2 * std[i_dim]).item(), 21)
                    examples = mean.repeat(len(i_f_vals), 1)
                    examples.data[:,i_dim] = i_f_vals.data
                    examples_per_class.append(examples)
                    outs_per_class.append(feature_model(set_inputs[i_class]))
                # number of traversal steps, used for the colour gradient
                N = len(examples_per_class[0])

                # latent values of the real trials (slightly below zero) and
                # the traversal points (at zero, coloured by position)
                fig, axes = plt.subplots(1,2, figsize=(6,3), sharex=True, sharey=True)
                for i_class in range(2):
                    examples = examples_per_class[i_class]
                    real_vals = var_to_np(outs_per_class[i_class])[:,i_dim].squeeze()
                    example_vals = var_to_np(examples)[:,i_dim].squeeze()
                    axes[i_class].plot(real_vals, real_vals * 0 - 0.01,
                                       ls='', marker='o', alpha=0.25, markersize=3,
                                       color=seaborn.color_palette()[i_class])
                    axes[i_class].scatter(example_vals, example_vals * 0,
                                          c=cmap(np.linspace(0, 1, N)))
                    if i_class == 0:
                        axes[i_class].set_title("Latent space:")
                display(fig)
                plt.close(fig)

                # inverted traversal points, with the same colour gradient
                with plt.rc_context({'axes.prop_cycle': cycler(color=cmap(np.linspace(0, 1, N)))}):
                    fig, axes = plt.subplots(1,2, figsize=(16,3), sharex=True, sharey=True)
                    for i_class in range(2):
                        inverted = invert_hierarchical(examples_per_class[i_class])
                        axes[i_class].plot(var_to_np(inverted).squeeze().T)
                    display(fig)
                    plt.close(fig)



In [None]:
# Reference value: OT distance between the real train and the real test
# trials of class index 1 (the second entry returned by create_inputs).
ot_euclidean_loss_for_samples(train_inputs[1].view(-1,64), test_inputs[1].view(-1,64))

In [None]:
# Reference value: OT distance between the real train and the real test
# trials of class index 0 (the first entry returned by create_inputs).
ot_euclidean_loss_for_samples(train_inputs[0].view(-1,64), test_inputs[0].view(-1,64))