In [None]:
import importlib
import logging
import sys

# Reload `logging` before configuring it: inside Jupyter the root logger has
# reportedly been configured already, which would make basicConfig() below a
# silent no-op. See https://stackoverflow.com/a/21475297/1469195
importlib.reload(logging)

# Root logger at INFO, writing timestamped records to stdout.
log = logging.getLogger()
log.setLevel('INFO')
logging.basicConfig(
    format='%(asctime)s %(levelname)s : %(message)s',
    level=logging.INFO,
    stream=sys.stdout,
)

In [None]:
%%capture
import os
import site
import sys

# Local source checkouts, presumably providing the `reversible2` and
# `braindecode` packages imported further down so they shadow installed copies.
# These absolute paths are specific to the original author's machine — adjust
# them for your environment.
LOCAL_SOURCE_PATHS = [
    '/home/schirrmr/code/reversible/',
    '/home/schirrmr/braindecode/code/braindecode/',
    '/home/schirrmr/code/explaining/reversible//',
]
for source_path in LOCAL_SOURCE_PATHS:
    # Remove stale copies first so re-running this cell keeps exactly one entry
    # per path instead of growing sys.path on every execution.
    while source_path in sys.path:
        sys.path.remove(source_path)
    # Each insert goes to the front, so the LAST path ends up with the highest
    # priority — the same order as three successive insert(0, ...) calls.
    sys.path.insert(0, source_path)


# Reload edited modules (e.g. the local reversible2 checkout) automatically
# before each cell is executed.
%load_ext autoreload
%autoreload 2
import numpy as np
# NOTE(review): this repeats the logging setup of the first cell. Per the
# logging docs, basicConfig() does nothing when the root logger already has
# handlers, so this only takes effect if the first cell was skipped.
import logging
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Global plot defaults, set after `%matplotlib inline` (activating the inline
# backend may presumably reset rcParams — keep this order). The very flat
# default figsize is overridden by the explicit figsize of every figure below.
matplotlib.rcParams['figure.figsize'] = (12.0, 1.0)
matplotlib.rcParams['font.size'] = 14
import seaborn
seaborn.set_style('darkgrid')

from reversible2.sliced import sliced_from_samples
from numpy.random import RandomState

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import copy
import math

import itertools
import torch as th
from braindecode.torch_ext.util import np_to_var, var_to_np
from reversible2.splitter import SubsampleSplitter

from reversible2.view_as import ViewAs
from reversible2.invert import invert

from reversible2.affine import AdditiveBlock
from reversible2.plot import display_text, display_close
from reversible2.bhno import load_file, create_inputs

In [None]:
import sklearn.datasets

# make_moons jitters the points with numpy noise; without an explicit
# random_state that noise depends on numpy's global RNG state at this point, so
# the dataset would differ between Restart & Run All executions.
MOONS_SEED = 0
X,y  = sklearn.datasets.make_moons(100, shuffle=False, noise=1e-4,
                                   random_state=MOONS_SEED)
# Select the two moons by label rather than by position (X[:50] / X[50:]),
# so this stays correct if n_samples or shuffling is changed.
is_class_0 = y == 0
plt.figure(figsize=(4,4))
plt.scatter(X[is_class_0,0], X[is_class_0,1], label='class 0 (training set)')
plt.scatter(X[~is_class_0,0], X[~is_class_0,1], label='class 1')
plt.legend()
plt.title('Two moons')
# Only class 0 (a single moon) is used as training data in the cells below.
train_inputs = np_to_var(X[is_class_0], dtype=np.float32)
cuda = False

In [None]:
from reversible2.distribution import TwoClassDist

from reversible2.blocks import dense_add_block, conv_add_block_3x3
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter


set_random_seeds(2019011641, cuda)
# Two invertible stacks on 2-d data (both are inverted with `invert` in the
# training loop): `mixture_model` maps samples of the per-point Gaussians into
# input space, `feature_model_a` maps input space to the latent space scored
# by `class_dist`.
feature_model_a = nn.Sequential(
    dense_add_block(2,200),
    dense_add_block(2,200),
    dense_add_block(2,200),
    dense_add_block(2,200),
)
mixture_model = nn.Sequential(
    dense_add_block(2,200),
    dense_add_block(2,200),
    dense_add_block(2,200),
    dense_add_block(2,200),
)
if cuda:
    feature_model_a.cuda()
    # mixture_model feeds feature_model_a in the forward pass, so it must live
    # on the same device (it was previously left on the CPU).
    mixture_model.cuda()
from reversible2.ot_exact import ot_euclidean_loss_for_samples
class_dist = TwoClassDist(2,0, [0,1])
if cuda:
    class_dist.cuda()

# One learnable diagonal Gaussian per training point: small random mean and
# log-std -2 (std ~= 0.135). Built directly as leaf tensors instead of
# mutating `.data`; the means draw from the RNG in the same order as before.
means_per_point = [(th.randn(2) * 0.05).requires_grad_() for _ in train_inputs]
log_stds_per_point = [th.full((2,), -2.0, requires_grad=True) for _ in train_inputs]

optim_model_a = th.optim.Adam(feature_model_a.parameters())
optim_mixture_model = th.optim.Adam(mixture_model.parameters())
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-2)
optim_mixture_dist = th.optim.Adam(means_per_point + log_stds_per_point, 
                                  lr=1e-2)

In [None]:
# Joint training of the per-point Gaussians, mixture_model, feature_model_a and
# the prior class_dist with loss = 100 * (OT_in + OT_out) + sym_kl_div.
n_epochs = 2001
# Floor added to the mixture densities so th.log never sees 0.
eps = 1e-7
for i_epoch in range(n_epochs):
    # One Gaussian per training point with covariance diag(std**2).
    point_dists = [th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
                   for m, s in zip(means_per_point, log_stds_per_point)]

    # 100 samples from the latent prior (first argument 0 is presumably the class index).
    out_samples = class_dist.get_samples(0,100,)
    
    # Two reparameterized samples (mean + std * noise) per point, concatenated
    # point by point -> (2 * n_points, 2); gradients flow into means/log-stds.
    mixture_samples = th.cat([(th.randn(2,len(m)) * th.exp(l).unsqueeze(0)) + m.unsqueeze(0)
                          for m,l in zip(means_per_point, log_stds_per_point)])
    # Mixture space -> input space -> latent space.
    in_samples_mixture = mixture_model(mixture_samples)
    out_mixture_samples = feature_model_a(in_samples_mixture)
    
    all_out_samples = th.cat((out_samples, out_mixture_samples))

    # Prior log-density of the pooled latent samples.
    prior_log_probs = class_dist.get_total_log_prob(0,all_out_samples)

    # Map the prior samples back through both models into mixture space so
    # they line up row-by-row with all_out_samples.
    inverted = invert(feature_model_a, out_samples)
    inverted_to_mixture = invert(mixture_model, inverted)
    all_pre_in_samples = th.cat((inverted_to_mixture, mixture_samples))


    # Now compute the likelihood of the inverted points (and the mixture
    # samples) under every per-point Gaussian -> (n_points, n_samples).
    log_probs_per_dist = th.stack([dist.log_prob(all_pre_in_samples) for dist in point_dists], 0)

    # Mixture density = average over the point Gaussians (+ eps).
    # NOTE(review): exp() may underflow for far-away samples; th.logsumexp would be more stable.
    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)
    prior_probs = th.exp(prior_log_probs)
    # Symmetrised KL-style divergence between prior and mixture, with both
    # densities evaluated on the pooled samples (printed as "JS" below, though
    # it is not the Jensen-Shannon divergence).
    # NOTE(review): densities from latent and mixture space are compared without
    # a Jacobian term — presumably valid because the blocks are volume-preserving; confirm.
    sym_kl_div = -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs)) - (
        0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))
    # "OT" costs: mean distance between each training point and the samples of
    # its own Gaussian, in input space (OT_in) and latent space (OT_out).
    # NOTE(review): a fixed point-to-own-Gaussian pairing, not a computed transport plan.
    in_diffs = in_samples_mixture.view(len(train_inputs), -1, train_inputs.shape[1]) - train_inputs.unsqueeze(1)
    OT_in = th.mean(th.norm(in_diffs,p=2, dim=-1))
    out_real = feature_model_a(train_inputs)
    out_diffs = out_mixture_samples.view(len(train_inputs), -1, out_real.shape[1]) - out_real.unsqueeze(1)
    OT_out = th.mean(th.norm(out_diffs,p=2, dim=-1))
    
    loss =  OT_in * 100 + OT_out * 100 + sym_kl_div
    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    optim_mixture_model.zero_grad()
    loss.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()
    optim_mixture_model.step()
    # Report and plot 21 times per run.
    if i_epoch % (n_epochs // 20) == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("JS: {:.1E}".format(sym_kl_div.item()))
        print("OT: {:.1E}".format((OT_in + OT_out).item()))
        print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
        print("mean and std prior",class_dist.get_mean_std(0))
        

        # Fresh samples for plotting only (overwrites the loop variables, which
        # are recomputed next epoch).
        with th. no_grad():
            out_samples = class_dist.get_samples(0,100,)
            inverted = invert(feature_model_a, out_samples)
            mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
            in_mixture_samples = mixture_model(mixture_samples)
            out_mixture_samples = feature_model_a(in_mixture_samples)
            out_real = feature_model_a(train_inputs)
            
            # Regenerate the transformed circles: the 1-std contour of each
            # point Gaussian (24 points on the unit circle, scaled and shifted),
            # pushed into input and latent space.
            radians = np.linspace(0,2*np.pi,24)
            circle_points = np.stack([np.cos(radians), np.sin(radians)], axis=-1)
            circle_th = np_to_var(circle_points, device=train_inputs.device, dtype=np.float32)
            ms = th.stack(means_per_point)
            stds = th.exp(th.stack(log_stds_per_point))
            circles_per_point = ms.unsqueeze(1) + (circle_th.unsqueeze(0) * stds.unsqueeze(1))
            in_circles = mixture_model(circles_per_point.view(-1, circles_per_point.shape[-1]))
            out_circles = feature_model_a(in_circles)
            in_circles = in_circles.view(circles_per_point.shape)
            out_circles= out_circles.view(circles_per_point.shape)


        # Left: input space (inverted prior samples, mixture samples, data);
        # right: latent space. Black curves are the transformed circles.
        fig, axes = plt.subplots(1,2, figsize=(8,4))
        axes[0].scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(in_mixture_samples)[:,0], var_to_np(in_mixture_samples)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(train_inputs)[:,0], var_to_np(train_inputs)[:,1])
        for c in var_to_np(in_circles):
            axes[0].plot(c[:,0], c[:,1],color='black')
        axes[0].set_title("Input space")
        axes[0].axis('equal')
        axes[1].scatter(var_to_np(out_samples)[:,0], var_to_np(out_samples)[:,1],
                       alpha=0.75,
                       label='Latent Prior')
        axes[1].scatter(var_to_np(out_mixture_samples)[:,0], var_to_np(out_mixture_samples)[:,1],
                       alpha=0.75,
                       label='Mixture')
        axes[1].scatter(var_to_np(out_real)[:,0], var_to_np(out_real)[:,1], label='Real data')
        for c in var_to_np(out_circles):
            axes[1].plot(c[:,0], c[:,1],color='black')
        axes[1].set_title("Output space")
        axes[1].axis('equal')
        axes[1].legend(bbox_to_anchor=(1,1,0,0))
        display_close(fig)
        



In [None]:
# Continue the same joint training for 20001 more epochs; models, per-point
# Gaussians and optimizers keep their state from the previous cell.
# Loss = 100 * (OT_in + OT_out) + sym_kl_div.
n_epochs = 20001
# Floor added to the mixture densities so th.log never sees 0.
eps = 1e-7
for i_epoch in range(n_epochs):
    # One Gaussian per training point with covariance diag(std**2).
    point_dists = [th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
                   for m, s in zip(means_per_point, log_stds_per_point)]

    # 100 samples from the latent prior.
    out_samples = class_dist.get_samples(0,100,)
    
    # Two reparameterized samples per point -> (2 * n_points, 2).
    mixture_samples = th.cat([(th.randn(2,len(m)) * th.exp(l).unsqueeze(0)) + m.unsqueeze(0)
                          for m,l in zip(means_per_point, log_stds_per_point)])
    # Mixture space -> input space -> latent space.
    in_samples_mixture = mixture_model(mixture_samples)
    out_mixture_samples = feature_model_a(in_samples_mixture)
    
    all_out_samples = th.cat((out_samples, out_mixture_samples))

    # Prior log-density of the pooled latent samples.
    prior_log_probs = class_dist.get_total_log_prob(0,all_out_samples)

    # Map prior samples back into mixture space (rows align with all_out_samples).
    inverted = invert(feature_model_a, out_samples)
    inverted_to_mixture = invert(mixture_model, inverted)
    all_pre_in_samples = th.cat((inverted_to_mixture, mixture_samples))


    # Now compute the likelihood of the inverted points (and the mixture
    # samples) under every per-point Gaussian -> (n_points, n_samples).
    log_probs_per_dist = th.stack([dist.log_prob(all_pre_in_samples) for dist in point_dists], 0)

    # Mixture density = average over the point Gaussians (+ eps).
    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)
    prior_probs = th.exp(prior_log_probs)
    # Symmetrised KL-style divergence on the pooled samples (printed as "JS").
    sym_kl_div = -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs)) - (
        0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))
    # Mean distance between each training point and the samples of its own
    # Gaussian, in input space (OT_in) and latent space (OT_out).
    in_diffs = in_samples_mixture.view(len(train_inputs), -1, train_inputs.shape[1]) - train_inputs.unsqueeze(1)
    OT_in = th.mean(th.norm(in_diffs,p=2, dim=-1))
    out_real = feature_model_a(train_inputs)
    out_diffs = out_mixture_samples.view(len(train_inputs), -1, out_real.shape[1]) - out_real.unsqueeze(1)
    OT_out = th.mean(th.norm(out_diffs,p=2, dim=-1))
    
    loss =  OT_in * 100 + OT_out * 100 + sym_kl_div
    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    optim_mixture_model.zero_grad()
    loss.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()
    optim_mixture_model.step()
    # Report and plot 21 times per run.
    if i_epoch % (n_epochs // 20) == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("JS: {:.1E}".format(sym_kl_div.item()))
        print("OT: {:.1E}".format((OT_in + OT_out).item()))
        print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
        print("mean and std prior",class_dist.get_mean_std(0))
        

        # Fresh samples for plotting only (overwrites the loop variables).
        with th. no_grad():
            out_samples = class_dist.get_samples(0,100,)
            inverted = invert(feature_model_a, out_samples)
            mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
            in_mixture_samples = mixture_model(mixture_samples)
            out_mixture_samples = feature_model_a(in_mixture_samples)
            out_real = feature_model_a(train_inputs)
            
            # Regenerate the transformed circles: 1-std contour of each point
            # Gaussian, pushed into input and latent space.
            radians = np.linspace(0,2*np.pi,24)
            circle_points = np.stack([np.cos(radians), np.sin(radians)], axis=-1)
            circle_th = np_to_var(circle_points, device=train_inputs.device, dtype=np.float32)
            ms = th.stack(means_per_point)
            stds = th.exp(th.stack(log_stds_per_point))
            circles_per_point = ms.unsqueeze(1) + (circle_th.unsqueeze(0) * stds.unsqueeze(1))
            in_circles = mixture_model(circles_per_point.view(-1, circles_per_point.shape[-1]))
            out_circles = feature_model_a(in_circles)
            in_circles = in_circles.view(circles_per_point.shape)
            out_circles= out_circles.view(circles_per_point.shape)


        # Left: input space; right: latent space; black curves = circles.
        fig, axes = plt.subplots(1,2, figsize=(8,4))
        axes[0].scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(in_mixture_samples)[:,0], var_to_np(in_mixture_samples)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(train_inputs)[:,0], var_to_np(train_inputs)[:,1])
        for c in var_to_np(in_circles):
            axes[0].plot(c[:,0], c[:,1],color='black')
        axes[0].set_title("Input space")
        axes[0].axis('equal')
        axes[1].scatter(var_to_np(out_samples)[:,0], var_to_np(out_samples)[:,1],
                       alpha=0.75,
                       label='Latent Prior')
        axes[1].scatter(var_to_np(out_mixture_samples)[:,0], var_to_np(out_mixture_samples)[:,1],
                       alpha=0.75,
                       label='Mixture')
        axes[1].scatter(var_to_np(out_real)[:,0], var_to_np(out_real)[:,1], label='Real data')
        for c in var_to_np(out_circles):
            axes[1].plot(c[:,0], c[:,1],color='black')
        axes[1].set_title("Output space")
        axes[1].axis('equal')
        axes[1].legend(bbox_to_anchor=(1,1,0,0))
        display_close(fig)
        



### Test: train with the OT loss only (KL term disabled)

In [None]:
# OT-only test run: loss = 100 * (OT_in + OT_out). The prior class_dist is not
# part of this loss, so it is not updated. sym_kl_div is still computed every
# epoch, but only for monitoring and without building an autograd graph.
n_epochs = 2001
# Floor added to the mixture densities so th.log never sees 0.
eps = 1e-7
for i_epoch in range(n_epochs):
    # One Gaussian per training point with covariance diag(std**2).
    point_dists = [th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
                   for m, s in zip(means_per_point, log_stds_per_point)]

    # 100 samples from the latent prior (drawn before the mixture samples, as
    # in the other training cells, to keep the RNG call order).
    out_samples = class_dist.get_samples(0,100,)

    # Two reparameterized samples per point -> (2 * n_points, 2).
    mixture_samples = th.cat([(th.randn(2,len(m)) * th.exp(l).unsqueeze(0)) + m.unsqueeze(0)
                          for m,l in zip(means_per_point, log_stds_per_point)])
    # Mixture space -> input space -> latent space.
    in_samples_mixture = mixture_model(mixture_samples)
    out_mixture_samples = feature_model_a(in_samples_mixture)

    # Monitoring only: the KL term is excluded from the loss below, so skip the
    # autograd graph for it (including both inversions).
    with th.no_grad():
        all_out_samples = th.cat((out_samples, out_mixture_samples))
        prior_log_probs = class_dist.get_total_log_prob(0,all_out_samples)
        # Map prior samples back into mixture space (rows align with all_out_samples).
        inverted = invert(feature_model_a, out_samples)
        inverted_to_mixture = invert(mixture_model, inverted)
        all_pre_in_samples = th.cat((inverted_to_mixture, mixture_samples))
        # Likelihood of every pooled sample under each point Gaussian
        # -> (n_points, n_samples).
        log_probs_per_dist = th.stack([dist.log_prob(all_pre_in_samples) for dist in point_dists], 0)
        mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
        mixture_log_probs = th.log(mixture_probs)
        prior_probs = th.exp(prior_log_probs)
        sym_kl_div = -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs)) - (
            0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))

    # Mean distance between each training point and the samples of its own
    # Gaussian, in input space (OT_in) and latent space (OT_out).
    in_diffs = in_samples_mixture.view(len(train_inputs), -1, train_inputs.shape[1]) - train_inputs.unsqueeze(1)
    OT_in = th.mean(th.norm(in_diffs,p=2, dim=-1))
    out_real = feature_model_a(train_inputs)
    out_diffs = out_mixture_samples.view(len(train_inputs), -1, out_real.shape[1]) - out_real.unsqueeze(1)
    OT_out = th.mean(th.norm(out_diffs,p=2, dim=-1))

    # KL term deliberately left out for this test.
    loss = OT_in * 100 + OT_out * 100
    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    optim_mixture_model.zero_grad()
    loss.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    # optim_dist is intentionally not stepped: class_dist gets no gradient
    # from this loss. If zero_grad() leaves zero-valued (not None) grads, Adam
    # would otherwise keep moving the prior with momentum from earlier cells.
    optim_mixture_model.step()
    # Report and plot 21 times per run.
    if i_epoch % (n_epochs // 20) == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("JS: {:.1E}".format(sym_kl_div.item()))
        print("OT: {:.1E}".format((OT_in + OT_out).item()))
        print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
        print("mean and std prior",class_dist.get_mean_std(0))

        # Fresh samples for plotting only (overwrites the loop variables).
        with th.no_grad():
            out_samples = class_dist.get_samples(0,100,)
            inverted = invert(feature_model_a, out_samples)
            mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
            in_mixture_samples = mixture_model(mixture_samples)
            out_mixture_samples = feature_model_a(in_mixture_samples)
            out_real = feature_model_a(train_inputs)

            # Regenerate the transformed circles: 1-std contour of each point
            # Gaussian, pushed into input and latent space.
            radians = np.linspace(0,2*np.pi,24)
            circle_points = np.stack([np.cos(radians), np.sin(radians)], axis=-1)
            circle_th = np_to_var(circle_points, device=train_inputs.device, dtype=np.float32)
            ms = th.stack(means_per_point)
            stds = th.exp(th.stack(log_stds_per_point))
            circles_per_point = ms.unsqueeze(1) + (circle_th.unsqueeze(0) * stds.unsqueeze(1))
            in_circles = mixture_model(circles_per_point.view(-1, circles_per_point.shape[-1]))
            out_circles = feature_model_a(in_circles)
            in_circles = in_circles.view(circles_per_point.shape)
            out_circles= out_circles.view(circles_per_point.shape)

        # Left: input space; right: latent space; black curves = circles.
        fig, axes = plt.subplots(1,2, figsize=(8,4))
        axes[0].scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(in_mixture_samples)[:,0], var_to_np(in_mixture_samples)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(train_inputs)[:,0], var_to_np(train_inputs)[:,1])
        for c in var_to_np(in_circles):
            axes[0].plot(c[:,0], c[:,1],color='black')
        axes[0].set_title("Input space")
        axes[0].axis('equal')
        axes[1].scatter(var_to_np(out_samples)[:,0], var_to_np(out_samples)[:,1],
                       alpha=0.75,
                       label='Latent Prior')
        axes[1].scatter(var_to_np(out_mixture_samples)[:,0], var_to_np(out_mixture_samples)[:,1],
                       alpha=0.75,
                       label='Mixture')
        axes[1].scatter(var_to_np(out_real)[:,0], var_to_np(out_real)[:,1], label='Real data')
        for c in var_to_np(out_circles):
            axes[1].plot(c[:,0], c[:,1],color='black')
        axes[1].set_title("Output space")
        axes[1].axis('equal')
        axes[1].legend(bbox_to_anchor=(1,1,0,0))
        display_close(fig)
        



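### Only KL-DIV

**Note:** the loss in the next cell is still `sym_kl_div + OT_in * 100 + OT_out * 100`, which is identical to the combined loss used above. Remove the OT terms to actually train on the KL divergence alone.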

In [None]:
# Training cell listed under the "Only KL-DIV" heading.
# NOTE(review): the loss below is still sym_kl_div + 100 * (OT_in + OT_out),
# i.e. the same combined loss as the joint-training cells — the OT terms have
# not been removed. Confirm which behavior is intended.
n_epochs = 2001
# Floor added to the mixture densities so th.log never sees 0.
eps = 1e-7
for i_epoch in range(n_epochs):
    # One Gaussian per training point with covariance diag(std**2).
    point_dists = [th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
                   for m, s in zip(means_per_point, log_stds_per_point)]

    # 100 samples from the latent prior.
    out_samples = class_dist.get_samples(0,100,)
    
    # Two reparameterized samples per point -> (2 * n_points, 2).
    mixture_samples = th.cat([(th.randn(2,len(m)) * th.exp(l).unsqueeze(0)) + m.unsqueeze(0)
                          for m,l in zip(means_per_point, log_stds_per_point)])
    # Mixture space -> input space -> latent space.
    in_samples_mixture = mixture_model(mixture_samples)
    out_mixture_samples = feature_model_a(in_samples_mixture)
    
    all_out_samples = th.cat((out_samples, out_mixture_samples))

    # Prior log-density of the pooled latent samples.
    prior_log_probs = class_dist.get_total_log_prob(0,all_out_samples)

    # Map prior samples back into mixture space (rows align with all_out_samples).
    inverted = invert(feature_model_a, out_samples)
    inverted_to_mixture = invert(mixture_model, inverted)
    all_pre_in_samples = th.cat((inverted_to_mixture, mixture_samples))


    # Now compute the likelihood of the inverted points (and the mixture
    # samples) under every per-point Gaussian -> (n_points, n_samples).
    log_probs_per_dist = th.stack([dist.log_prob(all_pre_in_samples) for dist in point_dists], 0)

    # Mixture density = average over the point Gaussians (+ eps).
    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)
    prior_probs = th.exp(prior_log_probs)
    # Symmetrised KL-style divergence on the pooled samples (printed as "JS").
    sym_kl_div = -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs)) - (
        0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))
    # Mean distance between each training point and the samples of its own
    # Gaussian, in input space (OT_in) and latent space (OT_out).
    in_diffs = in_samples_mixture.view(len(train_inputs), -1, train_inputs.shape[1]) - train_inputs.unsqueeze(1)
    OT_in = th.mean(th.norm(in_diffs,p=2, dim=-1))
    out_real = feature_model_a(train_inputs)
    out_diffs = out_mixture_samples.view(len(train_inputs), -1, out_real.shape[1]) - out_real.unsqueeze(1)
    OT_out = th.mean(th.norm(out_diffs,p=2, dim=-1))
    
    loss =   sym_kl_div + OT_in * 100 + OT_out * 100 
    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    optim_mixture_model.zero_grad()
    loss.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()
    optim_mixture_model.step()
    # Report and plot 21 times per run.
    if i_epoch % (n_epochs // 20) == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("JS: {:.1E}".format(sym_kl_div.item()))
        print("OT: {:.1E}".format((OT_in + OT_out).item()))
        print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
        print("mean and std prior",class_dist.get_mean_std(0))
        

        # Fresh samples for plotting only (overwrites the loop variables).
        with th. no_grad():
            out_samples = class_dist.get_samples(0,100,)
            inverted = invert(feature_model_a, out_samples)
            mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
            in_mixture_samples = mixture_model(mixture_samples)
            out_mixture_samples = feature_model_a(in_mixture_samples)
            out_real = feature_model_a(train_inputs)
            
            # Regenerate the transformed circles: 1-std contour of each point
            # Gaussian, pushed into input and latent space.
            radians = np.linspace(0,2*np.pi,24)
            circle_points = np.stack([np.cos(radians), np.sin(radians)], axis=-1)
            circle_th = np_to_var(circle_points, device=train_inputs.device, dtype=np.float32)
            ms = th.stack(means_per_point)
            stds = th.exp(th.stack(log_stds_per_point))
            circles_per_point = ms.unsqueeze(1) + (circle_th.unsqueeze(0) * stds.unsqueeze(1))
            in_circles = mixture_model(circles_per_point.view(-1, circles_per_point.shape[-1]))
            out_circles = feature_model_a(in_circles)
            in_circles = in_circles.view(circles_per_point.shape)
            out_circles= out_circles.view(circles_per_point.shape)


        # Left: input space; right: latent space; black curves = circles.
        fig, axes = plt.subplots(1,2, figsize=(8,4))
        axes[0].scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(in_mixture_samples)[:,0], var_to_np(in_mixture_samples)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(train_inputs)[:,0], var_to_np(train_inputs)[:,1])
        for c in var_to_np(in_circles):
            axes[0].plot(c[:,0], c[:,1],color='black')
        axes[0].set_title("Input space")
        axes[0].axis('equal')
        axes[1].scatter(var_to_np(out_samples)[:,0], var_to_np(out_samples)[:,1],
                       alpha=0.75,
                       label='Latent Prior')
        axes[1].scatter(var_to_np(out_mixture_samples)[:,0], var_to_np(out_mixture_samples)[:,1],
                       alpha=0.75,
                       label='Mixture')
        axes[1].scatter(var_to_np(out_real)[:,0], var_to_np(out_real)[:,1], label='Real data')
        for c in var_to_np(out_circles):
            axes[1].plot(c[:,0], c[:,1],color='black')
        axes[1].set_title("Output space")
        axes[1].axis('equal')
        axes[1].legend(bbox_to_anchor=(1,1,0,0))
        display_close(fig)
        



In [None]:
# Long run (20001 epochs) with the combined loss
# sym_kl_div + 100 * (OT_in + OT_out); all state carries over from earlier cells.
n_epochs = 20001
# Floor added to the mixture densities so th.log never sees 0.
eps = 1e-7
for i_epoch in range(n_epochs):
    # One Gaussian per training point with covariance diag(std**2).
    point_dists = [th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
                   for m, s in zip(means_per_point, log_stds_per_point)]

    # 100 samples from the latent prior.
    out_samples = class_dist.get_samples(0,100,)
    
    # Two reparameterized samples per point -> (2 * n_points, 2).
    mixture_samples = th.cat([(th.randn(2,len(m)) * th.exp(l).unsqueeze(0)) + m.unsqueeze(0)
                          for m,l in zip(means_per_point, log_stds_per_point)])
    # Mixture space -> input space -> latent space.
    in_samples_mixture = mixture_model(mixture_samples)
    out_mixture_samples = feature_model_a(in_samples_mixture)
    
    all_out_samples = th.cat((out_samples, out_mixture_samples))

    # Prior log-density of the pooled latent samples.
    prior_log_probs = class_dist.get_total_log_prob(0,all_out_samples)

    # Map prior samples back into mixture space (rows align with all_out_samples).
    inverted = invert(feature_model_a, out_samples)
    inverted_to_mixture = invert(mixture_model, inverted)
    all_pre_in_samples = th.cat((inverted_to_mixture, mixture_samples))


    # Now compute the likelihood of the inverted points (and the mixture
    # samples) under every per-point Gaussian -> (n_points, n_samples).
    log_probs_per_dist = th.stack([dist.log_prob(all_pre_in_samples) for dist in point_dists], 0)

    # Mixture density = average over the point Gaussians (+ eps).
    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)
    prior_probs = th.exp(prior_log_probs)
    # Symmetrised KL-style divergence on the pooled samples (printed as "JS").
    sym_kl_div = -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs)) - (
        0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))
    # Mean distance between each training point and the samples of its own
    # Gaussian, in input space (OT_in) and latent space (OT_out).
    in_diffs = in_samples_mixture.view(len(train_inputs), -1, train_inputs.shape[1]) - train_inputs.unsqueeze(1)
    OT_in = th.mean(th.norm(in_diffs,p=2, dim=-1))
    out_real = feature_model_a(train_inputs)
    out_diffs = out_mixture_samples.view(len(train_inputs), -1, out_real.shape[1]) - out_real.unsqueeze(1)
    OT_out = th.mean(th.norm(out_diffs,p=2, dim=-1))
    
    loss =   sym_kl_div + OT_in * 100 + OT_out * 100 
    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    optim_mixture_model.zero_grad()
    loss.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()
    optim_mixture_model.step()
    # Report and plot 21 times per run.
    if i_epoch % (n_epochs // 20) == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        print("JS: {:.1E}".format(sym_kl_div.item()))
        print("OT: {:.1E}".format((OT_in + OT_out).item()))
        print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
        print("mean and std prior",class_dist.get_mean_std(0))
        

        # Fresh samples for plotting only (overwrites the loop variables).
        with th. no_grad():
            out_samples = class_dist.get_samples(0,100,)
            inverted = invert(feature_model_a, out_samples)
            mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
            in_mixture_samples = mixture_model(mixture_samples)
            out_mixture_samples = feature_model_a(in_mixture_samples)
            out_real = feature_model_a(train_inputs)
            
            # Regenerate the transformed circles: 1-std contour of each point
            # Gaussian, pushed into input and latent space.
            radians = np.linspace(0,2*np.pi,24)
            circle_points = np.stack([np.cos(radians), np.sin(radians)], axis=-1)
            circle_th = np_to_var(circle_points, device=train_inputs.device, dtype=np.float32)
            ms = th.stack(means_per_point)
            stds = th.exp(th.stack(log_stds_per_point))
            circles_per_point = ms.unsqueeze(1) + (circle_th.unsqueeze(0) * stds.unsqueeze(1))
            in_circles = mixture_model(circles_per_point.view(-1, circles_per_point.shape[-1]))
            out_circles = feature_model_a(in_circles)
            in_circles = in_circles.view(circles_per_point.shape)
            out_circles= out_circles.view(circles_per_point.shape)


        # Left: input space; right: latent space; black curves = circles.
        fig, axes = plt.subplots(1,2, figsize=(8,4))
        axes[0].scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(in_mixture_samples)[:,0], var_to_np(in_mixture_samples)[:,1],
                       alpha=0.75)
        axes[0].scatter(var_to_np(train_inputs)[:,0], var_to_np(train_inputs)[:,1])
        for c in var_to_np(in_circles):
            axes[0].plot(c[:,0], c[:,1],color='black')
        axes[0].set_title("Input space")
        axes[0].axis('equal')
        axes[1].scatter(var_to_np(out_samples)[:,0], var_to_np(out_samples)[:,1],
                       alpha=0.75,
                       label='Latent Prior')
        axes[1].scatter(var_to_np(out_mixture_samples)[:,0], var_to_np(out_mixture_samples)[:,1],
                       alpha=0.75,
                       label='Mixture')
        axes[1].scatter(var_to_np(out_real)[:,0], var_to_np(out_real)[:,1], label='Real data')
        for c in var_to_np(out_circles):
            axes[1].plot(c[:,0], c[:,1],color='black')
        axes[1].set_title("Output space")
        axes[1].axis('equal')
        axes[1].legend(bbox_to_anchor=(1,1,0,0))
        display_close(fig)
        



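### Without learned prior model

Same experiment, but without the separate invertible `mixture_model`: samples from the per-point Gaussians are fed directly into `feature_model_a`.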

In [None]:
from reversible2.distribution import TwoClassDist

from reversible2.blocks import dense_add_block, conv_add_block_3x3
from reversible2.rfft import RFFT, Interleave
from reversible2.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
from reversible2.splitter import SubsampleSplitter


set_random_seeds(2019011641, cuda)
# Single invertible stack of four identical blocks; unlike the earlier setup
# there is no separate mixture_model, so samples of the per-point Gaussians
# are fed straight into feature_model_a.
feature_model_a = nn.Sequential(*[dense_add_block(2,200) for _ in range(4)])
if cuda:
    feature_model_a.cuda()
from reversible2.ot_exact import ot_euclidean_loss_for_samples
class_dist = TwoClassDist(2,0, [0,1])
if cuda:
    class_dist.cuda()

# One learnable diagonal Gaussian per training point: mean = 0.05 * N(0, 1)
# noise, log-std initialised to -2 (std ~= 0.135).
means_per_point, log_stds_per_point = [], []
for _ in range(len(train_inputs)):
    means_per_point.append((th.randn(2) * 0.05).requires_grad_())
    log_stds_per_point.append(th.full((2,), -2.0, requires_grad=True))

optim_model_a = th.optim.Adam(feature_model_a.parameters())
optim_dist = th.optim.Adam(class_dist.parameters(), lr=1e-2)
# Parameter order (all means, then all log-stds) kept as before.
optim_mixture_dist = th.optim.Adam(
    means_per_point + log_stds_per_point, lr=1e-2)

In [None]:
# Joint training of the per-point Gaussian mixture, the invertible feature
# model and the latent prior. The loss is a symmetric-KL-style term between
# the latent prior and the mixture, plus two distance penalties:
#   OT_in  - mean Euclidean distance of each mixture sample to its own training
#            point in input space,
#   OT_out - the same distance after mapping through `feature_model_a`.
n_epochs = 20001
eps = 1e-7  # added to the mixture density before the log to avoid log(0)
for i_epoch in range(n_epochs):
    # One diagonal Gaussian per training point, covariance = diag(exp(s) ** 2).
    point_dists = [th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
                   for m, s in zip(means_per_point, log_stds_per_point)]

    out_samples = class_dist.get_samples(0,100,)

    # Two samples per point drawn as mean + std * noise, so gradients reach
    # means_per_point / log_stds_per_point through the sample locations.
    # Rows are grouped by point: 2 consecutive rows per training input.
    mixture_samples = th.cat([(th.randn(2,len(m)) * th.exp(l).unsqueeze(0)) + m.unsqueeze(0)
                          for m,l in zip(means_per_point, log_stds_per_point)])
    out_mixture_samples = feature_model_a(mixture_samples)

    # Prior log-density of the prior samples and of the mapped mixture samples.
    all_out_samples = th.cat((out_samples, out_mixture_samples))
    prior_log_probs = class_dist.get_total_log_prob(0,all_out_samples)

    # The same sample set expressed in input space.
    inverted = invert(feature_model_a, out_samples)
    all_in_samples = th.cat((inverted, mixture_samples))

    # Now compute the mixture likelihood of the (inverted) input-space points;
    # the mixture density is the uniform average of the per-point densities.
    log_probs_per_dist = th.stack([dist.log_prob(all_in_samples) for dist in point_dists], 0)

    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)
    prior_probs = th.exp(prior_log_probs)
    sym_kl_div = -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs)) - (
        0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))

    # Distance of every mixture sample to its own training point (input space).
    in_diffs = mixture_samples.view(len(train_inputs), -1, train_inputs.shape[1]) - train_inputs.unsqueeze(1)
    OT_in = th.mean(th.norm(in_diffs,p=2, dim=-1))
    # Same distance after mapping both through the feature model.
    out_real = feature_model_a(train_inputs)
    out_diffs = out_mixture_samples.view(len(train_inputs), -1, out_real.shape[1]) - out_real.unsqueeze(1)
    OT_out = th.mean(th.norm(out_diffs,p=2, dim=-1))

    loss =  OT_in * 100 + OT_out * 100 + sym_kl_div
    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    loss.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()
    if i_epoch % (n_epochs // 20) == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        # NOTE: the "JS" label is historical; the value printed is sym_kl_div.
        print("JS: {:.1E}".format(sym_kl_div.item()))
        # This cell defines OT_in / OT_out (there is no `OT` here), so report
        # both terms instead of reading a stale or undefined global.
        print("OT in: {:.1E}".format(OT_in.item()))
        print("OT out: {:.1E}".format(OT_out.item()))
        print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
        print("mean and std prior",class_dist.get_mean_std(0))
        # Fresh samples for plotting, without building an autograd graph.
        with th.no_grad():
            out_samples = class_dist.get_samples(0,100,)
            inverted = invert(feature_model_a, out_samples)
            mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
            out_mixture_samples = feature_model_a(mixture_samples)
            out_real = feature_model_a(train_inputs)
        fig, axes = plt.subplots(1,2, figsize=(8,4))
        axes[0].scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1])
        axes[0].scatter(var_to_np(mixture_samples)[:,0], var_to_np(mixture_samples)[:,1])
        axes[0].scatter(var_to_np(train_inputs)[:,0], var_to_np(train_inputs)[:,1])
        axes[0].set_title("Input space")
        axes[0].axis('equal')
        axes[1].scatter(var_to_np(out_samples)[:,0], var_to_np(out_samples)[:,1],
                       label='Latent Prior')
        axes[1].scatter(var_to_np(out_mixture_samples)[:,0], var_to_np(out_mixture_samples)[:,1],
                       label='Mixture')
        axes[1].scatter(var_to_np(out_real)[:,0], var_to_np(out_real)[:,1], label='Real data')
        axes[1].set_title("Output space")
        axes[1].axis('equal')
        axes[1].legend(bbox_to_anchor=(1,1,0,0))
        display_close(fig)
        



In [None]:
# Continue training from the current state (nothing is re-initialised in this
# cell) with the symmetric-KL-style term plus only the input-space distance
# penalty `OT` (mean Euclidean distance of each mixture sample to its point).
n_epochs = 2001
eps = 1e-7  # added to the mixture density before the log to avoid log(0)
for i_epoch in range(n_epochs):
    # One diagonal Gaussian per training point, covariance = diag(exp(s) ** 2).
    # (A stray `optim_mixt` token used to sit inside this comprehension and
    # made the whole cell a SyntaxError.)
    point_dists = [th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
                   for m, s in zip(means_per_point, log_stds_per_point)]

    out_samples = class_dist.get_samples(0,100,)

    # Two samples per point drawn as mean + std * noise (gradients flow to the
    # per-point parameters); 2 consecutive rows per training input.
    mixture_samples = th.cat([(th.randn(2,len(m)) * th.exp(l).unsqueeze(0)) + m.unsqueeze(0)
                          for m,l in zip(means_per_point, log_stds_per_point)])
    out_mixture_samples = feature_model_a(mixture_samples)

    # Prior log-density of the prior samples and of the mapped mixture samples.
    all_out_samples = th.cat((out_samples, out_mixture_samples))
    prior_log_probs = class_dist.get_total_log_prob(0,all_out_samples)

    # The same sample set expressed in input space.
    inverted = invert(feature_model_a, out_samples)
    all_in_samples = th.cat((inverted, mixture_samples))

    # Now compute the mixture likelihood of the (inverted) input-space points;
    # the mixture density is the uniform average of the per-point densities.
    log_probs_per_dist = th.stack([dist.log_prob(all_in_samples) for dist in point_dists], 0)

    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)
    prior_probs = th.exp(prior_log_probs)
    sym_kl_div = -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs)) - (
        0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))

    # Distance of every mixture sample to its own training point (input space).
    diffs = mixture_samples.view(len(train_inputs), -1, mixture_samples.shape[1]) - train_inputs.unsqueeze(1)
    OT = th.mean(th.norm(diffs,p=2, dim=-1))
    loss =  OT * 100 + sym_kl_div
    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    loss.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()
    if i_epoch % (n_epochs // 20) == 0:
        print("Epoch {:d} of {:d}".format(i_epoch, n_epochs))
        # NOTE: the "JS" label is historical; the value printed is sym_kl_div.
        print("JS: {:.1E}".format(sym_kl_div.item()))
        print("OT: {:.1E}".format(OT.item()))
        print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
        print("mean and std prior",class_dist.get_mean_std(0))
        # Fresh samples for plotting, without building an autograd graph.
        with th.no_grad():
            out_samples = class_dist.get_samples(0,100,)
            inverted = invert(feature_model_a, out_samples)
            mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
            out_mixture_samples = feature_model_a(mixture_samples)
            out_real = feature_model_a(train_inputs)
        fig, axes = plt.subplots(1,2, figsize=(8,4))
        axes[0].scatter(var_to_np(inverted)[:,0], var_to_np(inverted)[:,1])
        axes[0].scatter(var_to_np(mixture_samples)[:,0], var_to_np(mixture_samples)[:,1])
        axes[0].scatter(var_to_np(train_inputs)[:,0], var_to_np(train_inputs)[:,1])
        axes[0].set_title("Input space")
        axes[0].axis('equal')
        axes[1].scatter(var_to_np(out_samples)[:,0], var_to_np(out_samples)[:,1],
                       label='Latent Prior')
        axes[1].scatter(var_to_np(out_mixture_samples)[:,0], var_to_np(out_mixture_samples)[:,1],
                       label='Mixture')
        axes[1].scatter(var_to_np(out_real)[:,0], var_to_np(out_real)[:,1], label='Real data')
        axes[1].set_title("Output space")
        axes[1].axis('equal')
        axes[1].legend(bbox_to_anchor=(1,1,0,0))
        display_close(fig)
        



In [None]:
# Variant trained with the symmetric-KL-style term alone (no distance
# penalties). Mixture samples come from `dist.sample`, which is not
# reparameterised, so the per-point parameters get gradients only through the
# `log_prob` evaluations, not through the sample locations.
n_epochs = 500
eps = 1e-7  # keeps the log of the mixture density finite
for i_epoch in range(n_epochs):
    # Per-point diagonal Gaussians with covariance diag(exp(s) ** 2).
    point_dists = [
        th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
        for m, s in zip(means_per_point, log_stds_per_point)
    ]

    # 100 latent prior samples, then 2 mixture samples per training point.
    out_samples = class_dist.get_samples(0, 100)
    mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
    out_mixture_samples = feature_model_a(mixture_samples)

    # Latent-space view of both sample sets, scored by the prior.
    all_out_samples = th.cat((out_samples, out_mixture_samples))
    prior_log_probs = class_dist.get_total_log_prob(0, all_out_samples)

    # Input-space view of the same samples, scored by the mixture
    # (uniform average of the per-point densities).
    inverted = invert(feature_model_a, out_samples)
    all_in_samples = th.cat((inverted, mixture_samples))
    log_probs_per_dist = th.stack(
        [dist.log_prob(all_in_samples) for dist in point_dists], dim=0)
    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)

    prior_probs = th.exp(prior_log_probs)
    sym_kl_div = (
        -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs))
        - 0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))

    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    sym_kl_div.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()

    # Report and plot only on every (n_epochs // 20)-th epoch.
    if i_epoch % (n_epochs // 20) != 0:
        continue
    print("{:.1E}".format(sym_kl_div.item()))
    print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
    # Fresh samples for plotting, without building an autograd graph.
    with th.no_grad():
        out_samples = class_dist.get_samples(0, 100)
        inverted = invert(feature_model_a, out_samples)
        mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
        out_mixture_samples = feature_model_a(mixture_samples)
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    # All arrays here are (n, 2); `.T` unpacks into the x and y columns.
    axes[0].scatter(*var_to_np(inverted).T)
    axes[0].scatter(*var_to_np(mixture_samples).T)
    axes[0].set_title("Input space")
    axes[1].scatter(*var_to_np(out_samples).T, label='Latent Prior')
    axes[1].scatter(*var_to_np(out_mixture_samples).T, label='Mixture')
    axes[1].set_title("Output space")
    axes[1].legend(bbox_to_anchor=(1, 1, 0, 0))
    display_close(fig)
        



### Earlier variant: only look at samples from the prior

In [None]:
# Earlier variant: the symmetric-KL-style term is evaluated only on samples
# drawn from the latent prior (mapped back to input space); the mixture is
# sampled only for plotting.
n_epochs = 500
eps = 1e-7  # keeps the log of the mixture density finite
for i_epoch in range(n_epochs):
    # Per-point diagonal Gaussians with covariance diag(exp(s) ** 2).
    point_dists = [
        th.distributions.MultivariateNormal(m, covariance_matrix=th.diag(th.exp(s) * th.exp(s)))
        for m, s in zip(means_per_point, log_stds_per_point)
    ]

    # Prior samples, their prior log-density, and their input-space preimages.
    out_samples = class_dist.get_samples(0, 100)
    prior_log_probs = class_dist.get_total_log_prob(0, out_samples)
    inverted = invert(feature_model_a, out_samples)

    # Mixture likelihood of the inverted points (uniform average over points).
    log_probs_per_dist = th.stack(
        [dist.log_prob(inverted) for dist in point_dists], dim=0)
    mixture_probs = th.mean(th.exp(log_probs_per_dist) + eps, dim=0)
    mixture_log_probs = th.log(mixture_probs)

    prior_probs = th.exp(prior_log_probs)
    sym_kl_div = (
        -0.5 * th.sum(prior_probs * (mixture_log_probs - prior_log_probs))
        - 0.5 * th.sum(mixture_probs * (prior_log_probs - mixture_log_probs)))

    optim_mixture_dist.zero_grad()
    optim_model_a.zero_grad()
    optim_dist.zero_grad()
    sym_kl_div.backward()
    optim_mixture_dist.step()
    optim_model_a.step()
    optim_dist.step()

    # Report and plot only on every (n_epochs // 20)-th epoch.
    if i_epoch % (n_epochs // 20) != 0:
        continue
    print("{:.1E}".format(sym_kl_div.item()))
    print("Mean std {:.1E}".format(th.mean(th.stack([th.exp(l) for l in log_stds_per_point]))))
    # Fresh samples for plotting, without building an autograd graph.
    with th.no_grad():
        out_samples = class_dist.get_samples(0, 100)
        mixture_samples = th.cat([dist.sample((2,)) for dist in point_dists])
        out_mixture_samples = feature_model_a(mixture_samples)
        inverted = invert(feature_model_a, out_samples)
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    # All arrays here are (n, 2); `.T` unpacks into the x and y columns.
    axes[0].scatter(*var_to_np(inverted).T)
    axes[0].scatter(*var_to_np(mixture_samples).T)
    axes[0].set_title("Input space")
    axes[1].scatter(*var_to_np(out_samples).T, label='Latent Prior')
    axes[1].scatter(*var_to_np(out_mixture_samples).T, label='Mixture')
    axes[1].set_title("Output space")
    axes[1].legend(bbox_to_anchor=(1, 1, 0, 0))
    display_close(fig)
        



In [None]:
# Inspect the latent prior's current mean and std (same call and argument 0
# as the "mean and std prior" printout in the training cells).
class_dist.get_mean_std(0)

In [None]:
# associate a gaussian with every point
# bring it forward 