In [None]:
# Notebook setup: working directory, autoreload, plotting defaults and logging.
# NOTE(review): absolute, user-specific path; relative paths used later
# (e.g. 'data/mnist/mnist.pkl.gz') are resolved against this directory.
%cd /home/schirrmr/

# Re-import edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2


import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Default figure size and a small font for the many diagnostic plots below.
matplotlib.rcParams['figure.figsize'] = (12.0, 4.0)
matplotlib.rcParams['font.size'] = 7

import matplotlib.lines as mlines
import seaborn
seaborn.set_style('darkgrid')
import logging
import importlib
# Reloading logging presumably makes basicConfig below take effect even in a kernel whose
# root logger was already configured.
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('DEBUG')
import sys
# Send timestamped DEBUG-and-above log records to the notebook's stdout.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.DEBUG, stream=sys.stdout)
seaborn.set_palette('colorblind')

In [None]:
import os
import sys

# Make the local checkout of the `reversible` package importable ahead of any installed copy.
# Uses the public `sys.path` rather than `os.sys`, which is an undocumented attribute of `os`.
sys.path.insert(0, '/home/schirrmr/code/explaining/reversible//')

In [None]:
import pickle
import gzip
from numpy.random import RandomState

# MNIST as a pickled (train, val, test) tuple of (X, y) pairs with flattened 28x28 images.
# Presumably downloaded from http://deeplearning.net/data/mnist/mnist.pkl.gz -- TODO confirm.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
# The `with` block closes the gzip handle; the original left the file open.
with gzip.open('data/mnist/mnist.pkl.gz') as mnist_file:
    train, val, test = pickle.load(mnist_file, encoding='bytes')

X_train, y_train = train
X_val, y_val = val

# Reshape flat vectors to (n_examples, 1 channel, 28, 28) image tensors.
X_train_topo = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_val_topo = X_val.reshape(X_val.shape[0], 1, 28, 28)
# Optional zero-padding to 32x32 (currently disabled):
#X_train_topo = np.pad(X_train_topo,((0,0),(0,0),(2,2),(2,2)), 'constant')
#X_val_topo = np.pad(X_val_topo,((0,0),(0,0),(2,2),(2,2)), 'constant')

In [None]:
from reversible.revnet import ReversibleBlock
import torch.nn as nn
def rev_block(n_chans, n_intermediate_chans):
    """Build a ReversibleBlock from two identically shaped fully connected sub-networks.

    Each sub-network maps ``n_chans // 2`` features through a hidden layer of
    ``n_intermediate_chans`` ReLU units and back to ``n_chans // 2`` features.
    The two sub-networks are constructed in order (first, then second).
    """
    half_width = n_chans // 2

    def make_subnet():
        # Linear -> ReLU -> Linear, projecting back to the half width.
        return nn.Sequential(
            nn.Linear(half_width, n_intermediate_chans),
            nn.ReLU(),
            nn.Linear(n_intermediate_chans, half_width),
        )

    first_subnet = make_subnet()
    second_subnet = make_subnet()
    return ReversibleBlock(first_subnet, second_subnet)

def plot_sorted_examples(sorted_examples, cmap=cm.Greys_r, vmin=0,vmax=1):
    """Draw images on a 2x10 grid in row-major order and return the figure.

    ``sorted_examples`` is squeezed and iterated along its first axis. At most 20 images
    are drawn, one per panel, all with the same ``cmap`` and fixed ``vmin``/``vmax`` so
    intensities are comparable across panels. Extra images are ignored, and unused
    panels stay empty.
    """
    fig, axes = plt.subplots(2, 10, figsize=(20, 5))
    panels = axes.ravel()
    images = sorted_examples.squeeze()
    for i_panel, image in enumerate(images):
        if i_panel >= len(panels):
            break
        panels[i_panel].imshow(image, vmin=vmin, vmax=vmax, cmap=cmap)
    return fig


In [None]:
# Keep only digits 0 and 1 (a trailing `[:1000]` slice was once used to subsample x).
mask = np.logical_or(y_train == 0, y_train == 1)
x = X_train_topo[mask]
y = y_train[mask]

