In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
%matplotlib inline
%config InlineBackend.figure_format = 'png'
# Default figure size (inches) and a small font for the many diagnostic plots.
matplotlib.rcParams['figure.figsize'] = (12.0, 4.0)
matplotlib.rcParams['font.size'] = 7

import matplotlib.lines as mlines
import seaborn
seaborn.set_style('darkgrid')
import logging
import importlib
# logging.basicConfig is a no-op when the root logger already has handlers;
# reloading the module first lets the basicConfig call below take effect.
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('DEBUG')
import sys
# Timestamped log lines written to stdout so they appear in the cell output.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.DEBUG, stream=sys.stdout)
seaborn.set_palette('colorblind')
import os

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

from reversible.plot import create_bw_image, create_rgb_image


from PIL import Image
import numpy as np

In [None]:
# Folder with your CelebA 64x64 images, stored as celeb_<index>.png
image_folder = '/data/schirrmr/schirrmr/celeba/CELEB_64/'
# Reference statistics for the Frechet Inception Distance (FID) that is
# monitored during this training; download them from
# https://drive.google.com/open?id=1H5dnLqe1UgwDZK3wewdqJvbl5pmSZxh1
# These FIDs are NOT the final numbers reported in the paper; those were
# computed with the TensorFlow implementation at https://github.com/bioinf-jku/TTUR
statistic_file = '/data/schirrmr/schirrmr/reversible-icml/fid_stats_celeb64_resize_normalize.npz'
# Checkpoint folder; the training loop creates one numbered subfolder per save.
model_save_folder = '/data/schirrmr/schirrmr/reversible-icml/models/celeba/Only_Clamp/'

### Load CelebA (adapt the next cell to however your copy of the images is stored)

In [None]:
# Raise the log level while loading, presumably to silence per-file DEBUG
# messages from PIL (the root logger was set to DEBUG in the setup cell).
log.setLevel('WARNING')
n_examples = 150000
images = []
for i_image in range(n_examples):
    image_path = os.path.join(image_folder, 'celeb_{:d}.png'.format(i_image))
    # The context manager closes each file handle deterministically instead of
    # leaving it to the garbage collector (150k files are opened here).
    with Image.open(image_path) as image:
        images.append(np.array(image))
    if i_image % 1000 == 0:
        print("Loaded {:d}...".format(i_image))


In [None]:
log.setLevel('INFO')
# Stack the loaded images into one 4-D float32 array, move the last axis to
# position 1 (presumably HWC -> CHW, the PyTorch layout) and divide by 255
# (presumably uint8 pixels -> [0, 1]).
x = np.array(images).astype(np.float32).transpose(0, 3, 1, 2) / 255.0
# Sanity check: show the first image, moved back to channels-last for imshow.
plt.imshow(x[0].transpose(1, 2, 0), interpolation='nearest')

In [None]:
from reversible.util import np_to_var, var_to_np
# Single-class setup: every face shares the same target, a constant 1.
targets = np_to_var(np.ones((x.shape[0], 1)), dtype=np.float32)
inputs = np_to_var(x, dtype=np.float32)
# Drop the numpy name; from here on the images are accessed via `inputs`.
del x

## Create model and latent distribution parameters

In [None]:
from reversible.models import create_celebA_model
from reversible.revnet import SubsampleSplitter, ViewAs
from reversible.util import set_random_seeds
from torch.nn import ConstantPad2d
import torch as th
torch.backends.cudnn.benchmark = True
set_random_seeds(3049, True)
feature_model = create_celebA_model()
feature_model = feature_model.cuda()
# One latent dimension per input value: 3 colour channels x 64 x 64 pixels.
n_dims = 3 * 64 * 64
n_clusters = 1
# Per-cluster mean and std of the latent Gaussian, created directly as leaf
# tensors with requires_grad=True (the autograd.Variable wrapper is deprecated).
# The stds are overwritten by the next cell; nothing in this notebook updates
# the means, so they stay zero.
means_per_cluster = [th.zeros(n_dims, device='cuda', requires_grad=True)
                     for _ in range(n_clusters)]
