In [None]:
"""Imports, define GAN"""
import tensorflow as tf
import tensorflow.keras.layers as layers

from utils.custom_layers import InstanceNormalization, LayerNormalization
from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.math import compute_mmd
from utils.models import gen_conv_mnist_nn, gen_fc_mnist, enc_conv_mnist, enc_fc_mnist, add_norm
from utils.viz import random_sample_grid, imshow


# data / run configuration
svhn = True  # True -> SVHN tfrecord files, False -> MNIST tfrecord files
n_critic = 5
batch_size = 128  # this is a "half batch"!
dim_noise = 100
clip = 0.1  # main parameter to tune!
# the base step count is scaled by n_critic, so train_steps counts loop iterations
train_steps = (25000 if svhn else 10000) * n_critic

# The tuple handed to parse_img_label_tfr is presumably the image shape
# (height, width, channels) -- confirm in utils.data.
if not svhn:
    train_files = ["data/tfrs/mnist_train.tfr"]
    test_files = ["data/tfrs/mnist_test.tfr"]
    parse_fn = lambda x: parse_img_label_tfr(x, (32, 32, 1))
else:
    train_files = ["data/tfrs/svhn_train.tfr", "data/tfrs/svhn_extra.tfr"]
    test_files = ["data/tfrs/svhn_test.tfr"]
    parse_fn = lambda x: parse_img_label_tfr(x, (32, 32, 3))

# shufrep is presumably a shuffle/repeat buffer size -- confirm in utils.data.
data = tfr_dataset_eager(
    train_files,
    batch_size,
    parse_fn,
    shufrep=600000 if svhn else 60000,
)