In [None]:
from reversible.iterator import GenerativeIterator
from reversible.revnet import SubsampleSplitter, ViewAs
from reversible.util import set_random_seeds
from reversible.revnet import init_model_params
import torch as th
# Seed the RNGs. The meaning of the second argument is defined in reversible.util
# (presumably CUDA/cudnn determinism -- TODO confirm).
# Statement order below matters: model construction and init consume the seeded RNG.
set_random_seeds(34093049, True)
# Invertible feature network: SubsampleSplitter (stride 2, checkerboard) yields 4x14x14
# tensors (per the ViewAs from-shape), which are flattened to 4*14*14 = 784 features and
# passed through three reversible fully connected blocks with hidden width 2000.
feature_model = th.nn.Sequential(
    SubsampleSplitter(stride=2,checkerboard=True),
    ViewAs((-1,4,14,14),(-1,4*14*14)),
    rev_block(784,2000),
    rev_block(784,2000),
    rev_block(784,2000),)
feature_model = feature_model.cuda()
# Re-initialise the parameters; the meaning of `1` is defined in reversible.revnet.
init_model_params(feature_model, 1)

# One diagonal Gaussian per class (cluster): n_dims = 1*28*28 features, n_clusters = number
# of distinct labels in y (2 for digits 0/1). Both are learnable leaf Variables on the GPU.
n_dims = int(np.prod(x.shape[1:]))
n_clusters = int(len(np.unique(y)))
# Means start at 0 and stds at 0.5 for every cluster and dimension.
means_per_dim = th.autograd.Variable(th.zeros(n_clusters,n_dims).cuda() * 1.0, requires_grad=True)
stds_per_dim = th.autograd.Variable(th.ones(n_clusters,n_dims).cuda()  * 0.5, requires_grad=True)

In [None]:
from reversible.sliced import sample_directions
from reversible.util import np_to_var
# Three batches of directions from the project helper sample_directions (the two `True` flags
# are defined there). No loss below uses them (adv_dirs=None and loss_fn_adv is None), but they
# are still drawn so later cells see the same random-number stream.
directions_adv = th.cat([sample_directions(n_dims, True, True),
                         sample_directions(n_dims, True, True),
                         sample_directions(n_dims, True, True),
                        ], dim=0)
directions_adv = th.autograd.Variable(directions_adv.data, requires_grad=True)

# The whole 0/1 training subset on the GPU. Targets are one-hot columns [y == 0, y == 1].
inputs = np_to_var(x, dtype=np.float32).cuda()
targets = np_to_var(np.array([y == 0, y == 1]).T, dtype=np.float32).cuda()

from reversible.training import init_std_mean

# Presumably initialises means_per_dim / stds_per_dim from the model outputs (project helper
# -- TODO confirm).
init_std_mean(feature_model, inputs, targets, means_per_dim, stds_per_dim,
                 set_phase_interval=True)

# Optimise the network weights and the per-cluster Gaussian parameters together.
# Fix: the original call placed the `directions_adv` group dict *outside* the params list,
# so it was passed as Adam's positional `lr` argument. Older PyTorch only stored it as the
# default lr, which this group's own 'lr' overrides, so directions_adv was never optimised.
# Current PyTorch rejects a non-numeric lr. The group is dropped because nothing in this
# notebook uses directions_adv. If adversarial directions are re-enabled, give them their
# own optimizer.
optimizer = th.optim.Adam([
    {'params': list(feature_model.parameters()) + [means_per_dim, stds_per_dim],
     'lr': 0.001},
])

