In [None]:
"""GPU setup"""
import os
# Has to be set before TensorFlow initializes its GPU devices (i.e. before the
# model cell runs). Defaults to GPU 3, but a value already exported in the
# environment wins, so the notebook runs on other machines without edits.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

In [None]:
"""Imports, define AE model"""
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np

from utils.data import mnist_eager
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_fc_mnist, enc_conv_mnist
from utils.viz import imshow, interpolate, random_sample_grid


# data / training configuration
batch_size = 256      # images per batch
train_steps = 1500    # step limit checked by the training loop
# The encoder output packs [means | log-variances] side by side, so its width
# is twice the latent size (16).
dim_code = 16 * 2

data = mnist_eager("data/mnist_train", batch_size)


conv = False
# Choose the architecture family: fully connected (default) or convolutional.
# Both use batch norm. The decoder is built with final_act=None (presumably no
# output activation), because ae_loss treats the decoder output as logits.
if not conv:
    encoder = enc_fc_mnist(dim_code, use_bn=True)
    decoder = gen_fc_mnist(use_bn=True, final_act=None)
else:
    encoder = enc_conv_mnist(dim_code, use_bn=True)
    decoder = gen_conv_mnist(use_bn=True, final_act=None)


def noise_fn(n_samples):
    """Draw `n_samples` standard-normal latent vectors.

    The latent size is half of `dim_code`, because the encoder output holds
    the means and the log-variances side by side.

    Returns:
        A tensor of shape [n_samples, dim_code // 2].
    """
    latent_dim = dim_code // 2
    return tf.random_normal(shape=[n_samples, latent_dim])


def ae_loss(imgs, recon):
    """Reconstruction loss between target images and decoder logits.

    `imgs` are used as per-pixel targets (`multi_class_labels`, so expected in
    [0, 1]). `recon` must be raw logits, which is why the decoder is built with
    `final_act=None`. With the default `tf.losses` reduction the result is a
    scalar averaged over all elements.

    Alternatives that were left commented out here: mean squared error
    (`tf.losses.mean_squared_error`) and L1 (`tf.losses.absolute_difference`).
    """
    return tf.losses.sigmoid_cross_entropy(logits=recon, multi_class_labels=imgs)


def kl_loss(means, logvars):
    """KL(N(means, exp(logvars)) || N(0, I)), averaged over the batch.

    Args:
        means: posterior means; in the training loop this is the first half of
            the encoder output split along axis 1 (presumably [batch, latent]).
        logvars: posterior log-variances, same shape as `means`.

    Returns:
        Scalar tensor: each example's KL (summed over latent dims, axis 1),
        averaged over the batch.
    """
    # Averaging over the batch, instead of summing, keeps this term's scale
    # independent of batch_size. ae_loss is also a mean (default tf.losses reduction).
    # NOTE(review): ae_loss is a per-pixel mean while this is a per-example sum
    # over latent dims, so KL is still weighted heavily relative to reconstruction;
    # confirm this balance is intended.
    per_example_kl = 0.5 * tf.reduce_sum(
        tf.square(means) + tf.exp(logvars) - logvars - 1, axis=1)
    return tf.reduce_mean(per_example_kl)


def mmd_loss(target_samples, generated):
    """Square root of `compute_mmd` between two sample sets.

    Passes a fixed list of kernel scales to `utils.math.compute_mmd`. This
    loss is not used by the training loop in this notebook.
    """
    # NOTE(review): if compute_mmd can return a slightly negative estimate,
    # tf.sqrt yields NaN; sqrt also has an unbounded gradient at 0 — confirm.
    kernel_scales = [0.03, 0.1, 0.3, 1., 3., 10.]
    mmd_estimate = compute_mmd(target_samples, generated, kernel_scales)
    return tf.sqrt(mmd_estimate)


# One Adam optimizer updates both encoder and decoder weights;
# 0.001 is AdamOptimizer's documented default learning rate, spelled out.
ae_opt = tf.train.AdamOptimizer(learning_rate=0.001)

