In [None]:
"""GPU setup"""
import os
# Expose only the GPU with index 2 to this process. This is set here, in the
# first cell, so it is already in the environment when TensorFlow is imported
# by the following cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})

In [None]:
"""Imports, define AE model"""
import tensorflow as tf
import numpy as np

from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.math import compute_mmd
from utils.models import gen_conv_mnist, gen_fc_mnist, enc_fc_mnist, enc_conv_mnist, wrap_sigmoid
from utils.viz import random_sample_grid, imshow, interpolate


# ---- experiment switches and sizes ----
svhn = False    # True: SVHN files with a 3-channel parse shape; False: MNIST files, 1 channel
conv = False    # True: convolutional encoder/decoder builders; False: fully connected ones
batch_size = 256
dim_code = 100  # handed to the encoder builder as its first argument
train_steps = 1500 if not svhn else 5000

# ---- dataset files and record parser ----
# Each branch fixes the file lists and a parser that forwards to
# parse_img_label_tfr with the matching shape tuple.
if not svhn:
    train_files = ["/cache/tfrs/mnist_train.tfr"]
    test_files = ["/cache/tfrs/mnist_test.tfr"]

    def parse_fn(x):
        return parse_img_label_tfr(x, (32, 32, 1))
else:
    train_files = ["/cache/tfrs/svhn_train.tfr", "/cache/tfrs/svhn_extra.tfr"]
    test_files = ["/cache/tfrs/svhn_test.tfr"]

    def parse_fn(x):
        return parse_img_label_tfr(x, (32, 32, 3))

# Training pipeline; shufrep is presumably the shuffle(-and-repeat) buffer
# size -- TODO confirm in utils.data.
data = tfr_dataset_eager(
    train_files,
    batch_size,
    parse_fn,
    shufrep=600000 if svhn else 60000,
)

# ---- autoencoder ----
# Pick the builder pair from the `conv` switch; the keyword arguments are the
# same for both variants.
encoder = (enc_conv_mnist if conv else enc_fc_mnist)(dim_code, use_bn=True)
decoder = (gen_conv_mnist if conv else gen_fc_mnist)(
    use_bn=True,
    final_act=None,
    channels=3 if svhn else 1,
)
# The decoder is built with final_act=None and the loss below is configured
# with from_logits=True; sig_decoder is the wrap_sigmoid-wrapped decoder that
# the visualisation cells call (presumably decoder + sigmoid -- confirm in
# utils.models).
sig_decoder = wrap_sigmoid(decoder)

ae_loss = tf.losses.BinaryCrossentropy(from_logits=True)
ae_opt = tf.optimizers.Adam()

In [None]:
"""Train AE"""
# Put Keras into training mode (learning phase 1) before the train step below
# is defined and run; the phase is set back to 0 at the end of this cell.
# NOTE(review): set_learning_phase is deprecated in later TF 2.x releases in
# favour of passing `training=` when calling the model -- confirm against the
# installed version.
tf.keras.backend.set_learning_phase(1)

@tf.function
def train(batch):
    """Run one optimisation step of the autoencoder.

    Args:
        batch: input passed through ``encoder`` then ``decoder``; it is also
            the target the reconstruction is compared against.

    Returns:
        The value of ``ae_loss(batch, reconstruction)`` for this batch.
    """
    with tf.GradientTape() as tape:
        codes = encoder(batch)
        logits = decoder(codes)
        recon_loss = ae_loss(batch, logits)

    # Collected after the forward pass so that any variables created on a
    # model's first call are included.
    ae_vars = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(recon_loss, ae_vars)
    ae_opt.apply_gradients(zip(grads, ae_vars))
    return recon_loss


# Optimise the autoencoder on steps 0..train_steps inclusive: the `>` test
# lets the step numbered train_steps still run, so it is logged as well (both
# configured train_steps values are multiples of 50).
for step, (img_batch, _) in enumerate(data):
    if step > train_steps:
        break
    recon_loss = train(img_batch)

    # Progress report every 50 steps; .numpy() prints the scalar value rather
    # than the tensor repr.
    if step % 50 == 0:
        print("Step", step)
        print("Loss", recon_loss.numpy())