In [None]:
# Batch size is one eighth of 10610. 10610 is presumably len(x), the number of 0/1 training
# digits -- TODO confirm. See reversible.iterator for what upsample_supervised does.
iterator = GenerativeIterator(batch_size=10610 // 8, upsample_supervised=True)


In [None]:
from reversible.sinkhorn import sinkhorn_to_gauss_dist
from reversible.sliced import sliced_from_samples_for_gauss_dist
from reversible.loss_util import hard_loss_per_cluster
from reversible.gaussian import get_gauss_samples
from reversible.revnet import invert
from reversible.ot_exact import ot_emd_loss

def reconstruct_loss(o,m,s):
    """Penalty on how much the inverse of ``feature_model`` changes under small output noise.

    o: model outputs; only the first ``len(o) // 2`` rows are used.
    m, s: Gaussian mean/std tensors. They are detached and used only to build noise with
        mean 0 and std 0.01 via ``get_gauss_samples``.
    Returns the mean squared difference plus the mean absolute difference between the
    inverted clean outputs and the inverted perturbed outputs. Relies on the
    notebook-global ``feature_model``.
    """
    first_half = o[:len(o) // 2]
    clean_inputs = invert(feature_model, first_half)
    noise = get_gauss_samples(len(first_half), m.detach() * 0, s.detach() * 0 + 0.01)
    noisy_inputs = invert(feature_model, first_half + noise)
    delta = clean_inputs - noisy_inputs
    return th.mean(delta * delta) + th.mean(th.abs(delta))

def hard_loss_fn(o, m, s):
    """Loss for one cluster's outputs ``o`` against the Gaussian ``(m, s)``.

    Sum of ``ot_emd_loss``, 15 times the sliced loss over ``n_dirs=4`` directions (no
    adversarial directions), and 15 times ``reconstruct_loss``. The original lambda had a
    stray duplicated ``+`` (parsed as unary plus); it has been removed.
    """
    return (ot_emd_loss(o, m, s)
            + 15 * sliced_from_samples_for_gauss_dist(o, m, s, n_dirs=4, adv_dirs=None)
            + 15 * reconstruct_loss(o, m, s))


def loss_fn(o, d, t, m, s):
    """Apply ``hard_loss_fn`` per cluster via ``hard_loss_per_cluster``; ``d`` is ignored."""
    return hard_loss_per_cluster(o, t, m, s, hard_loss_fn)


loss_fn_adv = None # no adversarial training

In [None]:
def train_one_batch(iterator, feature_model, loss_fn, means_per_dim, stds_per_dim,
                   optimizer):
    """Run one optimisation step on a single batch and return the loss as a numpy scalar.

    On every call, a fresh batch generator is created from ``iterator`` over the
    notebook-global ``inputs``/``targets``, and only its first batch is used.
    ``loss_fn`` is called as ``loss_fn(outs, None, b_y, means_per_dim, stds_per_dim)``.
    """
    # Local import: the notebook otherwise imports var_to_np only in a later cell.
    from reversible.util import var_to_np

    batches = iterator.get_batches(inputs, targets, None, None)
    b_X, b_y = next(batches)
    outs = feature_model(b_X)
    loss = loss_fn(outs, None, b_y, means_per_dim, stds_per_dim)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # ravel() handles both a 1-element loss (old PyTorch) and a 0-dim loss (newer PyTorch).
    # In the 0-dim case, indexing the numpy result with [0] raises IndexError.
    return np.ravel(var_to_np(loss))[0]

In [None]:
# Number of batches one full pass of the iterator yields over (inputs, targets).
n_batches_per_dataset = sum(1 for _ in iterator.get_batches(inputs, targets, None, None))
# Critic (adversary) updates per generator update; 0 means generator updates only.
n_critic_updates = 0
n_updates_per_epoch = int(np.ceil(n_batches_per_dataset / (n_critic_updates + 1.0)))

In [None]:
import pandas as pd
from reversible.util import var_to_np
from reversible.revnet import get_inputs_from_reverted_samples

rng = RandomState(1)
# Per-epoch metrics are collected in a plain list, and `epochs_dataframe` is rebuilt from it
# after every epoch. DataFrame.append was removed in pandas 2.0 and copied the whole frame on
# each call anyway. Rebuilding each epoch keeps `epochs_dataframe` current when the
# (effectively endless) loop is interrupted.
epoch_losses = []
epochs_dataframe = pd.DataFrame()
for i_epoch in range(100001):
    feature_model.train()
    for i_update in range(n_updates_per_epoch):
        if i_update % (n_critic_updates + 1) == n_critic_updates:
            # now run generator
            loss = train_one_batch(iterator, feature_model, loss_fn, means_per_dim, stds_per_dim, optimizer)
            # Keep the Gaussian stds non-negative after each step.
            stds_per_dim.data.clamp_(min=0)
        else:
            # Critic update; only reached when n_critic_updates > 0.
            # NOTE(review): needs a non-None `loss_fn_adv` and an `optimizer_adv`, which this
            # notebook never defines.
            loss = train_one_batch(iterator, feature_model, loss_fn_adv, means_per_dim, stds_per_dim, optimizer_adv)
            stds_per_dim.data.clamp_(min=0) # should not be necessary...
    feature_model.eval()
    # `loss` holds the loss of the last update of this epoch.
    epoch_losses.append(float(np.mean(loss)))
    epochs_dataframe = pd.DataFrame({'total_loss': epoch_losses})
    if i_epoch % 10 == 0:
        display(epochs_dataframe.iloc[-1:])
    if i_epoch % 100 == 0:
        all_outs = feature_model(inputs)
        all_outs = var_to_np(all_outs).squeeze()

        # Compare the learned Gaussian stds with the empirical stds of each class's outputs.
        for i_cluster in range(2):
            fig = plt.figure()
            plt.plot(var_to_np(stds_per_dim[i_cluster]))
            plt.plot(np.std(all_outs[y == i_cluster], axis=0))
            plt.legend(('Distribution', 'Outputs'))
            plt.title("Stds of dimensions in gaussian and in actual outputs", fontsize=18)
            display(fig)
            plt.close(fig)

        # Invert 15 samples drawn from each learned Gaussian back to image space.
        for i_cluster in range(2):
            rec_examples, _ = get_inputs_from_reverted_samples(
                15, means_per_dim[i_cluster:i_cluster+1], stds_per_dim[i_cluster:i_cluster+1],
                np_to_var([1,]), feature_model,
                to_4d=False)

            fig, axes = plt.subplots(3,5, figsize=(20,9))

            for i_example, ax in enumerate(axes.flatten()):
                ax.imshow(rec_examples[i_example].squeeze(), vmin=0, vmax=1, cmap=cm.Greys)
            fig.suptitle("Reverted examples using gaussian mean/std", fontsize=18)
            display(fig)
            plt.close(fig)

        # Same, but with each class's empirical output mean/std.
        # NOTE: this leaves `mean`/`std` as notebook globals, which later cells may read.
        for i_cluster in range(len(means_per_dim)):
            mean = np_to_var(np.mean(all_outs[y == i_cluster], axis=0, keepdims=True), dtype=np.float32).cuda()
            std = np_to_var(np.std(all_outs[y == i_cluster], axis=0, keepdims=True), dtype=np.float32).cuda()
            rec_examples, _ = get_inputs_from_reverted_samples(
                15, mean, std,
                np_to_var([1,]), feature_model,
                to_4d=False)

            fig, axes = plt.subplots(3,5, figsize=(20,9))

            for i_example, ax in enumerate(axes.flatten()):
                ax.imshow(rec_examples[i_example].squeeze(), vmin=0, vmax=1, cmap=cm.Greys)
            fig.suptitle("Reverted examples using mean/std of outputs", fontsize=18)
            display(fig)
            plt.close(fig)

        # For the 3 output dimensions with the largest empirical std per class, vary only that
        # dimension (all other stds set to 0) and show 20 of 1000 samples, sorted along it.
        for i_cluster in range(2):
            stds = np.std(all_outs[y == i_cluster], axis=0)
            sorted_stds = np.argsort(stds)[::-1]  # dimension indices, largest std first
            for i_large_std in sorted_stds[:3]:
                stds_cloned = stds_per_dim.clone()
                stds_cloned = stds_cloned * 0
                stds_cloned[i_cluster,i_large_std] = float(stds[i_large_std])
                rec_examples, gauss_samples = get_inputs_from_reverted_samples(
                    1000, means_per_dim[i_cluster:i_cluster+1],
                    stds_cloned[i_cluster:i_cluster+1], np_to_var([1]), feature_model, to_4d=False)
                i_sort = np.argsort(var_to_np(gauss_samples)[:, i_large_std])
                sorted_examples = rec_examples[i_sort]
                sorted_examples = sorted_examples[::1000//20]
                fig = plot_sorted_examples(sorted_examples)
                fig.suptitle("Dimension {:d}".format(i_large_std), fontsize=16)
                display(fig)
                plt.close(fig)

In [None]:
from reversible.training import select_outs_from_targets

In [None]:
# Outputs of the first 800 training examples, restricted to target column 0 (presumably
# digit 0; selection semantics are defined in reversible.training.select_outs_from_targets).
outs = select_outs_from_targets(feature_model(inputs[:800]), targets[:800], 0)

In [None]:
%%time
# Time Sinkhorn against the cluster-0 Gaussian, which matches `outs` (cluster-0 outputs).
# The original used the leftover globals `mean`/`std`. Those are the output stats of the
# *last* cluster (index 1) from a plotting loop, and are undefined on a fresh kernel path.
sinkhorn_to_gauss_dist(outs, means_per_dim[0:1], stds_per_dim[0:1], epsilon=1e-1, stop_threshold=0.001)

In [None]:
%%time
# Time the EMD loss on the cluster-0 outputs against the cluster-0 Gaussian.
# `emd_loss` is not defined anywhere in this notebook; the imported function is `ot_emd_loss`.
ot_emd_loss(outs, means_per_dim[0:1], stds_per_dim[0:1])

In [None]:
# Training curve: one point per epoch (the loss of that epoch's last update).
ax = epochs_dataframe.plot(title="Training loss per epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss");

In [None]:
# Post-training diagnostics on the full 0/1 training subset.
# NOTE(review): runs at notebook top level, so `all_outs`, `mean`, `std`, `rec_examples`,
# `fig`, etc. remain as globals afterwards. `mean`/`std` end up holding the output stats of
# cluster len(means_per_dim) - 1.
all_outs = feature_model(inputs)
all_outs = var_to_np(all_outs).squeeze()

# Learned Gaussian stds vs. empirical stds of each class's outputs, per dimension.
for i_cluster in range(2):
    fig = plt.figure()
    plt.plot(var_to_np(stds_per_dim[i_cluster]))
    plt.plot(np.std(all_outs[y == i_cluster], axis=0))
    plt.legend(('Distribution', 'Outputs'))
    plt.title("Stds of dimensions in gaussian and in actual outputs", fontsize=18)
    display(fig)
    plt.close(fig)

# Invert 15 samples from each learned Gaussian back to image space (3x5 grid).
for i_cluster in range(2):
    rec_examples, _ = get_inputs_from_reverted_samples(
         15, means_per_dim[i_cluster:i_cluster+1], stds_per_dim[i_cluster:i_cluster+1],
        np_to_var([1,]), feature_model,
        to_4d=False)

    fig, axes = plt.subplots(3,5, figsize=(20,9))

    for i_example, ax in enumerate(axes.flatten()):
        ax.imshow(rec_examples[i_example].squeeze(), vmin=0, vmax=1, cmap=cm.Greys)
    fig.suptitle("Reverted examples using gaussian mean/std", fontsize=18)
    display(fig)
    plt.close(fig)

# Same, but sampling from each class's empirical output mean/std (shape (1, n_dims)).
for i_cluster in range(len(means_per_dim)):
    mean =  np_to_var(np.mean(all_outs[y == i_cluster], axis=0, keepdims=True), dtype=np.float32).cuda()
    std = np_to_var(np.std(all_outs[y == i_cluster], axis=0, keepdims=True), dtype=np.float32).cuda()
    rec_examples, _ = get_inputs_from_reverted_samples(
         15, mean, std,
        np_to_var([1,]), feature_model,
        to_4d=False)

    fig, axes = plt.subplots(3,5, figsize=(20,9))

    for i_example, ax in enumerate(axes.flatten()):
        ax.imshow(rec_examples[i_example].squeeze(), vmin=0, vmax=1, cmap=cm.Greys)
    fig.suptitle("Reverted examples using mean/std of outputs", fontsize=18)
    display(fig)
    plt.close(fig)

# For each class, take the 3 output dimensions with the largest empirical std. Zero all other
# stds, draw 1000 samples, sort them by that dimension, and show every 50th (20 images).
for i_cluster in range(2):
    stds = np.std(all_outs[y == i_cluster], axis=0)
    # Dimension indices ordered from largest to smallest std (despite the name).
    sorted_stds = np.argsort(stds)[::-1]
    for i_large_std in sorted_stds[:3]:
        stds_cloned = stds_per_dim.clone()
        stds_cloned = stds_cloned * 0
        stds_cloned[i_cluster,i_large_std] = float(stds[i_large_std])
        rec_examples, gauss_samples = get_inputs_from_reverted_samples(
            1000, means_per_dim[i_cluster:i_cluster+1],
            stds_cloned[i_cluster:i_cluster+1], np_to_var([1]), feature_model, to_4d=False)
        i_sort = np.argsort(var_to_np(gauss_samples)[:, i_large_std])
        sorted_examples = rec_examples[i_sort]
        sorted_examples = sorted_examples[::1000//20]
        fig = plot_sorted_examples(sorted_examples)
        fig.suptitle("Dimension {:d}".format(i_large_std), fontsize=16)
        display(fig)
        plt.close(fig)

In [None]:
# Second copy of the post-training diagnostics. It differs only in drawing len(inputs)
# reverted samples (instead of 15) in the two sampling loops.
# NOTE(review): only the first 15 samples are displayed, so the extra samples matter only if
# the global `rec_examples` is inspected afterwards. Otherwise this looks like a leftover and
# could use 15. Like the previous diagnostics cell, it leaves `mean`/`std`/`rec_examples`
# as globals.
all_outs = feature_model(inputs)
all_outs = var_to_np(all_outs).squeeze()

# Learned Gaussian stds vs. empirical stds of each class's outputs, per dimension.
for i_cluster in range(2):
    fig = plt.figure()
    plt.plot(var_to_np(stds_per_dim[i_cluster]))
    plt.plot(np.std(all_outs[y == i_cluster], axis=0))
    plt.legend(('Distribution', 'Outputs'))
    plt.title("Stds of dimensions in gaussian and in actual outputs", fontsize=18)
    display(fig)
    plt.close(fig)

# Invert len(inputs) samples from each learned Gaussian; the first 15 are shown.
for i_cluster in range(2):
    rec_examples, _ = get_inputs_from_reverted_samples(
         len(inputs), means_per_dim[i_cluster:i_cluster+1], stds_per_dim[i_cluster:i_cluster+1],
        np_to_var([1,]), feature_model,
        to_4d=False)

    fig, axes = plt.subplots(3,5, figsize=(20,9))

    for i_example, ax in enumerate(axes.flatten()):
        ax.imshow(rec_examples[i_example].squeeze(), vmin=0, vmax=1, cmap=cm.Greys)
    fig.suptitle("Reverted examples using gaussian mean/std", fontsize=18)
    display(fig)
    plt.close(fig)

# Same, using each class's empirical output mean/std (shape (1, n_dims)).
for i_cluster in range(len(means_per_dim)):
    mean =  np_to_var(np.mean(all_outs[y == i_cluster], axis=0, keepdims=True), dtype=np.float32).cuda()
    std = np_to_var(np.std(all_outs[y == i_cluster], axis=0, keepdims=True), dtype=np.float32).cuda()
    rec_examples, _ = get_inputs_from_reverted_samples(
         len(inputs), mean, std,
        np_to_var([1,]), feature_model,
        to_4d=False)

    fig, axes = plt.subplots(3,5, figsize=(20,9))

    for i_example, ax in enumerate(axes.flatten()):
        ax.imshow(rec_examples[i_example].squeeze(), vmin=0, vmax=1, cmap=cm.Greys)
    fig.suptitle("Reverted examples using mean/std of outputs", fontsize=18)
    display(fig)
    plt.close(fig)

# Vary only each class's 3 largest-std output dimensions; show 20 of 1000 sorted samples.
for i_cluster in range(2):
    stds = np.std(all_outs[y == i_cluster], axis=0)
    # Dimension indices ordered from largest to smallest std (despite the name).
    sorted_stds = np.argsort(stds)[::-1]
    for i_large_std in sorted_stds[:3]:
        stds_cloned = stds_per_dim.clone()
        stds_cloned = stds_cloned * 0
        stds_cloned[i_cluster,i_large_std] = float(stds[i_large_std])
        rec_examples, gauss_samples = get_inputs_from_reverted_samples(
            1000, means_per_dim[i_cluster:i_cluster+1],
            stds_cloned[i_cluster:i_cluster+1], np_to_var([1]), feature_model, to_4d=False)
        i_sort = np.argsort(var_to_np(gauss_samples)[:, i_large_std])
        sorted_examples = rec_examples[i_sort]
        sorted_examples = sorted_examples[::1000//20]
        fig = plot_sorted_examples(sorted_examples)
        fig.suptitle("Dimension {:d}".format(i_large_std), fontsize=16)
        display(fig)
        plt.close(fig)

In [None]:
# discrete sag

