In [None]:
# Configure the root logger so that INFO messages are printed to stdout
# (and therefore show up in the notebook output).
import logging
import importlib
# Reload `logging` before configuring it; presumably this resets any handlers
# already installed in this kernel so that basicConfig() below takes effect.
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('INFO')
import sys

# Log format: timestamp, level name, message; stream is stdout.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)

In [None]:
%%capture
import os
import site
# Put the local checkouts of the project packages at the front of the import path.
# NOTE(review): these are absolute, user-specific paths; they have to be adapted
# when the notebook is run on another machine.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')
%cd /home/schirrmr/


# Reload edited modules automatically before code is executed.
%load_ext autoreload
%autoreload 2
import numpy as np
import logging
# Root logger at INFO level, written to stdout (same settings as the first cell).
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Notebook-wide plotting defaults: figure size, font size and seaborn style.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs
from reversible2.invert import invert

from reversible2.affine import AdditiveBlock

def display_text(text, fontsize=18):
    """Show `text` in the notebook output as the title of an otherwise empty figure.

    Args:
        text: String to show.
        fontsize: Font size used for the title.
    """
    # A very flat figure with hidden axes, so that only the title is visible.
    fig = plt.figure(figsize=(12,0.1))
    ax = fig.gca()
    ax.set_title(text, fontsize=fontsize)
    ax.axis('off')
    display(fig)
    # Close the figure so it is not rendered a second time by the inline backend.
    plt.close(fig)

In [None]:
from braindecode.datasets.bbci import BBCIDataset
from braindecode.mne_ext.signalproc import mne_apply
from collections import OrderedDict
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

def load_file(filename):
    """Load a BBCI recording, drop the 'STI 014' channel and re-reference it.

    Args:
        filename: Path of the file handed to `BBCIDataset`.

    Returns:
        The object returned by `mne_apply` after subtracting, at every index of
        the remaining axes, the mean over axis 0 of the data array.
    """
    def _subtract_mean_over_first_axis(data):
        # Presumably a common average reference ("car" in the original code):
        # assumes axis 0 of the array passed by mne_apply is the channel axis
        # -- TODO confirm.
        return data - np.mean(data, axis=0, keepdims=True)

    raw = BBCIDataset(filename).load()
    raw_without_stim = raw.drop_channels(['STI 014'])
    return mne_apply(_subtract_mean_over_first_axis, raw_without_stim)


def create_set(cnt):
    """Preprocess a continuous recording and cut it into labelled trials.

    Steps, in order: resample with target 250.0, exponential running
    standardization, resample with target 32.0 and then 64.0, and finally
    trial extraction with `create_signal_target_from_raw_mne`.

    Args:
        cnt: Continuous recording as returned by `load_file`.

    Returns:
        The dataset returned by `create_signal_target_from_raw_mne`.
    """
    from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
    from braindecode.datautil.signalproc import exponential_running_standardize, bandpass_cnt

    def _standardize(data):
        # The standardization is applied to the transposed array and the
        # result is transposed back.
        standardized_t = exponential_running_standardize(
            data.T, factor_new=1e-3, init_block_size=1000, eps=1e-4)
        return standardized_t.T

    # NOTE: the log messages say "train" regardless of which set is processed.
    log.info("Resampling train...")
    resampled = resample_cnt(cnt, 250.0)
    log.info("Standardizing train...")
    processed = mne_apply(_standardize, resampled)
    # First down to 32.0, then back up to 64.0 (presumably sampling rates in Hz,
    # acting as a low-pass -- TODO confirm).
    for target_rate in (32.0, 64.0):
        processed = resample_cnt(processed, target_rate)

    # Class name -> marker codes; the order of the entries is kept by the OrderedDict.
    class_markers = OrderedDict([('Right Hand', [1]), ('Left Hand', [2],),
                                 ('Rest', [3]), ('Feet', [4])])
    # Trial interval relative to the marker (presumably in milliseconds -- TODO confirm).
    trial_ival = [500,1500]
    return create_signal_target_from_raw_mne(processed, class_markers, trial_ival)
def create_inputs(dataset):
    """Build one input tensor per class from the first channel of `dataset.X`.

    Args:
        dataset: Object with arrays `X` and `y`; trials with `y == 0` and
            `y == 2` are used (presumably 'Right Hand' and 'Rest', following
            the marker order in `create_set` -- TODO confirm).

    Returns:
        List `[inputs_for_label_0, inputs_for_label_2]` of float32 variables.
    """
    selected_labels = (0, 2)
    inputs = []
    for label in selected_labels:
        class_trials = dataset.X[dataset.y == label]
        # Keep only index 0 of axis 1 (as a length-1 axis) and append a trailing
        # singleton axis.
        inputs.append(np_to_var(class_trials[:,0:1,:,None], dtype=np.float32))
    return inputs


