# <center> Generate Galaxies using Random Numbers (WGAN-GP)
    
---
    

## Practical Part: WGAN-GP

### Dataset
* We are using [5000 images of Galaxies captured via telescopes](https://www.kaggle.com/datasets/spartificial/low-and-high-resolution-galaxy-images-cgan-data).
* We are ONLY going to use the Ground Truth files for this project.

### Table of Contents of Implementing WGAN-GP
1. <a href='#1'>Importing dependencies</a>
2. <a href='#2'>Setting Memory Growth for Each GPU</a>
3. <a href="#3">Project Setup</a>
4. <a href='#4'>Dataset Preparation</a>
5. <a href='#5'>Building Generator and Critic</a>
6. <a href='#6'>Learning Rate Setup</a>
7. <a href='#7'>Building Checkpoints</a>
8. <a href='#8'>Function to Generate Images and Save Them After Each Epoch</a>
9. <a href='#9'>Build the Training Step for Critic and Generator</a>
10. <a href='#10'>Training of WGAN-GP</a>
11. <a href='#11'>Compare Generated Data with Real Data</a>

## 1. Importing Dependencies<section id="1">

In [None]:
# Import necessary modules and functions for file handling, timing, and mathematical operations.
import os
import time
import math

# Import NumPy for numerical operations and handling arrays.
import numpy as np

# Import matplotlib for plotting graphs and visualizations.
import matplotlib.pyplot as plt

# Import clear_output from IPython.display to clear the output in Jupyter notebooks.
from IPython.display import clear_output

# Import TensorFlow and its components for creating and training neural networks.
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Activation, Reshape, LayerNormalization, BatchNormalization
from tensorflow.keras.layers import Input, Dropout, Dense, LeakyReLU, Flatten
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal

# Define a constant that lets tf.data tune its parallelism and prefetching automatically.
AUTOTUNE = tf.data.AUTOTUNE

## 2. Setting Memory Growth for Each GPU<section id="2">

In [None]:
# Step 1: List all physical GPU devices available on the system.
# This function checks the computer for any GPUs that TensorFlow can use and returns a list of those devices.
devices = tf.config.experimental.list_physical_devices("GPU")

# Step 2: Iterate over each GPU device found.
# The 'devices' list contains all GPUs detected by TensorFlow. We loop through this list to configure each GPU.
for device in devices:
    # Step 3: Set memory growth for each GPU.
    # Memory growth configuration allows TensorFlow to gradually increase the amount of GPU memory it uses, rather than allocating all available memory at once.
    # This is useful to avoid running out of memory if other applications are also using the GPU or if TensorFlow only needs a small amount of memory to start.
    tf.config.experimental.set_memory_growth(device=device, enable=True)

## 3. Project Setup<section id="3">

In [None]:
# Define the name used for the output, log, and checkpoint directories.
# 'DCGAN' (Deep Convolutional Generative Adversarial Network) refers to the network architecture; the training objective is WGAN-GP.
MODEL_NAME = 'DCGAN'

# Define the directory where image data is stored. 'dataset/images/' is the path to the folder containing image files.
DATA_BASE_DIR = 'dataset/images/'

# Define the output path where results and generated outputs will be saved. It combines 'outputs' directory with the model name.
OUTPUT_PATH = os.path.join('outputs', MODEL_NAME)

# Define the directory for storing training logs. This path is used by TensorFlow to save logs during training for later analysis.
TRAIN_LOGDIR = os.path.join("logs", "tensorflow", MODEL_NAME, 'train_data')

# Check if the OUTPUT_PATH directory does not exist. If it does not exist, create the directory.
if not os.path.exists(OUTPUT_PATH):
    os.makedirs(OUTPUT_PATH)

# Define the target size of the images. Here, images are 128x128 pixels, which matches the generator's output resolution.
TARGET_IMG_SIZE = 128 

# Define the batch size, which is the number of images processed together in one iteration of training.
BATCH_SIZE = 64

# Define the dimension of the noise vector. This vector is input to the generator in the GAN, typically of fixed size.
NOISE_DIM = 100

# Define a hyperparameter lambda (λ) used for gradient penalty in the training of GANs. It helps stabilize training.
LAMBDA = 10 

# Define the number of epochs for training. One epoch is one complete pass through the training dataset.
EPOCHS = 1000

# Define the starting epoch. Training will start from epoch 1.
CURRENT_EPOCH = 1 

# Define how frequently (in terms of epochs) to save model checkpoints. Here, a checkpoint will be saved every 15 epochs.
SAVE_EVERY_N_EPOCH = 15 

# Define the number of times to train the critic (discriminator) before training the generator. This helps balance the GAN training.
N_CRITIC = 3 

# Define the initial learning rate for training. The learning rate determines how much to adjust the weights during training.
LR = 1e-4

# Define the minimum learning rate. This ensures the learning rate doesn’t go below this value, which can be useful for fine-tuning.
MIN_LR = 0.000001 

# Define the decay factor for the learning rate. Each time the learning rate is decayed, it is divided by this factor.
DECAY_FACTOR = 1.00004 

# Define the number of images to display after the training is completed
NUM_IMAGES = 36

# Create a file writer object for TensorFlow’s summary logs. This object writes logs to the directory defined in TRAIN_LOGDIR.
file_writer = tf.summary.create_file_writer(TRAIN_LOGDIR)
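`N_CRITIC` means the critic is updated several times for every generator update. The training loop itself is not part of this section, so the following is only a sketch of one common way to organise such a schedule, assuming one critic update per batch and one generator update every `N_CRITIC` batches. The helper `should_update_generator` is hypothetical.

```python
N_CRITIC = 3  # same value as the notebook setting


def should_update_generator(step, n_critic=N_CRITIC):
    """Hypothetical helper: True on every n_critic-th batch (steps counted from 1)."""
    return step % n_critic == 0


# Simulate 9 batches: the critic trains on every batch, the generator on every third one.
schedule = [(step, should_update_generator(step)) for step in range(1, 10)]
generator_steps = [step for step, update_g in schedule if update_g]
print(generator_steps)  # [3, 6, 9]
```

With `N_CRITIC = 3` the critic therefore receives three times as many updates as the generator, which is the balance the setting above is meant to enforce.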

## 4. Dataset Preparation<section id='4'>

In [None]:
# Modify the directory to the path of your dataset
# Create a TensorFlow dataset object that contains file paths to images. This dataset includes all files with a .jpeg extension in the specified directory.
list_ds = tf.data.Dataset.list_files('/kaggle/input/Final Ground Truth/*.jpeg')

# Loop through the dataset and take the first 5 file paths.
# For each file path in these 5 examples, convert the TensorFlow tensor to a NumPy array and print it. 
# This shows the actual file paths of the images.
for f in list_ds.take(5):
    print(f.numpy())

In [None]:
def normalize(image):
    '''
    Normalizes the image pixel values to the range [-1, 1].

    Pixel values in images typically range from 0 to 255. This function rescales 
    them to the range [-1, 1] to make them more suitable for neural network training.
    
    Args:
        image (tf.Tensor): The input image tensor with pixel values in the range [0, 255].
    
    Returns:
        tf.Tensor: The normalized image tensor with pixel values in the range [-1, 1].
    '''
    # Convert the image tensor to a float32 type to perform mathematical operations.
    image = tf.cast(image, tf.float32)
    
    # Normalize the image pixel values.
    # Subtract 127.5 to center the values around 0.
    # Divide by 127.5 to scale the values to the range [-1, 1].
    image = (image - 127.5) / 127.5
    
    return image

def resize_and_center_crop(image, target_size):
    '''
    Resizes an image so that it covers the target size while preserving the aspect ratio,
    then crops the center of the image to the exact target size.

    Args:
        image (tf.Tensor): The input image tensor of shape (height, width, channels).
        target_size (tuple): The target size (height, width) for resizing and cropping.

    Returns:
        tf.Tensor: The resized and center-cropped image tensor (float32).
    '''
    target_height, target_width = target_size
    
    # Get the current dimensions of the image as floats so they can be scaled.
    # Tensor ops are used instead of int() so this also works inside Dataset.map, where shapes are symbolic.
    img_height = tf.cast(tf.shape(image)[0], tf.float32)
    img_width = tf.cast(tf.shape(image)[1], tf.float32)

    # Use the LARGER of the two scale factors so the resized image covers the whole target box:
    # one side matches the target and the other overflows and is cropped away.
    # (Scaling to fit WITHIN the box would leave one side shorter than the target and make the crop fail.)
    scale = tf.maximum(target_height / img_height, target_width / img_width)
    new_height = tf.cast(tf.math.ceil(img_height * scale), tf.int32)
    new_width = tf.cast(tf.math.ceil(img_width * scale), tf.int32)

    # Resize the image to the new dimensions.
    image_resized = tf.image.resize(image, [new_height, new_width], method='bicubic', antialias=True)
    
    # Bicubic interpolation can overshoot the original [0, 255] range slightly, so clip it back.
    image_resized = tf.clip_by_value(image_resized, 0.0, 255.0)
    
    # Calculate the offsets for center cropping.
    start_y = (new_height - target_height) // 2
    start_x = (new_width - target_width) // 2

    # Crop the center of the resized image to the target dimensions.
    image_cropped = tf.image.crop_to_bounding_box(image_resized, start_y, start_x, target_height, target_width)

    return image_cropped

def preprocess_image(file_path):
    '''
    Preprocesses an image from a file path by reading, decoding, resizing with aspect ratio preservation,
    center cropping, and normalizing it.

    Args:
        file_path (str): The file path to the image.

    Returns:
        tf.Tensor: The preprocessed image tensor ready for input into a neural network.
    '''
    # Read the image file from the given file path.
    image = tf.io.read_file(file_path)
    
    # Decode the JPEG image file to a tensor.
    image = tf.image.decode_jpeg(image, channels=1)
    
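    # Optionally resize and center-crop the image while preserving the aspect ratio.
    # This step is disabled, which assumes the Ground Truth images are already
    # TARGET_IMG_SIZE x TARGET_IMG_SIZE; the generator and critic work with 128x128x1 images.
    #image = resize_and_center_crop(image, (TARGET_IMG_SIZE, TARGET_IMG_SIZE))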
    
    # Normalize the image to the range [-1, 1].
    image = normalize(image)
    
    return image
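The arithmetic behind `normalize` and `resize_and_center_crop` can be checked by hand. The framework-free sketch below uses plain Python numbers instead of tensors: it shows the [0, 255] to [-1, 1] mapping, its inverse used for plotting, and the "cover, then crop" geometry a center crop to an exact size needs (the resized image must be at least as large as the target on both sides). The helper names are hypothetical, and integer ceiling division is used here only to keep the example exact.

```python
def normalize_pixel(value):
    # Same formula as normalize(): [0, 255] -> [-1, 1]
    return (value - 127.5) / 127.5


def denormalize_pixel(value):
    # Inverse used for plotting: [-1, 1] -> [0, 1]
    return value * 0.5 + 0.5


def cover_then_crop(img_h, img_w, target_h, target_w):
    """Hypothetical helper: resized size and crop offsets for a center crop to (target_h, target_w)."""
    if img_w * target_h > img_h * target_w:
        # Image is relatively wider than the target: match the height, let the width overflow.
        new_h = target_h
        new_w = -(-img_w * target_h // img_h)  # ceiling division keeps new_w >= target_w
    else:
        # Image is relatively taller (or has the same shape): match the width.
        new_w = target_w
        new_h = -(-img_h * target_w // img_w)
    # Offsets that center the crop window inside the resized image.
    return new_h, new_w, (new_h - target_h) // 2, (new_w - target_w) // 2


print([normalize_pixel(v) for v in (0, 127.5, 255)])     # [-1.0, 0.0, 1.0]
print([denormalize_pixel(v) for v in (-1.0, 0.0, 1.0)])  # [0.0, 0.5, 1.0]
print(cover_then_crop(150, 200, 128, 128))               # (128, 171, 0, 21)
```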

In [None]:
# Apply the preprocessing function to each image in the dataset (in parallel).
# Cache the preprocessed images in memory so the JPEG files are decoded only once.
# Shuffle AFTER caching with a buffer size of 1000 so every epoch sees a new random order;
# shuffling before .cache() would freeze the first epoch's order for all later epochs.
# Batch the data with the specified batch size for training.
# Prefetch batches to improve performance by preparing data in the background while the model is training.
train_data = (list_ds
              .map(preprocess_image, num_parallel_calls=AUTOTUNE)
              .cache()
              .shuffle(1000)
              .batch(BATCH_SIZE)
              .prefetch(AUTOTUNE))
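Batching does not always produce full batches. Assuming the folder holds the 5000 images quoted in the dataset description (the exact count may differ), a batch size of 64 leaves a final batch of only 8 images. Any training step that builds tensors with a fixed batch size (for example an `epsilon` of shape `[batch_size, 1, 1, 1]`) must either take the size from the actual batch or drop the remainder with `batch(BATCH_SIZE, drop_remainder=True)`. The helper below is hypothetical and only mirrors how `Dataset.batch` splits the data.

```python
NUM_IMAGES_IN_DATASET = 5000  # assumption: the "5000 images" quoted for the dataset
BATCH_SIZE = 64


def batch_sizes(num_items, batch_size, drop_remainder=False):
    """Hypothetical helper: sizes of the batches Dataset.batch would produce."""
    full, rest = divmod(num_items, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)
    return sizes


sizes = batch_sizes(NUM_IMAGES_IN_DATASET, BATCH_SIZE)
print(len(sizes), sizes[-1])  # 79 8
print(len(batch_sizes(NUM_IMAGES_IN_DATASET, BATCH_SIZE, drop_remainder=True)))  # 78
```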

In [None]:
# Get a single batch of images from the train_data dataset.
sample_img = next(iter(train_data))

# Plot the first image from the batch with a title 'Sample'.
plt.title('Sample')

# Display the image using matplotlib.
# The image needs to be converted back to the [0, 1] range for visualization.
# This reverses the normalization done earlier, which mapped pixels to the range [-1, 1]:
# multiply by 0.5 and add 0.5 to rescale from [-1, 1] to [0, 1].
# np.clip ensures that pixel values stay within the [0, 1] range, which prevents display errors.
plt.imshow(np.clip(sample_img[0].numpy() * 0.5 + 0.5, 0, 1), cmap='gray')

# Show the plot.
plt.show()

## 5. Building Generator and Critic<section id='5'>

In [None]:
def CGAN_generator(input_z_shape=NOISE_DIM):
    """
    Creates a generator model for a Conditional Generative Adversarial Network (CGAN) using a DCGAN-like architecture.
    
    The generator takes a noise vector (usually a random vector) as input and generates synthetic images. 
    The architecture is designed to progressively upsample the noise vector to produce an image with the desired dimensions.
    
    Args:
        input_z_shape (int): The shape of the input noise vector. Default is NOISE_DIM, which represents the dimensionality of the noise vector.
    
    Returns:
        tensorflow.keras.Model: A Keras Model object that represents the generator network.
    
    The generator architecture consists of:
    1. A Dense layer that transforms the noise vector into a high-dimensional tensor.
    2. A Reshape layer that reshapes this tensor into a 4x4 spatial resolution with 512 channels.
    3. A series of Conv2DTranspose layers that progressively upsample the image. Each layer increases the spatial resolution (height and width) and reduces the number of channels.
    4. BatchNormalization layers are applied after each Conv2DTranspose layer to normalize the activations, which helps stabilize and speed up training.
    5. LeakyReLU activation functions are used to introduce non-linearity and avoid dead neurons.
    6. The final Conv2DTranspose layer produces an image with 3 channels (RGB) using the 'tanh' activation function to scale pixel values to the range [-1, 1].
    """
    # Define the input layer with the shape of the noise vector. 
    # This is the input to the generator network.
    input_z_layer = Input((input_z_shape,))
    
    # First, fully connect the noise vector to a high-dimensional tensor.
    # This tensor will be reshaped into a 4x4 image with 512 channels.
    z = Dense(4*4*512, use_bias=False)(input_z_layer)
    z = Reshape((4, 4, 512))(z)
    
    # Apply a series of Conv2DTranspose layers to upsample the image.
    # Each Conv2DTranspose layer increases the spatial dimensions of the image.

    # First transposed convolution layer
    x = Conv2DTranspose(512, (4, 4), strides=(1, 1), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(z)
    x = BatchNormalization()(x)  # Normalize the output of the convolution to help stabilize training
    x = LeakyReLU()(x)  # Apply Leaky ReLU activation function for non-linearity
    
    # Second transposed convolution layer
    x = Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    
    # Third transposed convolution layer
    x = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    
    # Fourth transposed convolution layer
    x = Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    
    # Fifth transposed convolution layer
    x = Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    
    # Final transposed convolution layer to generate the output image
    # The output image has 1 channel (grayscale) and uses 'tanh' activation to ensure pixel values are between -1 and 1
    output = Conv2DTranspose(1, (4, 4), strides=(2, 2), padding='same', use_bias=False, activation="tanh",
                             kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    
    # Create a model with the input and output layers defined above
    model = Model(inputs=input_z_layer, outputs=output)
    return model
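With `padding='same'` and no `output_padding`, a Keras `Conv2DTranspose` multiplies the spatial size by its stride. The sketch below traces the shapes through the layer list copied by hand from `CGAN_generator` (an assumption: keep it in sync if the architecture changes). It confirms that the network ends at 128x128x1, which matches `TARGET_IMG_SIZE`; adding or removing a stride-2 layer is how you would target a different resolution.

```python
# (filters, stride) for each Conv2DTranspose in CGAN_generator, in order.
GENERATOR_LAYERS = [(512, 1), (256, 2), (128, 2), (64, 2), (32, 2), (1, 2)]


def trace_generator_shapes(start=4, layers=GENERATOR_LAYERS):
    shapes = [(start, start, 512)]  # output of Dense + Reshape
    size = start
    for filters, stride in layers:
        size *= stride  # 'same' padding: output size = input size * stride
        shapes.append((size, size, filters))
    return shapes


for shape in trace_generator_shapes():
    print(shape)  # last line: (128, 128, 1)
```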

In [None]:
def CGAN_discriminator(input_x_shape=(TARGET_IMG_SIZE, TARGET_IMG_SIZE, 1)):
    """
    Creates a discriminator model for a Conditional Generative Adversarial Network (CGAN) using a DCGAN-like architecture.
    
    The discriminator takes an image as input and determines whether it is a real image from the dataset or a synthetic image generated by the generator. 
    The architecture is designed to progressively downsample the image to produce a single score indicating its authenticity.

    Args:
        input_x_shape (tuple): The shape of the input image. Default is (TARGET_IMG_SIZE, TARGET_IMG_SIZE, 3), where 
                               TARGET_IMG_SIZE is the height and width of the image, and 3 represents the RGB color channels.
    
    Returns:
        tensorflow.keras.Model: A Keras Model object that represents the discriminator network.

    The discriminator architecture consists of:
    1. A series of Conv2D layers that progressively reduce the spatial dimensions of the image while increasing the number of feature channels.
    2. Each Conv2D layer is followed by a LeakyReLU activation function, which introduces non-linearity and prevents dead neurons.
    3. An optional LayerNormalization step is commented out. Uncommenting it will normalize the activations across the features.
    4. The final Conv2D layer reduces the output to a single channel.
    5. A Flatten layer converts the 3D tensor into a 1D tensor.
    6. A Dense layer produces a single score that represents the probability of the image being real or fake.
    """
    # Define the input layer with the shape of the input image.
    # This is the input to the discriminator network.
    input_x_layer = Input(input_x_shape)
    
    # Apply a series of Conv2D layers to downsample the image.
    # Each Conv2D layer reduces the spatial dimensions of the image.

    # First convolutional layer
    x = Conv2D(64, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(input_x_layer)
    # x = LayerNormalization()(x)  # Uncomment to normalize activations; BatchNormalization is avoided in a WGAN-GP critic because the gradient penalty is computed per sample
    x = LeakyReLU()(x)  # Apply Leaky ReLU activation function for non-linearity
    
    # Second convolutional layer
    x = Conv2D(128, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    # x = LayerNormalization()(x)
    x = LeakyReLU()(x)
    
    # Third convolutional layer
    x = Conv2D(256, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    # x = LayerNormalization()(x)
    x = LeakyReLU()(x)
    
    # Fourth convolutional layer
    x = Conv2D(512, (4, 4), strides=(2, 2), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    # x = LayerNormalization()(x)
    x = LeakyReLU()(x)
    
    # Final convolutional layer that reduces the output to a single channel.
    # There is no activation (and no sigmoid later): a WGAN-GP critic outputs an unbounded score, not a probability.
    x = Conv2D(1, (4, 4), strides=(1, 1), padding='same', use_bias=False, kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(x)
    
    # Flatten the 3D output into a 1D tensor
    x = Flatten()(x)
    
    # Apply a Dense layer (no activation) to output a single unbounded critic score; higher means "more real"
    output = Dense(1)(x)
    
    # Create a model with the input and output layers defined above
    model = Model(inputs=input_x_layer, outputs=output)
    return model
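A Keras `Conv2D` with `padding='same'` produces `ceil(input / stride)` along each spatial axis. Tracing the layer list copied by hand from `CGAN_discriminator` (an assumption to keep in sync with the code) shows that a 128x128x1 image shrinks to an 8x8x1 map, so `Flatten()` hands just 64 features to the final `Dense(1)`.

```python
import math

# (filters, stride) for each Conv2D in CGAN_discriminator, in order.
CRITIC_LAYERS = [(64, 2), (128, 2), (256, 2), (512, 2), (1, 1)]


def trace_critic_shapes(size=128, channels=1, layers=CRITIC_LAYERS):
    shapes = [(size, size, channels)]
    for filters, stride in layers:
        size = math.ceil(size / stride)  # 'same' padding: output = ceil(input / stride)
        shapes.append((size, size, filters))
    return shapes


shapes = trace_critic_shapes()
flat_features = math.prod(shapes[-1])  # number of values Flatten() passes to Dense(1)
print(shapes[-1], flat_features)  # (8, 8, 1) 64
```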

In [None]:
# Create an instance of the CGAN generator model.
# This model takes a noise vector as input and generates an image.
# The noise vector has a dimensionality defined by NOISE_DIM (default is 100).
generator = CGAN_generator()
generator.summary()

In [None]:
# Create an instance of the CGAN discriminator (critic) model.
# This model takes an image as input and outputs an unbounded score: higher means the image looks more real.
# The input image shape is defined by the parameter (default is (TARGET_IMG_SIZE, TARGET_IMG_SIZE, 1)),
# where TARGET_IMG_SIZE is the height and width of the image and 1 is the single grayscale channel.
discriminator = CGAN_discriminator()
discriminator.summary()

## 6. Learning Rate Setup<section id='6'>

In [None]:
# Optimizers for the GAN models.
# The Adam optimizer is used for both the discriminator (D_optimizer) and the generator (G_optimizer).
# The learning rate for both optimizers is set to LR, which controls how much to adjust the weights in each iteration.
# The beta_1 parameter is set to 0.5 to help stabilize the training process by controlling the exponential decay rate of the first moment estimates.

D_optimizer = Adam(learning_rate=LR, beta_1=0.5)
G_optimizer = Adam(learning_rate=LR, beta_1=0.5)

def learning_rate_decay(current_lr, decay_factor=DECAY_FACTOR):
    '''
    Calculate the new learning rate by applying a decay factor.
    
    The learning rate decay function reduces the learning rate gradually over time, which helps to stabilize training 
    as the model converges. The decay factor controls how quickly the learning rate decreases. This function ensures 
    that the learning rate does not fall below a minimum threshold defined by MIN_LR.
    
    Args:
        current_lr (float): The current learning rate value.
        decay_factor (float): The factor by which the learning rate is divided to get the new rate. Default is DECAY_FACTOR.
        
    Returns:
        float: The updated learning rate, which is either the decayed rate or MIN_LR, whichever is higher.
    '''
    # Divide the current learning rate by the decay factor, but never go below MIN_LR.
    new_lr = max(current_lr / decay_factor, MIN_LR)
    return new_lr

def set_learning_rate(D_optimizer, G_optimizer, new_lr):
    '''
    Set a new learning rate for the optimizers.
    
    This function updates the learning rate for both the discriminator and generator optimizers. 
    It ensures that both optimizers use the same learning rate, which is important for balanced training.
    
    Args:
        D_optimizer (tf.keras.optimizers.Optimizer): The optimizer used for training the discriminator.
        G_optimizer (tf.keras.optimizers.Optimizer): The optimizer used for training the generator.
        new_lr (float): The new learning rate to be set for both optimizers.
        
    Returns:
        None: This function modifies the optimizers in-place and does not return a value.
    '''
    # Update the learning rate of the discriminator optimizer.
    D_optimizer.learning_rate.assign(new_lr)
    
    # Update the learning rate of the generator optimizer.
    G_optimizer.learning_rate.assign(new_lr)
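Because `learning_rate_decay` divides by the same factor on every call, after n calls the rate is `LR / DECAY_FACTOR**n` until it hits `MIN_LR`. This section does not show how often the function is called (per step or per epoch), so read n below simply as "number of calls". The sketch repeats the notebook's constants and decay rule and shows how gentle a factor of 1.00004 is: roughly 115,000 calls are needed to go from `LR` down to `MIN_LR`.

```python
import math

LR = 1e-4
MIN_LR = 0.000001
DECAY_FACTOR = 1.00004


def learning_rate_decay(current_lr, decay_factor=DECAY_FACTOR):
    # Same rule as the notebook: divide by the factor, never go below MIN_LR.
    return max(current_lr / decay_factor, MIN_LR)


# Closed form: LR / DECAY_FACTOR**n <= MIN_LR  <=>  n >= ln(LR / MIN_LR) / ln(DECAY_FACTOR)
steps_to_floor = math.ceil(math.log(LR / MIN_LR) / math.log(DECAY_FACTOR))
print(steps_to_floor)

# After 1000 calls the rate has dropped by only about 4%.
lr = LR
for _ in range(1000):
    lr = learning_rate_decay(lr)
print(lr)
```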

## 7. Building Checkpoints<section id='7'>

In [None]:
# Define the path where the checkpoints (saved states of the model) will be stored.
# This path includes a directory for TensorFlow checkpoints and a subdirectory named after the model.
checkpoint_path = os.path.join("checkpoints", "tensorflow", MODEL_NAME)

# Create a `tf.train.Checkpoint` object.
# This object keeps track of the generator, discriminator, and both optimizers.
# It allows us to save and restore their states during training.
ckpt = tf.train.Checkpoint(generator=generator,
                           discriminator=discriminator,
                           G_optimizer=G_optimizer,
                           D_optimizer=D_optimizer)

# Create a `tf.train.CheckpointManager` to manage the checkpoints.
# This manager will save the checkpoints to the path defined earlier and keep up to 5 of the most recent checkpoints.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Check if there are any existing checkpoints in the checkpoint manager.
# If there is at least one checkpoint, restore the latest one.
if ckpt_manager.latest_checkpoint:
    # Restore the state of the generator, discriminator, and optimizers from the latest checkpoint.
    ckpt.restore(ckpt_manager.latest_checkpoint)
    
    # Extract the save number from the checkpoint file name.
    # CheckpointManager names checkpoints 'ckpt-{save_number}', so take the text after the LAST '-'
    # (splitting on the first '-' would break if the directory path itself contained a '-').
    latest_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    
    # Calculate the current epoch based on the latest checkpoint.
    # A checkpoint is saved every SAVE_EVERY_N_EPOCH epochs, so multiply the save number by SAVE_EVERY_N_EPOCH.
    CURRENT_EPOCH = latest_epoch * SAVE_EVERY_N_EPOCH
    
    # Print a message indicating that the latest checkpoint has been restored and display the epoch number.
    print (f'Latest checkpoint of epoch {CURRENT_EPOCH} restored!!')
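The epoch recovery assumes exactly one checkpoint was saved every `SAVE_EVERY_N_EPOCH` epochs, so save number N corresponds to epoch N × `SAVE_EVERY_N_EPOCH`. The hypothetical helper below isolates that arithmetic and parses only the file name (the default `CheckpointManager` prefix is `ckpt`), so a `-` elsewhere in the path cannot break it; the second path is made up purely to show that case.

```python
import os

SAVE_EVERY_N_EPOCH = 15


def epoch_from_checkpoint(checkpoint_path, save_every=SAVE_EVERY_N_EPOCH):
    """Hypothetical helper: map '.../ckpt-<save number>' to the epoch it was saved at."""
    save_number = int(os.path.basename(checkpoint_path).rsplit('-', 1)[-1])
    return save_number * save_every


print(epoch_from_checkpoint(os.path.join("checkpoints", "tensorflow", "DCGAN", "ckpt-3")))  # 45
print(epoch_from_checkpoint("runs/wgan-gp/ckpt-7"))  # 105
```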

## 8. Function to Generate Images and Save Them After Each Epoch<section id='8'>

In [None]:
def generate_and_save_images(model, epoch, test_input, figure_size=(12,6), subplot=(3,6), save=True, is_flatten=False):
    '''
    Generate images using a trained model and save or display them.

    This function takes a trained model and a set of input data to generate images. It then plots these images
    in a grid layout and optionally saves the generated images to a file.

    Args:
        model (tf.keras.Model): The trained model used to generate images. This model should be capable of taking
                                the `test_input` and producing predictions (i.e., generated images).
        epoch (int): The current epoch number of training. This is used to name the saved image file, indicating
                     the state of the model at this epoch.
        test_input (numpy.ndarray or tf.Tensor): The input data to feed into the model for generating images. 
                                                 This should be an array or tensor of noise vectors or other input 
                                                 that the model can process to produce images.
        figure_size (tuple of int, optional): The size of the figure (in inches) on which the images will be plotted. 
                                              Default is (12, 6), which specifies a width of 12 inches and a height of 6 inches.
        subplot (tuple of int, optional): The grid layout for displaying the images. It is specified as (rows, columns). 
                                          Default is (3, 6), which means the images will be arranged in 3 rows and 6 columns.
        save (bool, optional): If True, the generated images will be saved to a file. Default is True.
        is_flatten (bool, optional): If True, reshapes flattened predictions back to (TARGET_IMG_SIZE, TARGET_IMG_SIZE, 1)
                                      before plotting. Default is False. Set to True if the model outputs flattened images.

    Returns:
        None: This function does not return any value but instead saves or displays the generated images.
    '''
    
    # Generate images using the provided model and input data.
    # The model's `predict` method takes the `test_input` and produces predictions (generated images).
    predictions = model.predict(test_input)
    
    # If `is_flatten` is True, reshape the predictions array back to the image dimensions.
    # This assumes that the images have been flattened (e.g., into a 1D array) and are
    # square, single-channel (grayscale) images of size TARGET_IMG_SIZE.
    if is_flatten:
        predictions = predictions.reshape(-1, TARGET_IMG_SIZE, TARGET_IMG_SIZE, 1).astype('float32')
    
    # Create a new figure with the specified size to plot the images.
    fig = plt.figure(figsize=figure_size)
    
    # Loop through each generated image and plot it.
    for i in range(predictions.shape[0]):
        # Create a subplot for each image based on the specified grid layout.
        axs = plt.subplot(subplot[0], subplot[1], i+1)
        
        # Display the image. The predictions are scaled from [-1, 1] to [0, 1] for proper visualization.
        plt.imshow(predictions[i] * 0.5 + 0.5, cmap='gray')
        
        # Hide the axis for a cleaner display of images.
        plt.axis('off')
    
    # If `save` is True, save the figure to a file. The filename includes the epoch number.
    if save:
        plt.savefig(os.path.join(OUTPUT_PATH, 'image_at_epoch_{:04d}.png'.format(epoch)))
    
    # Display the figure with the plotted images.
    plt.show()
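`plt.subplot(rows, cols, index)` only accepts indices from 1 to rows × cols, so the `subplot` argument must provide at least as many cells as there are noise vectors in `test_input`. The default (3, 6) fits the 18 sample vectors used below, but the 36 images set by `NUM_IMAGES` would need a larger grid. A hypothetical helper for picking one:

```python
import math


def grid_for(num_images, max_cols=6):
    """Hypothetical helper: (rows, cols) grid with at most max_cols columns that fits num_images."""
    if num_images < 1:
        raise ValueError("num_images must be at least 1")
    cols = min(num_images, max_cols)
    rows = math.ceil(num_images / cols)
    return rows, cols


print(grid_for(18))  # (3, 6)
print(grid_for(36))  # (6, 6)
```

The result can be passed directly as `subplot=grid_for(n)`; `figure_size` may also need to grow with the number of rows.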

In [None]:
# Define the number of examples (images) to generate.
# This sets how many images we want to produce in a single batch.
num_examples_to_generate = 18

# Create a batch of random noise vectors to be used as input for the generator.
# `tf.random.normal` generates a tensor of random values from a normal distribution.
# The shape `[num_examples_to_generate, NOISE_DIM]` specifies that we want
# `num_examples_to_generate` noise vectors, each with dimensionality `NOISE_DIM`.
sample_noise = tf.random.normal([num_examples_to_generate, NOISE_DIM])

# Generate images using the `generator` model with the `sample_noise` as input.
# The function `generate_and_save_images` takes several arguments:
# - `generator`: The generator model (still untrained at this point, so the images will look like noise).
# - `0`: The epoch number, used to label the image if saved (here it is 0 as we are not saving).
# - `[sample_noise]`: The input data to feed into the generator. We wrap `sample_noise` in a list 
#   to match the expected input format for the function.
# - `figure_size=(12,6)`: Specifies the size of the figure for plotting.
# - `subplot=(3,6)`: Specifies the grid layout for the subplot (3 rows and 6 columns).
# - `save=False`: Indicates that the generated images should not be saved to a file.
# - `is_flatten=False`: Specifies that the generated images are not in a flattened format.
generate_and_save_images(generator, 0, [sample_noise], figure_size=(12,6), subplot=(3,6), save=False, is_flatten=False)


## 9. Build the Training Step for Critic and Generator<section id='9'>

In [None]:
@tf.function
def WGAN_GP_train_d_step(real_image, batch_size, step):
    '''
    Perform one training step for the discriminator in the Wasserstein GAN with Gradient Penalty (WGAN-GP) framework.

    Args:
        real_image (tf.Tensor): A batch of real images from the dataset, used to train the discriminator.
        batch_size (int): The number of images in the batch.
        step (int): The current training step, used for logging purposes.

    Reference:
        This implementation is inspired by the TensorFlow DCGAN tutorial:
        https://www.tensorflow.org/tutorials/generative/dcgan
    '''
    # The @tf.function decorator traces this Python function into a TensorFlow graph,
    # which TensorFlow can optimize and run faster than eager, line-by-line Python execution.
    # A plain Python print() only runs while the function is being traced, so "retrace" is
    # printed each time TensorFlow builds a new graph (e.g. for a new input shape or a new
    # Python-valued argument), not on every call.
    print("retrace")

    # Generate a batch of random noise vectors. These vectors will be used to create fake images.
    noise = tf.random.normal([batch_size, NOISE_DIM])

    # Generate random values for epsilon in the range [0, 1], with the same batch size as the real images.
    # This is used to create interpolated images between real and fake images.
    epsilon = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=1)

    ###################################
    # Train the Discriminator (D)
    ###################################

    # Start recording operations for the discriminator loss.
    # The gradient-penalty computation runs inside this tape, so d_tape can differentiate through
    # the penalty term (a gradient of a gradient). d_tape.gradient is called only once, so a
    # non-persistent tape is sufficient.
    with tf.GradientTape() as d_tape:
        # Start a nested gradient tape to compute the gradient penalty.
        with tf.GradientTape() as gp_tape:
            # Generate fake images using the generator model and the random noise.
            fake_image = generator([noise], training=True)
            
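            # Create mixed images by interpolating between real images and fake images using epsilon.
            fake_image_mixed = epsilon * tf.dtypes.cast(real_image, tf.float32) + ((1 - epsilon) * fake_image)
            
            # Explicitly watch the interpolated images so gp_tape can return gradients with respect to them.
            gp_tape.watch(fake_image_mixed)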
            
            # Get the discriminator's predictions for the mixed images.
            fake_mixed_pred = discriminator([fake_image_mixed], training=True)
        
        # Compute the gradient penalty to enforce the Lipschitz constraint.
        # Calculate gradients of the discriminator's predictions with respect to the mixed images.
        grads = gp_tape.gradient(fake_mixed_pred, fake_image_mixed)
        
        # Compute the L2 norm of each sample's gradient: (x1, x2, x3) ---> sqrt(x1^2 + x2^2 + x3^2).
        # A tiny constant keeps sqrt differentiable when a gradient is exactly zero.
        grad_norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        
        # Calculate the gradient penalty as the mean squared difference between the gradient norms and 1.
        gradient_penalty = tf.reduce_mean(tf.square(grad_norms - 1))
        
        # Get the discriminator's predictions for the fake and real images.
        fake_pred = discriminator([fake_image], training=True)
        real_pred = discriminator([real_image], training=True)
        
        # Calculate the discriminator loss: the mean of fake predictions minus the mean of real predictions,
        # plus the gradient penalty term scaled by LAMBDA.
        D_loss = tf.reduce_mean(fake_pred) - tf.reduce_mean(real_pred) + LAMBDA * gradient_penalty
    
    # Compute gradients of the discriminator loss with respect to the discriminator's trainable variables.
    D_gradients = d_tape.gradient(D_loss, discriminator.trainable_variables)
    # zip pairs each gradient with its variable, e.g. zip((0, 1, 2), (1, 2, 3)) --> (0, 1), (1, 2), (2, 3)
    # Apply the computed gradients to the discriminator's optimizer to update its weights.
    D_optimizer.apply_gradients(zip(D_gradients, discriminator.trainable_variables))
    
    # Log the discriminator loss to TensorBoard every 10 steps for monitoring purposes.
    if step % 10 == 0:
        with file_writer.as_default():
            tf.summary.scalar('D_loss', tf.reduce_mean(D_loss), step=step)
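The critic step above implements the WGAN-GP critic loss $L_D = \mathbb{E}[D(\tilde{x})] - \mathbb{E}[D(x)] + \lambda\,\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$ with $\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$, where $x$ is a real image and $\tilde{x}$ a fake one. As a minimal sketch of that arithmetic, the plain-Python code below swaps the neural critic for a hypothetical toy critic $f(x) = \tfrac{1}{2}\lVert x \rVert^2$, whose input gradient is simply $x$, so no autodiff is needed. `lam=10.0` is an assumed value; the notebook's `LAMBDA` is set in an earlier cell.

```python
import math


def toy_critic(x):
    """Hypothetical stand-in for the critic: f(x) = 0.5 * ||x||^2."""
    return 0.5 * sum(v * v for v in x)


def toy_critic_grad(x):
    """Analytic gradient of the toy critic with respect to its input: grad f(x) = x."""
    return list(x)


def gradient_penalty(reals, fakes, epsilons):
    """Mean of (||grad f(x_hat)||_2 - 1)^2 over the batch."""
    penalties = []
    for real, fake, eps in zip(reals, fakes, epsilons):
        # x_hat = eps * real + (1 - eps) * fake, the analogue of fake_image_mixed
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]
        grad = toy_critic_grad(x_hat)
        grad_norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((grad_norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


def critic_loss(reals, fakes, epsilons, lam=10.0):
    """mean f(fake) - mean f(real) + lam * gradient penalty (lam is an assumed LAMBDA)."""
    mean_fake = sum(toy_critic(x) for x in fakes) / len(fakes)
    mean_real = sum(toy_critic(x) for x in reals) / len(reals)
    return mean_fake - mean_real + lam * gradient_penalty(reals, fakes, epsilons)
```

With one real sample `[1.0, 0.0]`, one fake sample `[0.0, 0.0]` and `eps = 0.5`, the interpolate is `[0.5, 0.0]`, its gradient norm is 0.5, the penalty is 0.25, and the loss is `0.0 - 0.5 + 10 * 0.25 = 2.0`.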

In [None]:
@tf.function
def WGAN_GP_train_g_step(real_image, batch_size, step):
    '''
    Perform one training step for the generator in the Wasserstein GAN with Gradient Penalty (WGAN-GP) framework.

    Args:
        real_image (tf.Tensor): A batch of real images from the dataset. Not used in this step; kept so the signature matches the critic step.
        batch_size (int): The number of images in the batch.
        step (int): The current training step, used for logging purposes.

    Reference:
        This implementation is inspired by the TensorFlow DCGAN tutorial:
        https://www.tensorflow.org/tutorials/generative/dcgan
    '''
    # The @tf.function decorator above compiles this Python function into a TensorFlow graph,
    # which TensorFlow can optimize and run faster than eager, line-by-line Python execution.
    # Python side effects such as print() run only while the function is being traced, so this
    # line prints "retrace" each time TensorFlow (re)builds the graph, e.g. for a new input shape or dtype.
    print("retrace")

    # Generate a batch of random noise vectors. These vectors will be used to create fake images.
    noise = tf.random.normal([batch_size, NOISE_DIM])

    ###################################
    # Train the Generator (G)
    ###################################

    # Start recording gradients for the generator's loss calculations.
    with tf.GradientTape() as g_tape:
        # Generate fake images using the generator model and the random noise.
        fake_image = generator([noise], training=True)
        
        # Get the discriminator's predictions for the fake images.
        fake_pred = discriminator([fake_image], training=True)
        
        # The generator wants the critic to give its fake images high scores, i.e. to maximize the mean of fake_pred.
        # Maximizing that mean is equivalent to minimizing its negative, which is used as the loss.
        G_loss = -tf.reduce_mean(fake_pred)
    
    # Compute gradients of the generator loss with respect to the generator's trainable variables.
    G_gradients = g_tape.gradient(G_loss, generator.trainable_variables)
    
    # Apply the computed gradients to the generator's optimizer to update its weights.
    G_optimizer.apply_gradients(zip(G_gradients, generator.trainable_variables))
    
    # Log the generator loss to TensorBoard every 10 steps for monitoring purposes.
    if step % 10 == 0:
        with file_writer.as_default():
            tf.summary.scalar('G_loss', G_loss, step=step)
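Why minimizing `-mean(critic(fake))` helps the generator can be seen in one dimension. The sketch below uses a hypothetical linear critic `D(x) = 3x` and a one-parameter generator `g(z) = z + theta` (both assumptions, not the notebook's models) and runs plain gradient descent on the same loss, using a central finite-difference gradient instead of `tf.GradientTape`.

```python
def toy_critic(x):
    """Hypothetical critic that scores larger x as more 'real' (slope 3)."""
    return 3.0 * x


def toy_generator(z, theta):
    """Hypothetical one-parameter generator: shifts each noise value by theta."""
    return z + theta


def generator_loss(theta, noise):
    """G_loss = -mean(critic(generator(z))), mirroring G_loss in the cell above."""
    scores = [toy_critic(toy_generator(z, theta)) for z in noise]
    return -sum(scores) / len(scores)


def train_generator(theta, noise, lr=0.1, steps=5, h=1e-6):
    """Gradient descent on G_loss with a central finite-difference gradient."""
    for _ in range(steps):
        grad = (generator_loss(theta + h, noise) - generator_loss(theta - h, noise)) / (2 * h)
        theta -= lr * grad
    return theta
```

With noise `[-1.0, 0.0, 1.0]` the loss is `-3 * theta`, so its gradient is -3 and each step raises `theta` by 0.3 (about 1.5 after five steps). The fake samples move toward higher critic scores, which is exactly what minimizing the negative mean score achieves.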

## 10. Training of WGAN-GP<section id='10'>

In [None]:
# Initialize the current learning rate to the initial learning rate specified by LR
current_learning_rate = LR

# Debugging flag; it is not used by the training loop below
trace = True

# Counter to track the number of discriminator (critic) updates
n_critic_count = 0

# Loop over each epoch in the training process
for epoch in range(CURRENT_EPOCH, EPOCHS + 1):
    # Record the start time of the epoch for performance tracking
    start = time.time()
    
    # Print the start of the epoch with the current epoch number
    print('Start of epoch %d' % (epoch,))
    
    # Update the learning rate using the decay function
    current_learning_rate = learning_rate_decay(current_learning_rate)
    
    # Print the updated learning rate for monitoring
    print('current_learning_rate %f' % (current_learning_rate,))
    
    # Apply the updated learning rate to both the discriminator and generator optimizers
    set_learning_rate(D_optimizer, G_optimizer, current_learning_rate)
    
    # Iterate over the training data in batches
    for step, image in enumerate(train_data):
        # Get the current batch size from the image tensor
        current_batch_size = image.shape[0]
        
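        # Train the discriminator (critic) on the current batch of real images.
        # batch_size and step are passed as tensors so that @tf.function does not retrace for every new Python value.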
        WGAN_GP_train_d_step(
            image,
            batch_size=tf.constant(current_batch_size, dtype=tf.int64),
            step=tf.constant(step, dtype=tf.int64)
        )
        
        # Increment the discriminator update counter
        n_critic_count += 1
        
        # If the number of discriminator updates reaches the specified threshold (N_CRITIC)
        if n_critic_count >= N_CRITIC:
            # Train the generator; it samples its own noise from batch_size and does not use the real images
            WGAN_GP_train_g_step(
                image,
                batch_size=tf.constant(current_batch_size, dtype=tf.int64),
                step=tf.constant(step, dtype=tf.int64)
            )
            # Reset the discriminator update counter
            n_critic_count = 0
        
        # Print a dot every 10 steps to indicate progress in training
        if step % 10 == 0:
            print('.', end='')
    
    # Clear the output of the Jupyter notebook cell to keep the output area tidy
    clear_output(wait=True)
    
    # Generate and save images from the generator to visualize progress
    # Use a consistent sample noise to compare progress over epochs
    generate_and_save_images(
        generator,
        epoch,
        [sample_noise],
        figure_size=(12,6),
        subplot=(3,6),
        save=True,
        is_flatten=False
    )
    
    # If the current epoch is a multiple of SAVE_EVERY_N_EPOCH, save a checkpoint
    if epoch % SAVE_EVERY_N_EPOCH == 0:
        # Save the model's state to a checkpoint file
        ckpt_save_path = ckpt_manager.save()
        # Print the path where the checkpoint was saved
        print('Saving checkpoint for epoch {} at {}'.format(epoch, ckpt_save_path))
    
    # Print the time taken for the current epoch
    print('Time taken for epoch {} is {} sec\n'.format(epoch, time.time() - start))

# Save a final checkpoint at the end of training
ckpt_save_path = ckpt_manager.save()
# Print the path where the final checkpoint was saved
print('Saving checkpoint for epoch {} at {}'.format(EPOCHS, ckpt_save_path))
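The loop above trains the critic on every batch and the generator once every `N_CRITIC` critic updates, and `n_critic_count` is not reset between epochs. The sketch below replays only that bookkeeping in plain Python to show where generator updates land; `critic_generator_schedule` is a hypothetical helper, and the batch counts and `n_critic=5` are assumed example values.

```python
def critic_generator_schedule(batches_per_epoch, epochs, n_critic):
    """Return the (epoch, step) pairs at which a generator update happens.

    Mirrors the training loop: every batch trains the critic, and the
    counter is NOT reset between epochs, so the phase of generator
    updates can shift from one epoch to the next.
    """
    generator_updates = []
    n_critic_count = 0
    for epoch in range(1, epochs + 1):
        for step in range(batches_per_epoch):
            n_critic_count += 1  # one critic update per batch
            if n_critic_count >= n_critic:
                generator_updates.append((epoch, step))
                n_critic_count = 0
    return generator_updates
```

With 7 batches per epoch and `n_critic=5`, generator updates happen at step 4 of epoch 1 and step 2 of epoch 2. Since `WGAN_GP_train_g_step` only logs when `step % 10 == 0`, `G_loss` reaches TensorBoard only when a generator update happens to land on such a step.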

## 11. Compare Generated Data with Real Data<section id='11'>

In [None]:
# NUM_IMAGES must be a non-negative perfect square so the images fill a square grid (e.g. 64 -> 8x8)
if NUM_IMAGES < 0:
    raise ValueError("Number must be non-negative")
if not np.sqrt(NUM_IMAGES).is_integer():
    raise ValueError(f"{NUM_IMAGES} is not a perfect square")

# Create a new batch of random noise vectors as input for the generator.
# This noise is used to generate images for assessing the model's performance.
# The batch contains NUM_IMAGES noise vectors, each of dimension NOISE_DIM.
test_noise = tf.random.normal([NUM_IMAGES, NOISE_DIM])

# Generate images using the generator model with the new batch of random noise vectors.
# The generator takes the noise as input and produces corresponding fake images.
prediction = generator.predict(test_noise)
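The perfect-square check above relies on the floating-point `np.sqrt`; for very large integers a float square root can round to a whole number even when the input is not a perfect square. A minimal sketch of an exact alternative uses the standard library's `math.isqrt` (Python 3.8+); `grid_side` is a hypothetical helper name.

```python
import math


def grid_side(num_images):
    """Return the side length of a square grid holding num_images, or raise ValueError."""
    if num_images < 0:
        raise ValueError("Number must be non-negative")
    side = math.isqrt(num_images)  # exact integer square root, no floating point
    if side * side != num_images:
        raise ValueError(f"{num_images} is not a perfect square")
    return side
```

For example, `grid_side(64)` returns 8, matching an 8x8 grid.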

In [None]:
def image_grid(images, fig, images_to_generate):
    '''
    Display a square grid of images in a given matplotlib figure.

    Args:
        images (numpy.ndarray or tf.Tensor): Images with shape (N, H, W, C), where N is the number of images,
            H and W are the image height and width, and C is the number of channels (1 for grayscale, 3 for RGB).
            Pixel values are expected in the range [-1, 1].
        fig (matplotlib.figure.Figure): The matplotlib figure on which the grid is drawn.
        images_to_generate (int): Number of images to display; must be a perfect square and at most N.

    Description:
        The images are arranged in a sqrt(images_to_generate) x sqrt(images_to_generate) grid
        (e.g. 64 images -> 8x8). Each subplot is created with `fig.add_subplot` and shown without axis ticks.
        Pixel values are rescaled from [-1, 1] to [0, 1] and clipped to that range before display.
    '''
    # Number of rows and columns in the square grid
    grid_size = int(np.sqrt(images_to_generate))

    # Iterate over the first `images_to_generate` images in the batch.
    for i in range(images_to_generate):
        # Add a subplot at position i + 1 (matplotlib subplot indices are 1-based, row-major).
        axs = fig.add_subplot(grid_size, grid_size, i + 1)
        # Remove x and y axis ticks for the subplot.
        axs.set_xticks([])
        axs.set_yticks([])
        # Rescale from [-1, 1] to [0, 1] and clip any overshoot; np.squeeze turns a
        # single-channel (H, W, 1) image into the (H, W) shape that imshow expects for grayscale.
        img = np.clip(np.squeeze(np.asarray(images[i]) * 0.5 + 0.5), 0.0, 1.0)
        # cmap='gray' only affects single-channel images; RGB images are shown in color.
        axs.imshow(img, cmap='gray')
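`image_grid` relies on two small pieces of arithmetic: mapping pixel values from [-1, 1] (the range implied by the `* 0.5 + 0.5` rescale) to [0, 1] for display, and matplotlib's 1-based, row-major subplot numbering. The plain-Python sketch below shows both on flat lists of pixel values; `to_display_range` and `grid_layout` are hypothetical helper names.

```python
import math


def to_display_range(pixels):
    """Map values from [-1, 1] to [0, 1], clipping anything that overshoots."""
    return [min(max(p * 0.5 + 0.5, 0.0), 1.0) for p in pixels]


def grid_layout(num_images):
    """Return (subplot_index, row, col) for each image in a square grid.

    subplot_index is the 1-based position passed to fig.add_subplot;
    row and col are 0-based and filled row by row.
    """
    side = math.isqrt(num_images)
    return [(i + 1, i // side, i % side) for i in range(num_images)]
```

For four images, subplot 3 lands in the second row, first column: `grid_layout(4)` yields `(3, 1, 0)` as its third entry.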

In [None]:
# Plot the real images from the dataset

# Create a new matplotlib figure with a specified size of 12x12 inches.
# This will be used to display the grid of images.
fig1 = plt.figure(figsize=(12,12))

# Call the image_grid function to add the images to the figure.
# `sample_img[:NUM_IMAGES]` takes the first NUM_IMAGES images from the sample image batch,
# which image_grid arranges in a sqrt(NUM_IMAGES) x sqrt(NUM_IMAGES) grid.
image_grid(sample_img[:NUM_IMAGES], fig1, NUM_IMAGES)

# Display the figure with the plotted images.
plt.show()

In [None]:
# Plot the fake images generated by the model from the last epoch

# Create a new matplotlib figure with a specified size of 12x12 inches.
# This will be used to display the grid of generated images.
fig2 = plt.figure(figsize=(12,12))

# Call the image_grid function to add the generated images to the figure.
# `prediction` holds the NUM_IMAGES fake images produced by the generator,
# which image_grid arranges in a sqrt(NUM_IMAGES) x sqrt(NUM_IMAGES) grid.
image_grid(prediction, fig2, NUM_IMAGES)

# Display the figure with the plotted generated images.
plt.show()

In [None]:
# Save the generator model
generator.save('generator_wgan_gp.h5')

In [None]:
import zipfile
import os

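# Folder to archive and the path of the zip file.
# The zip file lives inside the folder being archived, so it must be skipped during the walk.
data_folder = '/kaggle/working/'
zip_file = '/kaggle/working/saved_data.zip'

# Create a zip file containing everything under data_folder, with paths relative to that folder
with zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED) as zf:
    for root, dirs, files in os.walk(data_folder):
        for file in files:
            file_path = os.path.join(root, file)
            # Skip the archive itself; otherwise it would be added to itself while still being written
            if os.path.abspath(file_path) == os.path.abspath(zip_file):
                continue
            zf.write(file_path, os.path.relpath(file_path, data_folder))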

print(f"Data zipped and saved to {zip_file}")
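Before downloading the archive, it can be worth re-opening it to confirm it is intact. A minimal sketch using only the standard-library `zipfile` module: `ZipFile.testzip()` reads every member and returns the name of the first one whose CRC check fails (or `None`), and `namelist()` lists the stored paths. `verify_zip` is a hypothetical helper name.

```python
import zipfile


def verify_zip(zip_path):
    """Return (is_intact, sorted_member_names) for the archive at zip_path."""
    with zipfile.ZipFile(zip_path) as zf:
        # testzip() returns the first corrupt member's name, or None if all CRCs match
        first_bad = zf.testzip()
        return first_bad is None, sorted(zf.namelist())
```

For example, `ok, names = verify_zip(zip_file)` should give `ok == True`, and `'saved_data.zip'` should not appear in `names`.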