class MBD(tf.keras.layers.Layer):
    """Minibatch discrimination layer.

    Every example's feature vector (length ``d``) is projected to a ``p x q``
    matrix through a learned ``d x p x q`` tensor. For each of the ``p`` rows,
    the L1 distance between the rows of every pair of examples in the batch is
    turned into a similarity ``exp(-distance)``, and the similarities are
    summed over the batch. The resulting ``p`` statistics are appended to the
    unchanged input features.

    Input shape:  ``b x d``.
    Output shape: ``b x (d + p)``.

    Notes:
        * The output for one example depends on all other examples in the
          same batch.
        * The sum runs over all pairs including an example with itself, so
          every statistic contains a constant contribution of ``exp(0) = 1``.
        * NOTE(review): the weight tensor is created without a constraint, so
          kernel/bias constraints set on neighbouring Dense/Conv layers do not
          cover it -- confirm this is intended when weight clipping is used.

    Args:
        size_p: number of rows ``p`` of the per-example matrix (also the
            number of features appended to the input).
        size_q: number of columns ``q`` of the per-example matrix.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, size_p, size_q, **kwargs):
        super(MBD, self).__init__(**kwargs)
        self.p = size_p
        self.q = size_q

    def build(self, inp_shape):
        """Create the ``d x p x q`` projection tensor; ``d`` is the last input dim."""
        self.t = self.add_weight(shape=(inp_shape[-1], self.p, self.q),
                                 initializer="glorot_uniform",
                                 trainable=True)

    def call(self, inp):
        """Append the ``p`` minibatch statistics to ``inp`` (``b x d``)."""
        # Contract the feature dimension d of inp (b x d) with t (d x p x q).
        # This equals broadcasting inp over p and q, multiplying pointwise and
        # summing over d, but does not materialise a b x d x p x q tensor.
        projected = tf.tensordot(inp, self.t, axes=[[1], [0]])  # b x p x q
        # Broadcast to b x b x p x q for all pairwise (absolute) matrix
        # differences, then sum over the columns (q): L1 distance between
        # matching rows. Result is b x b x p.
        row_dist = tf.reduce_sum(
            tf.abs(projected[tf.newaxis] - projected[:, tf.newaxis]), axis=-1)
        similarity = tf.exp(-row_dist)
        # Sum over all examples to arrive at b x p, then append to the input.
        batch_stats = tf.reduce_sum(similarity, axis=1)
        return tf.concat([inp, batch_stats], axis=1)
    
    
def enc_fc_mnist(final_dim, use_bn=True, h_act=layers.LeakyReLU, clip=None,
                 norm=layers.BatchNormalization):
    """Build a fully connected encoder/critic that includes an MBD layer.

    Layout: Flatten -> three Dense layers (512, 256, 128 units), each followed
    by an ``h_act()`` activation -> ``MBD(16, 20)`` -> ``Dense(final_dim)``.

    Args:
        final_dim: number of units of the output Dense layer.
        use_bn: if True, the hidden Dense layers are created without bias and
            the layer list is passed through ``add_norm`` before the model is
            assembled.
        h_act: zero-argument callable producing an activation layer.
        clip: if truthy, kernels and biases of all Dense layers are constrained
            to ``[-clip, clip]`` with ``tf.clip_by_value``; otherwise no
            constraint is used.
        norm: normalization layer class handed to ``add_norm``.

    Returns:
        A ``tf.keras.Sequential`` model.
    """
    if clip:
        const = lambda v: tf.clip_by_value(v, -clip, clip)
    else:
        const = None

    stack = [layers.Flatten()]
    for n_units in (512, 256, 128):
        stack.append(layers.Dense(n_units, kernel_constraint=const,
                                  bias_constraint=const, use_bias=not use_bn))
        stack.append(h_act())
    stack.append(MBD(16, 20))
    #MBSD(False),
    # the output layer keeps the default bias setting regardless of use_bn
    stack.append(layers.Dense(final_dim, kernel_constraint=const,
                              bias_constraint=const))

    if use_bn:
        stack = add_norm(stack, norm=norm)
    return tf.keras.Sequential(stack)


def enc_conv_mnist(final_dim, use_bn=True, h_act=layers.LeakyReLU, clip=None,
                   norm=layers.BatchNormalization):
    """Build a convolutional encoder/critic that includes an MBD layer.

    Layout: four stages of ``Conv2D(filters, 4, padding="same")`` ->
    ``h_act()`` -> ``AveragePooling2D(padding="same")`` with 32, 64, 128 and
    256 filters, followed by Flatten -> ``MBD(16, 20)`` -> ``Dense(final_dim)``.

    Args:
        final_dim: number of units of the output Dense layer.
        use_bn: if True, all Conv2D layers and the output Dense layer are
            created without bias and the layer list is passed through
            ``add_norm`` before the model is assembled.
        h_act: zero-argument callable producing an activation layer.
        clip: if truthy, kernels and biases of the Conv2D/Dense layers are
            constrained to ``[-clip, clip]`` with ``tf.clip_by_value``;
            otherwise no constraint is used.
        norm: normalization layer class handed to ``add_norm``.

    Returns:
        A ``tf.keras.Sequential`` model.
    """
    if clip:
        const = lambda v: tf.clip_by_value(v, -clip, clip)
    else:
        const = None

    stack = []
    for n_filters in (32, 64, 128, 256):
        stack.extend([
            layers.Conv2D(n_filters, 4, padding="same", kernel_constraint=const,
                          bias_constraint=const, use_bias=not use_bn),
            h_act(),
            layers.AveragePooling2D(padding="same"),
        ])
    #MBSD(True),
    stack.append(layers.Flatten())
    stack.append(MBD(16, 20))
    stack.append(layers.Dense(final_dim, kernel_constraint=const,
                              bias_constraint=const, use_bias=not use_bn))

    if use_bn:
        stack = add_norm(stack, norm=norm)
    return tf.keras.Sequential(stack)


# Architecture switch: convolutional models with instance normalization, or
# fully connected models with layer normalization. The discriminator has a
# single output unit and receives the weight-clipping value `clip`.
conv = True
if not conv:
    discriminator = enc_fc_mnist(
        1, use_bn=True, clip=clip, norm=LayerNormalization)
    generator = gen_fc_mnist(
        use_bn=True, channels=3 if svhn else 1, norm=LayerNormalization)
else:
    discriminator = enc_conv_mnist(
        1, use_bn=True, clip=clip, norm=InstanceNormalization)
    generator = gen_conv_mnist_nn(
        use_bn=True, channels=3 if svhn else 1, norm=InstanceNormalization)


def noise_fn(n_samples):
    """Draw `n_samples` latent vectors of length `dim_noise`, uniform in [-1, 1)."""
    latent_shape = [n_samples, dim_noise]
    return tf.random.uniform(latent_shape, minval=-1, maxval=1)


# Learning rates decay from 3e-4 to 1e-5 over the given number of optimizer steps.
# NOTE(review): train_steps has already been multiplied by n_critic in the
# configuration, so the discriminator horizon below is n_critic**2 times the
# base step count and neither schedule reaches its end value within
# train_steps loop iterations -- confirm this is intended.
lr_g = tf.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.0003,
    decay_steps=train_steps,
    end_learning_rate=0.00001,
)
lr_d = tf.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.0003,
    decay_steps=train_steps * n_critic,
    end_learning_rate=0.00001,
)
gen_opt = tf.optimizers.RMSprop(learning_rate=lr_g)
disc_opt = tf.optimizers.RMSprop(learning_rate=lr_d)

In [None]:
"""Train WGAN"""

@tf.function
def train(batch, step):
    """Run one WGAN training step.

    The discriminator (critic) is updated on every call. The generator is
    updated only on every `n_critic`-th call, i.e. when
    ``(step + 1) % n_critic == 0``, so that the critic receives `n_critic`
    updates per generator update.

    Args:
        batch: batch of real images fed to the discriminator.
        step: zero-based loop counter (integer scalar tensor).

    Returns:
        The discriminator loss of this step,
        ``mean(D(generated)) - mean(D(real))``.
    """
    # Discriminator update. Real and generated images are fed through the
    # discriminator as two separate batches (not one mixed batch) so that
    # batch-dependent layers see only one kind of input at a time; the two
    # mean outputs are then combined into a single loss.
    batch_dim = tf.shape(batch)[0]
    # generated outside the tape: no gradient flows into the generator here
    gen_batch = generator(noise_fn(batch_dim))
    with tf.GradientTape() as d_tape:
        d_loss_real = tf.reduce_mean(discriminator(batch))
        d_loss_fake = tf.reduce_mean(discriminator(gen_batch))
        d_loss = -d_loss_real + d_loss_fake
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    disc_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    # Generator update only once per n_critic discriminator updates.
    if (step + 1) % n_critic == 0:
        # fresh generated batch for generator training (two "half batches");
        # only the generator's variables are watched, so the discriminator
        # receives no gradients from this tape
        with tf.GradientTape(watch_accessed_variables=False) as g_tape:
            for vari in generator.trainable_variables:
                g_tape.watch(vari)
            gen_only_batch = generator(noise_fn(2*batch_dim))
            g_loss = -tf.reduce_mean(discriminator(gen_only_batch))
        g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
        gen_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

    return d_loss
    

# learning phase 1 while training
tf.keras.backend.set_learning_phase(1)
for step, (img_batch, _) in enumerate(data):
    # steps 0..train_steps (inclusive) are executed, so the last step also
    # reaches the logging branch below
    if step > train_steps:
        break

    d_loss = train(img_batch, tf.constant(step, dtype=tf.int32))

    # report only every 250 * n_critic steps
    if step % (250 * n_critic):
        continue

    print("Step", step)
    print("Disc Loss", d_loss)
    # draw the sample grid with learning phase 0, then switch back
    tf.keras.backend.set_learning_phase(0)
    grid = random_sample_grid(generator, noise_fn)
    tf.keras.backend.set_learning_phase(1)

# leave the learning phase at 0 once training is done
tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""
tf.keras.backend.set_learning_phase(0)

import tensorflow_probability as tfp
tfd = tfp.distributions
# normal distribution (loc 0, scale 1) truncated to [-0.5, 0.5]
tdist = tfd.TruncatedNormal(loc=0., scale=1., low=-0.5, high=0.5)


def t_noise_fn(x):
    """Sample an `x` by `dim_noise` batch of latent vectors from `tdist`."""
    return tdist.sample([x, dim_noise])


grid = random_sample_grid(generator, t_noise_fn)