In [None]:
# Notebook setup: plotting stack, inline-figure configuration, logging and
# general-purpose imports. Statement order matters here (logging is reloaded
# before it is configured), so keep it as is.
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Default figure size and font size for every plot in this notebook.
matplotlib.rcParams['figure.figsize'] = (12.0, 4.0)
matplotlib.rcParams['font.size'] = 7

import matplotlib.lines as mlines
import seaborn
seaborn.set_style('darkgrid')
import logging
import importlib
# Reload the logging module before configuring it; presumably needed so that
# basicConfig below is not ignored inside Jupyter (see the linked answer).
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('DEBUG')
import sys
# Root logger writes timestamped records at DEBUG level to stdout.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.DEBUG, stream=sys.stdout)
seaborn.set_palette('colorblind')

import os

import itertools
from reversible.plot import create_bw_image

In [None]:
# Path to the gzipped MNIST pickle. Despite the variable name this is a file,
# not a folder; it is passed directly to gzip.open when the data is loaded.
# Presumably downloaded from http://deeplearning.net/data/mnist/mnist.pkl.gz -- TODO confirm.
mnist_folder = 'data/mnist/mnist.pkl.gz'
# Base directory for saved models; the training loop creates one sub-folder
# per checkpoint underneath it.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
model_save_folder = '/data//schirrmr/schirrmr/reversible-icml/models/mnist/OptimalTransportPerClassClampedStd/'

In [None]:
import pickle
import gzip
# Load the (train, val, test) splits from the gzipped MNIST pickle.
# The file is opened in a `with` block so the gzip handle is closed as soon
# as unpickling finishes instead of being left to the garbage collector.
# NOTE: pickle.load can execute arbitrary code -- only load a file you trust.
with gzip.open(mnist_folder) as mnist_file:
    train, val, test = pickle.load(mnist_file, encoding='bytes')

# Each split is an (images, labels) pair.
X_train, y_train = train
X_val, y_val = val

# Reshape the images to (n_examples, 1, 28, 28) for the convolutional model.
X_train_topo = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_val_topo = X_val.reshape(X_val.shape[0], 1, 28, 28)
from numpy.random import RandomState


In [None]:
# Boolean selector over the training labels. `x` and `y` are the images and
# labels actually used for training; every later cell reads these two names.
mask = y_train < 10 # all, can use this to only take a subset of classes
x = X_train_topo[mask]#[:1000]
y = y_train[mask]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

In [None]:
from reversible.util import set_random_seeds
from reversible.revnet import init_model_params
import torch as th
from reversible.models import create_MNIST_model
# Let cuDNN benchmark and pick convolution algorithms.
th.backends.cudnn.benchmark = True
# Seed before building the model so its initialisation is reproducible.
set_random_seeds(34093049, True)

# Build the model, move it to the GPU, then (re)initialise its parameters.
feature_model = create_MNIST_model().cuda()
init_model_params(feature_model, 1)

# Size of each per-cluster mean/std vector (hard-coded).
# NOTE(review): presumably the dimensionality of the model output -- confirm.
n_dims = 1024
# One cluster per distinct label in the selected training set.
n_clusters = len(np.unique(y))

# One trainable mean vector (zeros) and std vector (ones) per cluster, on the
# GPU. These are placeholders that will be initialized properly later.
means_per_cluster = []
stds_per_cluster = []
for _ in range(n_clusters):
    means_per_cluster.append(
        th.autograd.Variable(th.zeros(n_dims).cuda(), requires_grad=True))
    stds_per_cluster.append(
        th.autograd.Variable(th.ones(n_dims).cuda(), requires_grad=True))


In [None]:
from reversible.training import hard_init_std_mean
from reversible.util import np_to_var

