In [None]:
import tensorflow as tf
import tensorflow.keras.layers as layers

from utils.custom_layers import InstanceNormalization, LayerNormalization
from utils.data import tfr_dataset_eager, parse_img_label_tfr
from utils.models import model_up_to, add_norm
from utils.viz import random_sample_grid, interpolate

In [None]:
# Data / run configuration
cifer = True           # True: train on CIFAR-10; False: train on MNIST
batch_size = 128       # this is a "half batch": the training step uses 2x this for the generator
dim_noise = 100        # length of each noise vector fed to the generator; tune
label_smoothing = 0.9  # factor applied to the all-ones targets in the training step; tune
train_steps = 25000 if cifer else 10000

# Pick the record files and the matching parser (image shape differs in channels only).
if cifer:
    train_files = ["data/tfrs/cifar10_train.tfr"]
    test_files = ["data/tfrs/cifar10_test.tfr"]

    def parse_fn(x):
        return parse_img_label_tfr(x, (32, 32, 3))
else:
    train_files = ["/cache/tfrs/mnist_train.tfr"]
    test_files = ["/cache/tfrs/mnist_test.tfr"]

    def parse_fn(x):
        return parse_img_label_tfr(x, (32, 32, 1))

# NOTE(review): shufrep is 10x larger in the CIFAR branch than in the MNIST
# branch -- confirm against tfr_dataset_eager that this is the intended value.
data = tfr_dataset_eager(
    train_files,
    batch_size,
    parse_fn,
    shufrep=600000 if cifer else 60000,
)

In [None]:
# Feature-matching loss for the generator (from "Improved Techniques for Training GANs").
def feature_match(real, gened):
    """Return the norm of the gap between batch-averaged discriminator features.

    Both batches are run through ``disc_part``; the resulting features are
    averaged over axis 0 and ``tf.norm`` of the difference between the two
    averages is returned.

    Args:
        real: batch of real inputs for ``disc_part``.
        gened: batch of generated inputs for ``disc_part``.
    """
    mean_real = tf.reduce_mean(disc_part(real), axis=0)
    mean_gen = tf.reduce_mean(disc_part(gened), axis=0)
    return tf.norm(mean_real - mean_gen)

# Fully connected discriminator; the MBD / MBSD hooks below are currently disabled.
def enc_fc_mnist(final_dim, use_bn=True, h_act=layers.LeakyReLU, clip=None,
                 norm=layers.BatchNormalization):
    """Build a fully connected encoder as a ``tf.keras.Sequential``.

    The stack is Flatten -> Dense(512) -> Dense(256) -> Dense(128), each Dense
    followed by ``h_act()``, and a final ``Dense(final_dim)``.

    Args:
        final_dim: number of units of the output layer.
        use_bn: if true, hidden Dense layers are created without bias and the
            layer list is passed through ``add_norm`` before the model is built.
        h_act: layer class instantiated after every hidden Dense layer.
        clip: if truthy, kernels and biases are constrained to
            ``[-clip, clip]`` via ``tf.clip_by_value``; otherwise unconstrained.
        norm: normalization layer class forwarded to ``add_norm``.

    Returns:
        The assembled ``tf.keras.Sequential`` model.
    """
    if clip:
        def const(v):
            return tf.clip_by_value(v, -clip, clip)
    else:
        const = None

    seq = [layers.Flatten()]
    for units in (512, 256, 128):
        seq.append(layers.Dense(units, kernel_constraint=const,
                                bias_constraint=const, use_bias=not use_bn))
        seq.append(h_act())
    #seq.append(MBD(16, 20))
    #seq.append(MBSD(False))
    # The output layer always keeps its bias.
    seq.append(layers.Dense(final_dim, kernel_constraint=const,
                            bias_constraint=const))

    if use_bn:
        seq = add_norm(seq, norm=norm)
    return tf.keras.Sequential(seq)

