### Setting the stage for building and training the Pix2Pix model

In [None]:
# Import statements

import tensorflow as tf
import os
import time
import datetime
import numpy as np
from tqdm.notebook import tqdm
from natsort import natsorted
from PIL import Image
import io

from matplotlib import pyplot as plt
from IPython import display

In [None]:
# Paths to dataset directories
# Both directories are scanned for '.jpg' files by the input pipelines below.

train_path = "sonar_data_train" # Change to lidar_data_train accordingly
test_path = "sonar_data_test" # Change to lidar_data_test accordingly

In [None]:
# Visualising a sample image from the training set and printing its shape
# (the image is shown without axes; the shape is that of the decoded tensor)

sample_img = tf.io.read_file(os.path.join(train_path, 'train_562.jpg'))
sample_img = tf.io.decode_jpeg(sample_img)
plt.figure()
plt.axis("off")
plt.imshow(sample_img)
print(sample_img.shape)

In [None]:
# Function to load the file and convert it into a Tensor

def load(image_file, split_width=1024):
    """Read a composite JPEG and split it into an (input, real) image pair.

    Args:
        image_file: Path of the JPEG file to read.
        split_width: Column at which the composite image is cut. Columns
            ``[0, split_width)`` become the input image and the remaining
            columns become the real image. Defaults to 1024, the value that
            used to be hard-coded here.

    Returns:
        Tuple ``(input_image, real_image)`` of float32 tensors holding the
        un-normalised pixel values.
    """
    # Read and decode an image file to a uint8 tensor. Forcing three channels
    # keeps grayscale JPEGs compatible with the 3-channel crop and models.
    image = tf.io.read_file(image_file)
    image = tf.io.decode_jpeg(image, channels=3)

    # Split each image tensor into two tensors: one with the visual image and
    # one with the sensor scan.
    # NOTE(review): which half is the scan and which is the visual image is
    # not established here - confirm against the dataset layout.
    input_image = image[:, :split_width, :]
    real_image = image[:, split_width:, :]

    # Convert both images to float32 tensors
    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    return input_image, real_image

In [None]:
# Resizing the input image and the corresponding real image for processing
# `inp` and `re` are notebook-level names that later cells reuse.
# NOTE(review): `re` has the same name as the stdlib `re` module; harmless
# here as long as that module is not imported in this notebook.

inp, re = load(os.path.join(train_path, 'train_562.jpg'))
inp = tf.image.resize(inp, [128, 256])
re = tf.image.resize(re, [128, 256])

plt.figure()
plt.axis("off")
plt.imshow(inp/255.0)  # scaled for display - assumes 0-255 pixel values
plt.figure()
plt.axis("off")
plt.imshow(re/255.0)  # scaled for display - assumes 0-255 pixel values

In [None]:
# Setting the training parameters

# Shuffle-buffer size passed to Dataset.shuffle in the training pipeline.
# The training set contains ~ 1200 images
BUFFER_SIZE = 1200

# Baseline - A batch size of 1 produced better results in the original pix2pix paper
# Batch-size-4x - A batch size of 4 is used for the second experiment
BATCH_SIZE = 1

# Resize settings to the smaller image to maintain stability in output
# (used as the crop size for training and the resize target for testing)
IMG_WIDTH = 256
IMG_HEIGHT = 128

### Applying random jittering to preprocess the training set

In [None]:
# Resizing an image pair to a given height and width (nearest-neighbour)

