# Pix2Pix GAN
### Pix2Pix is a type of conditional GAN (cGAN) in which an input image serves as the condition for image generation, i.e. the model translates one image into another.
### Examples include colorizing black-and-white photos and transforming satellite images into map routes.

In [87]:
import tensorflow as tf 
import os 
import pathlib 
import time
import datetime 
from matplotlib import pyplot as plt 
from IPython import display
from PIL import Image
import numpy as np

# Load the dataset

In [None]:
dataset_name = "maps"#@param ["cityscapes", "edges2handbags", "edges2shoes", "facades", "maps", "night2day"]
_URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'
# tf.keras.utils.get_file downloads the archive from _URL (extract=True also unpacks it) and returns a local path
path2zip = tf.keras.utils.get_file(fname=f"{dataset_name}.tar.gz", origin = _URL, extract= True)
# Wrap the returned path in pathlib.Path so it can be combined with the `/` operator below
path2zip = pathlib.Path(path2zip)
# PATH is the directory named after the dataset, next to the path returned by get_file.
# NOTE(review): where the extracted files end up depends on the Keras version -- confirm PATH exists.
PATH = path2zip.parent/dataset_name

# tf.io.read_file reads the raw file content as a binary string
sample_image = tf.io.read_file(str(PATH/'train/1.jpg'))
# tf.io.decode_jpeg decodes the binary string into a tensor representing the image
sample_image = tf.io.decode_jpeg(sample_image)
print(f'shape of the sample image is: {sample_image.shape}')
# Show the sample; load() below treats the left and right halves as the input and target images
plt.figure()
plt.imshow(sample_image)

# Load and Preprocess the Dataset

In [89]:
# Height and width (in pixels) that every image is resized to by load_image_train / load_image_test
IMG_HEIGHT = 256
IMG_WIDTH = 256
# Number of image pairs per batch produced by load_dataset
BATCH_SIZE = 1  # Depending on your GPU memory
# Passed to dataset.map and dataset.prefetch so tf.data tunes parallelism and buffer size itself
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Function to preprocess images: resize, normalize, etc.
def load(image_file):
    """Read one side-by-side JPEG and split it into its two halves.

    Args:
        image_file: Path of the JPEG file to read.

    Returns:
        A tuple ``(input_image, target_image)`` of float32 tensors holding
        the left half and the right half of the decoded image.
    """
    decoded = tf.image.decode_jpeg(tf.io.read_file(image_file))

    # The file stores both pictures next to each other: cut at the middle column.
    half_width = tf.shape(decoded)[1] // 2
    left_half = tf.cast(decoded[:, :half_width, :], tf.float32)
    right_half = tf.cast(decoded[:, half_width:, :], tf.float32)

    return left_half, right_half

def resize(input_image, target_image, height, width):
    """Resize both images of a pair to ``height`` x ``width``.

    Args:
        input_image: Input image tensor.
        target_image: Target image tensor.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        The resized ``(input_image, target_image)`` pair.
    """
    new_size = [height, width]
    return (
        tf.image.resize(input_image, new_size),
        tf.image.resize(target_image, new_size),
    )

def normalize(input_image, target_image):
    """Rescale both images so that 0 maps to -1 and 255 maps to 1.

    The [-1, 1] range matches the tanh output of the generator.

    Args:
        input_image: Input image values.
        target_image: Target image values.

    Returns:
        The rescaled ``(input_image, target_image)`` pair.
    """
    def _rescale(pixels):
        # Divide by half of 255, then shift down by one.
        return (pixels / 127.5) - 1

    return _rescale(input_image), _rescale(target_image)

def load_image_train(image_file):
    """Load a training pair: split the file, resize to the model size, normalize.

    Args:
        image_file: Path of the JPEG file to read.

    Returns:
        The ``(input_image, target_image)`` pair, resized to
        IMG_HEIGHT x IMG_WIDTH and rescaled by ``normalize``.
    """
    pair = load(image_file)
    pair = resize(*pair, IMG_HEIGHT, IMG_WIDTH)
    return normalize(*pair)

