In [None]:
"""GPU setup"""
import os

# Expose only device "2" to CUDA. This cell runs before the cell that imports
# TensorFlow, so the setting is already in the environment by then.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [None]:
"""Imports, define AE model"""
import tensorflow as tf
import numpy as np
# Turn on eager execution right after importing TensorFlow: the cells below
# iterate over datasets with plain Python loops and call .numpy() on tensors.
tf.enable_eager_execution()

from matplotlib import pyplot as plt

from utils.data import mnist_eager
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_fc_mnist, enc_conv_mnist


# hyperparameters
batch_size = 256    # batch size handed to mnist_eager
train_steps = 1500  # step budget used by both training loops
dim_noise = 10      # width of the noise vectors fed to the code generator
dim_code = 16       # code size passed to the encoder

data = mnist_eager("data/mnist_train", batch_size)


# Choose the architecture family once, then build both halves with
# use_bn=True (presumably batch normalisation -- confirm in utils.models).
conv = False
make_encoder, make_decoder = (
    (enc_conv_mnist, gen_conv_mnist) if conv else (enc_fc_mnist, gen_fc_mnist)
)
encoder = make_encoder(dim_code, use_bn=True)
decoder = make_decoder(use_bn=True)


def ae_loss(imgs):
    """Return the mean squared error between `imgs` and their reconstruction.

    The batch is pushed through `encoder` and then `decoder`; the result is
    compared against the input batch itself.
    """
    codes = encoder(imgs)
    recon = decoder(codes)
    # Alternative losses kept for reference:
    #   tf.losses.absolute_difference(imgs, recon)
    #   tf.losses.sigmoid_cross_entropy(multi_class_labels=imgs, logits=recon)
    return tf.losses.mean_squared_error(imgs, recon)


# Adam with its default hyperparameters, used for the autoencoder only.
ae_opt = tf.train.AdamOptimizer()


In [None]:
"""Train AE"""
# Put Keras into training mode for the duration of the loop.
tf.keras.backend.set_learning_phase(1)
try:
    for step, (img_batch, _) in enumerate(data):
        # Steps 0..train_steps (inclusive) are executed, so the final step
        # is a multiple of 50 and gets logged as well.
        if step > train_steps:
            break

        with tf.GradientTape() as tape:
            recon_loss = ae_loss(img_batch)
        # Build the variable list once per step, after the forward pass, and
        # reuse it so gradients and variables are guaranteed to line up.
        ae_vars = encoder.variables + decoder.variables
        grads = tape.gradient(recon_loss, ae_vars)
        ae_opt.apply_gradients(zip(grads, ae_vars))

        if not step % 50:
            print("Step", step)
            print("Loss", recon_loss)
finally:
    # Always switch back to inference mode, even when training is interrupted
    # (e.g. via the notebook's "interrupt kernel") or raises; otherwise the
    # later cells would run the models in training mode without any warning.
    tf.keras.backend.set_learning_phase(0)

In [None]:
"""Check reconstructions"""
test_data = mnist_eager("data/mnist_test", batch_size, train=False)

# Number of original/reconstruction pairs to draw. The cell used to wait on
# input() after every test image, which stalls "Restart & Run All"; a bounded
# preview lets the cell finish on its own. Raise the value to inspect more.
num_previews = 10

num_shown = 0
for img_batch, _ in test_data:
    recon_batch = decoder(encoder(img_batch))
    for img, rec in zip(img_batch, recon_batch):
        if num_shown >= num_previews:
            break
        # Original on the left, reconstruction on the right.
        compare = np.concatenate((img.numpy().reshape((32, 32)), rec.numpy().reshape((32, 32))), axis=1)
        plt.imshow(compare, cmap="Greys_r", vmin=0, vmax=1)
        plt.show()
        num_shown += 1
    # Stop pulling batches as soon as enough pairs have been displayed.
    if num_shown >= num_previews:
        break

In [None]:
"""Create training set of codes"""
# train=False presumably yields every training image once, in order --
# confirm against utils.data.mnist_eager.
train_data_straight = mnist_eager("data/mnist_train", batch_size, train=False)

# Encode batch by batch and convert each result to NumPy immediately, so the
# list holds plain arrays rather than TensorFlow tensors. The labels are not
# needed here, and neither is a step counter.
train_code_batches = [encoder(img_batch).numpy() for img_batch, _ in train_data_straight]
# Stack the per-batch code arrays along the first axis into one array.
train_codes = np.concatenate(train_code_batches)