stds_per_cluster = [th.ones(n_dims, device='cuda', requires_grad=True)
                    for _ in range(n_clusters)]

In [None]:
# Pass the whole dataset through the (still untrained) network in chunks of
# 500 to find which output dimensions vary most. Inference only, so no
# autograd graph is built.
with th.no_grad():
    outs = [var_to_np(feature_model(inputs[i:i + 500].cuda()))
            for i in range(0, len(inputs), 500)]

outs = np.concatenate(outs)
stds = np.std(outs, axis=0)
# The full output matrix (one row per image) is large and not needed anymore.
del outs
# The n_wanted_stds dims with the largest empirical std become the active
# latent dims (std 1); all other dims get std 0.
n_wanted_stds = 64
# Explicit start index: argsort(...)[-0:] would select *every* dim when
# n_wanted_stds == 0.
i_active_dims = np.argsort(stds)[max(len(stds) - n_wanted_stds, 0):]
stds_per_cluster[0].data[:] = 0
# One vectorized assignment covers all active dims at once.
stds_per_cluster[0].data[i_active_dims] = 1

In [None]:
# Adam over the invertible network's weights only: the latent means/stds are
# not handed to the optimizer, so training never changes them.
optimizer = th.optim.Adam(
    [dict(params=list(feature_model.parameters()), lr=0.0001, weight_decay=0)],
    betas=(0, 0.9),
)

In [None]:
import time
from reversible.gaussian import get_gauss_samples
from reversible.revnet import invert 

def take_only_large_stds(l_out, std, n_wanted_stds):
    """Keep only the features of ``l_out`` whose ``std`` entries are largest.

    Parameters
    ----------
    l_out : torch.Tensor
        Tensor whose dimension 1 indexes the same features as ``std``
        (callers pass a batch of network outputs or an ``unsqueeze(0)``-ed
        vector).
    std : torch.Tensor
        1-D tensor with one value per feature.
    n_wanted_stds : int
        Number of features to keep. 0 keeps none; values larger than the
        number of features keep all of them.

    Returns
    -------
    torch.Tensor
        ``l_out`` restricted along dim 1 to the ``n_wanted_stds`` features
        with the largest ``std``, ordered by ascending ``std``.
    """
    i_sorted = th.sort(std)[1]
    # Slice from an explicit start index: ``i_sorted[-0:]`` would select *all*
    # features when n_wanted_stds == 0 instead of none.
    i_start = max(len(i_sorted) - n_wanted_stds, 0)
    i_stds = i_sorted[i_start:]
    return l_out.index_select(index=i_stds, dim=1)