def load_image_test(image_file):
    """Load a test pair: split the file, resize to the model size, normalize.

    Args:
        image_file: Path of the JPEG file to read.

    Returns:
        The ``(input_image, target_image)`` pair, resized to
        IMG_HEIGHT x IMG_WIDTH and rescaled by ``normalize``.
    """
    pair = load(image_file)
    pair = resize(*pair, IMG_HEIGHT, IMG_WIDTH)
    return normalize(*pair)

# Create Train and Test datasets

In [None]:
# Directories that load_dataset scans for '*.jpg' files
train_path = PATH / 'train'  # Assuming the dataset has a 'train' folder
test_path = PATH / 'val'     # Assuming the dataset has a 'val' folder

def load_dataset(dataset_path, load_image_fn):
    """Build a shuffled, batched tf.data pipeline over the JPEGs of a folder.

    Args:
        dataset_path: Directory (pathlib.Path) containing the '*.jpg' files.
        load_image_fn: Function mapping a file path to an
            ``(input_image, target_image)`` pair.

    Returns:
        A tf.data.Dataset yielding batches of BATCH_SIZE image pairs.
    """
    file_pattern = str(dataset_path/'*.jpg')
    return (
        tf.data.Dataset.list_files(file_pattern)
        # Decode and preprocess several files in parallel.
        .map(load_image_fn, num_parallel_calls=AUTOTUNE)
        .shuffle(buffer_size=400)  # Adjust buffer_size as needed
        .batch(BATCH_SIZE)
        # Prepare upcoming batches while the current one is being consumed.
        .prefetch(buffer_size=AUTOTUNE)
    )

# Create training and test datasets
train_dataset = load_dataset(train_path, load_image_train)
test_dataset = load_dataset(test_path, load_image_test)

# Sanity check: print the batched shapes of one (input, target) pair from the training dataset
for input_image, target_image in train_dataset.take(1):
    print(f"Train Input image shape: {input_image.shape}")
    print(f"Train Target image shape: {target_image.shape}")

# Same sanity check for the test dataset
for input_image, target_image in test_dataset.take(1):
    print(f"Test Input image shape: {input_image.shape}")
    print(f"Test Target image shape: {target_image.shape}")

# U-Net Generator Model

In [None]:
# Number of filters of the generator's final layer, i.e. the channel count of the generated image
Output_channel = 3
# Encoder
def downsample(filters: int, kernel_size: int, stride: int = 2, apply_batchnorm: bool = True):
    """Build a down-sampling block: Conv2D -> (BatchNorm) -> LeakyReLU.

    Used by the encoder of the U-Net generator and by the discriminator.

    Args:
        filters (int): number of filters in the convolutional layer
        kernel_size (int): size of the filters
        stride (int, optional): stride of the convolution. Defaults to 2.
        apply_batchnorm (bool, optional): insert batch normalization after the
            convolutional layer. Defaults to True.

    Returns:
        tf.keras.Sequential: the assembled block.
    """
    block_layers = [
        tf.keras.layers.Conv2D(
            filters,
            kernel_size,
            strides=stride,
            padding='same',
            # weights drawn from N(0, 0.02)
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        )
    ]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization(momentum=0.8))
    block_layers.append(tf.keras.layers.LeakyReLU(alpha=0.2))

    return tf.keras.Sequential(block_layers)


# Decoder
def upsample(filters: int, kernel_size: int, apply_dropout: bool = False):
    """Build an up-sampling block: Conv2DTranspose -> BatchNorm -> (Dropout) -> ReLU.

    Used by the decoder of the U-Net generator. The transposed convolution
    has stride 2 and 'same' padding.

    Args:
        filters (int): number of filters in the transposed convolutional layer
        kernel_size (int): size of the filters
        apply_dropout (bool, optional): insert Dropout(0.5) after the batch
            normalization. Defaults to False.

    Returns:
        tf.keras.Sequential: the assembled block.
    """
    block_layers = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            kernel_size,
            strides=2,
            padding='same',
            # weights drawn from N(0, 0.02)
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(momentum=0.8),
    ]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(block_layers)

