# Casual2Professional CycleGAN Tester

Tests photo translation (casual to professional) using a trained CycleGAN model.

## Imports and Declarations

In [None]:
import os
import numpy as np

import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.data as tf_data

from tensorflow.keras import layers

from tensorflow.keras.models import Model

from tensorflow_addons.layers import InstanceNormalization

from tensorflow.keras.initializers import RandomNormal

from tensorflow.keras.callbacks import Callback
from tensorflow.keras.callbacks import ModelCheckpoint

from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.losses import MeanAbsoluteError

from tensorflow.keras.optimizers import Adam

from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.preprocessing.image import load_img

In [None]:
# Local Imports
from cyclegan_model import *

In [None]:
# --- Dataset location and image geometry ---
# Root of the casual2professional dataset; testA/ and testB/ are read from here.
image_dataset_path = './casual2professional/'

# Edge length (pixels) of the square images: used when reading files and as
# the model input height/width.
pic_size = 256
read_image_size = (pic_size,) * 2            # (height, width) passed to load_img
model_image_size = read_image_size + (3,)    # (height, width, channels) for the models
# Residual block count per generator; the models built below must match the
# architecture whose weights are loaded.
resent_blocks = 9

# --- tf.data pipeline settings ---
autotune = tf.data.experimental.AUTOTUNE
buffer_size = 256
batch_size = 1

# --- Checkpoint selection and output location ---
# Epoch whose saved weights are evaluated; it selects both the checkpoint file
# and the name of the folder translated images are written to.
test_epoch = 190
weight_file = './c2p_{}_checkpoints/cyclegan_checkpoints.{:03d}'.format(pic_size, test_epoch)
test_image_output_path = './c2p_{}_test_output_epoch_{}/'.format(pic_size, test_epoch)

## Load and Convert Test Image Dataset

In [None]:
def load_images(path, size = read_image_size):
    """Load every image file in a directory into a single numpy array.

    Args:
        path: Directory containing the images. A trailing path separator is
            optional, because entries are joined with os.path.join.
        size: Target (height, width) passed to load_img; defaults to the
            module-level read_image_size.

    Returns:
        np.ndarray built from the img_to_array output of each image, in sorted
        filename order. Hidden entries (names starting with '.') and anything
        that is not a regular file are skipped.
    """
    data_list = []
    # Sort so the load order does not depend on the filesystem's arbitrary
    # os.listdir ordering.
    for filename in sorted(os.listdir(path)):
        file_path = os.path.join(path, filename)
        # Skip entries load_img cannot decode, e.g. .DS_Store or the
        # .ipynb_checkpoints/ folder Jupyter leaves behind.
        if filename.startswith('.') or not os.path.isfile(file_path):
            continue
        pixels = img_to_array(load_img(file_path, target_size = size))
        data_list.append(pixels)
    return np.asarray(data_list)

def convert_image_to_dataset(image_data, label):
    """Pair each image with the same label and wrap the pairs in a tf.data.Dataset.

    Args:
        image_data: Array of images; its first axis is sliced into dataset elements.
        label: Value repeated once per image (this notebook uses 0 for domain A
            and 1 for domain B).

    Returns:
        A tf.data.Dataset whose elements are (image, label) tuples.
    """
    per_image_labels = [label for _ in range(len(image_data))]
    return tf_data.Dataset.from_tensor_slices((image_data, per_image_labels))

def normalise_img(img):
    """Cast img to float32 and rescale it linearly so 0 -> -1.0 and 255 -> 1.0."""
    as_float = tf.cast(img, dtype = tf.float32)
    return as_float / 127.5 - 1.0

def preprocess_test_image(img, label):
    """Resize one test image to the model's input height/width, then normalise it.

    No augmentation is applied to test images. The label argument exists only so
    this function can be mapped over the (image, label) dataset; it is dropped,
    and only the processed image is returned.
    """
    target_hw = list(model_image_size[:2])
    return normalise_img(tf.image.resize(img, target_hw))

In [None]:
# Raw pixel arrays for each domain's test split (domain A first, then B).
test_A, test_B = (load_images(image_dataset_path + split) for split in ('testA/', 'testB/'))

### Convert to tensor datasets and perform preprocessing

In [None]:
# Convert the image numpy arrays to tf datasets.
# Domain A is labelled 0 and Domain B is labelled 1.
def _build_test_pipeline(images, label):
    """Wrap images in a labelled dataset, then resize/normalise, cache and batch them.

    The test pipeline deliberately does not shuffle. Every test image is
    translated exactly once anyway, and a fixed order keeps the saved file index
    (pic_<i>) pointing at the same input photo from run to run.
    """
    dataset = convert_image_to_dataset(images, label)
    return (dataset
            .map(preprocess_test_image, num_parallel_calls = autotune)
            .cache()
            .batch(batch_size))

