In [None]:
import importlib
import logging
import sys

# Reload the logging module first (see https://stackoverflow.com/a/21475297/1469195)
# so the basicConfig call below configures a fresh root logger; basicConfig is a
# no-op when the root logger already has handlers.
importlib.reload(logging)

log = logging.getLogger()
log.setLevel('INFO')
# Timestamped INFO-level messages, written to stdout so they show up in the cell output.
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
# Environment setup: sys.path, autoreload, logging, plotting defaults and imports.
# %%capture swallows everything this cell prints.
import os
# NOTE(review): `site` appears unused in this notebook.
import site
# NOTE(review): hardcoded absolute paths to the author's local checkouts; adjust these
# (or install the packages) before running on another machine.
# (`os.sys` is just the `sys` module reached through `os`.)
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# Automatically re-import edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2
import numpy as np
# Repeats the logging setup from the first cell so this cell also works on its own;
# logging.basicConfig does nothing once the root logger already has handlers.
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Very flat default figure size; every figure below passes an explicit figsize anyway.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# (duplicate of the numpy import above; harmless)
import numpy as np
import copy
import math

import itertools
# `th` is the torch alias used by the cells below.
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs
from reversible2.invert import invert

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs

In [None]:
# Hand-picked seed (author's notes: a seed ending in 2 was "quite good", one ending in 13 "very good").
rng = RandomState(201904113)

# Five 2-d points: x drawn from N(-1, 1), y pinned to 0 by the zero scale.
X = rng.randn(5, 2) * np.array([[1, 0]]) + np.array([[-1, 0]])

fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(X[:, 0], X[:, 1])
# Black marker at (-1, 0), the mean the points were drawn around.
ax.scatter([-1], [0], color='black')
ax.set_xlim(-2.5, 5.5)
ax.set_ylim(-4, 4)

In [None]:
import sklearn.datasets

In [None]:
# Two interleaving half-circles (200 points). shuffle=False keeps each moon in a
# contiguous half, which the cells below rely on when splitting at index 100.
# NOTE(review): no random_state is passed, so the tiny (1e-4) noise differs per run.
X, y = sklearn.datasets.make_moons(n_samples=200, shuffle=False, noise=1e-4)

In [None]:
# Set to True to run models and data on the GPU (the next cell moves the models).
cuda = False

fig, ax = plt.subplots(figsize=(4, 4))
# The two moons occupy the two halves of X (make_moons was called with shuffle=False).
ax.scatter(X[:100, 0], X[:100, 1])
ax.scatter(X[100:, 0], X[100:, 1])

# Only the first moon is modelled: even-indexed points train, odd-indexed points validate.
train_inputs = np_to_var(X[:100][::2], dtype=np.float32)
valid_inputs = np_to_var(X[:100][1::2], dtype=np.float32)
if cuda:
    # Keep the data on the same device as the models, otherwise the forward passes fail.
    train_inputs = train_inputs.cuda()
    valid_inputs = valid_inputs.cuda()

In [None]:
from reversible2.distribution import TwoClassDist
from reversible2.spectral_norm import spectral_norm
from reversible2.blocks import dense_add_block, conv_add_block_3x3
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter
from copy import deepcopy


set_random_seeds(2019011641, cuda)
# Feature model: a stack of four dense additive blocks on 2-d inputs (200 hidden units each).
feature_model_a = nn.Sequential(*[dense_add_block(2, 200) for _ in range(4)])
if cuda:
    feature_model_a.cuda()

# The adversary starts out as an exact copy of the feature model.
adv_model = deepcopy(feature_model_a)
from reversible2.ot_exact import ot_euclidean_loss_for_samples
class_dist = TwoClassDist(2, 0, [0, 1])
if cuda:
    class_dist.cuda()

# Wrap every weighted layer of the adversary with reversible2's spectral_norm.
# Each layer gets its own learnable scalar (initialised to 1), collected in
# spec_norms_adv and trained by its own optimiser below.
spec_norms_adv = []
for layer in adv_model.modules():
    if not hasattr(layer, 'weight'):
        continue
    layer_norm = th.ones(1, device=layer.weight.device, requires_grad=True)
    spectral_norm(layer, 'weight', layer_norm, n_power_iterations=3)
    spec_norms_adv.append(layer_norm)

# One optimiser per parameter group so each phase of training can step them separately.
optim_model_a = th.optim.Adam(feature_model_a.parameters())
optim_adv = th.optim.Adam(adv_model.parameters())
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-2)
optim_norms = th.optim.Adam(spec_norms_adv, lr=1e-3)