# Generator Class
def Generator(image_shape=(256, 256, 3)):
    """Build the U-Net generator as a Keras functional model.

    Eight stride-2 down-sampling blocks form the encoder; seven up-sampling
    blocks form the decoder, each followed by a concatenation with the
    encoder output of matching resolution. A final transposed convolution
    with tanh produces an image with ``Output_channel`` channels.

    Args:
        image_shape (tuple): shape of the input image, without the batch
            dimension. Defaults to (256, 256, 3).

    Returns:
        tf.keras.Model: maps an input image to a generated image in [-1, 1].
    """
    inputs = tf.keras.layers.Input(shape=image_shape)

    # Encoder: C64-C128-C256-C512-C512-C512-C512-C512, as (filters, batchnorm).
    # Batch norm is skipped on the first block and on the bottleneck.
    encoder_spec = [
        (64, False), (128, True), (256, True), (512, True),
        (512, True), (512, True), (512, True), (512, False),
    ]
    encoder_blocks = [
        downsample(filters=n_filters, kernel_size=4, apply_batchnorm=use_bn)
        for n_filters, use_bn in encoder_spec
    ]

    # Decoder: CD512-CD512-CD512-C512-C256-C128-C64, as (filters, dropout).
    decoder_spec = [
        (512, True), (512, True), (512, True), (512, False),
        (256, False), (128, False), (64, False),
    ]
    decoder_blocks = [
        upsample(filters=n_filters, kernel_size=4, apply_dropout=use_dropout)
        for n_filters, use_dropout in decoder_spec
    ]

    # Run the encoder and remember every intermediate output for the skips.
    features = inputs
    encoder_outputs = []
    for block in encoder_blocks:
        features = block(features)
        encoder_outputs.append(features)

    # Skip connections use every encoder output except the bottleneck,
    # deepest first; `features` itself starts as the bottleneck.
    for block, skip in zip(decoder_blocks, encoder_outputs[-2::-1]):
        features = block(features)
        features = tf.keras.layers.Concatenate()([features, skip])

    # tanh keeps the output values between -1 and 1
    to_image = tf.keras.layers.Conv2DTranspose(
        filters=Output_channel,
        kernel_size=4,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='tanh',
    )
    return tf.keras.Model(inputs=inputs, outputs=to_image(features))
# Build the generator with the default input shape and print its layer summary
gen_model = Generator()
gen_model.summary()

# PatchGAN Discriminator Model

In [None]:
 # A 70x70 PatchGAN will classify 70x70 patches of the input image as real or fake
def Discriminator(image_shape=(256, 256, 3)):
    """Build the PatchGAN-style discriminator of the conditional GAN.

    The input image and the target (or generated) image are concatenated
    along the channel axis and passed through four stride-2 down-sampling
    blocks (C64-C128-C256-C512), followed by a one-filter convolution and
    a sigmoid. Every value of the output map scores one patch of the input
    as real or fake.

    NOTE(review): with four stride-2 blocks the receptive field of one
    output value works out to 94x94 pixels, not 70x70; the reference
    70x70 PatchGAN uses stride 1 in the C512 block -- confirm which one
    is intended.

    Args:
       image_shape (tuple): The shape of each input image, without the batch
           dimension. Default is (256, 256, 3).

    Returns:
        tf.keras.Model: An (uncompiled) Keras model taking
        ``[input_image, target_image]`` and returning a map of per-patch
        probabilities (16x16x1 for the default 256x256 input).
    """
    input_image = tf.keras.layers.Input(shape=image_shape, name='input_image')
    target_image = tf.keras.layers.Input(shape=image_shape, name='target_image')

    # Conditioning: the discriminator always sees the pair stacked on channels.
    features = tf.keras.layers.concatenate([input_image, target_image])

    # C64-C128-C256-C512 as (filters, batchnorm); each block halves the resolution.
    for n_filters, use_bn in ((64, False), (128, True), (256, True), (512, True)):
        features = downsample(filters=n_filters, kernel_size=4, apply_batchnorm=use_bn)(features)

    # One-filter convolution turns the features into a single score per patch.
    patch_scores = tf.keras.layers.Conv2D(
        filters=1,
        kernel_size=4,
        strides=1,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
    )(features)
    patch_probabilities = tf.keras.layers.Activation('sigmoid')(patch_scores)

    return tf.keras.Model(inputs=[input_image, target_image], outputs=patch_probabilities)