test_A = _build_test_pipeline(test_A, 0)
test_B = _build_test_pipeline(test_B, 1)

## Reload Trained Model

### Create Empty Model

In [None]:
# Build the untrained networks with the same architecture used in training, so
# the checkpoint can be restored into them. gen_G is the generator applied to
# domain A test images further down (A -> B); gen_F is presumably the reverse
# mapping (B -> A).
gen_G, gen_F = (
    get_resnet_generator(name = gen_name, num_residual_blocks = resent_blocks, model_image_size = model_image_size)
    for gen_name in ("generator_G", "generator_F")
)
disc_X, disc_Y = (
    get_discriminator(name = disc_name, model_image_size = model_image_size)
    for disc_name in ("discriminator_X", "discriminator_Y")
)

# Wrap the four networks in the CycleGan container that owns the checkpoint.
cycle_gan_tester = CycleGan(
    generator_G = gen_G,
    generator_F = gen_F,
    discriminator_X = disc_X,
    discriminator_Y = disc_Y,
)

# Adam hyper-parameters; each network gets its own optimizer instance.
learn_rate = 2e-4
beta_1_value = 0.5

cycle_gan_tester.compile(
    gen_G_optimizer = Adam(learning_rate = learn_rate, beta_1 = beta_1_value),
    gen_F_optimizer = Adam(learning_rate = learn_rate, beta_1 = beta_1_value),
    disc_X_optimizer = Adam(learning_rate = learn_rate, beta_1 = beta_1_value),
    disc_Y_optimizer = Adam(learning_rate = learn_rate, beta_1 = beta_1_value),
    gen_loss_fn = generator_loss_fn,
    disc_loss_fn = discriminator_loss_fn,
)

### Load Weights

In [None]:
# Restore the trained weights selected by weight_file. expect_partial() silences
# warnings about checkpoint values this object does not restore.
load_status = cycle_gan_tester.load_weights(weight_file)
load_status.expect_partial()
print('Weights loaded successfully. Generating translation with test images...')

## Translate Test Pictures with Trained Model

In [None]:
def make_image_output_dir(folder_path):
    """Create folder_path, including any missing parent directories, if it does not exist.

    Args:
        folder_path: Directory to create. Calling this when it already exists
            is a no-op.
    """
    # exist_ok avoids the check-then-create race and makes repeat calls safe.
    os.makedirs(folder_path, exist_ok = True)

def plot_save_translation_pair(image, translated, folder_path, epoch, iteration, show = True):
    """Plot an input image beside its translation and save the figure as a PNG.

    The file is written to <folder_path>/epoch_<epoch>_pic_<iteration>.png.

    Args:
        image: Input image, anything plt.imshow accepts (the caller in this
            notebook passes uint8 arrays).
        translated: Translated image, drawn to the right of image.
        folder_path: Existing directory the PNG is written into.
        epoch: Checkpoint epoch; it is embedded in the file name.
        iteration: Index of the test image; it is embedded in the file name.
        show: If True (the default), display the figure with plt.show()
            before it is closed. Pass False for headless batch runs.
    """
    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize = (5, 3))
    for ax, picture in ((ax_in, image), (ax_out, translated)):
        ax.axis('off')
        ax.imshow(picture)

    # Save through the figure handle rather than pyplot's implicit current
    # figure. os.path.join avoids a doubled separator when folder_path ends
    # with '/'.
    fig.savefig(os.path.join(folder_path, 'epoch_{}_pic_{}.png'.format(epoch, iteration)))
    if show:
        plt.show()
    # One figure is created per test image, and pyplot keeps every figure
    # alive until it is closed.
    plt.close(fig)

In [None]:
# Currently only the A -> B direction (gen_G applied to test_A) is evaluated;
# test_B is prepared above but is not used here.

def _denormalise_to_uint8(pixels):
    """Invert normalise_img, mapping floats in [-1, 1] back to uint8 in [0, 255].

    Values are clipped before the cast, so anything outside [-1, 1] saturates
    instead of wrapping around in uint8.
    """
    return np.clip((pixels * 127.5) + 127.5, 0, 255).astype(np.uint8)

make_image_output_dir(test_image_output_path)

pic_index = 0
for img_batch in test_A:
    # Run the generator in inference mode explicitly.
    predictions = cycle_gan_tester.gen_G(img_batch, training = False).numpy()
    # Iterate over every image in the batch (not only the first), so no test
    # image is skipped when batch_size > 1. With batch_size == 1, pic_index
    # matches the batch index used previously.
    for source, translated in zip(img_batch.numpy(), predictions):
        plot_save_translation_pair(
            _denormalise_to_uint8(source),
            _denormalise_to_uint8(translated),
            test_image_output_path,
            test_epoch,
            pic_index,
        )
        pic_index += 1