In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from IPython import display
from time import perf_counter
from tqdm.notebook import tqdm
import os

# Launch each GPU's kernels from its own private host threads rather than the
# shared thread pool (see the TF_GPU_THREAD_MODE notes in TensorFlow's GPU
# performance guide).
os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'

# Let TensorFlow grow GPU memory on demand instead of reserving it all at
# start-up; this has to be configured before the GPUs are initialised.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Mixed precision: layers compute in float16 while keeping float32 variables.
keras.mixed_precision.set_global_policy(keras.mixed_precision.Policy('mixed_float16'))

In [None]:
def test_model(model, length):
    """Show a ``length`` x ``length`` grid of freshly generated samples.

    Args:
        model: object with a ``random_generate(n)`` method returning a batch of
            images as a tensor (anything exposing ``.numpy()``). Values are
            assumed to lie in [-1, 1] because the generator ends in ``tanh``.
        length: number of rows (and columns) in the grid.
    """
    n_images = length ** 2
    images = model.random_generate(n_images).numpy()
    # Map the tanh range [-1, 1] to [0, 1]; clipping guards against values
    # that land marginally outside the range after rounding.
    images = np.clip((images + 1.0) / 2.0, 0.0, 1.0)
    # Drop a trailing single channel so imshow receives plain (H, W) images;
    # derived from the data rather than hard-coding 28x28.
    if images.ndim == 4 and images.shape[-1] == 1:
        images = images[..., 0]

    fig, axes = plt.subplots(length, length, figsize=(length, length), squeeze=False)
    for ax, image in zip(axes.flat, images):
        # Fixed vmin/vmax: every tile shares one intensity scale instead of
        # being contrast-stretched independently.
        ax.imshow(image, cmap='gray', vmin=0.0, vmax=1.0)
        ax.axis('off')
    fig.tight_layout()
    plt.show()

def show_history(history):
    """Plot the per-epoch mean discriminator and generator losses.

    Args:
        history: dict with 'disc_loss' and 'gen_loss' sequences holding one
            value per completed epoch (the structure built by ``DCGAN.fit``).
    """
    fig, ax = plt.subplots()
    for key, label in (('disc_loss', 'Discriminator Loss'), ('gen_loss', 'Generator Loss')):
        values = history[key]
        # 1-based epoch numbers so the axis matches the "Epoch k/N" progress
        # line; markers keep a single-epoch history visible.
        ax.plot(range(1, len(values) + 1), values, marker='.', label=label)
    ax.set(title='Learning Curve', xlabel='Epoch', ylabel='Loss')
    ax.legend()
    ax.grid()
    fig.tight_layout()
    plt.show()