# Build the discriminator with the default input shape and print its layer summary
disc_model = Discriminator()
disc_model.summary()

# Generator Loss

In [93]:
# Weight of the L1 term relative to the adversarial term in generator_loss
LAMBDA = 100
# Binary cross-entropy shared by generator_loss and discriminator_loss
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits= False) # because we have used sigmoid in the last layer of discriminator
def generator_loss(discriminator_generated_output, generative_output,  target):
    """Compute the generator loss: adversarial term plus weighted L1 term.

    Args:
        discriminator_generated_output (tf.Tensor): output of the discriminator for the generated image
        generative_output (tf.Tensor): image produced by the generator
        target (tf.Tensor): ground truth image

    Returns:
        tuple: ``(total_loss, adversarial_loss, l1_loss)`` where
        ``total_loss = adversarial_loss + LAMBDA * l1_loss``.
    """
    # The generator wants the discriminator to label its output as real (1).
    real_labels = tf.ones_like(discriminator_generated_output)
    adversarial_term = loss_object(real_labels, discriminator_generated_output)

    # Mean absolute error; L1 is used instead of L2 because it encourages less blurring.
    l1_term = tf.reduce_mean(tf.abs(target - generative_output))

    return adversarial_term + (LAMBDA * l1_term), adversarial_term, l1_term

# Discriminator Loss

In [94]:
def discriminator_loss(discriminator_real_output, discriminator_generated_output):
    """Compute the discriminator loss as the sum of its real and fake terms.

    Real outputs are compared against labels of 1 and generated outputs
    against labels of 0, both with the shared binary cross-entropy
    ``loss_object``.

    Args:
        discriminator_real_output (tf.Tensor): Discriminator's output for real images
        discriminator_generated_output (tf.Tensor): Discriminator's output for generated/fake images

    Returns:
        Tensor: Total loss value
    """
    # label of real image = 1   label of fake image = 0
    real_loss = loss_object(tf.ones_like(discriminator_real_output), discriminator_real_output)
    # Build the fake labels from the tensor they are compared against, so the
    # label shape always matches the prediction shape.
    generated_loss = loss_object(tf.zeros_like(discriminator_generated_output), discriminator_generated_output)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss

# Optimizers

In [95]:
# Separate Adam optimizers (learning rate 2e-4, beta_1 = 0.5) for the generator and the discriminator
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1= 0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1= 0.5)
# Checkpoints are written under checkpoint_dir with the file prefix 'ckpt'
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
# Track both optimizers and both models so that checkpoint.save stores them together
checkpoint = tf.train.Checkpoint(generator_optimizer = generator_optimizer, 
                                                     discriminator_optimizer = discriminator_optimizer,
                                                     discriminator = disc_model,
                                                     generator = gen_model,
                                                     )

# Image Generating

In [96]:
def generate_images(model, test_input, tar):
    """Plot the input, the ground truth and the generator's prediction side by side.

    Only the first element of each batch is shown.

    Args:
        model (tf.keras.Model): The generator used to produce the prediction.
        test_input (tf.Tensor): Batch of input images; the first one is displayed.
        tar (tf.Tensor): Batch of ground-truth images; the first one is displayed.
    """
    # The model is deliberately called with training=True.
    # NOTE(review): presumably so batch norm uses the statistics of this batch -- confirm.
    prediction = model(test_input, training=True)

    plt.figure(figsize=(15, 15))
    panels = (
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Prediction Image', prediction[0]),
    )
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # Map pixel values from [-1, 1] back to [0, 1] for display.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# Training