# Convolutional discriminator; the MBD / MBSD hooks below are currently disabled.
def enc_conv_mnist(final_dim, use_bn=True, h_act=layers.LeakyReLU, clip=None,
                   norm=layers.BatchNormalization):
    """Build a convolutional encoder as a ``tf.keras.Sequential``.

    Four stages of Conv2D (32, 64, 128, 256 filters, kernel size 4, "same"
    padding) -> ``h_act()`` -> AveragePooling2D are followed by Flatten and a
    ``Dense(final_dim)`` output layer.

    Args:
        final_dim: number of units of the output layer.
        use_bn: if true, layers are created without bias and the layer list is
            passed through ``add_norm`` before the model is built.
        h_act: layer class instantiated after every convolution.
        clip: if truthy, kernels and biases are constrained to
            ``[-clip, clip]`` via ``tf.clip_by_value``; otherwise unconstrained.
        norm: normalization layer class forwarded to ``add_norm``.

    Returns:
        The assembled ``tf.keras.Sequential`` model.
    """
    if clip:
        def const(v):
            return tf.clip_by_value(v, -clip, clip)
    else:
        const = None

    seq = []
    for filters in (32, 64, 128, 256):
        seq.append(layers.Conv2D(filters, 4, padding="same",
                                 kernel_constraint=const,
                                 bias_constraint=const, use_bias=not use_bn))
        seq.append(h_act())
        seq.append(layers.AveragePooling2D(padding="same"))
    #seq.append(MBSD(True))
    seq.append(layers.Flatten())
    #seq.append(MBD(16, 20))
    # NOTE(review): unlike enc_fc_mnist, the output Dense also drops its bias
    # when use_bn is set -- confirm this is intended.
    seq.append(layers.Dense(final_dim, kernel_constraint=const,
                            bias_constraint=const, use_bias=not use_bn))

    if use_bn:
        seq = add_norm(seq, norm=norm)
    return tf.keras.Sequential(seq)

def gen_fc_mnist(use_bn=True, h_act=layers.LeakyReLU, final_act=tf.nn.sigmoid,
                 norm=layers.BatchNormalization, channels=1):
    """Build a fully connected generator producing ``(32, 32, channels)`` outputs.

    The stack is Dense(256) -> Dense(512) -> Dense(1024), each followed by
    ``h_act()``, then a ``Dense(32*32*channels)`` with ``final_act`` and a
    reshape to ``(32, 32, channels)``.

    Args:
        use_bn: if true, hidden Dense layers are created without bias and the
            layer list is passed through ``add_norm`` (with ``up_to=-2``).
        h_act: layer class instantiated after every hidden Dense layer.
        final_act: activation of the output Dense layer.
        norm: normalization layer class forwarded to ``add_norm``.
        channels: number of output channels.

    Returns:
        The assembled ``tf.keras.Sequential`` model.
    """
    seq = []
    for units in (256, 512, 1024):
        seq.append(layers.Dense(units, use_bias=not use_bn))
        seq.append(h_act())
    seq.append(layers.Dense(32*32*channels, activation=final_act))
    seq.append(layers.Reshape((32, 32, channels)))

    if use_bn:
        seq = add_norm(seq, up_to=-2, norm=norm)
    return tf.keras.Sequential(seq)

def gen_conv_mnist_nn(use_bn=True, h_act=layers.LeakyReLU,
                      final_act=tf.nn.sigmoid, norm=layers.BatchNormalization,
                      channels=1):
    """Build an upsample-then-convolve generator as a ``tf.keras.Sequential``.

    The input gets two singleton axes inserted after the batch axis and then
    goes through five UpSampling2D + Conv2D (kernel size 4, "same" padding)
    stages. The first four stages use 256, 128, 64 and 32 filters followed by
    ``h_act()``; the last one has ``channels`` filters and ``final_act``.

    Args:
        use_bn: if true, hidden convolutions are created without bias and the
            layer list is passed through ``add_norm``.
        h_act: layer class instantiated after every hidden convolution.
        final_act: activation of the output convolution.
        norm: normalization layer class forwarded to ``add_norm``.
        channels: number of output channels.

    Returns:
        The assembled ``tf.keras.Sequential`` model.
    """
    seq = [layers.Lambda(lambda x: x[:, tf.newaxis, tf.newaxis, :])]
    for filters in (256, 128, 64, 32):
        seq.append(layers.UpSampling2D())
        seq.append(layers.Conv2D(filters, 4, padding="same",
                                 use_bias=not use_bn))
        seq.append(h_act())
    # Output stage: one more upsampling, then project to the image channels.
    seq.append(layers.UpSampling2D())
    seq.append(layers.Conv2D(channels, 4, padding="same",
                             activation=final_act))

    if use_bn:
        seq = add_norm(seq, norm=norm)
    return tf.keras.Sequential(seq)

def noise_fn(n_samples):
    """Draw ``n_samples`` noise vectors of length ``dim_noise``, uniform between -1 and 1."""
    noise_shape = [n_samples, dim_noise]
    return tf.random.uniform(noise_shape, minval=-1, maxval=1)

In [None]:
conv = True             # True: convolutional models; False: fully connected models
noise_scale = 2. / 256  # stddev of the Gaussian noise added to discriminator inputs

# Both networks use the same normalization layer; the generator's channel
# count follows the chosen dataset.
if not conv:
    discriminator = enc_fc_mnist(1, use_bn=True,
                                 norm=layers.LayerNormalization)
    generator = gen_fc_mnist(use_bn=True, channels=3 if cifer else 1,
                             norm=layers.LayerNormalization)
else:
    discriminator = enc_conv_mnist(1, use_bn=True,
                                   norm=InstanceNormalization)
    generator = gen_conv_mnist_nn(use_bn=True, channels=3 if cifer else 1,
                                  norm=InstanceNormalization)