In [None]:
# NOTE(review): get_gauss_samples is not used below.
from reversible.gaussian import get_gauss_samples
# Adversarial training. Every gen_frequency-th epoch updates the generator
# (feature_model_a, run in the inverse direction via `invert`); all other epochs
# update the adversary, the class distribution and the adversary's spectral norms.
n_epochs = 2001
gen_frequency = 2
# Show progress about 20 times; max() avoids a modulo-by-zero when n_epochs < 20.
display_every = max(1, n_epochs // 20)
# Placeholders so the progress message works even if a display epoch comes before
# the first adversary step (e.g. gen_frequency == 1); overwritten on adversary steps.
nll_real = th.tensor(float('nan'))
nll_valid = th.tensor(float('nan'))
for i_epoch in range(n_epochs):
    # Fake inputs: class-0 samples of the latent distribution mapped back to input space.
    samples = class_dist.get_samples(0, 100)
    inverted = invert(feature_model_a, samples)
    outs_fake = adv_model(inverted)
    # no grad through dist: evaluate under a copy so class_dist receives no gradient here
    log_probs_fake = deepcopy(class_dist).get_class_log_prob(0, outs_fake)
    nll_fake = -th.mean(log_probs_fake)
    if (i_epoch % gen_frequency) == (gen_frequency - 1):
        # Generator step: make the fakes likely under the adversary.
        # This backward also fills grads of adv_model / the spectral norms; those are
        # zeroed before their own optimisers step.
        loss = nll_fake
        optim_model_a.zero_grad()
        loss.backward()
        optim_model_a.step()
    else:
        # Adversary step: lower the NLL of real data (weighted 2x), raise that of fakes.
        outs_real = adv_model(train_inputs)
        log_probs_real = class_dist.get_class_log_prob(0, outs_real)
        nll_real = -th.mean(log_probs_real)
        loss = 2 * nll_real - nll_fake
        optim_adv.zero_grad()
        optim_dist.zero_grad()
        loss.backward()
        optim_adv.step()
        optim_dist.step()

        # Tune only the spectral-norm scales on the held-out points.
        outs_valid = adv_model(valid_inputs)
        log_probs_valid = class_dist.get_class_log_prob(0, outs_valid)
        nll_valid = -th.mean(log_probs_valid)
        loss = nll_valid
        optim_norms.zero_grad()
        loss.backward()
        optim_norms.step()

    if i_epoch % display_every == 0:
        display_text("NLL Real: {:.1E}, NLL Valid: {:.1E} NLL Fake: {:.1E}".format(
            nll_real.item(), nll_valid.item(), nll_fake.item()))

        # Visualisation only: no autograd graphs needed.
        with th.no_grad():
            outs = feature_model_a(train_inputs)
            # A fresh, nearly noise-free copy of the first moon, drawn as a curve in output space.
            other_X = sklearn.datasets.make_moons(200, shuffle=False, noise=1e-4)[0][:100]
            other_ins = np_to_var(other_X, dtype=np.float32)
            if cuda:
                other_ins = other_ins.cuda()
            other_outs = feature_model_a(other_ins)
            # Separate names so the training-step tensors above are not shadowed.
            vis_samples = class_dist.get_samples(0, 100)
            vis_inverted = invert(feature_model_a, vis_samples)

        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        axes[0].plot(var_to_np(other_outs[:, 0]), var_to_np(other_outs[:, 1]), label="All Outputs",
                     color=seaborn.color_palette()[1])
        axes[0].scatter(var_to_np(outs[:, 0]), var_to_np(outs[:, 1]), s=30, c=[seaborn.color_palette()[0]],
                        label="Actual data outputs")
        axes[0].legend()
        axes[0].axis('equal')
        axes[0].set_title("Output space")

        axes[1].scatter(var_to_np(vis_inverted)[:, 0], var_to_np(vis_inverted)[:, 1], s=30,
                        label="Fake/Unknown Samples", c=[seaborn.color_palette()[1]])
        axes[1].scatter(var_to_np(train_inputs)[:, 0], var_to_np(train_inputs)[:, 1], s=30,
                        label="Real data", c=[seaborn.color_palette()[0]])
        axes[1].legend(bbox_to_anchor=(1, 1, 0, 0))
        axes[1].set_title("Input space")
        axes[1].axis('equal')
        display_close(fig)

        # Current value of each learnable spectral-norm scale of the adversary, one point per layer.
        fig = plt.figure(figsize=(8, 2))
        plt.plot(var_to_np(th.stack(spec_norms_adv)))
        display_close(fig)

In [None]:
# Adversary's validation NLL from the most recent adversary (non-generator) step.
nll_valid