def train_gen_on_batch(b_X, only_clamp, n_wanted_stds):
    """Perform one optimizer step of the invertible network on a single batch.

    The network outputs are clamped to ``mean +- 200 * std`` of the latent
    Gaussian (``std`` being the elementwise square of ``stds_per_cluster``),
    which pins every dimension whose std is zero to its mean. The clamped
    outputs are inverted back to input space, and 10 x the mean absolute
    reconstruction error is the main loss.

    Parameters
    ----------
    b_X : torch.Tensor
        Batch of input images, already moved to the GPU by the caller.
    only_clamp : bool
        If True, train only on the input-space reconstruction loss; the
        output-space loss is logged as zero. If False, additionally add
        10 x the mean squared distance between clamped and unclamped outputs.
    n_wanted_stds : int
        Unused; kept so existing call sites keep working.

    Returns
    -------
    dict
        Python floats ``g_loss``, ``in_loss``, ``out_loss``, the wall-clock
        ``runtime_g`` (seconds) and the pre-clipping gradient norm
        ``grad_norm_g``.
    """
    start_time = time.time()
    i_class = 0
    mean = means_per_cluster[i_class]
    std = stds_per_cluster[i_class] * stds_per_cluster[i_class]
    outs_real = feature_model(b_X)

    # As a hack, clamp to mean +- 200 std: where std == 0 this forces the
    # output to the mean (zero here), while active dims are presumably never
    # that far out. The bounds are detached so no gradient reaches them.
    min_vals = (mean - 200 * std).unsqueeze(0)
    max_vals = (mean + 200 * std).unsqueeze(0)
    outs_clamped = th.min(th.max(outs_real, min_vals.detach()), max_vals.detach())

    inverted_clamped = invert(feature_model, outs_clamped)
    in_clamp_loss = th.mean(th.abs(b_X - inverted_clamped)) * 10

    if only_clamp:
        # Output-space loss disabled: a constant zero, logged but not trained on.
        out_clamp_loss = th.zeros(1).cuda()
    else:
        out_clamp_loss = th.mean((outs_clamped - outs_real) ** 2) * 10
    g_loss = in_clamp_loss + out_clamp_loss

    optimizer.zero_grad()
    g_loss.backward()
    all_params = [p for group in optimizer.param_groups for p in group['params']]
    # clip_grad_norm_ returns a tensor in recent PyTorch versions; float()
    # keeps the logged value a plain number across versions.
    grad_norm = float(th.nn.utils.clip_grad_norm_(all_params, max_norm=5, norm_type=2))
    optimizer.step()
    runtime = time.time() - start_time
    # .item() yields Python floats, so the per-batch results form numeric
    # DataFrame columns that can be averaged and plotted.
    return {
        'g_loss': g_loss.item(),
        'in_loss': in_clamp_loss.item(),
        'out_loss': out_clamp_loss.item(),
        'runtime_g': runtime,
        'grad_norm_g': grad_norm,
    }

