In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

In [30]:
class WGAN:
    """Wasserstein GAN on MNIST with a DRAGAN-style one-sided Lipschitz penalty.

    The constructor loads MNIST, scales pixels to [-1, 1], and builds a convolutional
    generator (noise vector -> 28x28x1 image, tanh output) and critic (image -> one
    unbounded score). ``train()`` performs ``critic_step`` critic updates followed by one
    generator update per batch, and records the unpenalised objectives in
    ``self.d_obj`` (batch, epoch, critic step) and ``self.g_obj`` (batch, epoch).

    Every constructor argument is optional; the defaults reproduce the original
    hard-coded configuration, so ``WGAN()`` behaves as before.

    Args:
        batch_size: Images per training batch.
        noise_dim: Length of the generator's input noise vector.
        total_epoch: Number of passes over the training set made by ``train()``.
        critic_step: Critic updates per generator update (all on the same batch).
        visualize: If True, save a 4x4 grid of samples after every epoch.
        out_path: Directory for the sample grids; defaults to the current working
            directory at construction time.
        grad_penalty: Weight of the Lipschitz penalty in the critic loss.
        perturb_factor: Scale of the perturbation used to pick penalty points, relative
            to the standard deviation of the joined real/generated batch.
    """

    # Added inside the gradient-norm sqrt: d/dx sqrt(x) is infinite at x = 0, so a
    # sample with an exactly-zero critic gradient would otherwise back-propagate NaN.
    _NORM_EPS = 1e-12

    def __init__(self, batch_size=128, noise_dim=128, total_epoch=100, critic_step=1,
                 visualize=True, out_path=None, grad_penalty=10.0, perturb_factor=1.0):
        (self.x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
        self.init_dim = 7          # spatial size of the generator's first feature map
        self.strides = (2, 2, 1)   # 7 -> 14 -> 28 -> 28; the output layer reuses strides[2]
        self.data_shape = self.x_train.shape + (1,)  # append a channel axis

        # training configuration
        self.batch_size = batch_size
        self.noise_dim = noise_dim
        self.total_epoch = total_epoch
        self.critic_step = critic_step
        self.visualize = visualize
        self.out_path = os.getcwd() if out_path is None else out_path

        # storage for the per-batch objectives; ceil division because the last batch
        # produced by Dataset.batch may be partial
        self.batch_num = -(-self.data_shape[0] // self.batch_size)
        self.d_obj = np.zeros([self.batch_num, self.total_epoch, self.critic_step])
        self.g_obj = np.zeros([self.batch_num, self.total_epoch])

        # regularization parameters
        self.grad_penalty = grad_penalty
        self.perturb_factor = perturb_factor

        # scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
        images = self.x_train.reshape(self.data_shape).astype('float32')
        images = (images - 127.5) / 127.5
        self.x_train = (tf.data.Dataset.from_tensor_slices(images)
                        .shuffle(self.data_shape[0])
                        .batch(self.batch_size))

        # setup optimizers
        self.D_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4,
                                                    beta_1=0.5,
                                                    beta_2=0.999,
                                                    epsilon=1e-7,
                                                    amsgrad=False)
        self.G_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5,
                                                    beta_1=0.2,
                                                    beta_2=0.999,
                                                    epsilon=1e-7,
                                                    amsgrad=False)

        # setup models
        self.G = self.set_generator()
        self.D = self.set_discriminator()

    def set_generator(self):
        """Build the generator: noise (noise_dim,) -> image with data_shape[3] channels.

        A dense layer is reshaped to init_dim x init_dim x 256, then upsampled by
        transposed convolutions with ``self.strides`` ('same' padding); the final layer
        uses tanh so outputs lie in [-1, 1] like the normalised training data.
        """
        g = tf.keras.Sequential()
        g.add(layers.Dense(self.init_dim * self.init_dim * 256, use_bias=False, input_shape=(self.noise_dim,)))
        g.add(layers.BatchNormalization())
        g.add(layers.LeakyReLU())
        g.add(layers.Reshape((self.init_dim, self.init_dim, 256)))

        g.add(layers.Conv2DTranspose(128, 5, strides=self.strides[0], padding='same', use_bias=False))
        g.add(layers.BatchNormalization())
        g.add(layers.LeakyReLU())

        g.add(layers.Conv2DTranspose(64, 5, strides=self.strides[1], padding='same', use_bias=False))
        g.add(layers.BatchNormalization())
        g.add(layers.LeakyReLU())

        g.add(layers.Conv2DTranspose(32, 5, strides=self.strides[2], padding='same', use_bias=False))
        g.add(layers.BatchNormalization())
        g.add(layers.LeakyReLU())

        g.add(layers.Conv2DTranspose(self.data_shape[3], 5, strides=self.strides[2],
                                     padding='same', use_bias=False, activation='tanh'))

        return g

    def set_discriminator(self):
        """Build the critic: image -> single unbounded score (no output activation).

        Uses LayerNormalization rather than BatchNormalization, so each sample's
        score (and therefore its input gradient in the penalty) does not depend on
        the other samples in the batch.
        """
        d = tf.keras.Sequential()
        d.add(layers.Conv2D(32, kernel_size=5, strides=2, padding='same', input_shape=self.data_shape[1:]))
        d.add(layers.LeakyReLU())

        d.add(layers.Conv2D(64, kernel_size=5, strides=2, padding='same'))
        d.add(layers.LayerNormalization())
        d.add(layers.LeakyReLU())

        d.add(layers.Conv2D(128, kernel_size=5, strides=2, padding='same'))
        d.add(layers.LayerNormalization())
        d.add(layers.LeakyReLU())

        d.add(layers.Flatten())
        d.add(layers.Dense(1))

        return d

    @tf.function
    def lipschitz_penalty(self, x, x_hat):
        """One-sided Lipschitz penalty at perturbed real and generated samples.

        Penalty points are drawn DRAGAN-style: the real and generated batches are
        concatenated and perturbed with Gaussian noise whose standard deviation is
        ``perturb_factor`` times the standard deviation of the joined batch.

        Args:
            x: Batch of real images (rank 4).
            x_hat: Batch of generated images with the same per-sample shape as ``x``.

        Returns:
            Scalar mean of ``max(0, ||dD/dx_tilde|| - 1) ** 2`` over the perturbed points.
        """
        x_join = tf.concat([x, x_hat], axis=0)
        # a single scalar variance over every element of the joined batch
        _, batch_var = tf.nn.moments(x_join, axes=[0, 1, 2, 3])
        delta = tf.random.normal(tf.shape(x_join), stddev=self.perturb_factor * tf.sqrt(batch_var))
        x_tilde = x_join + delta

        # input gradient of the critic; this inner tape runs inside the caller's tape,
        # so the penalty stays differentiable w.r.t. the critic's weights
        with tf.GradientTape() as D_tape:
            D_tape.watch(x_tilde)
            y_tilde = self.D(x_tilde)
        d_grad = D_tape.gradient(y_tilde, x_tilde)
        grad_norm = tf.sqrt(tf.reduce_sum(tf.square(d_grad), axis=[1, 2, 3]) + self._NORM_EPS)

        return tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.0)))

    @tf.function
    def train_discriminator(self, x_batch):
        """Run one critic update on ``x_batch`` and return the unpenalised objective.

        The objective is ``mean(D(fake)) - mean(D(real))``; the gradient step also
        includes ``grad_penalty * lipschitz_penalty``.
        """
        with tf.GradientTape() as D_tape:
            # sample as many fakes as there are real images (dynamic: last batch is smaller)
            noise = tf.random.uniform([tf.shape(x_batch)[0], self.noise_dim])
            x_gen = self.G(noise, training=True)

            # scoring with the discriminator
            y_real = self.D(x_batch, training=True)
            y_gen = self.D(x_gen, training=True)

            # compute the objective
            d_obj = tf.math.reduce_mean(y_gen) - tf.math.reduce_mean(y_real)
            d_obj_pen = d_obj + self.grad_penalty * self.lipschitz_penalty(x_batch, x_gen)
        # update the discriminator only
        d_grad = D_tape.gradient(d_obj_pen, self.D.trainable_variables)
        self.D_optimizer.apply_gradients(zip(d_grad, self.D.trainable_variables))

        return d_obj

    @tf.function
    def train_generator(self, x_batch_size):
        """Run one generator update on ``x_batch_size`` fresh samples.

        Returns the objective ``-mean(D(G(z)))``.
        """
        with tf.GradientTape() as G_tape:
            x_gen = self.G(tf.random.uniform([x_batch_size, self.noise_dim]), training=True)
            y_gen = self.D(x_gen, training=True)

            # compute the objective
            g_obj = -tf.math.reduce_mean(y_gen)
        # update the generator only
        g_grad = G_tape.gradient(g_obj, self.G.trainable_variables)
        self.G_optimizer.apply_gradients(zip(g_grad, self.G.trainable_variables))

        return g_obj

    def _save_samples(self, vis_seed, epoch):
        """Save a 4x4 grid of generator samples for the fixed noise ``vis_seed``.

        Args:
            vis_seed: Noise batch created once in ``train()`` (16 rows) and reused
                every epoch so the grids are comparable across epochs.
            epoch: 1-based epoch number used in the title and the file name.
        """
        vis_gen = self.G(vis_seed, training=False).numpy()
        fig, axes = plt.subplots(4, 4, figsize=(4, 4))
        fig.suptitle('Epoch: {:03d}'.format(epoch))
        for ax, img in zip(axes.flat, vis_gen):
            if img.shape[-1] == 1:
                # fixed display range; imshow would otherwise stretch each sample's contrast
                ax.imshow(img[:, :, 0] * 127.5 + 127.5, cmap='gray', vmin=0, vmax=255)
            else:
                ax.imshow((img + 1) / 2)
            ax.axis('off')
        fig.savefig(os.path.join(self.out_path, "WGAN_{}_Epoch_{:03d}.png".format("MNIST", epoch)))
        plt.close(fig)

    def train(self):
        """Train for ``total_epoch`` epochs, filling ``d_obj``/``g_obj`` as it goes.

        Prints the time taken by each epoch and the running total; when
        ``visualize`` is set, saves a sample grid to ``out_path`` after each epoch.
        """
        vis_seed = None
        if self.visualize:
            # fixed seed for checking training progress
            vis_seed = tf.random.uniform([16, self.noise_dim])

        print("Training...")
        ts_start = float(tf.timestamp())
        for t in range(self.total_epoch):
            ts_epoch = float(tf.timestamp())
            for batch_id, b in enumerate(self.x_train):
                for k in range(self.critic_step):
                    self.d_obj[batch_id, t, k] = float(self.train_discriminator(b))
                self.g_obj[batch_id, t] = float(self.train_generator(b.shape[0]))

            now = float(tf.timestamp())
            print("Epoch {} took {:0.2f} seconds ({:0.2f} seconds in total).".format(
                t + 1, now - ts_epoch, now - ts_start))

            if self.visualize:
                self._save_samples(vis_seed, t + 1)
        print("Done! {:0.2f} seconds have passed.".format(float(tf.timestamp()) - ts_start))