In [None]:
class DCGAN:
    """Deep Convolutional GAN trained with a hand-written training loop.

    Each ``train_step`` performs one discriminator update followed by one
    generator update, both with binary cross-entropy on logits.

    Args:
        disc: discriminator model; must return one logit per input image.
        gen: generator model mapping ``(n, z_dim)`` noise to images.
        disc_opt: optimizer for ``disc``; ``None`` is allowed for a
            sampling-only instance that is never trained.
        gen_opt: optimizer for ``gen``; ``None`` allowed as above.
        z_dim: length of the latent noise vector fed to ``gen``.
    """

    def __init__(self, disc, gen, disc_opt, gen_opt, z_dim=64):
        self.disc = disc
        self.gen = gen
        # Under a float16 compute policy a custom loop must scale the loss
        # itself, otherwise small fp16 gradients underflow to zero.
        self.disc_opt = self._with_loss_scaling(disc_opt)
        self.gen_opt = self._with_loss_scaling(gen_opt)
        self.z_dim = z_dim
        self.BCE_loss = keras.losses.BinaryCrossentropy(from_logits=True)
        # Per-epoch mean losses; appended to by ``fit`` and never reset.
        self.history = {'disc_loss': [], 'gen_loss': []}
        # Running means of the per-batch losses within the current epoch.
        self.disc_mean_loss = keras.metrics.Mean(name='disc_loss')
        self.gen_mean_loss = keras.metrics.Mean(name='gen_loss')

    @staticmethod
    def _with_loss_scaling(optimizer):
        """Wrap ``optimizer`` in a LossScaleOptimizer when float16 compute is active."""
        if optimizer is None or isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer):
            return optimizer
        if keras.mixed_precision.global_policy().compute_dtype != 'float16':
            return optimizer
        return keras.mixed_precision.LossScaleOptimizer(optimizer)

    @staticmethod
    def _scale_loss(optimizer, loss):
        """Multiply ``loss`` by the optimizer's loss scale, if it has one."""
        if isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer):
            return optimizer.get_scaled_loss(loss)
        return loss

    @staticmethod
    def _unscale_gradients(optimizer, grads):
        """Undo ``_scale_loss`` on gradients before they are applied."""
        if isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer):
            return optimizer.get_unscaled_gradients(grads)
        return grads

    def random_generate(self, n_images):
        """Sample ``n_images`` outputs from the generator in inference mode."""
        noise = tf.random.normal([n_images, self.z_dim])
        return self.gen(noise, training=False)

    # execute in graph mode with XLA for superior performance
    @tf.function(jit_compile=True)
    def train_step(self, real):
        """Run one discriminator update and one generator update on ``real``.

        Args:
            real: batch of real images shaped like the generator output
                (presumably ``(batch, 28, 28, 1)`` here - confirm for other data).
        """
        batch_size = real.shape[0]
        if batch_size is None:  # static batch size unknown (no drop_remainder)
            batch_size = tf.shape(real)[0]

        # Discriminator update: fakes are labelled 0, reals 1. A single mixed
        # batch means the discriminator sees both kinds in one forward pass.
        noise = tf.random.normal([batch_size, self.z_dim])
        fake = self.gen(noise, training=False)
        concatenated = tf.concat((fake, real), axis=0)
        concatenated_labels = tf.concat(
            (tf.zeros([batch_size, 1]), tf.ones([batch_size, 1])), axis=0)
        with tf.GradientTape() as disc_tape:
            # training=True must be explicit: when omitted, BatchNormalization
            # defaults to inference mode and its statistics never update.
            disc_concatenated_pred = self.disc(concatenated, training=True)
            disc_loss = self.BCE_loss(concatenated_labels, disc_concatenated_pred)
            scaled_disc_loss = self._scale_loss(self.disc_opt, disc_loss)
        grad_of_disc = disc_tape.gradient(scaled_disc_loss, self.disc.trainable_variables)
        grad_of_disc = self._unscale_gradients(self.disc_opt, grad_of_disc)
        self.disc_opt.apply_gradients(zip(grad_of_disc, self.disc.trainable_variables))
        self.disc_mean_loss.update_state(disc_loss)

        # Generator update: push the discriminator to label fresh fakes as 1.
        noise = tf.random.normal([batch_size, self.z_dim])
        with tf.GradientTape() as gen_tape:
            fake = self.gen(noise, training=True)
            disc_fake_pred = self.disc(fake, training=False)
            gen_loss = self.BCE_loss(tf.ones_like(disc_fake_pred), disc_fake_pred)
            scaled_gen_loss = self._scale_loss(self.gen_opt, gen_loss)
        grad_of_gen = gen_tape.gradient(scaled_gen_loss, self.gen.trainable_variables)
        grad_of_gen = self._unscale_gradients(self.gen_opt, grad_of_gen)
        self.gen_opt.apply_gradients(zip(grad_of_gen, self.gen.trainable_variables))
        self.gen_mean_loss.update_state(gen_loss)

    def fit(self, data, epochs=1):
        """Train for ``epochs`` passes over ``data``, reporting after each one.

        After every epoch the cell output is cleared and replaced by a progress
        line, a 4x4 sample grid and the learning curve so far.

        Args:
            data: iterable of real-image batches (e.g. a ``tf.data.Dataset``).
            epochs: number of passes over ``data``.

        Returns:
            ``self.history``, the per-epoch mean losses accumulated across calls.
        """
        for epoch in range(epochs):
            tic = perf_counter()
            self.disc_mean_loss.reset_state()
            self.gen_mean_loss.reset_state()
            for real in tqdm(data):
                # Train this instance, not whatever global happens to be named `model`.
                self.train_step(real)
            self.history['disc_loss'].append(self.disc_mean_loss.result().numpy())
            self.history['gen_loss'].append(self.gen_mean_loss.result().numpy())
            display.clear_output(wait=True)
            print("Epoch %d/%d - %.1fs | disc_loss: %.5f - gen_loss: %.5f"%(
                epoch+1, epochs, perf_counter() - tic, self.history['disc_loss'][-1], self.history['gen_loss'][-1]))
            test_model(self, 4)
            show_history(self.history)
        return self.history

