# Casual2Professional CycleGAN Loader

Loads a saved CycleGAN checkpoint from file and continues training it.

## Imports and Declarations

In [None]:
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.data as tf_data

from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

from tensorflow.keras.callbacks import ModelCheckpoint

In [None]:
# Local Imports
from cyclegan_model import *

In [None]:
# Common Parameters
pic_size = 256
read_image_size = (pic_size,) * 2           # (pic_size, pic_size): size images are loaded at
model_image_size = read_image_size + (3,)   # (pic_size, pic_size, 3): model input shape
resent_blocks = 9                           # passed to create_default_cyclegan_model

samples_display_size = 10

# Image Preprocessing Parameters
image_dataset_path = './casual2professional/'

autotune = tf.data.experimental.AUTOTUNE
buffer_size, batch_size = 256, 1

# Model Parameters
epochs_to_train = 100

# Checkpoint parameters
epoch_load = 100
model_load_file = f'./c2p_{pic_size}_checkpoints/cyclegan_checkpoints.{epoch_load:03d}'
# Deliberately NOT an f-string: ModelCheckpoint substitutes {epoch} itself.
checkpoint_filepath = "./c2p_checkpoints_loader/cyclegan_checkpoints_cont.{epoch:03d}"

# Loss values file
loss_value_file = 'loss_values_loader.csv'

## Load and Convert Image Dataset

In [None]:
def load_images(path, size = read_image_size):
    """Load every image file in a directory into a single NumPy array.

    Parameters
    ----------
    path : str
        Directory to read. A trailing separator is optional.
    size : tuple of int
        Passed to ``load_img`` as ``target_size``; defaults to ``read_image_size``.

    Returns
    -------
    numpy.ndarray
        ``np.asarray`` of the per-image arrays produced by ``img_to_array``,
        stacked along a new leading axis in sorted filename order.
    """
    data_list = list()
    # Sorted so the sample order is reproducible; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(path)):
        file_path = os.path.join(path, filename)
        # Skip sub-directories (e.g. Jupyter's `.ipynb_checkpoints`) and hidden
        # files such as `.DS_Store`; load_img cannot decode them and would raise.
        if filename.startswith('.') or not os.path.isfile(file_path):
            continue
        pixels = img_to_array(load_img(file_path, target_size = size))
        data_list.append(pixels)
    return np.asarray(data_list)

def convert_image_to_dataset(image_data, label):
    """Wrap an image array in a ``tf.data.Dataset`` of ``(image, label)`` pairs.

    ``image_data`` is sliced along its first axis; every slice is paired with
    the same ``label`` value.
    """
    repeated_labels = [label for _ in range(len(image_data))]
    return tf_data.Dataset.from_tensor_slices((image_data, repeated_labels))

def normalise_img(img):
    """Cast ``img`` to float32 and rescale it so that [0, 255] maps onto [-1, 1]."""
    as_float = tf.cast(img, dtype = tf.float32)
    halved_range = as_float / 127.5
    return halved_range - 1.0

def preprocess_train_image(img, label):
    """Augment and normalise one training image; the label is accepted but discarded.

    Steps: random horizontal flip, resize to ``read_image_size``, random crop
    to ``model_image_size``, then scale into [-1, 1] with ``normalise_img``.
    When ``read_image_size`` equals ``model_image_size[:2]``, as currently
    configured, the crop keeps the whole image and only the flip is random.
    """
    flipped = tf.image.random_flip_left_right(img)
    resized = tf.image.resize(flipped, list(read_image_size))
    cropped = tf.image.random_crop(resized, size = list(model_image_size))
    return normalise_img(cropped)

def preprocess_test_image(img, label):
    """Resize one test image to the model input size and normalise it (no augmentation).

    The label argument is accepted for ``Dataset.map`` compatibility and dropped.
    """
    target_height, target_width = model_image_size[:2]
    return normalise_img(tf.image.resize(img, [target_height, target_width]))