In [None]:
# Training data: load the recording, keep only channels C3 and C4 (in that order),
# cut it into trials and build the per-class input tensors.
train_cnt = load_file(
    '/data/schirrmr/schirrmr/HGD-public/reduced/train/4.mat'
).reorder_channels(['C3', 'C4'])
train_set = create_set(train_cnt)
train_inputs = create_inputs(train_set)

In [None]:
# Test data: same pipeline as for the training recording.
test_cnt = load_file(
    '/data/schirrmr/schirrmr/HGD-public/reduced/test/4.mat'
).reorder_channels(['C3', 'C4'])
test_set = create_set(test_cnt)
test_inputs = create_inputs(test_set)

In [None]:
# Global switch read by later cells (seeding and moving the model to the GPU).
cuda = True
if cuda:
    # Move every per-class input tensor to the GPU.
    train_inputs = [class_batch.cuda() for class_batch in train_inputs]
    test_inputs = [class_batch.cuda() for class_batch in test_inputs]

In [None]:
%%writefile distribution.py
import torch as th
from reversible2.gaussian import get_gauss_samples
class TwoClassDist(object):
    """Learnable class-conditional Gaussian with independent dimensions.

    The first `n_classes` dimensions are "class dimensions": for class `i`
    only class dimension `i` has a learnable mean and log-std, the other class
    dimensions are fixed to mean 0 and log-std `masked_log_std`. The remaining
    `n_non_class_dims` dimensions have learnable means/log-stds shared by all
    classes.

    Args:
        n_classes: Number of classes, i.e. number of class dimensions.
        n_non_class_dims: Number of dimensions shared between classes.
        masked_log_std: Log-std used for the class dimensions that do not
            belong to the requested class.

    The defaults (2, 62, -9) reproduce the previously hard-coded behavior.
    """
    def __init__(self, n_classes=2, n_non_class_dims=62, masked_log_std=-9.0):
        super(TwoClassDist, self).__init__()
        self.masked_log_std = masked_log_std
        # All four tensors are leaf tensors that are optimized directly,
        # see parameters().
        self.class_means = th.zeros(n_classes, requires_grad=True)
        self.non_class_means = th.zeros(n_non_class_dims, requires_grad=True)
        self.class_log_stds = th.zeros(n_classes, requires_grad=True)
        self.non_class_log_stds = th.zeros(n_non_class_dims, requires_grad=True)

    def get_mean_std(self, i_class):
        """Return `(mean, std)` vectors of the Gaussian of class `i_class`.

        Both vectors have length `n_classes + n_non_class_dims`; `std` is the
        exponential of the assembled log-std vector.
        """
        device = self.class_means.device
        n_before = i_class
        n_after = len(self.class_means) - i_class - 1
        cur_mean = th.cat((
            th.zeros(n_before, device=device),
            self.class_means[i_class:i_class + 1],
            th.zeros(n_after, device=device),
            self.non_class_means))
        # Class dimensions of the other classes get the fixed (small) log-std.
        cur_log_std = th.cat((
            th.ones(n_before, device=device) * self.masked_log_std,
            self.class_log_stds[i_class:i_class + 1],
            th.ones(n_after, device=device) * self.masked_log_std,
            self.non_class_log_stds))
        return cur_mean, th.exp(cur_log_std)

    def get_samples(self, i_class, n_samples):
        """Draw `n_samples` samples for class `i_class` via `get_gauss_samples`."""
        cur_mean, cur_std = self.get_mean_std(i_class)
        samples = get_gauss_samples(n_samples, cur_mean, cur_std)
        return samples

    def cuda(self):
        """Move all parameters to the GPU in place and return `self`."""
        # Replace `.data` so the tensors already handed out by parameters()
        # (e.g. to an optimizer) stay the same objects.
        for param in self.parameters():
            param.data = param.data.cuda()
        return self

    def parameters(self):
        """Return the list of learnable tensors (means and log-stds)."""
        return [self.class_means, self.non_class_means, self.class_log_stds, self.non_class_log_stds]

In [None]:
from reversible2.blocks import dense_add_block, conv_add_block_3x3
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter


# Seed before creating the model; presumably this makes the initial weights
# reproducible.
set_random_seeds(2019011641, cuda)
# Stack of subsampling splitters and additive blocks, followed by dense
# additive blocks on the flattened 64-dimensional representation and an RFFT.
# The trailing comments give the intended "channels x length" after each splitter.
feature_model = nn.Sequential(
    SubsampleSplitter(stride=[2,1],chunk_chans_first=False),# 2 x 32
    conv_add_block_3x3(2,32),
    conv_add_block_3x3(2,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 4 x 16
    conv_add_block_3x3(4,32),
    conv_add_block_3x3(4,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 8 x 8
    conv_add_block_3x3(8,32),
    conv_add_block_3x3(8,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 16 x 4
    conv_add_block_3x3(16,32),
    conv_add_block_3x3(16,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 32 x 2
    conv_add_block_3x3(32,32),
    conv_add_block_3x3(32,32),
    SubsampleSplitter(stride=[2,1],chunk_chans_first=True), # 64 x 1
    ViewAs((-1,64,1, 1), (-1,64)),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    dense_add_block(64,64),
    RFFT(),
)
if cuda:
    feature_model.cuda()
# Device of the first model parameter.
device = next(feature_model.parameters()).device
from distribution import TwoClassDist
from reversible2.ot_exact import ot_euclidean_loss_for_samples
class_dist = TwoClassDist()
# Honor the `cuda` flag here as well, so that model and distribution always end
# up on the same device (previously the distribution was moved unconditionally).
if cuda:
    class_dist.cuda()


In [None]:
# Separate Adam optimizers: the invertible network uses Adam's default learning
# rate, the class distribution parameters use a learning rate of 1e-2.
optim_model = th.optim.Adam(params=feature_model.parameters())
optim_dist = th.optim.Adam(params=class_dist.parameters(), lr=1e-2)

In [None]:
%%writefile plot.py
import torch as th
import matplotlib.pyplot as plt
import numpy as np
from reversible2.util import var_to_np
from matplotlib.patches import Ellipse
import seaborn

def _diag_gaussian(mean, std):
    """Return a Gaussian with independent dimensions from per-dimension mean and std."""
    # `scale_tril` is the lower-triangular factor L with covariance L @ L.T; for
    # independent dimensions this is diag(std). Passing diag(std) as
    # `covariance_matrix` would instead treat the stds as variances.
    return th.distributions.MultivariateNormal(mean, scale_tril=th.diag(std))


def plot_outs(feature_model_a, train_inputs, test_inputs, class_dist):
    """Plot the first two output dimensions per class and report accuracies.

    For the train and the test inputs, the first two output dimensions are
    classified in two ways: with the Gaussians given by `class_dist`, and with
    Gaussians fitted to the mean/std of the train outputs. One figure per set
    shows the encodings, sigma-ellipses of `class_dist` and the per-class mean
    of the encodings.

    Args:
        feature_model_a: Callable mapping an input batch to outputs with at
            least two columns.
        train_inputs: List with one input batch per class.
        test_inputs: List with one input batch per class.
        class_dist: Object providing `get_mean_std(i_class)`, `class_means`
            and `class_log_stds`. The ellipse drawing assumes exactly two classes.
    """
    palette = seaborn.color_palette()
    # Gaussians fitted to mean/std of the train encodings (first two dims only).
    data_cls_dists = []
    for class_ins in train_inputs:
        this_class_outs = feature_model_a(class_ins)[:,:2]
        data_cls_dists.append(_diag_gaussian(
            th.mean(this_class_outs, dim=0), th.std(this_class_outs, dim=0)))

    for setname, set_inputs in (("Train", train_inputs), ("Test", test_inputs)):
        outs = [feature_model_a(ins) for ins in set_inputs]
        c_outs = [o[:,:2] for o in outs]
        c_outs_all = th.cat(c_outs)

        # Gaussians of the learned class distribution (first two dims only).
        cls_dists = []
        for i_class in range(len(c_outs)):
            mean, std = class_dist.get_mean_std(i_class)
            cls_dists.append(_diag_gaussian(mean[:2], std[:2]))

        # True labels: class index repeated once per trial of that class.
        labels = np.concatenate([np.ones(len(set_inputs[i_cls])) * i_cls
                                 for i_cls in range(len(set_inputs))])

        # Predicted class = class whose Gaussian gives the highest log-probability.
        preds = th.stack([dist.log_prob(c_outs_all) for dist in cls_dists], dim=-1)
        pred_labels = np.argmax(var_to_np(preds), axis=1)
        acc = np.mean(labels == pred_labels)

        data_preds = th.stack([dist.log_prob(c_outs_all) for dist in data_cls_dists],
                              dim=-1)
        data_pred_labels = np.argmax(var_to_np(data_preds), axis=1)
        data_acc = np.mean(labels == data_pred_labels)

        print("{:s} Accuracy: {:.2f}%".format(setname, acc * 100))
        fig = plt.figure(figsize=(5,5))
        ax = plt.gca()
        for i_class in range(len(c_outs)):
            o = var_to_np(c_outs[i_class]).squeeze()
            plt.scatter(o[:,0], o[:,1], s=20, alpha=0.75)
            # Center/extent of this class' ellipse: the class' own dimension uses
            # the learned mean/std, the other class dimension is centered at 0
            # and drawn with a small fixed std of 0.05 so it stays visible.
            means = var_to_np(class_dist.class_means).copy()
            means[1-i_class] = 0
            stds = var_to_np(th.exp(class_dist.class_log_stds))
            stds[1-i_class] = 0.05
            for sigma in [0.5,1,2,3]:
                # Ellipse expects full axis lengths (diameters), so the contour
                # at `sigma` standard deviations is 2 * sigma * std wide.
                ellipse = Ellipse(means, 2 * stds[0] * sigma, 2 * stds[1] * sigma)
                ax.add_artist(ellipse)
                ellipse.set_edgecolor(palette[i_class])
                ellipse.set_facecolor("None")
        # Mean of the encodings of each class, drawn as triangles.
        for i_class in range(len(c_outs)):
            o = var_to_np(c_outs[i_class]).squeeze()
            plt.scatter(np.mean(o[:,0]), np.mean(o[:,1]),
                       color=palette[i_class+2], s=80, marker="^")

        plt.title("{:6s} Accuracy:        {:.2f}%\n"
                  "From data mean/std: {:.2f}%".format(setname, acc * 100, data_acc * 100))
        plt.legend(("Right", "Rest", "Right Mean", "Rest Mean"))
        display(fig)
        plt.close(fig)
        
def display_close(fig):
    """Display `fig` in the notebook output and close it afterwards."""
    display(fig)
    plt.close(fig)

In [None]:
def get_th_dist(class_dist, i_class):
    """Return the Gaussian of class `i_class` as a torch distribution.

    Args:
        class_dist: Object providing `get_mean_std(i_class)`, which returns
            per-dimension mean and standard deviation vectors.
        i_class: Index of the class.

    Returns:
        `th.distributions.MultivariateNormal` with independent dimensions.
    """
    mean, std = class_dist.get_mean_std(i_class)
    # scale_tril = diag(std) gives covariance diag(std ** 2), i.e. `std` is used
    # as a standard deviation and not as a variance.
    return th.distributions.MultivariateNormal(mean, scale_tril=th.diag(std))

In [None]:
from plot import plot_outs, display_close
n_epochs = 2001
for i_epoch in range(n_epochs):
    optim_model.zero_grad()
    optim_dist.zero_grad()
    # Gradients of all classes are accumulated before a single optimizer step.
    for i_class in range(len(train_inputs)):
        class_ins = train_inputs[i_class]
        # Generative direction: draw 5 latent samples per real trial, invert
        # them to input space and compare with the real trials.
        samples = class_dist.get_samples(i_class, len(class_ins) * 5)
        inverted = invert(feature_model, samples)
        ot_loss = ot_euclidean_loss_for_samples(class_ins.squeeze(), inverted.squeeze())
        # Encoding direction: negative log-likelihood of the real encodings
        # under the class Gaussian.
        outs = feature_model(class_ins)
        mean, std = class_dist.get_mean_std(i_class)
        # get_mean_std returns standard deviations; scale_tril=diag(std) uses
        # them as such (covariance diag(std ** 2)). Passing diag(std) as
        # covariance_matrix would treat them as variances.
        # NOTE(review): compared to the covariance_matrix=diag(std) variant this
        # makes the NLL much stiffer on the masked class dimension; re-check
        # the loss weighting if training becomes unstable.
        th_dist = th.distributions.MultivariateNormal(mean, scale_tril=th.diag(std))
        # 1/64: the class distribution has 2 + 62 = 64 dimensions by default.
        nll_mean = th.mean(-th_dist.log_prob(outs)) * (1/64)
        loss = ot_loss + nll_mean
        loss.backward()
    optim_model.step()
    optim_dist.step()
    if i_epoch % (n_epochs // 20) == 0:
        # The printed losses are those of the last class processed above.
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("Loss: {:E}".format(loss.item()))
        print("OT Loss: {:E}".format(ot_loss.item()))
        print("NLL Loss: {:E}".format(nll_mean.item()))
        plot_outs(feature_model, train_inputs, test_inputs,
                 class_dist)
        # Learned standard deviations: class dimensions first, then the rest.
        fig = plt.figure(figsize=(8,2))
        plt.plot(var_to_np(th.cat((th.exp(class_dist.class_log_stds),
                                 th.exp(class_dist.non_class_log_stds)))),
                marker='o')
        display_close(fig)

In [None]:
def _ot_and_spec_ot(ins_a, ins_b):
    """Return (OT loss between the signals, OT loss between their flattened rFFTs)."""
    # NOTE(review): th.rfft with `signal_ndim` only exists in older PyTorch
    # versions; kept as is because that is the API this notebook uses.
    ot = ot_euclidean_loss_for_samples(ins_a.squeeze(), ins_b.squeeze())
    spec_ot = ot_euclidean_loss_for_samples(
        th.rfft(ins_a.squeeze(), signal_ndim=1, normalized=True).view(ins_a.shape[0],-1),
        th.rfft(ins_b.squeeze(), signal_ndim=1, normalized=True).view(ins_b.shape[0],-1))
    return ot, spec_ot


# Losses between real inputs of each set and inputs generated from class samples.
for setname, set_inputs in (("Train", train_inputs), ("Test", test_inputs)):
    total_loss = 0
    total_spec_ot = 0
    for i_class in range(2):
        class_ins = set_inputs[i_class]
        # The number of generated samples is tied to the train set size for both sets.
        samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
        inverted = invert(feature_model, samples)
        loss, spec_ot = _ot_and_spec_ot(class_ins, inverted)
        total_loss += loss
        total_spec_ot += spec_ot
    print(setname + " Loss: {:.1f}".format(total_loss.item() / 2))
    # Report the mean over both classes (previously only the last class' value
    # was printed, halved).
    print(setname + " SpecLoss: {:.1f}".format(total_spec_ot.item() / 2))
# Show losses between sets: real train inputs vs. real test inputs, as a
# reference for the values above. No generated samples are needed here.
for setname, set_inputs in (("Train", train_inputs),):
    total_loss = 0
    total_spec_ot = 0
    for i_class in range(2):
        class_ins = set_inputs[i_class]
        other_ins = test_inputs[i_class]
        loss, spec_ot = _ot_and_spec_ot(class_ins, other_ins)
        total_loss += loss
        total_spec_ot += spec_ot
    print(setname + " Loss: {:.1f}".format(total_loss.item() / 2))
    print(setname + " SpecLoss: {:.1f}".format(total_spec_ot.item() / 2))


In [None]:
def create_th_grid(dim_0_vals, dim_1_vals):
    """Build the grid of all (dim_0, dim_1) value pairs.

    Args:
        dim_0_vals: 1-d tensor with the values along the first grid axis.
        dim_1_vals: 1-d tensor with the values along the second grid axis.

    Returns:
        Tensor of shape (len(dim_0_vals), len(dim_1_vals), 2) where entry
        [i, j] is (dim_0_vals[i], dim_1_vals[j]).
    """
    rows = [
        th.stack([th.stack((val_0, val_1)) for val_1 in dim_1_vals])
        for val_0 in dim_0_vals
    ]
    return th.stack(rows)

In [None]:
# Mean and std of the first two (class) dimensions, one row per class.
mean_std_per_class = [class_dist.get_mean_std(i_cls)
                      for i_cls in range(len(train_inputs))]
means = th.stack([mean[:2] for mean, _ in mean_std_per_class])
stds = th.stack([std[:2] for _, std in mean_std_per_class])
# Plot range per latent dimension: half a standard deviation around the class
# means. The reduction has to run over the classes (dim=0), because mins/maxs
# are indexed by latent dimension afterwards. For the two-class case here the
# result equals the former dim=1 reduction, since the off-diagonal entries of
# the 2x2 matrices are identical.
mins = th.min(means - stds * 0.5, dim=0)[0]
maxs = th.max(means + stds * 0.5, dim=0)[0]
mean_for_plot = th.mean(th.stack((mins, maxs)), dim=0)

In [None]:
#mins.data[0] = -0.2

In [None]:
# Regular grid over the two class dimensions, between mins and maxs.
n_vals = 20
dim_0_vals, dim_1_vals = [
    th.linspace(mins[i_dim].item(), maxs[i_dim].item(),n_vals, device=mins.device)
    for i_dim in range(2)]

grid = create_th_grid(dim_0_vals, dim_1_vals)

# Fill the remaining latent dimensions of every grid point with the learned
# non-class means.
n_rows, n_cols = grid.shape[0], grid.shape[1]
non_class_part = class_dist.non_class_means.repeat(n_rows, n_cols, 1)
full_grid = th.cat((grid, non_class_part), dim=-1)

# Invert all grid points at once, then restore the grid layout.
flat_latents = full_grid.view(-1, full_grid.shape[-1])
inverted = invert(feature_model, flat_latents).squeeze()
inverted_grid = inverted.view(len(dim_0_vals), len(dim_1_vals),-1)

# Size of one grid cell in latent coordinates.
x_len, y_len = var_to_np(maxs - mins) / full_grid.shape[:-1]

# Scale so that the largest absolute value uses 90% of half a cell height.
max_abs = th.max(th.abs(inverted_grid))
y_factor =  (y_len / (2*max_abs)).item() * 0.9

In [None]:
plt.figure(figsize=(32,32))

# Draw every inverted grid point as a small curve inside its grid cell.
x_min, y_min = mins[0].item(), mins[1].item()
n_x, n_y = full_grid.shape[0], full_grid.shape[1]
for i_x, i_y in itertools.product(range(n_x), range(n_y)):
    # Curves span the middle 80% of the cell width and are centered vertically.
    x_start = x_min + x_len * i_x + 0.1 * x_len
    x_end = x_min + x_len * i_x + 0.9 * x_len
    y_center = y_min + y_len * i_y + 0.5 * y_len

    curve = var_to_np(inverted_grid[i_x][i_y])
    # Only the first curve gets a legend entry.
    label = 'Generated data' if (i_x == 0 and i_y == 0) else ''
    plt.plot(np.linspace(x_start, x_end, len(curve)),
             curve * y_factor + y_center, color='black',
             label=label)

# Encodings of the real training inputs in the first two output dimensions.
outs = [feature_model(ins) for ins in train_inputs]
c_outs = [o[:,:2] for o in outs]

# Keep the axis limits given by the grid when adding the encodings.
plt.gca().set_autoscale_on(False)
for i_cls, c_out in enumerate(c_outs):
    class_name = ["Right", "Rest"][i_cls]
    plt.scatter(var_to_np(c_out[:,0]), var_to_np(c_out[:,1]),
                s=200, alpha=0.75,
                label='Encodings {:s}'.format(class_name))

plt.legend(fontsize=14)

In [None]:
# Invert the mean of each class distribution and plot the resulting signals.
fig = plt.figure(figsize=(8,2))
for i_class in range(2):
    class_mean = class_dist.get_mean_std(i_class)[0]
    # invert() gets a batch with a single element.
    inverted = invert(feature_model, class_mean.unsqueeze(0))
    plt.plot(var_to_np(inverted.squeeze()))
plt.legend(("Right Hand", "Rest"))
display_close(fig)

In [None]:
from plot import display_close

In [None]:
# Means of the two class distributions. This notebook defines neither a
# module-level get_mean_std() nor class_means/... variables; the means come
# from the class_dist object.
mean_0, _ = class_dist.get_mean_std(0)
mean_1, _ = class_dist.get_mean_std(1)

# Linear interpolation in latent space from mean_0 (alpha=1) to mean_1 (alpha=0).
n_interpolates = 100
alphas = th.linspace(1,0,n_interpolates, device=mean_0.device)

interpolates = mean_0.unsqueeze(0) * alphas.unsqueeze(1) + mean_1.unsqueeze(0) * (1-alphas.unsqueeze(1))
inverted = invert(feature_model,interpolates, )

from matplotlib import rcParams, cycler
cmap = plt.cm.coolwarm
# One line per interpolation step, colored along the coolwarm colormap.
with plt.rc_context(rc={'axes.prop_cycle':cycler(color=cmap(np.linspace(0, 1, len(inverted))))}):
    fig = plt.figure(figsize=(8,2))
    plt.plot(var_to_np(inverted).squeeze().T,);
    display_close(fig)