def resize(input_image, real_image, height, width):
    """Resize both images of a pair to ``height`` x ``width``.

    Nearest-neighbour sampling is used for both images.

    Returns:
        Tuple ``(input_image, real_image)`` of the resized images.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    return (
        tf.image.resize(input_image, target_size, method=nearest),
        tf.image.resize(real_image, target_size, method=nearest),
    )

In [None]:
# Cropping out random patches from the image

def random_crop(input_image, real_image):
    """Cut the same random IMG_HEIGHT x IMG_WIDTH patch out of both images.

    The two images are stacked so that a single random crop is applied to
    both of them at once.

    Returns:
        Tuple ``(input_patch, real_patch)``.
    """
    pair = tf.stack([input_image, real_image], axis=0)
    patch_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    patches = tf.image.random_crop(pair, size=patch_shape)

    return patches[0], patches[1]

In [None]:
# Normalising the image pixel values

def normalise(input_image, real_image):
    """Rescale both images with ``x / 127.5 - 1``.

    For pixel values in 0-255 this yields values in [-1, 1].

    Returns:
        Tuple ``(input_image, real_image)`` of the rescaled images.
    """
    half_range = 127.5

    return (input_image / half_range) - 1, (real_image / half_range) - 1

In [None]:
# Random jitter: resizing, random cropping and random horizontal mirroring of an image pair

@tf.function()
def random_jitter(input_image, real_image):
    """Apply the training-time augmentation to an image pair.

    The pair is resized to IMG_HEIGHT x IMG_WIDTH, randomly cropped to the
    same size and, with probability 0.5, both images are mirrored left-right.

    Returns:
        Tuple ``(input_image, real_image)`` of the augmented images.
    """
    # Resize to the working resolution, taken from the shared constants so it
    # cannot drift from the crop size used by random_crop.
    # NOTE(review): the resize target equals the crop size, so the random crop
    # below always returns the whole image; resize to a larger size here if
    # real crop jitter is wanted.
    input_image, real_image = resize(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)

    input_image, real_image = random_crop(input_image, real_image)

    # Mirror both images together so the pair stays aligned
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

In [None]:
# Visualising a few preprocessed images from the training set
# The sample pair is passed through random_jitter four times and the jittered
# input image of each run is shown in a 2x2 grid.

plt.figure(figsize=(5,5))
for i in range(4):
    rj_inp, rj_re = random_jitter(inp, re)
    plt.subplot(2,2,i+1)
    plt.imshow(rj_inp/255.0)  # scaled for display - assumes 0-255 pixel values
    plt.axis("off")
plt.show()

In [None]:
# Loading helper function for the training images

def load_image_train(image_file):
    """Load one training pair: read the file, apply jitter, then normalise.

    Returns:
        Tuple ``(input_image, real_image)``.
    """
    pair = load(image_file)
    pair = random_jitter(*pair)
    input_image, real_image = normalise(*pair)

    return input_image, real_image

In [None]:
# Loading helper function for the testing images

def load_image_test(image_file):
    """Load one test pair: read the file, resize to IMG_HEIGHT x IMG_WIDTH, then normalise.

    No random augmentation is applied to test images.

    Returns:
        Tuple ``(input_image, real_image)``.
    """
    pair = load(image_file)
    pair = resize(*pair, IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalise(*pair)

    return input_image, real_image

In [None]:
# Building an input pipeline for training data

# List every '.jpg' under train_path, preprocess in parallel, shuffle and batch
train_dataset = (
    tf.data.Dataset.list_files(os.path.join(train_path, '*.jpg'))
    .map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
# Building an input pipeline for testing data

# Natural-sorted listing keeps the test images in a stable, predictable order
file_names = natsorted(os.listdir(test_path))
file_paths = [
    os.path.join(test_path, jpg_name)
    for jpg_name in file_names
    if jpg_name.endswith('.jpg')
]
test_dataset = (
    tf.data.Dataset.from_tensor_slices(file_paths)
    .map(load_image_test)
    .batch(BATCH_SIZE)
)

### Building the Generator

In [None]:
# Number of filters (output channels) of the Generator's final layer
OUTPUT_CHANNELS = 3

In [None]:
# The downsampling function of the Generator

def downsample(filters, size, apply_batchnorm = True):
    """Build a downsampling block: stride-2 Conv2D, optional BatchNorm, LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_batchnorm: Whether to insert a BatchNormalization layer after
            the convolution.

    Returns:
        A ``tf.keras.Sequential`` containing the layers above.
    """
    conv = tf.keras.layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0.,0.02),
        use_bias=False,
    )

    block_layers = [conv]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(block_layers)

# Smoke test: push the sample image (with a batch axis added) through one
# downsampling block and print the shape of the result.
down_model = downsample(3,4)
down_result = down_model(tf.expand_dims(inp, 0))

print(down_result.shape)

In [None]:
# The upsampling function of the Generator

def upsample(filters, size, apply_dropout = False):
    """Build an upsampling block: stride-2 Conv2DTranspose, BatchNorm, optional Dropout, ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: Whether to insert a Dropout(0.5) layer after the
            batch normalisation.

    Returns:
        A ``tf.keras.Sequential`` containing the layers above.
    """
    deconv = tf.keras.layers.Conv2DTranspose(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0.,0.02),
        use_bias=False,
    )

    block_layers = [deconv, tf.keras.layers.BatchNormalization()]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(block_layers)

# Smoke test: feed the downsampled sample through one upsampling block and
# print the shape of the result.
up_model = upsample(3,4)
up_result = up_model(down_result)

print(up_result.shape)

In [None]:
# Building the Generator with the downsampling and upsampling functions

def Generator():
    """Build the encoder-decoder generator with skip connections.

    The encoder is a stack of seven stride-2 downsampling blocks. The decoder
    has one upsampling block per skip connection (six); each decoder output is
    concatenated with the matching encoder output. A final stride-2 transposed
    convolution with tanh activation produces OUTPUT_CHANNELS channels.

    Returns:
        A ``tf.keras.Model`` taking a ``[128, 256, 3]`` input.
    """
    inputs = tf.keras.layers.Input(shape=[128, 256, 3])

    down_stack = [
        downsample(64, 4, apply_batchnorm=False),
        downsample(128, 4),
        downsample(256, 4),
        downsample(512, 4),
        downsample(512, 4),
        downsample(512, 4),
        downsample(512, 4),
    ]

    # Exactly one block per skip connection (len(down_stack) - 1). A seventh
    # block used to be listed here but was silently dropped by zip() below and
    # never became part of the model.
    up_stack = [
        upsample(512, 4, apply_dropout=True),
        upsample(512, 4, apply_dropout=True),
        upsample(512, 4, apply_dropout=True),
        upsample(512, 4),
        upsample(256, 4),
        upsample(128, 4),
    ]

    initialiser = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides=2, padding='same', kernel_initializer=initialiser, activation='tanh')

    x = inputs
    skips = []

    # Downsampling through the model
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck output is the decoder input, not a skip; the remaining
    # encoder outputs are consumed from deepest to shallowest.
    skips = reversed(skips[:-1])

    # Upsampling through the model
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Visualising the Generator
# Builds the notebook-level `generator` model and writes its layer diagram
# (with tensor shapes) to 'gen.png'.

generator = Generator()
tf.keras.utils.plot_model(generator, to_file='gen.png', show_shapes=True, dpi=300)

In [None]:
# Running the untrained Generator on the sample training image
# NOTE(review): `inp` is fed in without going through normalise() here,
# unlike the dataset pipelines - confirm this is intended for the demo.

gen_output = generator(inp[tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...])  # first (only) image of the batch

In [None]:
# Defining the Generator loss

# Weight applied to the L1 term in the total generator loss
LAMBDA = 100

# Binary cross-entropy on raw logits; shared by the generator and
# discriminator loss functions
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(disc_generated_output, gen_output, target):
    """Compute the generator loss: adversarial term plus LAMBDA-weighted L1 term.

    Args:
        disc_generated_output: Discriminator output for the generated image.
        gen_output: Image produced by the generator.
        target: Ground-truth image.

    Returns:
        Tuple ``(total_gen_loss, gan_loss, l1_loss)``.
    """
    # Adversarial term: the discriminator output is scored against all-ones labels
    all_real = tf.ones_like(disc_generated_output)
    adversarial_term = loss_object(all_real, disc_generated_output)

    # Mean absolute error between target and generated image
    l1_term = tf.reduce_mean(tf.abs(target-gen_output))

    return adversarial_term + (LAMBDA*l1_term), adversarial_term, l1_term

### Building the discriminator

In [None]:
# Building the Discriminator by using a PatchGAN architecture

def Discriminator():
    """Build the discriminator.

    The model takes an input image and a target image (both ``[128, 256, 3]``),
    concatenates them, applies three downsampling blocks followed by two
    zero-padded stride-1 convolutions, and outputs a single-channel map.

    Returns:
        A ``tf.keras.Model`` with inputs ``[input_image, target_image]``.
    """
    init = tf.random_normal_initializer(0., 0.02)

    source = tf.keras.layers.Input(shape=[128, 256, 3], name='input_image')
    target = tf.keras.layers.Input(shape=[128, 256, 3], name='target_image')

    features = tf.keras.layers.concatenate([source, target])

    # Three stride-2 blocks; only the first one skips batch normalisation
    for filters, use_batchnorm in ((64, False), (128, True), (256, True)):
        features = downsample(filters, 4, use_batchnorm)(features)

    features = tf.keras.layers.ZeroPadding2D()(features)
    features = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=init, use_bias=False)(features)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.LeakyReLU()(features)

    features = tf.keras.layers.ZeroPadding2D()(features)
    patch_output = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=init)(features)

    return tf.keras.Model(inputs=[source, target], outputs=patch_output)

In [None]:
# Visualising the Discriminator
# Builds the notebook-level `discriminator` model and plots its layer diagram
# with tensor shapes.

discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=300)

In [None]:
# Testing the Discriminator on the sample image
# The untrained discriminator scores the (sample input, generated output) pair;
# the last channel of the first batch element is shown as a heat map.

disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()

In [None]:
# Defining the Discriminator loss

def discriminator_loss(disc_real_output, disc_generated_output):
    """Compute the discriminator loss.

    The output for the real pair is scored against all-ones labels and the
    output for the generated pair against all-zeros labels; the two losses
    are summed.

    Returns:
        The total discriminator loss.
    """
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_generated = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    return loss_on_real + loss_on_generated

### Optimisers and Checkpoint saving function

In [None]:
# Defining Generator and Discriminator optimiser
# Both networks use Adam with learning rate 2e-4 and beta_1 = 0.6.

generator_optimiser = tf.keras.optimizers.Adam(2e-4, beta_1=0.6)
discriminator_optimiser = tf.keras.optimizers.Adam(2e-4, beta_1=0.6)

In [None]:
# Defining checkpoints to save model weights
# The checkpoint tracks both optimisers and both models; files are written
# with the prefix '<checkpoint_dir>/ckpt'.

checkpoint_dir = 'training_checkpoints_v1'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimiser, discriminator_optimizer=discriminator_optimiser, generator=generator, discriminator=discriminator)

### Function to plot images during training

In [None]:
# Function to visualise images while training

def generate_images(model, test_input, tar):
    """Plot the input, ground truth and model prediction side by side.

    Only the first element of each batch is shown. The model is called with
    ``training=True``.

    Args:
        model: Model used to produce the prediction.
        test_input: Batch of input images.
        tar: Batch of ground-truth images.
    """
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15, 15))

    panels = zip(
        ['Input Image', 'Ground Truth', 'Predicted Image'],
        [test_input[0], tar[0], prediction[0]],
    )

    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # Getting the pixel values in the [0, 1] range to plot.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
# Visualising the output of the untrained Generator on the first test batch

for example_input, example_target in test_dataset.take(1):
    generate_images(generator, example_input, example_target)

### Training phase

In [None]:
# Building the train step iterator as a helper function

@tf.function
def train_step(input_image, target, step):
    """Run one optimisation step for the generator and the discriminator.

    Both networks are run under their own gradient tape, their losses are
    computed, gradients are applied with the respective optimisers and the
    losses are logged through the notebook-level ``summary_writer`` (which
    must be defined before this function is first called).

    Args:
        input_image: Batch of input images.
        target: Batch of ground-truth images.
        step: Current training step; summaries are logged at ``step // 1000``.

    Returns:
        Tuple ``(gen_total_loss, disc_loss)``.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimiser.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimiser.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

    # Log before returning: the summaries used to sit after the return
    # statement and were therefore never written.
    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
        tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
        tf.summary.scalar('disc_loss', disc_loss, step=step//1000)

    return gen_total_loss, disc_loss


# Visualising adversarial losses

def show_losses(gen_total_loss, disc_loss):
    """Print the generator and discriminator loss values.

    Both arguments must expose a ``.numpy()`` method.
    """
    labelled_losses = (
        ("generator loss: ", gen_total_loss),
        ("Discriminator loss: ", disc_loss),
    )
    for label, loss in labelled_losses:
        print(label, loss.numpy())

In [None]:
# Logs to visualise training metrics using Tensorflow
# Each run writes to its own time-stamped sub-directory of 'logs/fit/'.

log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
# Training the model

def fit(train_ds, test_ds, steps):
    """Run the training loop for a fixed number of steps.

    Args:
        train_ds: Dataset of (input_image, target) batches; it is repeated
            and truncated to ``steps`` batches.
        test_ds: Dataset whose first batch is used as the fixed example for
            the preview plots.
        steps: Total number of training steps to run.

    Side effects:
        Plots a preview and prints timing every step, prints the losses every
        100 steps, prints an epoch counter every 400 steps, and saves a
        checkpoint every 5000 steps.
    """
    example_input, example_target = next(iter(test_ds.take(1)))
    start = time.time()

    epochs_done = 1
    steps_done = 100

    for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
        if (step) % 1000 == 0:
            display.clear_output(wait=True)

        if step != 0:
            print(f'Time taken for step no. {step} - {time.time()-start:.2f} sec\n')

        start = time.time()

        # NOTE(review): a preview figure is rendered on every step, which is
        # likely to dominate the step time; consider moving this under the
        # `step % 1000 == 0` branch if per-step previews are not needed.
        generate_images(generator, example_input, example_target)
        print(f"Step: {step//1000}k\n")

        # Keep the losses computed during the update so that reporting does
        # not need extra forward passes in training mode (which would also
        # update the batch-normalisation statistics).
        gen_total_loss, disc_loss = train_step(input_image, target, step)

        if (step+1) % 100 == 0:
            print(steps_done, ' steps done\n', end='', flush=True)
            steps_done += 100
            show_losses(gen_total_loss, disc_loss)

        if (step+1) % 400 == 0:
            print(epochs_done, ' epochs completed\n', end='', flush=True)
            epochs_done += 1

        # Save (checkpoint) the model every 5k steps
        if (step + 1) % 5000 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# Logging information to tensorboard and visualising training metrics

%load_ext tensorboard
%tensorboard --logdir {'logsfit'}

In [None]:
# Checking the availability of GPU
# Displays the list of GPU devices visible to TensorFlow.

tf.config.experimental.list_physical_devices('GPU') 

In [None]:
# Restoring the most recent checkpoint found in checkpoint_dir
# NOTE(review): presumably nothing is restored when the directory holds no
# checkpoint yet - confirm against the tf.train.latest_checkpoint docs.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# Starting the training
# Runs 40000 training steps; the test dataset only supplies the preview example.

fit(train_dataset, test_dataset, steps=40000)

### Testing phase

In [None]:
# Applying the model to generate a few visualisations on testing set
# Note: this loop rebinds the notebook-level name `inp`.

for inp, tar in test_dataset.take(5):
    generate_images(generator, inp, tar)

In [None]:
# Functions to apply the model to the test dataset directory and save the output
# The output is saved in the same visual format as the input: the scan and the generated image concatenated side by side

def generate_images_test(model, test_input, tar):
    """Render the input and the model prediction side by side as a PIL image.

    Only the first element of the batch is rendered. The model is called with
    ``training=True``.

    Args:
        model: Model used to produce the prediction.
        test_input: Batch of input images.
        tar: Batch of ground-truth images (accepted for a uniform signature;
            not used).

    Returns:
        A ``PIL.Image`` opened from an in-memory JPEG of the rendered figure.
    """
    prediction = model(test_input, training=True)
    fig = plt.figure(figsize=(15, 15))

    try:
        # Concatenating the images along the width axis
        concatenated_image = np.concatenate([test_input[0], prediction[0]], axis=1)

        # Displaying the concatenated image, rescaled to the [0, 1] range
        plt.imshow(concatenated_image * 0.5 + 0.5)

        # Adjusting the size of the axes to fill the entire figure
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)

        # Removing the axis
        plt.axis('off')

        # Saving the image to a buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='jpg', bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure: one is created per test image, and leaving them
        # open accumulates memory over the whole test set.
        plt.close(fig)

    # Converting the buffer to an image
    buf.seek(0)
    img = Image.open(buf)

    return img

