In [None]:
"""GPU setup"""
import os
# Expose only GPU index 3 to CUDA. This runs before TensorFlow is imported
# in the following cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "3"})

In [None]:
"""Imports, define AE model"""
import tensorflow as tf
import numpy as np
# Switch TF 1.x to eager execution. The later cells depend on it: the training
# loop uses tf.GradientTape and the plotting cell calls Tensor.numpy().
tf.enable_eager_execution()

from matplotlib import pyplot as plt

from utils.data import mnist_eager
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_fc_mnist, enc_conv_mnist


# Hyperparameters and training data
batch_size = 256
train_steps = 1500
# The training cell splits the encoder output into two equal halves
# (means and log-variances), so the latent itself has dim_code / 2 = 16 entries.
dim_code = 2*16

data = mnist_eager("data/mnist_train", batch_size)


# Architecture switch: True selects the convolutional builders, False the
# fully-connected ones. Both variants are built with use_bn=True.
conv = False
encoder = (enc_conv_mnist if conv else enc_fc_mnist)(dim_code, use_bn=True)
decoder = (gen_conv_mnist if conv else gen_fc_mnist)(use_bn=True)


def ae_loss(imgs, recon):
    """Reconstruction loss between input images and their reconstructions.

    Args:
        imgs: batch of target images (passed as the `labels` argument).
        recon: batch of reconstructions (passed as the `predictions` argument).

    Returns:
        The result of tf.losses.mean_squared_error(imgs, recon).
    """
    # Alternative reconstruction losses that were tried and left disabled:
    #   tf.losses.absolute_difference(imgs, recon)
    #   tf.losses.sigmoid_cross_entropy(multi_class_labels=imgs, logits=recon)
    mse = tf.losses.mean_squared_error(imgs, recon)
    return mse


def kl_loss(means, logvars):
    """KL divergence of diagonal Gaussians N(means, exp(logvars)) from N(0, I).

    Evaluates 0.5 * sum(mean^2 + var - log(var) - 1), where the sum runs over
    every element of the inputs (no axis is given to reduce_sum), so the
    result is a single scalar.

    NOTE(review): because the batch dimension is summed as well, the value
    grows with batch size, whereas ae_loss is a mean squared error. Confirm
    that this relative weighting of the two loss terms is intended.

    Args:
        means: tensor of posterior means.
        logvars: tensor of posterior log-variances, same shape as `means`.

    Returns:
        Scalar tensor holding the summed KL divergence.
    """
    variances = tf.exp(logvars)
    kl_terms = tf.square(means) + variances - logvars - 1
    return 0.5 * tf.reduce_sum(kl_terms)


# Adam optimizer constructed with no arguments, i.e. TensorFlow's default
# hyperparameters. The training cell uses this single optimizer to update the
# encoder and decoder variables together.
ae_opt = tf.train.AdamOptimizer()


In [None]:
"""Train AE"""
tf.keras.backend.set_learning_phase(1)
for step, (img_batch, _) in enumerate(data):
    if step > train_steps:
        break
    
    with tf.GradientTape() as tape:
        code = encoder(img_batch)
        means, logvars = tf.split(code, 2, axis=1)
        code_samples = tf.random_normal(tf.shape(means))*tf.sqrt(tf.exp(logvars)) + means

        recon = decoder(code)
        recon_loss = ae_loss(img_batch, recon)
        reg_loss = kl_loss(means, logvars)
        total_loss = recon_loss + reg_loss
    grads = tape.gradient(total_loss, encoder.variables + decoder.variables)
    ae_opt.apply_gradients(zip(grads, encoder.variables + decoder.variables))
    
    if not step % 50:
        print("Step", step)
        print("Loss", total_loss)

tf.keras.backend.set_learning_phase(0)

In [None]:
"""Check reconstructions"""
test_data = mnist_eager("data/mnist_test", batch_size, train=False)

for img_batch, _ in test_data:
    recon_batch = decoder(encoder(img_batch))
    for img, rec in zip(img_batch, recon_batch):
        compare = np.concatenate((img.numpy().reshape((32, 32)), rec.numpy().reshape((32, 32))), axis=1)
        plt.imshow(compare, cmap="Greys_r", vmin=0, vmax=1)
        plt.show()
        input()