from reversible.sliced import sliced_from_samples
# let's add sliced metric
def compute_sliced_dist_on_outputs(n_wanted_stds):
    """Relative sliced distance between network outputs and the latent Gaussian.

    The network outputs for the first tenth of ``inputs`` and samples from the
    latent Gaussian are both restricted to the ``n_wanted_stds`` dims with the
    largest latent std. Their sliced distance (``dist='sqw2'``, presumably a
    squared Wasserstein-2 over 2 orthogonalized random directions) is divided
    by the same distance between two independent Gaussian sample sets, so a
    value near 1 presumably means the outputs look as Gaussian as real samples.

    Parameters
    ----------
    n_wanted_stds : int
        Number of highest-std latent dims to compare.

    Returns
    -------
    torch.Tensor
        Ratio ``sliced_dist / sliced_ref`` as returned by
        ``sliced_from_samples`` (computed without an autograd graph).
    """
    i_class = 0
    # Pure evaluation: the latent mean/std are leaf tensors with
    # requires_grad=True, so without no_grad every sample and distance below
    # would record an unused autograd graph.
    with th.no_grad():
        mean = means_per_cluster[i_class]
        std = stds_per_cluster[i_class] * stds_per_cluster[i_class]
        reduced_outs = [
            var_to_np(take_only_large_stds(
                feature_model(inputs[i:i + 500].cuda()), std, n_wanted_stds))
            for i in range(0, len(inputs) // 10, 500)]
        reduced_outs = np.concatenate(reduced_outs, axis=0)
        reduced_mean = take_only_large_stds(
            mean.unsqueeze(0), std, n_wanted_stds=n_wanted_stds).squeeze(0)
        reduced_std = take_only_large_stds(
            std.unsqueeze(0), std, n_wanted_stds=n_wanted_stds).squeeze(0)

        gauss_samples = get_gauss_samples(len(reduced_outs), reduced_mean, reduced_std)
        sliced_dist = sliced_from_samples(
            np_to_var(reduced_outs, dtype=np.float32).cuda(),
            gauss_samples, n_dirs=2, adv_dirs=None, orthogonalize=True,
            dist='sqw2')

        # Reference: distance between two independent Gaussian sample sets.
        gauss_samples_2 = get_gauss_samples(len(reduced_outs), reduced_mean, reduced_std)
        sliced_ref = sliced_from_samples(
            gauss_samples_2, gauss_samples, n_dirs=2, adv_dirs=None,
            orthogonalize=True, dist='sqw2')
        sliced_rel = sliced_dist / sliced_ref
    return sliced_rel

## Helpers for the Frechet Inception Distance (FID)

In [None]:
from reversible.fid_score import calculate_activation_statistics
from reversible.fid_score import calculate_frechet_distance
from reversible.inception import InceptionV3

# Reference FID statistics of the real images (presumably mean `mu` and
# covariance `sig` of Inception activations), downloadable from
# https://drive.google.com/open?id=1H5dnLqe1UgwDZK3wewdqJvbl5pmSZxh1
# `statistic_file` is the path set in the configuration cell at the top; it is
# not redefined here, so changing it there takes effect.
# The context manager closes the .npz archive once both arrays are read.
with np.load(statistic_file) as ref_vals:
    mu_ref = ref_vals['mu']
    sig_ref = ref_vals['sig']

model = InceptionV3(resize_input=True, normalize_input=True)
model = model.cuda()


def generate_examples(n_examples):
    """Sample ``n_examples`` latents from the Gaussian and invert them.

    Returns
    -------
    numpy.ndarray
        float64 array of the inverted samples, one entry per example
        (presumably images of shape (3, 64, 64)).
    """
    # Inference only: no autograd graph is needed for sampling.
    with th.no_grad():
        samples = get_gauss_samples(n_examples, means_per_cluster[0], stds_per_cluster[0])
        examples = var_to_np(invert(feature_model, samples)).astype(np.float64)
    return examples


def calculate_current_fid(n_samples=5000, batch_size=500):
    """FID of ``n_samples`` generated images against the reference statistics.

    Images are generated in chunks of at most ``batch_size`` (default: 10
    chunks of 500), clipped to [0, 1] and passed through the Inception model.

    Returns
    -------
    The value returned by ``calculate_frechet_distance``.
    """
    chunk_sizes = [batch_size] * (n_samples // batch_size)
    if n_samples % batch_size:
        chunk_sizes.append(n_samples % batch_size)
    examples = np.concatenate([generate_examples(n) for n in chunk_sizes], axis=0)
    mu, sig = calculate_activation_statistics(np.clip(examples, a_min=0, a_max=1), model,
                                              cuda=True, verbose=True)
    fid = calculate_frechet_distance(mu, sig, mu_ref, sig_ref)
    return fid

In [None]:
# Baseline FID of the freshly initialized model, computed before any training.
calculate_current_fid()

In [None]:
from reversible.iterators import BalancedBatchSizeIterator

# Number of images per training step of the invertible network.
batch_size = 250
# Batch iterator; the training loop calls
# iterator.get_batches(inputs, targets, shuffle=True) once per epoch.
iterator = BalancedBatchSizeIterator(batch_size)

In [None]:
import pandas as pd
# One row per training epoch: averaged batch metrics, runtimes, sliced distance and FID.
epochs_dataframe = pd.DataFrame()

In [None]:
from reversible.gaussian import get_gauss_samples
import pickle
# Training schedule (in epochs).
n_epochs = 2000
plot_every_n_epochs = 1
save_every_n_epochs = 30
for i_epoch in range(n_epochs):
    start_time = time.time()
    g_results = []
    for b_X, b_y in iterator.get_batches(inputs, targets, shuffle=True):
        b_X = b_X.cuda()
        g_result = train_gen_on_batch(b_X, only_clamp=True, n_wanted_stds=n_wanted_stds)
        g_results.append(g_result)

    # Logging, showing some samples, etc.
    # Average the per-batch metrics into one record for this epoch.
    result = dict(pd.DataFrame(g_results).mean())
    relative_sliced = compute_sliced_dist_on_outputs(n_wanted_stds)
    starttime_fid = time.time()
    fid = calculate_current_fid()
    runtime_fid = time.time() - starttime_fid

    epoch_time = time.time() - start_time
    result['runtime_fid'] = runtime_fid
    result['runtime'] = epoch_time
    result['sliced_rel'] = var_to_np(relative_sliced)
    result['fid'] = fid
    # DataFrame.append was removed in pandas 2.0; concatenate a one-row frame.
    epoch_row = pd.DataFrame([result])
    if epochs_dataframe.empty:
        epochs_dataframe = epoch_row
    else:
        epochs_dataframe = pd.concat([epochs_dataframe, epoch_row], ignore_index=True)

    if i_epoch % plot_every_n_epochs == 0:
        display(epochs_dataframe.iloc[-1:])
        # Learning curves without the runtimes; FID is divided by 10 so it
        # fits on the same axis as the losses.
        fig = plt.figure(figsize=(8, 3))
        df_copy = epochs_dataframe.drop(['runtime', 'runtime_fid'], axis=1)
        df_copy['fid'] = df_copy['fid'] / 10
        df_copy.plot(ax=fig.gca())
        display(fig)
        plt.close(fig)

        i_class = 0
        mean = means_per_cluster[i_class]
        std = stds_per_cluster[i_class] * stds_per_cluster[i_class]
        # Visualization only: no autograd graph for the forward pass.
        with th.no_grad():
            outs_real = feature_model(inputs[:1000].cuda())
        # Configured latent stds vs. empirical output stds on 1000 real images.
        fig = plt.figure(figsize=(8, 3))
        plt.plot(var_to_np(stds_per_cluster[0]))
        plt.plot(var_to_np(th.std(outs_real, dim=0)))
        display(fig)
        plt.close(fig)

        # 9x9 grid: vary the two dims with the largest latent std over
        # mean +- 2 std, keep all other dims at the mean, and invert.
        i_std_1, i_std_2 = np.argsort(var_to_np(stds_per_cluster[0]))[::-1][:2]
        feature_a_values = th.linspace(float(mean[i_std_1].data - 2 * std[i_std_1].data),
                                       float(mean[i_std_1].data + 2 * std[i_std_1].data), 9)
        feature_b_values = th.linspace(float(mean[i_std_2].data - 2 * std[i_std_2].data),
                                       float(mean[i_std_2].data + 2 * std[i_std_2].data), 9)
        image_grid = np.zeros((len(feature_a_values), len(feature_b_values), 3, 64, 64))

        with th.no_grad():
            for i_f_a_val, f_a_val in enumerate(feature_a_values):
                for i_f_b_val, f_b_val in enumerate(feature_b_values):
                    this_out = mean.clone()
                    this_out.data[i_std_1] = f_a_val
                    this_out.data[i_std_2] = f_b_val
                    inverted = var_to_np(invert(feature_model, this_out.unsqueeze(0))[0]).squeeze()
                    image_grid[i_f_a_val, i_f_b_val] = np.copy(inverted)
        im = create_rgb_image(image_grid[::-1]).resize((6 * 100, 6 * 100))
        display(im)

        # 40 random samples from the latent Gaussian, shown as a 5x8 grid.
        with th.no_grad():
            samples = get_gauss_samples(40, means_per_cluster[0], stds_per_cluster[0])
            inverted = var_to_np(invert(feature_model, samples)).astype(np.float64)

        inverted = inverted.reshape(5, 8, 3, 64, 64)
        im = create_rgb_image(inverted).resize((8 * 64, 5 * 64))
        display(im)

    # Save model regularly
    if i_epoch % save_every_n_epochs == 0:
        folder = os.path.join(model_save_folder, str(len(epochs_dataframe)))
        # exist_ok=False raises FileExistsError rather than overwriting an
        # existing checkpoint (e.g. from an earlier run).
        os.makedirs(folder, exist_ok=False)
        with open(os.path.join(folder, 'epochs_df.pkl'), 'wb') as f:
            pickle.dump(epochs_dataframe, f)
        th.save(optimizer.state_dict(), os.path.join(folder, 'optim_dict.pkl'))
        th.save(feature_model.state_dict(), os.path.join(folder, 'model_dict.pkl'))
        th.save(means_per_cluster, os.path.join(folder, 'means.pkl'))
        th.save(stds_per_cluster, os.path.join(folder, 'stds.pkl'))
        log.info("Saved to {:s}".format(folder))