# Back to inference mode for the cells that follow.
tf.keras.backend.set_learning_phase(0)

In [None]:
"""Check reconstructions"""
# Make sure Keras is in inference mode even if this cell is run on its own.
tf.keras.backend.set_learning_phase(0)

# Test pipeline: same parser and batch size as training, but no shufrep argument.
test_data = tfr_dataset_eager(test_files, batch_size, parse_fn)

for test_imgs, _ in test_data:
    test_codes = encoder(test_imgs)
    test_recons = sig_decoder(test_codes)
    for original, reconstruction in zip(test_imgs, test_recons):
        # Join input and reconstruction along axis 1 so both appear in one
        # image (presumably side by side -- confirm the image layout).
        side_by_side = np.concatenate([original, reconstruction], axis=1)
        imshow(side_by_side)
        # Wait for Enter before showing the next pair.
        # NOTE(review): this iterates over every test image and blocks each
        # time, so "Run All" stalls here until the kernel is interrupted.
        input()

In [None]:
"""Create training set of codes"""
# Second pipeline over the training files, built without the shufrep argument
# used for training (presumably a single, unshuffled pass -- confirm in
# utils.data).
train_data_straight = tfr_dataset_eager(train_files, batch_size, parse_fn)

# Encode every batch, then stack the per-batch codes along axis 0 into one
# NumPy array.
train_code_batches = [encoder(img_batch) for img_batch, _ in train_data_straight]
train_codes = np.concatenate(train_code_batches)

