In [None]:
import importlib
import logging
import sys

# Reload the logging module before configuring it, so that the basicConfig
# call below applies; see https://stackoverflow.com/a/21475297/1469195
importlib.reload(logging)

# Root logger at INFO level.
log = logging.getLogger()
log.setLevel('INFO')

# Timestamped records with level name, written to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s %(levelname)s : %(message)s',
)

In [None]:
%%capture
# Notebook-wide setup. The %%capture cell magic on the first line captures this
# cell's output instead of displaying it.
import os
import site
# Put local source checkouts at the front of the module search path.
# NOTE(review): these are absolute, user-specific paths -- they have to be
# adapted when the notebook is run on another machine.
os.sys.path.insert(0, '/home/schirrmr/code/reversible/')
os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')
os.sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')


# Load the autoreload extension in mode 2.
%load_ext autoreload
%autoreload 2
import numpy as np
# Root logger at INFO level, records written to stdout with timestamp and level.
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
# Inline PNG figures; the default figure size is wide and flat (12 x 1), so the
# plotting cells below pass their own figsize.
%matplotlib inline
%config InlineBackend.figure_format = 'png'
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs
from reversible2.invert import invert

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs

In [None]:
# Toy setup: two 2-D classes with different means and spreads (Gaussians, or uniform per dimension).
# Train on them and show what the different visualizations would look like.


In [None]:
# Fixed seed so that the toy data set is the same on every run.
rng = RandomState(20190416)

# 30 points per class, drawn uniformly from an axis-aligned rectangle:
#   class A: x in [-1, -0.5), y in [-1, 1)
#   class B: x in [1, 4),     y in [1, 1.5)
XA = np.array([[-1, -1]]) + np.array([[0.5, 2]]) * rng.rand(30, 2)
XB = np.array([[1, 1]]) + np.array([[3, 0.5]]) * rng.rand(30, 2)

In [None]:
# Scatter both toy classes into one square figure with fixed axis limits.
plt.figure(figsize=(5, 5))
for class_points in (XA, XB):
    plt.scatter(class_points[:, 0], class_points[:, 1])
plt.xlim(-2.5, 5.5)
plt.ylim(-4, 4)

In [None]:
# One float32 variable per class, in the order (class A, class B).
train_inputs = tuple(np_to_var(class_data, dtype=np.float32) for class_data in (XA, XB))

In [None]:
# Device switch: when True, the setup cell below calls .cuda() on both feature
# models and on the class distribution, and passes the flag to set_random_seeds.
cuda = False

In [None]:
from reversible2.distribution import TwoClassDist

from reversible2.blocks import dense_add_block, conv_add_block_3x3
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter


# Seed before any module is constructed.
set_random_seeds(2019011641, cuda)

# Two feature networks with the same layout: four stacked dense_add_block(2, 200)
# modules each. Network a is built completely before network b.
feature_model_a = nn.Sequential(*[dense_add_block(2, 200) for _ in range(4)])
feature_model_b = nn.Sequential(*[dense_add_block(2, 200) for _ in range(4)])
if cuda:
    feature_model_a.cuda()
    feature_model_b.cuda()
from reversible2.ot_exact import ot_euclidean_loss_for_samples

# Class distribution shared by both networks.
class_dist = TwoClassDist(2, 0)
if cuda:
    class_dist.cuda()

# One Adam optimizer per network (default settings) and one for the
# distribution parameters (explicit lr=1e-3).
optim_model_a, optim_model_b = (
    th.optim.Adam(model.parameters())
    for model in (feature_model_a, feature_model_b)
)
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-3)