In [None]:
"""Interpolate between some codes"""
ind1 = 124
ind2 = 15645
granul = 20
# Length-1 slices (rather than plain indexing) keep the leading batch axis.
a_code, b_code = (train_codes[i:i + 1] for i in (ind1, ind2))

for interp in np.linspace(0, 1, granul):
    # interp == 0 reproduces a_code, interp == 1 reproduces b_code.
    p_code = (1 - interp) * a_code + interp * b_code
    p_rec = decoder(p_code)
    frame = p_rec.numpy().reshape((32, 32))
    plt.imshow(frame, cmap="Greys_r", vmin=0, vmax=1)
    plt.show()


In [None]:
"""Set up data/model for training code GMMN"""
code_data = tf.data.Dataset.from_tensor_slices(train_codes)
# Size the shuffle buffer from the data itself instead of hard-coding 60000,
# so the whole code set is shuffled even if the number of codes changes.
code_data = code_data.apply(tf.data.experimental.shuffle_and_repeat(len(train_codes)))
# The code generator trains on batches half the size of the AE batches.
code_data = code_data.batch(batch_size//2)


def noise_fn(shape):
    """Sample a tensor of the requested shape, uniform between -1 and 1."""
    return tf.random_uniform(shape, minval=-1, maxval=1)


# Three leaky-ReLU hidden layers (32, 64 and 64 units) followed by an output
# layer of width dim_code with no activation argument.
code_gmmn = tf.keras.Sequential(
    [tf.keras.layers.Dense(units, tf.nn.leaky_relu) for units in (32, 64, 64)]
    + [tf.keras.layers.Dense(dim_code)]
)


def gmmn_loss(code):
    """Return sqrt(compute_mmd(...)) between `code` and freshly generated codes.

    One noise vector of width `dim_noise` is drawn per row of `code`, mapped
    through `code_gmmn`, and compared against `code` with `compute_mmd`.
    """
    n_rows = tf.shape(code)[0]
    noise = noise_fn((n_rows, dim_noise))
    generated = code_gmmn(noise)
    # Third argument of compute_mmd; presumably kernel bandwidths -- confirm
    # in utils.math.
    mmd_params = [0.03, 0.1, 0.3, 1., 3., 10.]
    # NOTE(review): if compute_mmd can return exactly 0 or a slightly negative
    # estimate, the sqrt yields inf gradients / NaN -- confirm its range.
    return tf.sqrt(compute_mmd(code, generated, mmd_params))


# Separate Adam instance (default hyperparameters) for the code generator.
gmmn_opt = tf.train.AdamOptimizer()


In [None]:
"""Train code GMMN"""
for step, code_batch in enumerate(code_data):
    # The step budget is the exit condition; steps 0..train_steps all run.
    if step > train_steps:
        break

    with tf.GradientTape() as tape:
        mmd = gmmn_loss(code_batch)
    gmmn_vars = code_gmmn.variables
    grads = tape.gradient(mmd, gmmn_vars)
    gmmn_opt.apply_gradients(zip(grads, gmmn_vars))

    # Report progress on every 50th step.
    if step % 50 == 0:
        print("Step", step)
        print("Loss", mmd)

In [None]:
"""Check what the code space looks like"""
# Extremes of the training codes along axis 0: minimum first, then maximum.
for reduce_fn in (np.min, np.max):
    print(reduce_fn(train_codes, axis=0))

In [None]:
"""Generate reconstructions from some random code"""
# Codes are drawn uniformly between -code_halfwidth and +code_halfwidth.
code_halfwidth = 15
for _ in range(20):
    randcode = tf.random_uniform([1, dim_code], minval=-code_halfwidth, maxval=code_halfwidth)
    randrec = decoder(randcode)
    picture = randrec.numpy().reshape((32, 32))
    plt.imshow(picture, cmap="Greys_r", vmin=0, vmax=1)
    plt.show()

In [None]:
"""Generate reconstructions from GMMN noise"""
for _ in range(20):
    # noise -> generated code -> decoded image, one sample at a time
    randnoise = noise_fn([1, dim_noise])
    randrec = decoder(code_gmmn(randnoise))
    picture = randrec.numpy().reshape((32, 32))
    plt.imshow(picture, cmap="Greys_r", vmin=0, vmax=1)
    plt.show()