def save_images_test(model, test_dataset, directory, frames_test_dir, size=(512, 150)):
    """Run the model over the test dataset and save one rendered image per batch.

    Args:
        model: Model used to generate the predictions.
        test_dataset: Dataset yielding ``(input, target)`` batches.
        directory: Output directory; created if it does not exist.
        frames_test_dir: Directory whose natural-sorted '.jpg' file names are
            reused, in order, as the output file names.
        size: ``(width, height)`` the rendered image is resized to.

    NOTE(review): only the first element of each batch is rendered, so the
    file-name alignment assumes a batch size of 1.
    """
    # Create the directory if it doesn't exist
    os.makedirs(directory, exist_ok=True)

    # Get the '.jpg' file names in the frames_test directory. The same filter
    # is used when the test dataset is built, so the i-th name matches the
    # i-th dataset element even if the directory holds other files.
    file_names = [
        file_name
        for file_name in natsorted(os.listdir(frames_test_dir))
        if file_name.endswith('.jpg')
    ]

    # Iterate over the test dataset
    for i, (inp, tar) in tqdm(enumerate(test_dataset)):
        # Generate the image
        img = generate_images_test(model, inp, tar)

        # Resize the image (Image.ANTIALIAS was removed in Pillow 10;
        # Image.LANCZOS is the same filter)
        img = img.resize(size, Image.LANCZOS)

        # Save the image to the specified directory under the source file name
        img.save(os.path.join(directory, file_names[i]))

# Write the generated test images to 'sonar_data_augmented', reusing the file
# names found in 'sonar_data_test'.
save_images_test(generator, test_dataset, 'sonar_data_augmented', 'sonar_data_test')