# Building a GAN to Generate Handwritten Digit Images - Solved Version

In this assignment, you'll use TensorFlow to build a simple Generative Adversarial Network (GAN) that generates images of handwritten digits from the MNIST dataset.

## Steps Overview:


1. Create a Generator model and a Discriminator model
2. Create the loss functions
3. Create the training step function
4. Create a function to generate and save the images
5. Create the main training function
6. Load the sample MNIST dataset
7. Run the training function for 50 epochs
8. Assemble an animated GIF showing the progressively better images

### Exercise 1: Create the Generator and Discriminator Models
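Use TensorFlow's Keras API to define a generator that maps a 100-dimensional noise vector to a 28×28×1 image, and a discriminator that maps a 28×28×1 image to a single logit (an unnormalized real-vs-fake score).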

In [None]:
# Install the required libraries
%pip install tensorflow
%pip install imageio


In [None]:
import tensorflow as tf
from tensorflow.keras import layers

# Create the generator model: 100-dim noise vector -> 28x28x1 image
def make_generator_model():
    model = tf.keras.Sequential()
    # Project the noise and reshape it into a 7x7x256 feature map
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    # 7x7x256 -> 7x7x128 (stride 1 keeps the spatial size)
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # 7x7x128 -> 14x14x64
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # 14x14x64 -> 28x28x1; tanh keeps outputs in [-1, 1], the same range as the normalized MNIST images
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return model

# Create the discriminator model: 28x28x1 image -> one logit (no sigmoid; the loss uses from_logits=True)
def make_discriminator_model():
    model = tf.keras.Sequential()
    # 28x28x1 -> 14x14x64
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # 14x14x64 -> 7x7x128
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # Flatten to 7*7*128 = 6272 features and score them with a single linear unit
    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# Instantiate the generator and discriminator models
generator = make_generator_model()
discriminator = make_discriminator_model()

# Print the model summaries
generator.summary()
discriminator.summary()
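As a cross-check on the two summaries, the sketch below recomputes each network's spatial sizes and parameter counts in plain Python. It assumes Keras's usual rules: `Conv2D` with `padding='same'` outputs `ceil(size / stride)`, `Conv2DTranspose` with `padding='same'` outputs `size * stride`, and `BatchNormalization` stores four values per feature (trainable gamma and beta, non-trainable moving mean and variance). The helper names are hypothetical and not part of the notebook; the totals should match the "Total params" lines printed above.

```python
import math

def conv_transpose_same(size, stride):
    # Conv2DTranspose with padding='same' (no output_padding) multiplies the size by the stride
    return size * stride

def conv_same(size, stride):
    # Conv2D with padding='same' gives ceil(size / stride)
    return math.ceil(size / stride)

def conv_params(k, c_in, c_out, bias):
    # One k x k kernel per (input channel, output channel) pair, plus an optional bias per output channel
    return k * k * c_in * c_out + (c_out if bias else 0)

def bn_params(features):
    # gamma and beta (trainable) plus moving mean and variance (non-trainable)
    return 4 * features

# Generator: 7 -> 7 -> 14 -> 28 spatially
gen_sizes = [7]
for stride in (1, 2, 2):
    gen_sizes.append(conv_transpose_same(gen_sizes[-1], stride))

gen_params = (
    100 * 7 * 7 * 256                      # Dense(7*7*256, use_bias=False)
    + bn_params(7 * 7 * 256)
    + conv_params(5, 256, 128, bias=False) + bn_params(128)
    + conv_params(5, 128, 64, bias=False) + bn_params(64)
    + conv_params(5, 64, 1, bias=False)
)

# Discriminator: 28 -> 14 -> 7, then Flatten and Dense(1)
disc_sizes = [28]
for stride in (2, 2):
    disc_sizes.append(conv_same(disc_sizes[-1], stride))
flat = disc_sizes[-1] * disc_sizes[-1] * 128

disc_params = (
    conv_params(5, 1, 64, bias=True)
    + conv_params(5, 64, 128, bias=True)
    + flat * 1 + 1                         # Dense(1) with bias
)

print(gen_sizes, gen_params)           # [7, 7, 14, 28] 2330944
print(disc_sizes, flat, disc_params)   # [28, 14, 7] 6272 212865
```

Most of the generator's parameters sit in the first `Dense` layer (1,254,400 of them), which is why this projection-then-upsample layout is cheap in the later layers.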


### Exercise 2: Create the Loss Functions
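Create the loss functions for the generator and the discriminator. Both use binary cross-entropy on the discriminator's raw logits (`from_logits=True`, because its final `Dense(1)` layer has no activation). The discriminator is trained to label real images 1 and generated images 0; the generator is trained to make the discriminator label its images 1.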

In [4]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
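To see what values these losses take, here is a NumPy sketch of the same two functions. It uses the numerically stable form of sigmoid cross-entropy on a logit x with label z, `max(x, 0) - x*z + log(1 + exp(-|x|))`, and assumes Keras's default reduction (mean over the batch). The `_np` function names are hypothetical stand-ins, not part of the notebook.

```python
import numpy as np

def bce_from_logits(labels, logits):
    # Stable binary cross-entropy on raw logits, averaged over all elements:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    return float(np.mean(np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))))

def generator_loss_np(fake_logits):
    # The generator wants its images classified as real (label 1)
    return bce_from_logits(np.ones_like(fake_logits), fake_logits)

def discriminator_loss_np(real_logits, fake_logits):
    # Real images should get label 1, generated images label 0
    return (bce_from_logits(np.ones_like(real_logits), real_logits)
            + bce_from_logits(np.zeros_like(fake_logits), fake_logits))

# An undecided discriminator: logit 0 (probability 0.5) for every image
undecided = np.zeros((4, 1))
print(generator_loss_np(undecided))                 # ln 2, about 0.693
print(discriminator_loss_np(undecided, undecided))  # 2 ln 2, about 1.386
```

When the discriminator becomes confident and correct (large positive logits on real images, large negative logits on fakes), its loss approaches 0 while the generator's loss grows roughly linearly with the size of the negative fake logit.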


### Exercise 3: Create the Training Step Function
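Define a function that performs one training step: generate a batch of fake images, score both the real and the fake images with the discriminator, compute both losses, and update each network with its own gradient tape and optimizer.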

In [5]:
@tf.function
def train_step(images):
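    # Match the noise batch to the incoming batch; the last MNIST batch is smaller than batch_size
    noise = tf.random.normal([tf.shape(images)[0], noise_dim])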

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
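Every gradient that `gen_tape` pushes back into the generator starts from the derivative of `generator_loss` with respect to each fake logit x. Because `generator_loss` labels fakes as real (1), that per-example derivative is σ(x) - 1. For comparison only, the original minimax objective (minimizing log(1 - σ(x)), which this notebook does not use) would give -σ(x), which nearly vanishes when the discriminator confidently rejects the fakes (x very negative), a situation that is common early in training. The NumPy sketch below checks both closed forms with finite differences; all function names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def grad_non_saturating(logit):
    # d/dx of -log(sigmoid(x)): BCE with label 1, as in generator_loss
    return sigmoid(logit) - 1.0

def grad_minimax(logit):
    # d/dx of log(1 - sigmoid(x)): the minimax generator objective, shown for comparison
    return -sigmoid(logit)

def numeric_grad(f, x, h=1e-6):
    # Central finite difference, used to double-check the closed forms
    return (f(x + h) - f(x - h)) / (2 * h)

# Compare gradient magnitudes when the discriminator rejects, is unsure about, or accepts a fake
for logit in (-6.0, 0.0, 6.0):
    print(logit, grad_non_saturating(logit), grad_minimax(logit))
```

At a fake logit of -6, the non-saturating gradient has magnitude close to 1, while the minimax one is about 0.0025, so the form used in `generator_loss` keeps the generator learning even when it is losing badly.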


### Exercise 4: Create a Function to Generate and Save Images
Run the generator on a batch of noise vectors (`test_input`), plot the outputs in a 4×4 grid, and save each epoch's grid as a PNG.

In [8]:
import matplotlib.pyplot as plt
import imageio

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
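`generate_and_save_images` maps the generator's `tanh` outputs from [-1, 1] back to pixel values with `x * 127.5 + 127.5` and draws them in a fixed 4×4 subplot grid, so `test_input` should hold at most 16 noise vectors. The NumPy sketch below applies the same mapping and tiles the images into a single array, which could be handy for saving a frame without matplotlib. `to_uint8_grid` is a hypothetical helper, and the random input only stands in for real generator output.

```python
import numpy as np

def to_uint8_grid(predictions, rows=4, cols=4):
    # Hypothetical helper: predictions has shape (N, 28, 28, 1) with values in [-1, 1]
    n, h, w, _ = predictions.shape
    if n > rows * cols:
        raise ValueError(f"{n} images do not fit in a {rows}x{cols} grid")
    # Same mapping as generate_and_save_images: [-1, 1] -> [0, 255]
    pixels = np.clip(predictions[..., 0] * 127.5 + 127.5, 0, 255).round().astype(np.uint8)
    grid = np.zeros((rows * h, cols * w), dtype=np.uint8)
    for i in range(n):
        # Fill the grid row by row, like plt.subplot(4, 4, i + 1)
        r, c = divmod(i, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = pixels[i]
    return grid

fake = np.random.default_rng(0).uniform(-1, 1, size=(16, 28, 28, 1))
grid = to_uint8_grid(fake)
print(grid.shape, grid.dtype)  # (112, 112) uint8
```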


### Exercise 5: Create the Main Training Function
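Define the training loop: for each epoch, call `train_step` on every batch, then save a grid of images generated from the same fixed `seed`, so the frames are comparable across epochs.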

In [9]:
def train(dataset, epochs):
    seed = tf.random.normal([num_examples_to_generate, noise_dim])

    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)

        # Produce images for the GIF as we go
        generate_and_save_images(generator, epoch + 1, seed)
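`train` draws `seed` once, before the epoch loop, so every saved frame shows the generator's output for the same 16 latent vectors. The sketch below illustrates why that matters, using NumPy in place of `tf.random.normal` and a hypothetical `toy_generator` (a linear map followed by `tanh`) in place of the real model: with a fixed seed, the change between two frames comes only from the weight update, while resampling the noise changes every sample regardless of training.

```python
import numpy as np

NOISE_DIM = 100     # noise_dim in the notebook
NUM_EXAMPLES = 16   # num_examples_to_generate in the notebook
rng = np.random.default_rng(0)

# Drawn once, before the epoch loop, like `seed` in train()
seed = rng.standard_normal((NUM_EXAMPLES, NOISE_DIM))

def toy_generator(weights, noise):
    # Hypothetical stand-in for the real generator: a linear map followed by tanh
    return np.tanh(noise @ weights)

# Pretend one epoch of training nudged the weights slightly
w_epoch1 = rng.standard_normal((NOISE_DIM, 8)) * 0.1
w_epoch2 = w_epoch1 + rng.standard_normal((NOISE_DIM, 8)) * 0.01

frame1 = toy_generator(w_epoch1, seed)
frame2 = toy_generator(w_epoch2, seed)  # same inputs: differences come from the weights only
frame2_resampled = toy_generator(w_epoch2, rng.standard_normal((NUM_EXAMPLES, NOISE_DIM)))

print(np.abs(frame2 - frame1).mean())            # small: only the weight update differs
print(np.abs(frame2_resampled - frame1).mean())  # larger: new noise changes every sample
```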


### Exercise 6: Load the MNIST Dataset
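Load the MNIST training images, reshape them to `(60000, 28, 28, 1)`, scale pixel values to [-1, 1] to match the generator's `tanh` output, then shuffle and batch them.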

In [None]:
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

batch_size = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(batch_size)
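Two quick checks on this cell, written in plain Python and NumPy. First, `Dataset.batch` keeps a final partial batch by default (`drop_remainder=False`), so with 60,000 images and a batch size of 256 each epoch ends with a smaller batch; any code that assumes every batch holds exactly 256 images will be off on that last step. Second, the normalization maps the pixel range [0, 255] exactly onto [-1, 1]. The variable names are illustrative only.

```python
import math
import numpy as np

images_total = 60000   # size of the MNIST training split
per_batch = 256        # batch_size above

num_batches = math.ceil(images_total / per_batch)
last_batch = images_total - (num_batches - 1) * per_batch
print(num_batches, last_batch)  # 235 96

# Same normalization as above: [0, 255] -> [-1, 1], matching the generator's tanh range
raw = np.array([0, 127.5, 255], dtype=np.float32)
normalized = (raw - 127.5) / 127.5
print(normalized.tolist())  # [-1.0, 0.0, 1.0]
```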


### Exercise 7: Run the Training Function for 50 Epochs
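Set the noise size, the number of sample images, and the optimizers, then run the training loop for 50 epochs and watch the generated images improve. `train_step` and `train` read these values as globals, so they must be defined before `train` is called.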

In [None]:
epochs = 50
noise_dim = 100
num_examples_to_generate = 16

# Optimizers
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Run the training loop
train(train_dataset, epochs)
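A back-of-the-envelope count of the work this cell schedules, assuming the 60,000-image training split and `batch_size = 256` from Exercise 6. The variable names are illustrative and chosen so they do not overwrite the notebook's globals.

```python
import math

num_train_images = 60000   # MNIST training split
batch = 256                # batch_size from Exercise 6
num_epochs = 50            # epochs above

# The partial last batch still triggers a full train_step
steps_per_epoch = math.ceil(num_train_images / batch)
total_steps = num_epochs * steps_per_epoch

print(steps_per_epoch, total_steps)  # 235 11750
```

Each `train_step` applies one Adam update to the generator and one to the discriminator, so both networks receive 11,750 updates at a learning rate of 1e-4, and `train` writes one PNG frame per epoch (50 in total).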


### Exercise 8: Assemble an Animated GIF
Use the saved images to create an animated GIF showing the GAN's progress.

In [None]:
import glob
import imageio.v2 as imageio  # v2 API: imread/get_writer without the deprecation warning newer imageio emits for top-level imread

def create_gif():
    # Zero-padded epoch numbers make alphabetical order match epoch order
    filenames = sorted(glob.glob('image_at_epoch_*.png'))
    with imageio.get_writer('gan_training.gif', mode='I') as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
    print("GIF saved as 'gan_training.gif'")

create_gif()
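The frames end up in the right order only because `generate_and_save_images` zero-pads the epoch number (`{:04d}`), so plain string sorting matches numeric order. The sketch below shows what would happen without the padding and a numeric sort key that works either way; the filenames are illustrative and `epoch_of` is a hypothetical helper.

```python
import re

epochs_seen = (1, 2, 10, 50)
padded = ["image_at_epoch_{:04d}.png".format(e) for e in epochs_seen]
unpadded = ["image_at_epoch_{}.png".format(e) for e in epochs_seen]

def epoch_of(name):
    # Hypothetical helper: extract the epoch number from a frame filename
    return int(re.search(r"epoch_(\d+)", name).group(1))

print(sorted(padded) == padded)                             # True
print([epoch_of(n) for n in sorted(unpadded)])              # [1, 10, 2, 50]
print([epoch_of(n) for n in sorted(unpadded, key=epoch_of)])  # [1, 2, 10, 50]
```

Note that `glob` matches every file with this pattern in the working directory, so frames left over from an earlier, longer run would also end up in the GIF.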