In [None]:
# Load the four image folders in order: trainA, testA, trainB, testB.
train_A, test_A, train_B, test_B = (
    load_images(image_dataset_path + subdir)
    for subdir in ('trainA/', 'testA/', 'trainB/', 'testB/')
)

### Convert to tensor datasets and perform preprocessing

In [None]:
# Convert image numpy arrays to tf datasets.
# Set Domain A as label 0 and Domain B as label 1
train_A = convert_image_to_dataset(train_A, 0)
test_A = convert_image_to_dataset(test_A, 0)

train_B = convert_image_to_dataset(train_B, 1)
test_B = convert_image_to_dataset(test_B, 1)

# Training pipelines: preprocess_train_image applies random augmentation, so its
# output must NOT be cached. Caching it would freeze one random flip/crop per
# image and replay it every epoch. The source arrays already live in memory,
# so no cache is needed before the map either.
train_A = (train_A.map(preprocess_train_image, num_parallel_calls = autotune)
           .shuffle(buffer_size).batch(batch_size).prefetch(autotune))
train_B = (train_B.map(preprocess_train_image, num_parallel_calls = autotune)
           .shuffle(buffer_size).batch(batch_size).prefetch(autotune))

# Test pipelines: preprocessing is deterministic, so caching its output is safe.
test_A = (test_A.map(preprocess_test_image, num_parallel_calls = autotune)
          .cache().shuffle(buffer_size).batch(batch_size))
test_B = (test_B.map(preprocess_test_image, num_parallel_calls = autotune)
          .cache().shuffle(buffer_size).batch(batch_size))

## Create Empty CycleGAN Model

In [None]:
# Build the CycleGAN model with the project helper from cyclegan_model, using
# resent_blocks (presumably the number of ResNet blocks in the generators —
# confirm in cyclegan_model) and the model input shape model_image_size.
# NOTE(review): the weights are presumably freshly initialised here and are
# replaced by the checkpoint loaded in the next cell.
cycle_gan_model = create_default_cyclegan_model(resent_blocks, model_image_size)

## Load Weights to Model

In [None]:
# Restore the saved weights from model_load_file into the model built above.
# expect_partial() silences TensorFlow's warnings about saved values that are not
# restored — presumably optimizer state that does not exist yet on this freshly
# built model (confirm).
cycle_gan_model.load_weights(model_load_file).expect_partial()
# Reached only if load_weights returned without raising.
print('Weights loaded successfully')

### Continue Training

In [None]:
print('Continuing to train CycleGAN model from epoch {}...'.format(epoch_load + 1))

# Callbacks
plotter = GANMonitor(num_img = samples_display_size, test_A = test_A)
# save_weights_only=True writes weight checkpoints, the same kind of file this
# notebook restores with load_weights(...).expect_partial(). Without it,
# ModelCheckpoint attempts a full model save instead.
model_checkpoint_callback = ModelCheckpoint(filepath = checkpoint_filepath,
                                            save_weights_only = True,
                                            verbose = 1)

# Training the model.
# initial_epoch makes Keras resume its epoch counter at epoch_load. The progress
# log and the {epoch:03d} checkpoint names therefore start at epoch_load + 1,
# matching the message above, instead of restarting at 001. Training still runs
# for epochs_to_train additional epochs.
history = cycle_gan_model.fit(
    tf.data.Dataset.zip((train_A, train_B)),
    initial_epoch = epoch_load,
    epochs = epoch_load + epochs_to_train,
    callbacks = [plotter, model_checkpoint_callback],
)

### Export Loss Values

In [None]:
# Pull the four per-epoch loss curves recorded by fit() into named Series,
# then write them side by side (one column each) to loss_value_file.
loss_columns = ('G_loss', 'F_loss', 'D_X_loss', 'D_Y_loss')
g_loss_history, f_loss_history, dx_loss_history, dy_loss_history = (
    pd.Series(history.history[column], name = column) for column in loss_columns
)

loss_df = pd.concat(
    [g_loss_history, f_loss_history, dx_loss_history, dy_loss_history],
    axis = 1,
)
loss_df.to_csv(loss_value_file, index = False)