# All selected training images as a float32 variable on the GPU.
inputs = np_to_var(x, dtype=np.float32).cuda()
# One-hot class membership, shape (n_examples, n_classes). Column i marks the
# examples whose label equals the i-th smallest label present in `y`.
# Comparing against the labels that actually occur (instead of range(k)) keeps
# the columns correct when `mask` selects a subset of classes that is not
# exactly 0..k-1; for the full label set 0..9 the result is unchanged.
targets = np_to_var(
    np.array([y == label for label in np.unique(y)]).T,
    dtype=np.float32).cuda()

# Initialise the per-cluster means/stds using the first 10000 examples.
# NOTE(review): presumably derived from the model outputs per class -- confirm
# against reversible.training.hard_init_std_mean.
hard_init_std_mean(means_per_cluster, stds_per_cluster, feature_model,
                   inputs[:10000], targets[:10000])

# A single Adam parameter group optimises the network weights together with
# the per-cluster means and stds.
optimizer = th.optim.Adam([
    {'params': (list(feature_model.parameters())
                + means_per_cluster + stds_per_cluster),
     'lr': 0.001},
])

In [None]:
from reversible.iterators import BalancedBatchSizeIterator

# Batch size handed to the iterator; train_one_epoch draws batches from this
# iterator separately for each class.
# NOTE(review): presumably batches are evened out to sizes close to
# batch_size -- confirm against reversible.iterators.
batch_size = 650
iterator = BalancedBatchSizeIterator(batch_size,)

In [None]:
from reversible.revnet import invert
from reversible.gaussian import get_gauss_samples
def reconstruct_loss(feature_model, inputs, outputs):
    """Penalise how much the inverted image changes under small output noise.

    The outputs are inverted once as they are and once after adding samples
    drawn via ``get_gauss_samples`` with a zero mean vector and a std vector
    of 0.01. The loss is the mean squared difference plus the mean absolute
    difference between the two inversions.

    Parameters
    ----------
    feature_model : the model passed to ``invert``.
    inputs : accepted for interface compatibility; its value is never read
        (the comparison is between two inversions, not against the inputs).
    outputs : model outputs for a batch; ``outputs[0]`` fixes the shape of
        the noise mean/std vectors.

    Returns
    -------
    Scalar loss tensor.
    """
    clean_inverted = invert(feature_model, outputs)

    # Noise parameters are detached so no gradient flows through them.
    noise_mean = th.zeros_like(outputs[0]).detach()
    noise_std = th.ones_like(outputs[0].detach()) * 0.01
    noise = get_gauss_samples(len(outputs), noise_mean, noise_std)

    noisy_inverted = invert(feature_model, outputs + noise)

    delta = clean_inverted - noisy_inverted
    squared_term = th.mean(delta * delta)
    absolute_term = th.mean(th.abs(delta))
    return squared_term + absolute_term

In [None]:
import ot
from reversible.util import ensure_on_same_device, np_to_var, var_to_np
def ot_euclidean_loss(outs, mean, std):
    """Optimal-transport loss between outputs and samples from (mean, std).

    Draws ``len(outs)`` samples via ``get_gauss_samples(len(outs), mean, std)``,
    computes all pairwise Euclidean distances between outputs and samples,
    solves the transport problem on those distances with ``ot.emd`` and
    returns the sum of (transport plan * distances).

    Parameters
    ----------
    outs : model outputs, indexed as (example, dimension).
    mean, std : parameters forwarded to ``get_gauss_samples``.

    Returns
    -------
    Scalar loss tensor; gradients flow through the distances only, the
    transport plan is computed in numpy and treated as a constant.
    """
    samples = get_gauss_samples(len(outs), mean, std)

    # Broadcast to every (output, sample) pair, then reduce over dimension 2.
    pairwise = outs.unsqueeze(1) - samples.unsqueeze(0)
    del samples
    # Clamp before the square root so the value under the root is never
    # smaller than 1e-6.
    distances = th.sqrt(th.clamp(th.sum(pairwise * pairwise, dim=2), min=1e-6))
    del pairwise

    # Empty marginals are passed to ot.emd; see the POT documentation for how
    # these are interpreted (uniform weights).
    plan = ot.emd([], [], var_to_np(distances))
    # sometimes weird low values, try to prevent them:
    # zero every entry that is not above 1 / (number of distance entries).
    min_mass = 1.0 / (distances.numel())
    plan = plan * (plan > min_mass)

    plan = np_to_var(plan, dtype=np.float32)
    distances, plan = ensure_on_same_device(distances, plan)
    return th.sum(plan * distances)