In [None]:
# Joint training of both feature networks and the class distribution.
# Each step handles one class and maximises the (scaled) likelihood of its real
# inputs under both networks; about 20 times per run the current state is plotted.
n_epochs = 2001
# Display interval. max() keeps it >= 1 so that the modulo below cannot
# divide by zero when n_epochs < 20.
display_every = max(1, n_epochs // 20)
net_a, optim_a = feature_model_a, optim_model_a
net_b, optim_b = feature_model_b, optim_model_b
for i_epoch in range(n_epochs):
    for i_class in range(len(train_inputs)):
        optim_a.zero_grad()
        optim_b.zero_grad()
        optim_dist.zero_grad()
        class_ins = train_inputs[i_class]
        # Likelihood of the real inputs under each network: exp(mean(log prob)).
        outs_r_a = net_a(class_ins)
        outs_r_b = net_b(class_ins)
        p_real_a = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_r_a)))
        p_real_b = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_r_b)))
        # Create samples: five times as many as there are real inputs, inverted
        # through each network.
        n_samples = len(class_ins) * 5
        inv_a = invert(net_a, class_dist.get_samples(i_class, n_samples))
        inv_b = invert(net_b, class_dist.get_samples(i_class, n_samples))
        # Compute probs with the opposite network. p_a and p_b are only reported
        # below; they are not part of the loss.
        outs_f_b = net_b(inv_a)
        p_a = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_f_b)))
        outs_f_a = net_a(inv_b)
        p_b = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_f_a)))

        # Disabled experiment (kept for reference):
        #p_b.backward() # p b should get higher for net b!
        #_ = [p.grad.data.neg_() for p in net_b.parameters()]
        #old_vals = [p.grad.data for p in net_a.parameters()]
        #_ = [p.grad.data.zero_() for p in net_a.parameters()]
        #p_a.backward() # p a should get higher for net a
        #_ = [p.grad.data.neg_() for p in net_a.parameters()]
        #_ = [p.grad.data.add_(old_val) for p, old_val in zip(net_a.parameters(), old_vals)]

        # Minimising the negated likelihoods maximises them.
        real_loss = 10 * (-p_real_a - p_real_b)
        real_loss.backward()
        optim_a.step()
        optim_b.step()
        optim_dist.step()
    if i_epoch % display_every == 0:
        display_text("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        # The probabilities shown are those of the last class processed above.
        display_text("PrA, PrB, PfA, PfB {:.1E} {:.1E} {:.1E} {:.1E}".format(
            p_real_a.item(), p_real_b.item(), p_a.item(), p_b.item()))
        # Figure 1: samples of every class inverted through network a.
        fig = plt.figure(figsize=(5, 5))
        for i_class in range(len(train_inputs)):
            samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
            inverted = invert(feature_model_a, samples)
            sample_loss = ot_euclidean_loss_for_samples(train_inputs[i_class].squeeze(), inverted.squeeze())
            plt.scatter(var_to_np(inverted)[:, 0], var_to_np(inverted)[:, 1])
        # sample_loss holds the value of the last class only. "\n" puts it on a
        # second title line.
        plt.title("Samples\nOT: {:.2E}".format(sample_loss.item()))
        display_close(fig)
        # Figure 2: outputs of network a for the real inputs of every class.
        fig = plt.figure(figsize=(5, 5))
        for i_class in range(len(train_inputs)):
            outs = feature_model_a(train_inputs[i_class])

            plt.scatter(var_to_np(outs)[:, 0], var_to_np(outs)[:, 1])
        # Same OT value as in the sample figure above.
        plt.title("Outs\nOT: {:.2E}".format(sample_loss.item()))
        display_close(fig)

In [None]:
# Further training with the same procedure: this cell does not re-create the
# models, the class distribution or the optimizers, so running it continues
# training the existing ones for another n_epochs epochs.
# Each step handles one class and maximises the (scaled) likelihood of its real
# inputs under both networks; about 20 times per run the current state is plotted.
n_epochs = 2001
# Display interval. max() keeps it >= 1 so that the modulo below cannot
# divide by zero when n_epochs < 20.
display_every = max(1, n_epochs // 20)
net_a, optim_a = feature_model_a, optim_model_a
net_b, optim_b = feature_model_b, optim_model_b
for i_epoch in range(n_epochs):
    for i_class in range(len(train_inputs)):
        optim_a.zero_grad()
        optim_b.zero_grad()
        optim_dist.zero_grad()
        class_ins = train_inputs[i_class]
        # Likelihood of the real inputs under each network: exp(mean(log prob)).
        outs_r_a = net_a(class_ins)
        outs_r_b = net_b(class_ins)
        p_real_a = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_r_a)))
        p_real_b = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_r_b)))
        # Create samples: five times as many as there are real inputs, inverted
        # through each network.
        n_samples = len(class_ins) * 5
        inv_a = invert(net_a, class_dist.get_samples(i_class, n_samples))
        inv_b = invert(net_b, class_dist.get_samples(i_class, n_samples))
        # Compute probs with the opposite network. p_a and p_b are only reported
        # below; they are not part of the loss.
        outs_f_b = net_b(inv_a)
        p_a = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_f_b)))
        outs_f_a = net_a(inv_b)
        p_b = th.exp(th.mean(class_dist.get_total_log_prob(i_class, outs_f_a)))

        # Disabled experiment (kept for reference):
        #p_b.backward() # p b should get higher for net b!
        #_ = [p.grad.data.neg_() for p in net_b.parameters()]
        #old_vals = [p.grad.data for p in net_a.parameters()]
        #_ = [p.grad.data.zero_() for p in net_a.parameters()]
        #p_a.backward() # p a should get higher for net a
        #_ = [p.grad.data.neg_() for p in net_a.parameters()]
        #_ = [p.grad.data.add_(old_val) for p, old_val in zip(net_a.parameters(), old_vals)]

        # Minimising the negated likelihoods maximises them.
        real_loss = 10 * (-p_real_a - p_real_b)
        real_loss.backward()
        optim_a.step()
        optim_b.step()
        optim_dist.step()
    if i_epoch % display_every == 0:
        display_text("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        # The probabilities shown are those of the last class processed above.
        display_text("PrA, PrB, PfA, PfB {:.1E} {:.1E} {:.1E} {:.1E}".format(
            p_real_a.item(), p_real_b.item(), p_a.item(), p_b.item()))
        # Figure 1: samples of every class inverted through network a.
        fig = plt.figure(figsize=(5, 5))
        for i_class in range(len(train_inputs)):
            samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
            inverted = invert(feature_model_a, samples)
            sample_loss = ot_euclidean_loss_for_samples(train_inputs[i_class].squeeze(), inverted.squeeze())
            plt.scatter(var_to_np(inverted)[:, 0], var_to_np(inverted)[:, 1])
        # sample_loss holds the value of the last class only. "\n" puts it on a
        # second title line.
        plt.title("Samples\nOT: {:.2E}".format(sample_loss.item()))
        display_close(fig)
        # Figure 2: outputs of network a for the real inputs of every class.
        fig = plt.figure(figsize=(5, 5))
        for i_class in range(len(train_inputs)):
            outs = feature_model_a(train_inputs[i_class])

            plt.scatter(var_to_np(outs)[:, 0], var_to_np(outs)[:, 1])
        # Same OT value as in the sample figure above.
        plt.title("Outs\nOT: {:.2E}".format(sample_loss.item()))
        display_close(fig)

In [None]:
# Asymmetric case: B tries to increase the likelihood of the real data and to reduce the likelihood of the fake data,
# while A tries to increase the likelihood of its own fake data.
# What would happen?

In [None]:
# Show what class_dist.get_mean_std returns for argument 0 as the cell output
# (presumably the mean and std of class 0 -- confirm against TwoClassDist).
class_dist.get_mean_std(0)

In [None]:
    # NOTE(review): this cell looks like the tail of a per-epoch training loop
    # whose head is not present: every line is indented one level, and i_epoch,
    # class_ins and sample_loss are used without being assigned in this cell.
    # feature_model and optim_model are not assigned anywhere in the visible
    # notebook (only feature_model_a/_b and optim_model_a/_b are) -- confirm
    # before running.
    for i_class in range(len(train_inputs)):
        
        # Inputs of the other class (index 1 - i_class), mapped to outputs,
        # changed to class i_class by class_dist and inverted back.
        other_class_ins = train_inputs[1-i_class]
        outs = feature_model(other_class_ins)
        changed_outs = class_dist.change_to_other_class(outs, i_class_from=1-i_class, i_class_to=i_class)
        changed_inverted = invert(feature_model, changed_outs)
        # class_ins is not updated inside this loop, so both iterations compare
        # against whatever value it had before the cell ran.
        transform_class_loss = ot_euclidean_loss_for_samples(class_ins.squeeze(), changed_inverted.squeeze())
        # sample_loss likewise comes from outside this cell.
        loss = sample_loss + transform_class_loss
        loss.backward()
    # No zero_grad call is visible in this cell before these optimizer steps.
    optim_model.step()
    optim_dist.step()
    if i_epoch % (n_epochs // 20) == 0:
        # i_class still holds the last value from the loop above, so the loss in
        # the titles below is computed for that class only.
        class_ins = train_inputs[i_class]
        samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
        inverted = invert(feature_model, samples)
        loss = ot_euclidean_loss_for_samples(class_ins.squeeze(), inverted.squeeze())
        display_text("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        # Figure 1: samples of every class inverted through feature_model.
        fig = plt.figure(figsize=(5,5))
        for i_class in range(len(train_inputs)):
            samples = class_dist.get_samples(i_class, len(train_inputs[i_class]) * 5)
            inverted = invert(feature_model, samples)

            plt.scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1])
        plt.title("Samples\nLoss: {:.2E}".format(loss.item()) )
        display_close(fig)
        # Figure 2: feature_model outputs for the real inputs of every class.
        fig = plt.figure(figsize=(5,5))
        for i_class in range(len(train_inputs)):
            outs = feature_model(train_inputs[i_class])

            plt.scatter(var_to_np(outs)[:,0], var_to_np(outs)[:,1])
        plt.title("Outs\nLoss: {:.2E}".format(loss.item()) )
        display_close(fig)