In [None]:
# Create a summary writer
log_dir = "logs/"
# Create the log directory if it is missing; exist_ok=True avoids a separate
# existence check (and the race between checking and creating).
os.makedirs(log_dir, exist_ok=True)
# Scalars written by train_step go to event files under log_dir (readable by TensorBoard)
summary_writer = tf.summary.create_file_writer(log_dir)
# this decorator converts the 'train_step' function into a TensorFlow graph function, improving its performance.
@tf.function
def train_step (input_image, target, step):
    """Run one optimization step for both the generator and the discriminator.

    The discriminator is conditioned on ``input_image`` in both passes: the
    real pair is ``[input_image, target]`` and the fake pair is
    ``[input_image, generative_output]``, matching the
    ``[input_image, target_image]`` input order of ``disc_model``.

    Args:
        input_image (tf.Tensor): batch of input (condition) images.
        target (tf.Tensor): batch of ground-truth images for ``input_image``.
        step: current training step; ``step // 1000`` is used as the x-axis
            value of the logged summaries.
    """
    with tf.GradientTape() as generative_tape, tf.GradientTape() as discriminator_tape:
        # Generate an output
        generative_output = gen_model(input_image, training=True)
        # Discriminator decision for the real pair (condition, ground truth)
        discriminator_real_output = disc_model([input_image, target], training=True)
        # Discriminator decision for the fake pair (condition, generated image).
        # The generated image goes in the target slot; the condition is unchanged.
        discriminator_generated_output = disc_model([input_image, generative_output], training=True)
        # Calculate loss values
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(discriminator_generated_output, generative_output, target)
        disc_loss = discriminator_loss(discriminator_real_output, discriminator_generated_output)

    # Only the forward pass and the losses need to be recorded; gradients and
    # updates are computed after leaving the tape context.
    generator_gradients = generative_tape.gradient(gen_total_loss, gen_model.trainable_variables)
    discriminator_gradients = discriminator_tape.gradient(disc_loss, disc_model.trainable_variables)
    # Apply the computed gradients to update the trainable variables of the generator and discriminator models
    generator_optimizer.apply_gradients(zip(generator_gradients, gen_model.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, disc_model.trainable_variables))
    # Write the generator and discriminator losses as scalar summaries for visualization using TensorBoard
    with summary_writer.as_default():
        tf.summary.scalar('generative_total_loss', gen_total_loss, step=step//1000)
        tf.summary.scalar('generative_gan_loss', gen_gan_loss, step=step//1000)
        tf.summary.scalar('generative_l1_loss', gen_l1_loss, step=step//1000)
        tf.summary.scalar('discriminator_loss', disc_loss, step=step//1000)
def fit(train_dataset, test_dataset, steps):
    """Train the GAN for ``steps`` steps, with periodic previews and checkpoints.

    Every 1000 steps the notebook output is cleared and a preview of the
    generator on one fixed test batch is shown; a dot is printed every
    10 steps and a checkpoint is saved every 5000 steps.

    Args:
        train_dataset: dataset of ``(input_image, target)`` batches; it is
            repeated as often as needed to reach ``steps``.
        test_dataset: dataset from which the single preview batch is taken.
        steps: total number of training steps to run.
    """
    # One test batch, drawn once and reused for every preview.
    preview_input, preview_target = next(iter(test_dataset.take(1)))
    numbered_batches = train_dataset.repeat().take(steps).enumerate()

    interval_start = time.time()
    for step, (input_image, target) in numbered_batches:
        if step % 1000 == 0:
            display.clear_output(wait=True)
            if step != 0:
                print(f'Time taken for 1000 steps: {time.time() - interval_start:.02f} sec\n')
            interval_start = time.time()
            generate_images(gen_model, preview_input, preview_target)
            print(f"Step: {step//1000}k")

        train_step(input_image, target, step)

        # Progress dot and checkpoint are keyed on the number of completed steps.
        completed = step + 1
        if completed % 10 == 0:
            print('.', end='', flush=True)
        if completed % 5000 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
# train the model
# NOTE(review): fit() saves a checkpoint only every 5000 steps, so with steps=4000 no checkpoint is written -- confirm this is intended
fit(train_dataset, test_dataset, steps=4000)