In [None]:
"""Set up data/model for training code GMMN"""
# Endless, reshuffled stream of code batches for GMMN training.
# shuffle(...).repeat() replaces the deprecated
# tf.data.experimental.shuffle_and_repeat; the shuffle buffer spans the whole
# code set. Batches are half the autoencoder batch size.
code_data = (
    tf.data.Dataset.from_tensor_slices(train_codes)
    .shuffle(len(train_codes))
    .repeat()
    .batch(batch_size // 2)
)

# Width of the noise vectors produced by noise_fn and fed to the code GMMN.
dim_noise = 10


def noise_fn(n_samples):
    """Draw ``n_samples`` noise vectors of width ``dim_noise``.

    Each entry is sampled uniformly with ``minval=-1`` and ``maxval=1``.

    Args:
        n_samples: number of rows in the returned tensor.

    Returns:
        Tensor of shape ``[n_samples, dim_noise]``.
    """
    noise_shape = [n_samples, dim_noise]
    return tf.random.uniform(noise_shape, minval=-1, maxval=1)


# Generator from noise to codes: three leaky-ReLU Dense layers of widths
# 32, 64 and 64, followed by a Dense output layer of width dim_code with no
# activation argument.
code_gmmn = tf.keras.Sequential(
    [tf.keras.layers.Dense(width, tf.nn.leaky_relu) for width in (32, 64, 64)]
    + [tf.keras.layers.Dense(dim_code)]
)


def gmmn_loss(real, generated):
    """Return the square root of ``compute_mmd(real, generated, scales)``.

    Args:
        real: batch of codes taken from the data.
        generated: batch of codes produced by the generator.

    Returns:
        ``tf.sqrt`` of the value computed by ``compute_mmd``.
    """
    # Six fixed values handed to compute_mmd as its third argument
    # (presumably kernel bandwidths -- TODO confirm in utils.math).
    scales = [0.03, 0.1, 0.3, 1., 3., 10.]
    mmd_value = compute_mmd(real, generated, scales)
    # NOTE(review): sqrt assumes compute_mmd never returns a negative value -- confirm.
    return tf.sqrt(mmd_value)


# Dedicated Adam optimizer (default arguments) for the code GMMN; it is a
# separate instance from ae_opt, so the two models do not share optimizer state.
gmmn_opt = tf.optimizers.Adam()

In [None]:
"""Check target loss"""
# Baseline: the loss between two different batches of real codes, printed for
# 15 pairs, for comparison with the values logged during GMMN training.
print("MMD for 15 pairs of real batches...")
data_iter = iter(code_data)
for _ in range(15):
    first_batch = next(data_iter)
    second_batch = next(data_iter)
    print(gmmn_loss(first_batch, second_batch).numpy())

In [None]:
"""Train code GMMN"""

# NOTE(review): this definition reuses the name `train` from the autoencoder
# cell; once this cell has run, `train` refers to the GMMN step below.
@tf.function
def train(batch):
    """Run one optimisation step of the code GMMN.

    Generates as many codes as there are rows in ``batch`` from fresh noise
    and updates ``code_gmmn`` to reduce ``gmmn_loss`` between the two sets.

    Args:
        batch: batch of real codes.

    Returns:
        The ``gmmn_loss`` value for this step.
    """
    with tf.GradientTape() as tape:
        n_generated = tf.shape(batch)[0]
        gen_batch = code_gmmn(noise_fn(n_generated))
        mmd = gmmn_loss(batch, gen_batch)

    gmmn_vars = code_gmmn.trainable_variables
    grads = tape.gradient(mmd, gmmn_vars)
    gmmn_opt.apply_gradients(zip(grads, gmmn_vars))
    return mmd

# Train the code GMMN with the `train` step defined earlier in this cell.
# The step budget reuses train_steps from the autoencoder configuration, and
# the `>` test means steps 0..train_steps inclusive are run.
tf.keras.backend.set_learning_phase(1)
for step, code_batch in enumerate(code_data):
    if step > train_steps:
        break

    mmd = train(code_batch)

    # Progress report every 50 steps; .numpy() prints the scalar value rather
    # than the tensor repr.
    if step % 50 == 0:
        print("Step", step)
        print("Loss", mmd.numpy())

# Back to inference mode for the sampling cells that follow.
tf.keras.backend.set_learning_phase(0)

In [None]:
"""Check what the code space looks like"""
# Minimum and maximum over axis 0 of the collected training codes, i.e. one
# value per code dimension.
print(np.min(train_codes, axis=0))
print(np.max(train_codes, axis=0))

In [None]:
"""Decodings directly from random code space entries"""
# Inference mode for decoding.
tf.keras.backend.set_learning_phase(0)


def sample_uniform_codes(bs):
    """Draw ``bs`` codes of width ``dim_code`` uniformly with minval=-3, maxval=3."""
    return tf.random.uniform([bs, dim_code], minval=-3, maxval=3)


# Decode codes drawn directly from the uniform box above, bypassing the GMMN.
grid = random_sample_grid(sig_decoder, sample_uniform_codes)

In [None]:
"""Interpolate between some AE codes"""
# Pick two training codes at random (two independent draws, so the same index
# can come up twice) and interpolate between them through the sigmoid decoder.
n_codes = len(train_codes)
ind1 = np.random.choice(n_codes)
ind2 = np.random.choice(n_codes)
a_code, b_code = train_codes[ind1], train_codes[ind2]

interpolate(a_code, b_code, gen=sig_decoder, method="slerp")

In [None]:
"""Decodings from random noise first converted to codes"""
def noise_to_image(x):
    """Map noise ``x`` to codes with ``code_gmmn`` and decode them with ``sig_decoder``."""
    return sig_decoder(code_gmmn(x))


# Sample noise with noise_fn and push it through GMMN + decoder.
grid = random_sample_grid(noise_to_image, noise_fn)

In [None]:
"""Interpolation behavior in noise space"""
# Two noise vectors from noise_fn serve as interpolation endpoints; each
# interpolated point is mapped to a code by code_gmmn and then decoded.
# (The names say "code", but the endpoints live in noise space.)
codes = noise_fn(2)
a_code, b_code = codes[0], codes[1]
interpolate(
    a_code,
    b_code,
    gen=lambda x: sig_decoder(code_gmmn(x)),
    method="slerp",
)