In [None]:
import time
def train_one_epoch():
    """Run one training epoch over all clusters and return averaged metrics.

    Uses the notebook-level ``iterator``, ``inputs``, ``targets``,
    ``feature_model``, ``optimizer``, ``means_per_cluster`` and
    ``stds_per_cluster``.

    One batch generator is created per cluster (examples whose target column
    for that cluster equals 1). In every pass, one batch is drawn from each
    generator, the gradients of ``15 * reconstruct_loss + ot_euclidean_loss``
    are accumulated over the clusters, and a single optimizer step is taken.
    The epoch ends after the first pass in which any generator is exhausted;
    clusters that still had a batch in that pass are trained on it.

    Returns
    -------
    dict with keys ``'rec_loss'``, ``'ot_losses'``, ``'loss'`` (means over all
    processed batches; NaN if no batch was processed) and ``'runtime'``
    (seconds spent in this call).
    """
    start_time = time.time()
    b_gens = [iterator.get_batches(inputs[targets[:, i_cluster] == 1],
                                   targets[targets[:, i_cluster] == 1],
                                   shuffle=True)
              for i_cluster in range(len(means_per_cluster))]

    rec_losses = []
    ot_losses = []
    losses = []
    # Without any cluster no generator could ever signal exhaustion, so do not
    # enter the loop at all in that case.
    more_batches = len(b_gens) > 0
    while more_batches:
        optimizer.zero_grad()
        n_backward = 0
        for i_cluster, b_gen in enumerate(b_gens):
            # Only the draw itself is guarded: a StopIteration raised anywhere
            # else must not be mistaken for "iterator exhausted".
            try:
                b_X, b_y = next(b_gen)
            except StopIteration:
                more_batches = False
                continue
            outs = feature_model(b_X)
            rec_loss = reconstruct_loss(feature_model, b_X, outs)
            ot_loss = ot_euclidean_loss(outs, means_per_cluster[i_cluster],
                                        stds_per_cluster[i_cluster])
            loss = rec_loss * 15 + ot_loss
            loss.backward()
            n_backward += 1
            rec_losses.append(var_to_np(rec_loss))
            ot_losses.append(var_to_np(ot_loss))
            losses.append(var_to_np(loss))
        # Skip the update when no cluster contributed a gradient in this pass;
        # otherwise the optimizer would step on stale/zero gradients.
        if n_backward > 0:
            optimizer.step()
            # Keep the stds non-negative after the update.
            for std in stds_per_cluster:
                std.data.clamp_(min=0)
    # Measured once, after the epoch has finished.
    runtime = time.time() - start_time
    return {'rec_loss': np.mean(rec_losses),
            'ot_losses': np.mean(ot_losses),
            'loss': np.mean(losses),
            'runtime': runtime}

In [None]:
import pandas as pd
# Per-epoch training metrics: the training loop adds one row per epoch and
# saves the frame as CSV with every checkpoint.
# Re-running this cell discards the recorded history.
epochs_dataframe = pd.DataFrame()