In [None]:
# Partial discriminator used for feature matching.
# Basically builds a model that includes the first N layers; N is hardcoded
# to 16 here, which is bad: the right value differs between the convolutional
# and the MLP discriminator. Note that the layer index counts conv, pooling
# and activation layers alike. The easiest way to choose it is to inspect
# discriminator.layers and decide up to which index the model should run.
disc_part = model_up_to(discriminator, 16)

# Binary cross-entropy on raw logits; the learning rate follows a polynomial
# decay from 3e-4 to 1e-5 over train_steps steps.
loss = tf.losses.BinaryCrossentropy(from_logits=True)
lr = tf.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.0003,
    decay_steps=train_steps,
    end_learning_rate=0.00001,
)

# Separate Adam optimizers for the two networks, built from the same schedule.
gen_opt = tf.optimizers.Adam(learning_rate=lr, beta_1=0.5)
disc_opt = tf.optimizers.Adam(learning_rate=lr, beta_1=0.5)

In [None]:
"""Train GAN"""

@tf.function
def train(batch):
    """Run one discriminator update followed by one generator update.

    Both models are called with ``training=True`` explicitly, so the traced
    graph does not depend on the global Keras learning phase that happens to
    be set when this function is first traced.

    Args:
        batch: tensor of real images; its leading dimension is the
            ("half") batch size.

    Returns:
        Tuple ``(g_loss, d_loss)`` holding the generator and discriminator
        losses computed in this step.
    """
    batch_dim = tf.shape(batch)[0]

    def perturb(images):
        # Every discriminator input gets Gaussian noise with stddev noise_scale.
        return images + tf.random.normal(tf.shape(images), stddev=noise_scale)

    # Discriminator step. Real and generated images are fed in two separate
    # forward passes (so normalization never sees a mixed batch) and the two
    # losses are averaged.
    gen_batch = generator(noise_fn(batch_dim), training=True)
    real_labels = label_smoothing*tf.ones([batch_dim, 1])
    gen_labels = tf.zeros([batch_dim, 1])
    with tf.GradientTape() as d_tape:
        d_loss_real = loss(real_labels,
                           discriminator(perturb(batch), training=True))
        d_loss_fake = loss(gen_labels,
                           discriminator(perturb(gen_batch), training=True))
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    disc_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    # Generator step on a fresh generated batch of twice the half-batch size.
    # Only the generator's variables are watched, so the tape does not track
    # the discriminator's weights.
    with tf.GradientTape(watch_accessed_variables=False) as g_tape:
        for vari in generator.trainable_variables:
            g_tape.watch(vari)
        gen_only_batch = generator(noise_fn(2*batch_dim), training=True)
        # NOTE(review): the generator target is smoothed as well; label
        # smoothing is commonly applied to the discriminator's real targets
        # only -- confirm this is intended.
        g_loss = loss(label_smoothing*tf.ones([2*batch_dim, 1]),
                      discriminator(perturb(gen_only_batch), training=True))
        #g_loss = feature_match(batch, gen_only_batch)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    gen_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

    return g_loss, d_loss
    

# Training loop: steps 0..train_steps (inclusive) are executed.
tf.keras.backend.set_learning_phase(1)
for step, (img_batch, _) in enumerate(data):
    if step > train_steps:
        break

    g_loss, d_loss = train(img_batch)

    # Only every 250th step reports the losses and draws a sample grid.
    if step % 250 != 0:
        continue

    print("Step", step)
    print("Gen Loss", g_loss)
    print("Disc Loss", d_loss)
    # Switch the learning phase to 0 while sampling, then back to 1.
    tf.keras.backend.set_learning_phase(0)
    grid = random_sample_grid(generator, noise_fn)
    tf.keras.backend.set_learning_phase(1)

# Leave the learning phase at 0 once training is done.
tf.keras.backend.set_learning_phase(0)

In [None]:
"""Generate samples"""
tf.keras.backend.set_learning_phase(0)

# Sample with the "truncation trick" (supposedly increases sample quality):
# noise comes from a standard normal truncated to [-0.5, 0.5].
import tensorflow_probability as tfp
tfd = tfp.distributions
tdist = tfd.TruncatedNormal(loc=0., scale=1., low=-0.5, high=0.5)


def t_noise_fn(x):
    """Draw ``x`` noise vectors of length ``dim_noise`` from ``tdist``."""
    return tdist.sample([x, dim_noise])


grid = random_sample_grid(generator, t_noise_fn)

In [None]:
"""Interpolation behavior"""
# Two random noise codes serve as the endpoints of the interpolation.
codes = noise_fn(2)
a_code, b_code = codes[0], codes[1]
interpolate(a_code, b_code, gen=generator, method="slerp", figsize=(3, 3))