# Deep Convolutional Generative Adversarial Network Example

Build a deep convolutional generative adversarial network (DCGAN) to generate digit images from a noise distribution with Tensorflow v2

References:

- Unsupervised representation learning with deep convolutional generative adversarial networks. A Radford, L Metz, S Chintala, 2016.
- Understanding the difficulty of training deep feedforward neural networks. X Glorot, Y Bengio. Aistats 9, 249-256
- Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Sergey Ioffe, Christian Szegedy. 2015.

In [None]:
! pip install tensorflow

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers

In [None]:
# MNIST Dataset parameters
num_features = 784 # data features (img shape: 28*28).

# Training parameters.
lr_generator = 0.0002
lr_discriminator = 0.0002
training_steps = 20000
batch_size = 128
display_step = 500

# Network parameters.
noise_dim = 100 # Noise data points.

In [None]:
# Prepare MNIST data.
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Convert to float32.
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
# Normalize images value from [0, 255] to [0, 1].
x_train, x_test = x_train / 255., x_test / 255.

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Use tf.data API to shuffle and batch data.
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(10000).batch(batch_size).prefetch(1)

In [None]:
# Create TF Model
class Generator(Model):
  # Set layers
  def __init__(self):
    super(Generator, self).__init__()
    self.fc1 = layers.Dense(7 * 7 * 128)
    self.bn1 = layers.BatchNormalization()
    self.conv2tr1 = layers.Conv2DTranspose(64, 5, strides=2, padding="SAME")
    self.bn2 = layers.BatchNormalization()
    self.conv2tr2 = layers.Conv2DTranspose(1, 5, strides=2, padding="SAME")

  # Set forward pass
  def call(self, x, is_training=False):
    x = self.fc1(x)
    x = self.bn1(x, training=is_training)
    x = tf.nn.leaky_relu(x)

    # Reshape to a 4D array of images: (batch, height, width, channels)
    # New shape: (batch, 7, 7, 128)
    x = tf.reshape(x, shape=[-1, 7, 7, 128])
    # Deconvulotion, image shape: (batch, 14, 14, 64)
    x = self.conv2tr1(x)
    x = self.bn2(x, training=is_training)
    x = tf.nn.leaky_relu(x)
    # Deconvolution, image shape (batch, 28, 28, 1)
    x = self.conv2tr2(x)
    x = tf.nn.tanh(x)

    return x

# Generator Network
# Input: Noise, Output: Image
# Note that batch normalization has different behaviour at training and inference time
# we then use a placeholder to indicate the layer if we are training or not

class Discriminator(Model):
  # Set layers
  def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = layers.Conv2D(64, 5, strides=2, padding='SAME')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(128, 5, strides=2, padding='SAME')
        self.bn2 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(1024)
        self.bn3 = layers.BatchNormalization()
        self.fc2 = layers.Dense(2)

  # Set forward pass.
  def call(self, x, is_training=False):
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = self.conv1(x)
        x = self.bn1(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn3(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        return self.fc2(x)
    

# Build neural netowrk model
generator = Generator()
discriminator = Discriminator()

In [None]:
# Losses
def generator_loss(reconstructed_images) -> float:
  gen_loss: float = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=reconstructed_images, labels=tf.ones([batch_size], dtype=tf.int32)
  ))
  return gen_loss

def discriminator_loss(disc_fake, disc_real) -> float:
  disc_loss_real: float = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=disc_real, labels=tf.ones([batch_size], dtype=tf.int32)
  ))
  disc_loss_fake: float = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=disc_fake, labels=tf.ones([batch_size], dtype=tf.int32)
  ))
  
  return disc_loss_real + disc_loss_fake

# Optimizers
optimizer_gen = tf.optimizers.Adam(learning_rate=lr_generator) # beta_1=0.5, beta_2=0.99
optimizer_disc = tf.optimizers.Adam(learning_rate=lr_discriminator) # beta_1=0.5, beta_2=0.99