In [None]:
from reversible.util import var_to_np
rng = RandomState(1)
for i_epoch in range(100001):
    # --- train for one epoch and record its metrics -------------------------
    feature_model.train()
    result = train_one_epoch()
    feature_model.eval()
    # DataFrame.append was removed in pandas 2.0; pd.concat works on old and
    # new versions. The first row is assigned directly so nothing is ever
    # concatenated with an empty frame.
    result_row = pd.DataFrame([result])
    if len(epochs_dataframe) == 0:
        epochs_dataframe = result_row
    else:
        epochs_dataframe = pd.concat([epochs_dataframe, result_row],
                                     ignore_index=True)

    # --- every 10 epochs: show metrics and diagnostic figures ---------------
    if i_epoch % 10 == 0:
        display(epochs_dataframe.iloc[-1:])

        all_outs = feature_model(inputs[:5000])
        all_outs = var_to_np(all_outs).squeeze()

        # Cluster i belongs to the i-th smallest label present in `y`
        # (identical to the cluster index when the labels are 0..k-1).
        cluster_labels = np.unique(y)

        # Learned std per dimension vs. empirical std of the model outputs of
        # that class (first 5000 training examples).
        for i_cluster in range(len(means_per_cluster)):
            fig = plt.figure()
            plt.plot(var_to_np(stds_per_cluster[i_cluster]))
            plt.plot(np.std(
                all_outs[y[:5000] == cluster_labels[i_cluster]], axis=0))
            plt.legend(('Distribution', 'Outputs'))
            plt.title("Stds of dimensions in gaussian and in actual outputs", fontsize=18)
            display(fig)
            plt.close(fig)

        # 3 x 8 images per class, inverted from samples of the class
        # mean/std.
        for i_class in range(len(means_per_cluster)):
            samples = get_gauss_samples(3*8, means_per_cluster[i_class], stds_per_cluster[i_class],)

            inverted = var_to_np(invert(feature_model, samples)).astype(np.float64)
            inverted = inverted.reshape(3, 8, 28, 28)
            im = create_bw_image(inverted).resize((4*100, int(1.5*100)))
            display(im)

        # 8 x 8 grid per class: start from the class mean and vary the two
        # dimensions with the largest learned std over mean +- 2 * std.
        for i_cluster in range(len(means_per_cluster)):
            mean = means_per_cluster[i_cluster]
            std = stds_per_cluster[i_cluster]
            # Last two indices of the ascending sort = two largest stds.
            i_feature_a, i_feature_b = th.sort(std)[1][-2:]
            feature_a_values = th.linspace(float(mean[i_feature_a].data - 2 * std[i_feature_a].data),
                                           float(mean[i_feature_a].data + 2 * std[i_feature_a].data), 8)
            feature_b_values = th.linspace(float(mean[i_feature_b].data - 2 * std[i_feature_b].data),
                                           float(mean[i_feature_b].data + 2 * std[i_feature_b].data), 8)

            image_grid = np.zeros((len(feature_a_values), len(feature_b_values), 28, 28))

            for i_f_a_val, f_a_val in enumerate(feature_a_values):
                for i_f_b_val, f_b_val in enumerate(feature_b_values):
                    # Work on a copy so the learned mean itself is untouched.
                    this_out = mean.clone()
                    this_out.data[i_feature_a.data] = f_a_val
                    this_out.data[i_feature_b.data] = f_b_val
                    inverted = var_to_np(invert(feature_model, this_out.unsqueeze(0))[0]).squeeze()

                    image_grid[i_f_a_val, i_f_b_val] = np.copy(inverted)
            im = create_bw_image(image_grid).resize((4*100, 4*100))
            display(im)

    # --- every 30 epochs: checkpoint metrics, optimizer, model, means, stds -
    if i_epoch % 30 == 0:
        # The sub-folder is named after the number of recorded epochs.
        folder = os.path.join(model_save_folder, str(len(epochs_dataframe)))
        # exist_ok=False: raises instead of overwriting an existing checkpoint
        # (NOTE(review): presumably deliberate -- a rerun with a fresh
        # epochs_dataframe needs a new model_save_folder).
        os.makedirs(folder, exist_ok=False)
        epochs_dataframe.to_csv(os.path.join(folder, 'epochs_df.csv'))
        th.save(optimizer.state_dict(), os.path.join(folder, 'optim_dict.pkl'))
        th.save(feature_model.state_dict(), os.path.join(folder, 'model_dict.pkl'))
        th.save(means_per_cluster, os.path.join(folder, 'means.pkl'))
        th.save(stds_per_cluster, os.path.join(folder, 'stds.pkl'))
        log.info("Saved to {:s}".format(folder))