In [None]:
# --- Input pipeline ---
TRAIN_BUFFER = 70_000      # shuffle buffer size; presumably sized to hold all MNIST train+test images
BATCH_SIZE = 64
EPOCHS = 50
DIMS = (28, 28, 1)         # (height, width, channels) of a single image
AUTOTUNE = tf.data.AUTOTUNE

# --- Model / optimiser hyper-parameters ---
z_dim = 64                 # latent noise length fed to the generator
lr = 2e-4                  # Adam learning rate, shared by both networks
beta_1, beta_2 = 0.5, 0.999

In [None]:
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
# Use both MNIST splits (labels are ignored) and rescale pixels from [0, 255]
# to [-1, 1], the output range of the generator's tanh layer.
X = np.concatenate((x_train, x_test)).astype(np.float32).reshape(-1, *DIMS) / 127.5 - 1.

dataloader = (
    tf.data.Dataset.from_tensor_slices(X)
    .shuffle(TRAIN_BUFFER)
    # drop_remainder keeps every batch the same size, which the XLA-compiled
    # train_step relies on through the static batch dimension.
    .batch(BATCH_SIZE, drop_remainder=True, num_parallel_calls=AUTOTUNE, deterministic=False)
    .prefetch(AUTOTUNE)
)


def _discriminator_unit(filters, kernel_size, strides):
    """Bias-free Conv2D -> BatchNorm -> LeakyReLU(0.2) unit for the discriminator."""
    return [
        layers.Conv2D(filters, kernel_size, strides=strides, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
    ]


def _generator_unit(filters):
    """Bias-free stride-2 Conv2DTranspose -> BatchNorm -> ReLU unit (doubles H and W)."""
    return [
        layers.Conv2DTranspose(filters, 4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
    ]


discriminator = keras.Sequential(
    [layers.InputLayer(DIMS)]
    + _discriminator_unit(32, 4, 2)
    + _discriminator_unit(64, 4, 2)
    + _discriminator_unit(128, 3, 1)
    # float32 output layer: logits stay in full precision under mixed_float16.
    + [layers.Flatten(), layers.Dense(1, dtype=tf.float32)]
)

generator = keras.Sequential(
    [
        layers.InputLayer([z_dim]),
        layers.Dense(7*7*128, use_bias=False),
        layers.Reshape([7, 7, 128]),
        layers.BatchNormalization(),
        layers.ReLU(),
    ]
    + _generator_unit(64)   # 7x7 -> 14x14
    + _generator_unit(32)   # 14x14 -> 28x28
    + [layers.Conv2D(1, 3, strides=1, padding='same', activation='tanh', dtype=tf.float32)]
)


def _make_adam():
    """Fresh Adam optimizer with the shared hyper-parameters (one per network)."""
    return keras.optimizers.Adam(learning_rate=lr, beta_1=beta_1, beta_2=beta_2)


model = DCGAN(
    disc=discriminator,
    gen=generator,
    disc_opt=_make_adam(),
    gen_opt=_make_adam(),
    z_dim=z_dim,
)

In [None]:
history = model.fit(data=dataloader, epochs=EPOCHS)  # redraws samples + learning curve after every epoch

In [None]:
test_model(model, length=8)  # 8x8 grid of samples from the trained generator

In [None]:
#model_path = "./DCGAN"
#model.gen.save(model_path, include_optimizer=False)

In [None]:
"""model_path = "./DCGAN"
loaded_model = DCGAN(
    gen=keras.models.load_model(model_path, compile=False),
    disc=None,
    gen_opt=None,
    disc_opt=None,
    z_dim=64
)"""

In [None]:
#test_model(loaded_model, 8)