In [None]:
# Optimization process. Inputs: real and noise
def run_optimization(real_images):

  # Rescale to [-1, 1], the input range of the discriminator
  real_images = real_images * 2.0 - 1.0

  # Generate Noise
  noise = np.random.normal(-1.0, 1.0, size=[batch_size, noise_dim]).astype(np.float32)

  with tf.GradientTape() as g:
    fake_images = generator(noise, is_training=True)
    disc_fake = discriminator(fake_images, is_training=True)
    disc_real = discriminator(real_images, is_training=True)

    disc_loss: float = discriminator_loss(disc_fake, disc_real)
  
  # Training Variables for each optimizer
  gradients_disc = g.gradient(disc_loss, discriminator.trainable_variables)
  optimizer_disc.apply_gradients(zip(gradients_disc, discriminator.trainable_variables))

  # Generate Noise
  noise = np.random.normal(-1.0, 1.0, size=[batch_size, noise_dim]).astype(np.float32)

  with tf.GradientTape() as g:
    fake_images = generator(noise, is_training=True)
    disc_fake = discriminator(fake_images, is_training=True)

    gen_loss = generator_loss(disc_fake)
  
  gradients_gen = g.gradient(gen_loss, generator.trainable_variables)
  optimizer_gen.apply_gradients(zip(gradients_gen, generator.trainable_variables))

  return gen_loss, disc_loss

In [None]:
# Run training for the given number of steps
for step, (batch_x, _) in enumerate(train_data.take(training_steps + 1)):
  if step == 0:
    # Generate noise
    noise = np.random.normal(-1.0, 1.0, size=[batch_size, noise_dim]).astype(np.float32)
    gen_loss = generator_loss(discriminator(generator(noise)))
    disc_loss = discriminator_loss(discriminator(batch_x), discriminator(generator(noise)))
    print(f"Initial: generator loss: {gen_loss}, discriminator loss: {disc_loss}")
    continue
  
  # Run the optimization
  gen_loss, disc_loss = run_optimization(batch_x)

  if step % display_step == 0:
    print(f"Step: {step}, generator loss: {gen_loss}, discriminator loss: {disc_loss}")

Initial: generator loss: 0.6894901990890503, discriminator loss: 1.3845398426055908
Step: 500, generator loss: 0.0019445784855633974, discriminator loss: 0.0028123047668486834
Step: 1000, generator loss: 0.0006655532633885741, discriminator loss: 0.0009768111631274223
Step: 1500, generator loss: 0.0003221117949578911, discriminator loss: 0.0004940598737448454
Step: 2000, generator loss: 0.0001908612612169236, discriminator loss: 0.0002851457102224231
Step: 2500, generator loss: 0.00012523136683739722, discriminator loss: 0.00018676454783417284
Step: 3000, generator loss: 8.692976552993059e-05, discriminator loss: 0.0001265893515665084
Step: 3500, generator loss: 6.305928400252014e-05, discriminator loss: 9.020417201099917e-05
Step: 4000, generator loss: 4.316943886806257e-05, discriminator loss: 6.424854655051604e-05
Step: 4500, generator loss: 3.233210736652836e-05, discriminator loss: 4.7906396503094584e-05


In [None]:
# Visualize predictions.
import matplotlib.pyplot as plt

In [2]:
Testing
# Generate images from noise, using the generator network.
n = 6
canvas = np.empty((28 * n, 28 * n))
for i in range(n):
    # Noise input.
    z = np.random.normal(-1., 1., size=[n, noise_dim]).astype(np.float32)
    # Generate image from noise.
    g = generator(z).numpy()
    # Rescale to original [0, 1]
    g = (g + 1.) / 2
    # Reverse colours for better display
    g = -1 * (g - 1)
    for j in range(n):
        # Draw the generated digits
        canvas[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = g[j].reshape([28, 28])

plt.figure(figsize=(n, n))
plt.imshow(canvas, origin="upper", cmap="gray")
plt.show()