In [32]:
# Seed Python, NumPy and TensorFlow before the model is built so weight initialisation,
# dataset shuffling and noise draws repeat across kernel restarts (as far as the
# backend's own determinism allows).
SEED = 42
tf.keras.utils.set_random_seed(SEED)
model = WGAN()

In [33]:
# Runs the full training loop for model.total_epoch epochs (100 by default).
# In the recorded run below each epoch took roughly 1060 seconds and the run was
# interrupted after two epochs, so a complete run is a very long job.
model.train()

Training...
Time used for epoch 1 are 1064.78 seconds.
Time used for epoch 2 are 2124.62 seconds.


KeyboardInterrupt: 

In [None]:
# Settings that were previously read from an undefined `model_param` object
# (NameError on a fresh kernel); they are now derived from `model` and constants here.
MODEL_NAME = "WGAN"
DATASET = "MNIST"
OUTPUT_DIR = model.out_path
N_SAMPLES = 100   # number of generated examples in the grid
GRID_COLS = 10    # examples per row
GAP_PX = 5        # pixels between neighbouring examples
MARGIN_PX = 10    # pixels between the grid and the figure edge
DPI = 100         # figimage places images in pixels, so the figure size is pixels / DPI

# ---- grid of generator samples, one image pixel per figure pixel ----
vis_seed = tf.random.uniform([N_SAMPLES, model.noise_dim])
vis_gen = model.G(vis_seed, training=False).numpy()
img_px = vis_gen.shape[1]
cell_px = img_px + GAP_PX
n_rows = -(-N_SAMPLES // GRID_COLS)
width_px = 2 * MARGIN_PX + GRID_COLS * cell_px - GAP_PX
height_px = 2 * MARGIN_PX + n_rows * cell_px - GAP_PX

sample_fig = plt.figure(figsize=(width_px / DPI, height_px / DPI), dpi=DPI)
for i, img in enumerate(vis_gen):
    x_off = MARGIN_PX + (i % GRID_COLS) * cell_px
    y_off = MARGIN_PX + (i // GRID_COLS) * cell_px
    if img.shape[-1] == 1:
        # fixed grey range so all examples share one intensity scale
        sample_fig.figimage(img[:, :, 0] * 127.5 + 127.5, x_off, y_off,
                            cmap='gray', vmin=0, vmax=255)
    else:
        sample_fig.figimage((img + 1) / 2, x_off, y_off)
sample_fig.savefig(os.path.join(OUTPUT_DIR, "{}_{}_Example.png".format(MODEL_NAME, DATASET)),
                   dpi=DPI)

# ---- median value of the objective functions per finished epoch ----
# NOTE(review): an epoch is treated as finished when every generator objective in it
# was written; slots never reached (e.g. after an interrupted run) keep the 0.0 they
# were allocated with, and plotting them would show a misleading flat line at zero.
finished = np.all(model.g_obj != 0, axis=0)
epochs = np.arange(1, model.g_obj.shape[1] + 1)[finished]
obj_fig, ax = plt.subplots()
ax.plot(epochs, np.median(model.d_obj, axis=(0, 2))[finished], label='Discriminator')
ax.plot(epochs, np.median(model.g_obj, axis=0)[finished], label='Generator')
ax.set(title="Objective Functions of {} (Dataset: {})".format(MODEL_NAME, DATASET),
       xlabel="Epoch", ylabel="Median Value")
ax.legend()
obj_fig.savefig(os.path.join(OUTPUT_DIR, "{}_{}_Objective.png".format(MODEL_NAME, DATASET)))