In [None]:
"""Train AE"""
# Put Keras layers (e.g. batch norm) in training mode for this loop. The
# `finally` restores inference mode even if training is interrupted, so later
# cells never run the models in training mode by accident.
tf.keras.backend.set_learning_phase(1)
try:
    for step, (img_batch, _) in enumerate(data):
        # steps are 0-based, so `>=` performs exactly `train_steps` updates
        if step >= train_steps:
            break

        with tf.GradientTape() as tape:
            code = encoder(img_batch)
            # encoder output = [means | log-variances] along axis 1
            means, logvars = tf.split(code, 2, axis=1)
            # reparameterization z = mu + sigma * eps, with sigma = exp(0.5 * logvar)
            # (same value as sqrt(exp(logvar)), without the exp/sqrt round trip)
            code_samples = noise_fn(tf.shape(means)[0]) * tf.exp(0.5 * logvars) + means

            recon = decoder(code_samples)  # logits, consumed by ae_loss
            recon_loss = ae_loss(img_batch, recon)
            reg_loss = kl_loss(means, logvars)
            total_loss = recon_loss + reg_loss
        # Collect variables after the forward pass: Keras models may only create
        # their weights on first call.
        variables = encoder.trainable_variables + decoder.trainable_variables
        grads = tape.gradient(total_loss, variables)
        ae_opt.apply_gradients(zip(grads, variables))

        if not step % 50:
            print("Step", step)
            print("Loss", total_loss, recon_loss, reg_loss)
            print("Mean mean/var", tf.reduce_mean(means), tf.reduce_mean(tf.exp(logvars)))
finally:
    tf.keras.backend.set_learning_phase(0)

In [None]:
"""Check reconstructions"""
# Show a few test images next to their reconstructions (input left,
# reconstruction right). A fixed count, with no input() pause, keeps
# Restart & Run All from blocking on the whole test set.
n_show = 10
test_data = mnist_eager("data/mnist_test", batch_size, train=False)

n_shown = 0
for img_batch, _ in test_data:
    if n_shown >= n_show:
        break
    # note this just uses the mean encodings
    code_means = tf.split(encoder(img_batch), 2, axis=1)[0]
    # The decoder outputs logits (it is trained through sigmoid_cross_entropy).
    # Map them to [0, 1] so both halves of the concatenation share the image's
    # value range.
    recon_batch = tf.nn.sigmoid(decoder(code_means))
    for img, rec in zip(img_batch, recon_batch):
        if n_shown >= n_show:
            break
        compare = np.concatenate((img.numpy().reshape((32, 32)), rec.numpy().reshape((32, 32))), axis=1)
        imshow(compare)
        n_shown += 1

In [None]:
"""Create training set of codes"""
# Encode every training image once more and keep the posterior parameters.
# train=False presumably gives a single unshuffled pass — confirm in utils.data.
train_data_straight = mnist_eager("data/mnist_train", batch_size, train=False)

train_code_mean_batches = []
train_code_var_batches = []
for batch_imgs, _ in train_data_straight:
    batch_means, batch_logvars = tf.split(encoder(batch_imgs), 2, axis=1)
    train_code_mean_batches.append(batch_means)
    # store variances (exp of the log-variances), not log-variances
    train_code_var_batches.append(tf.exp(batch_logvars))
train_code_means = np.concatenate(train_code_mean_batches)
train_code_vars = np.concatenate(train_code_var_batches)

In [None]:
# Value ranges of the posterior means (first line) and variances (second line).
for code_stats in (train_code_means, train_code_vars):
    print(code_stats.min(), code_stats.max())

In [None]:
"""Interpolate between some codes"""
# Pick two distinct indices from however many codes were actually collected,
# rather than assuming exactly 60000 training images.
ind1, ind2 = np.random.choice(len(train_code_means), size=2, replace=False)
a_code = train_code_means[ind1]
b_code = train_code_means[ind2]

# NOTE(review): decoder returns logits (built with final_act=None); confirm that
# `interpolate` maps them to [0, 1] before display.
interpolate(a_code, b_code, gen=decoder, method="linear")

In [None]:
"""Generate reconstructions from some random code"""
# Decodes latent vectors drawn by noise_fn (standard normal, i.e. the VAE prior)
# and lays the results out as a grid. These are samples, not reconstructions of data.
# NOTE(review): decoder outputs logits (final_act=None); confirm random_sample_grid
# applies a sigmoid before display.
grid = random_sample_grid(decoder, noise_fn)