In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Concatenate, LeakyReLU, Activation, Layer, InputSpec
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
from tensorflow.keras.metrics import Mean
from tensorflow_addons.layers import InstanceNormalization
from IPython import display
from time import perf_counter
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import os
import random

# Run GPU host-side work on threads private to the GPU (TF_GPU_THREAD_MODE=gpu_private).
os.environ.update(TF_GPU_THREAD_MODE='gpu_private')

# Enable memory growth on every visible GPU so VRAM is claimed as needed, not all at start-up.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Global mixed-precision policy for every Keras layer built afterwards.
keras.mixed_precision.set_global_policy(keras.mixed_precision.Policy('mixed_float16'))

In [None]:
class ReflectionPadding2D(Layer):
    """Reflection-pad the two spatial axes (1 and 2) of a 4-D tensor.

    Args:
        padding: pair ``(row_pad, col_pad)``. ``row_pad`` rows are added at both ends of
            axis 1 and ``col_pad`` columns at both ends of axis 2. This assumes a
            channels-last layout (batch, rows, cols, channels), the Conv2D default used in
            this notebook.
    """

    def __init__(self, padding, **kwargs):
        # Layer.__init__ resets input_spec, so it must run before input_spec is assigned;
        # otherwise the ndim=4 check is silently discarded.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` grown by ``2*row_pad`` rows and ``2*col_pad`` columns."""
        row_pad, col_pad = self.padding

        def grow(dim, pad):
            # Unknown (None) spatial dims stay unknown.
            return None if dim is None else dim + 2*pad

        return (input_shape[0], grow(input_shape[1], row_pad), grow(input_shape[2], col_pad), input_shape[3])

    def call(self, x, mask=None):
        row_pad, col_pad = self.padding
        # Same axis convention as compute_output_shape: padding[0] -> axis 1, padding[1] -> axis 2.
        return tf.pad(x, [[0, 0], [row_pad, row_pad], [col_pad, col_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        """Include ``padding`` so the layer can be re-created from a saved model config."""
        config = super(ReflectionPadding2D, self).get_config()
        config.update({'padding': self.padding})
        return config

def define_discriminator(input_shape, name='Discriminator'):
    """Build a PatchGAN-style discriminator.

    Five 4x4 'same'-padded conv stages (64, 128, 256, 512 filters with stride 2, then 512
    with stride 1) are each followed by LeakyReLU(0.2). Every stage except the first also
    gets InstanceNormalization and a bias-free conv. A final 4x4 single-filter conv named
    'Patch' emits one score per spatial position. It is built with dtype float32, so it runs
    in float32 regardless of the global mixed-precision policy.

    Args:
        input_shape: shape of one input image, without the batch axis.
        name: Keras model name.

    Returns:
        A ``keras.Model`` mapping an image batch to a patch score map.
    """
    weight_init = RandomNormal(stddev=0.02)
    image_in = Input(shape=input_shape)
    # (filters, stride, normalise) for each conv -> [instance norm] -> LeakyReLU stage.
    stages = (
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 2, True),
        (512, 1, True),
    )
    features = image_in
    for filters, stride, normalise in stages:
        # The conv bias is omitted whenever a normalisation layer follows it.
        features = Conv2D(filters, 4, strides=stride, padding='same',
                          kernel_initializer=weight_init, use_bias=not normalise)(features)
        if normalise:
            features = InstanceNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)
    patch_scores = Conv2D(1, 4, padding='same', kernel_initializer=weight_init,
                          name='Patch', dtype=tf.float32)(features)
    return Model(image_in, patch_scores, name=name)

def resnet_block(n_filters, input_layer):
    """Residual block: two reflection-padded 3x3 conv + instance-norm stages added to the input.

    Args:
        n_filters: filters in both convolutions. It must equal the channel count of
            ``input_layer``, because the skip connection is an element-wise sum.
        input_layer: 4-D feature tensor (the padding layer only accepts rank-4 input).

    Returns:
        Tensor with the same shape as ``input_layer``.
    """
    init = RandomNormal(stddev=0.02)
    g = ReflectionPadding2D(padding=(1, 1))(input_layer)
    g = Conv2D(n_filters, 3, kernel_initializer=init, use_bias=False)(g)
    g = InstanceNormalization()(g)
    g = Activation('relu')(g)
    g = ReflectionPadding2D(padding=(1, 1))(g)
    g = Conv2D(n_filters, 3, kernel_initializer=init, use_bias=False)(g)
    g = InstanceNormalization()(g)
    # Identity skip connection. Concatenating here instead would widen the output by
    # n_filters channels per block (256 -> 2560 after 9 blocks), so it would not be residual.
    g = keras.layers.Add()([g, input_layer])
    return g

def define_generator(input_shape, output_channel, n_resnet=9, name='Generator'):
    """Build a ResNet-style image-to-image generator.

    Layout: a reflection-padded 7x7 conv (64) and two stride-2 3x3 convs (128, 256), each
    with instance norm + ReLU. Then ``n_resnet`` residual blocks at 256 filters, two stride-2
    3x3 transposed convs (128, 64) with instance norm + ReLU, and a reflection-padded 7x7 conv
    to ``output_channel`` channels. The final tanh activation is built with dtype float32.

    Args:
        input_shape: shape of one input image, without the batch axis.
        output_channel: number of channels of the generated image.
        n_resnet: number of residual blocks in the bottleneck.
        name: Keras model name.

    Returns:
        A ``keras.Model`` mapping an image batch to a generated image batch.
    """
    weight_init = RandomNormal(stddev=0.02)

    def padded_conv_stage(x, filters, kernel, pad, strides=1):
        # reflection pad -> bias-free conv -> instance norm -> ReLU
        x = ReflectionPadding2D(padding=(pad, pad))(x)
        x = Conv2D(filters, kernel, strides=strides, kernel_initializer=weight_init, use_bias=False)(x)
        x = InstanceNormalization()(x)
        return Activation('relu')(x)

    def upsample_stage(x, filters):
        # stride-2 transposed conv -> instance norm -> ReLU
        x = Conv2DTranspose(filters, 3, strides=2, padding='same', kernel_initializer=weight_init, use_bias=False)(x)
        x = InstanceNormalization()(x)
        return Activation('relu')(x)

    image_in = Input(shape=input_shape)
    x = padded_conv_stage(image_in, 64, 7, 3)
    x = padded_conv_stage(x, 128, 3, 1, strides=2)
    x = padded_conv_stage(x, 256, 3, 1, strides=2)
    for _ in range(n_resnet):
        x = resnet_block(256, x)
    x = upsample_stage(x, 128)
    x = upsample_stage(x, 64)
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = Conv2D(output_channel, 7, kernel_initializer=weight_init)(x)
    image_out = Activation('tanh', dtype=tf.float32)(x)
    return Model(image_in, image_out, name=name)

In [None]:
def data_train_preprocessing(sample):
    """Read and decode one training JPEG, mapping pixel values from [0, 255] to [-1, 1].

    ``sample`` is a file path string tensor. ``decode_jpeg`` is called without
    ``channels``, so the channel count comes from the file itself.
    """
    raw_bytes = tf.io.read_file(sample)
    pixels = tf.cast(tf.io.decode_jpeg(raw_bytes), tf.float32)
    return pixels / 127.5 - 1.

def data_train_augmentation(image):
    """Randomly flip and rotate a training image tensor.

    Flips left-right and up-down, each with probability 0.5, and rotates by 90 degrees with
    probability 0.5. The rotation decision is drawn with a TF random op, so it is re-sampled
    on every call of the mapped function. A Python ``random.random()`` here would run only
    once, when ``Dataset.map`` traces the function, fixing the rotation for the entire run.

    Args:
        image: image tensor. ``create_dataset`` applies this after ``batch``, so the input
            is a batch; the single rotation draw is shared by all images in it.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # k is 0 or 1 with equal probability: rotate by 90 degrees half of the time.
    k = tf.random.uniform([], minval=0, maxval=2, dtype=tf.int32)
    image = tf.image.rot90(image, k=k)
    return image

def data_val_preprocessing(sample):
    """Load one (visible, thermal) validation pair and map both to [-1, 1].

    ``sample`` holds two file paths: ``sample[0]`` is the visible image and
    ``sample[1]`` the thermal image.
    """
    def load_scaled(path):
        # read -> decode JPEG -> float32 in [-1, 1]
        decoded = tf.io.decode_jpeg(tf.io.read_file(path))
        return tf.cast(decoded, tf.float32) / 127.5 - 1.

    visible = load_scaled(sample[0])
    thermal = load_scaled(sample[1])
    return visible, thermal

def create_dataset(data, mode, batch_size=1):
    """Turn a dataset of JPEG paths into a batched, cached, prefetched image pipeline.

    Args:
        data: ``tf.data.Dataset`` of file paths. In ``'training'`` mode each element is a
            single path; in ``'validation'`` mode each is a (visible, thermal) path pair.
        mode: ``'training'`` decodes, batches, caches, shuffles and then augments.
            Augmentation runs after ``cache()``, so it is not frozen into the cache.
            ``'validation'`` decodes, batches and caches in a fixed order.
        batch_size: batch size used in both modes; an incomplete final batch is dropped.

    Returns:
        The transformed ``tf.data.Dataset``.

    Raises:
        ValueError: if ``mode`` is neither ``'training'`` nor ``'validation'``.
    """
    if mode not in ('training', 'validation'):
        raise ValueError("Invalid value for argument 'mode': %s. Supposed to be either 'training' or 'validation'."%(mode))
    # Use tf.data.AUTOTUNE directly instead of relying on a global defined in a later cell.
    autotune = tf.data.AUTOTUNE
    if mode == 'training':
        dataset = data.map(data_train_preprocessing, num_parallel_calls=autotune, deterministic=False)
        dataset = dataset.batch(batch_size, num_parallel_calls=autotune, deterministic=False, drop_remainder=True)
        dataset = dataset.cache().shuffle(len(data))
        dataset = dataset.map(data_train_augmentation, num_parallel_calls=autotune, deterministic=False)
    else:
        dataset = data.map(data_val_preprocessing, num_parallel_calls=autotune)
        # Honour the batch_size argument (the global BATCH_SIZE was used here before).
        dataset = dataset.batch(batch_size, num_parallel_calls=autotune, drop_remainder=True)
        dataset = dataset.cache()
    return dataset.prefetch(autotune)

In [None]:
def plot_history(history):
    """Plot the per-epoch generator losses as a learning curve.

    Args:
        history: dict holding per-epoch loss sequences under the keys
            ``'gen_v2t_loss'`` and ``'gen_t2v_loss'``.
    """
    curves = (
        ('gen_v2t_loss', 'Generator V2T Loss'),
        ('gen_t2v_loss', 'Generator T2V Loss'),
    )
    for key, label in curves:
        plt.plot(history[key], label=label)
    plt.title('Learning Curve')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

def test_model(gen_v2t, gen_t2v, data_val, num_samples=9):
    """Display a grid of validation translations.

    Rows, top to bottom: real visible, T->V translation of the paired thermal image, real
    thermal (gray), and V->T translation of the visible image (gray). There is one column
    per batch from ``data_val.take(num_samples)``. ``tf.squeeze`` is used to get a plottable
    image, which presumes one image pair per batch (BATCH_SIZE = 1) — TODO confirm if that
    changes.

    Args:
        gen_v2t: visible -> thermal generator.
        gen_t2v: thermal -> visible generator.
        data_val: dataset of (visible, thermal) batches with pixels in [-1, 1].
        num_samples: number of columns (validation batches) to show; defaults to 9.
    """
    rows = 4
    plt.figure(figsize=(num_samples+1, rows+1), dpi=300)
    for col, (real_v, real_t) in enumerate(data_val.take(num_samples)):
        fake_t = gen_v2t(real_v, training=False)
        fake_v = gen_t2v(real_t, training=False)
        # (image, colormap) per row; RGB rows use imshow's default colormap handling.
        panels = ((real_v, None), (fake_v, None), (real_t, 'gray'), (fake_t, 'gray'))
        for row, (image, cmap) in enumerate(panels):
            plt.subplot(rows, num_samples, row*num_samples + col + 1)
            # Map [-1, 1] back to [0, 1] for display.
            plt.imshow((tf.squeeze(image) + 1.)/2., cmap=cmap)
            plt.axis('off')
    # Half-inch margins around a gap-free grid.
    plt.subplots_adjust(
        top=1-0.5/(rows+1),
        bottom=0.5/(rows+1),
        left=0.5/(num_samples+1),
        right=1-0.5/(num_samples+1),
        wspace=0,
        hspace=0)
    plt.show()

In [None]:
class CycleGAN:
    """Custom training loop for a visible (V) <-> thermal (T) CycleGAN.

    A training step runs, in order: discriminator-V update, generator T->V update,
    discriminator-T update, generator V->T update. Discriminators use an MSE loss against
    all-zero (fake) / all-one (real) patch targets and are shown fakes drawn from a
    per-domain history pool. Generators minimise an MSE adversarial loss plus an L1
    cycle-consistency loss in both directions, weighted by ``cycle_weight``.

    Args:
        disc_v, disc_t: discriminators for the visible / thermal domain.
        gen_v2t, gen_t2v: generators mapping V->T and T->V.
        disc_v_opt, disc_t_opt, gen_v2t_opt, gen_t2v_opt: one optimizer per network. When the
            global Keras policy computes in float16, each is wrapped in ``LossScaleOptimizer``
            (unless already wrapped), as custom training loops under mixed precision require.
        pool_size: number of generated batches remembered per domain; <= 0 disables the pool.
        cycle_weight: weight of the cycle-consistency loss in the generator objective.
    """

    def __init__(self, disc_v, disc_t, gen_v2t, gen_t2v, disc_v_opt, disc_t_opt, gen_v2t_opt, gen_t2v_opt, pool_size=50, cycle_weight=10.0):
        self.disc_v = disc_v
        self.disc_t = disc_t
        self.gen_v2t = gen_v2t
        self.gen_t2v = gen_t2v
        self.disc_v_opt = self._with_loss_scaling(disc_v_opt)
        self.disc_t_opt = self._with_loss_scaling(disc_t_opt)
        self.gen_v2t_opt = self._with_loss_scaling(gen_v2t_opt)
        self.gen_t2v_opt = self._with_loss_scaling(gen_t2v_opt)
        self.pool_size = pool_size
        self.cycle_weight = cycle_weight
        self.pool_count = 0  # number of completed training steps
        self.pool_v = []
        self.pool_t = []
        self.MSE = MeanSquaredError()
        self.MAE = MeanAbsoluteError()
        self.history = {'disc_v_loss':[], 'disc_t_loss':[], 'gen_v2t_loss':[], 'gen_t2v_loss':[]}
        self.disc_v_mean_loss = Mean(name='disc_v_loss')
        self.disc_t_mean_loss = Mean(name='disc_t_loss')
        self.gen_v2t_mean_loss = Mean(name='gen_v2t_loss')
        self.gen_t2v_mean_loss = Mean(name='gen_t2v_loss')
        # Only the tensor math is compiled. The image pools are plain Python state: inside a
        # tf.function their logic would execute once at trace time and never again.
        self._translate = tf.function(self._translate_impl, jit_compile=True)
        self._discriminator_step = tf.function(self._discriminator_step_impl, jit_compile=True)
        self._generator_step = tf.function(self._generator_step_impl, jit_compile=True)

    @staticmethod
    def _with_loss_scaling(optimizer):
        """Wrap ``optimizer`` in LossScaleOptimizer when the global policy computes in float16."""
        if keras.mixed_precision.global_policy().compute_dtype != 'float16':
            return optimizer
        if isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer):
            return optimizer
        return keras.mixed_precision.LossScaleOptimizer(optimizer)

    @staticmethod
    def _scale_loss(optimizer, loss):
        """Scale ``loss`` for backprop if ``optimizer`` does loss scaling; otherwise return it as is."""
        if isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer):
            return optimizer.get_scaled_loss(loss)
        return loss

    @staticmethod
    def _unscale_gradients(optimizer, grads):
        """Undo the loss scale applied by ``_scale_loss`` on the gradients."""
        if isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer):
            return optimizer.get_unscaled_gradients(grads)
        return grads

    def update_fake_pool(self, image, pool):
        """Return the generated batch to show the discriminator, updating ``pool`` in place.

        While ``pool`` holds fewer than ``pool_size`` entries, the new batch is stored and
        returned. Afterwards, with probability 0.5 a random stored batch is returned and
        replaced by the new one; otherwise the new batch is returned and the pool is left
        unchanged. Uses a Python list and ``random``, so it must run eagerly.
        """
        if self.pool_size <= 0:
            return image
        if len(pool) < self.pool_size:
            pool.append(image)
            return image
        if random.random() < 0.5:
            index = random.randrange(len(pool))
            selected, pool[index] = pool[index], image
            return selected
        return image

    def _translate_impl(self, gen, images):
        """Run ``gen`` in inference mode (compiled via ``self._translate``)."""
        return gen(images, training=False)

    def _discriminator_step_impl(self, disc, fake, real, disc_opt, disc_mean_loss):
        """One MSE discriminator update on a fake and a real batch (compiled)."""
        with tf.GradientTape() as disc_tape:
            fake_pred = disc(fake, training=True)
            real_pred = disc(real, training=True)
            # *_like targets match the discriminator's output shape for any batch/image size.
            disc_loss = (self.MSE(tf.zeros_like(fake_pred), fake_pred) + self.MSE(tf.ones_like(real_pred), real_pred)) / 2.
            scaled_loss = self._scale_loss(disc_opt, disc_loss)
        grad_of_disc = self._unscale_gradients(disc_opt, disc_tape.gradient(scaled_loss, disc.trainable_variables))
        disc_opt.apply_gradients(zip(grad_of_disc, disc.trainable_variables))
        disc_mean_loss.update_state(disc_loss)

    def update_discriminator(self, disc, gen, origin, real, pool, disc_opt, disc_mean_loss):
        """Update ``disc`` using ``gen(origin)``, routed through ``pool``, as the fake batch."""
        fake = self._translate(gen, origin)
        fake = self.update_fake_pool(fake, pool)
        self._discriminator_step(disc, fake, real, disc_opt, disc_mean_loss)

    def _generator_step_impl(self, disc, gen_o2t, gen_t2o, origin, target, gen_opt, gen_mean_loss):
        """One update of ``gen_o2t``: adversarial + weighted forward/backward cycle loss (compiled)."""
        backward_fake = gen_t2o(target, training=False)
        with tf.GradientTape() as gen_tape:
            forward_fake = gen_o2t(origin, training=True)
            forward_fake_pred = disc(forward_fake, training=False)
            adversarial_loss = self.MSE(tf.ones_like(forward_fake_pred), forward_fake_pred)
            forward_cycle = gen_t2o(forward_fake, training=False)
            forward_loss = self.MAE(origin, forward_cycle)
            backward_cycle = gen_o2t(backward_fake, training=True)
            backward_loss = self.MAE(target, backward_cycle)
            gen_loss = adversarial_loss + self.cycle_weight*(forward_loss + backward_loss)
            scaled_loss = self._scale_loss(gen_opt, gen_loss)
        grad_of_gen = self._unscale_gradients(gen_opt, gen_tape.gradient(scaled_loss, gen_o2t.trainable_variables))
        gen_opt.apply_gradients(zip(grad_of_gen, gen_o2t.trainable_variables))
        gen_mean_loss.update_state(gen_loss)

    def update_generator(self, disc, gen_o2t, gen_t2o, origin, target, gen_opt, gen_mean_loss):
        """Update generator ``gen_o2t`` (only its variables receive gradients)."""
        self._generator_step(disc, gen_o2t, gen_t2o, origin, target, gen_opt, gen_mean_loss)

    def train_step(self, real_v, real_t):
        """Run one full CycleGAN step on a visible batch and a thermal batch.

        Deliberately not a tf.function: the pools must be updated eagerly on every step.
        The tensor work inside each update is compiled.
        """
        self.update_discriminator(self.disc_v, self.gen_t2v, real_t, real_v, self.pool_v, self.disc_v_opt, self.disc_v_mean_loss)
        self.update_generator(self.disc_v, self.gen_t2v, self.gen_v2t, real_t, real_v, self.gen_t2v_opt, self.gen_t2v_mean_loss)
        self.update_discriminator(self.disc_t, self.gen_v2t, real_v, real_t, self.pool_t, self.disc_t_opt, self.disc_t_mean_loss)
        self.update_generator(self.disc_t, self.gen_v2t, self.gen_t2v, real_v, real_t, self.gen_v2t_opt, self.gen_v2t_mean_loss)

    def fit(self, visible_train, thermal_train, data_val, epochs=1):
        """Train for ``epochs`` epochs, recording per-epoch mean losses in ``self.history``.

        After each epoch it clears the cell output, prints the losses, plots the learning curve
        and shows validation translations.

        Returns:
            ``self.history``: dict of per-epoch mean losses.
        """
        loss_metrics = {
            'disc_v_loss': self.disc_v_mean_loss,
            'disc_t_loss': self.disc_t_mean_loss,
            'gen_v2t_loss': self.gen_v2t_mean_loss,
            'gen_t2v_loss': self.gen_t2v_mean_loss,
        }
        # zip stops at the shorter domain, so that is the real number of steps per epoch.
        steps = min(len(visible_train), len(thermal_train))
        for epoch in range(epochs):
            tic = perf_counter()
            for metric in loss_metrics.values():
                metric.reset_state()
            for real_v, real_t in tqdm(zip(visible_train, thermal_train), total=steps):
                self.train_step(real_v, real_t)
                self.pool_count += 1
            for key, metric in loss_metrics.items():
                self.history[key].append(metric.result().numpy())
            display.clear_output(wait=True)
            print("Epoch %d/%d - %.1fs | disc_v_loss: %.5f - disc_t_loss: %.5f - gen_v2t_loss: %.5f - gen_t2v_loss: %.5f"%(
                epoch+1, epochs, perf_counter() - tic,
                self.history['disc_v_loss'][-1], self.history['disc_t_loss'][-1],
                self.history['gen_v2t_loss'][-1], self.history['gen_t2v_loss'][-1]))
            plot_history(self.history)
            test_model(self.gen_v2t, self.gen_t2v, data_val)
        return self.history

In [None]:
data_path = "./processed_data"
training_path = os.path.join(data_path, "training")
validation_path = os.path.join(data_path, "validation_2")
visible_training_path = os.path.join(training_path, "visible")
thermal_training_path = os.path.join(training_path, "thermal")
visible_validation_path = os.path.join(validation_path, "visible")
thermal_validation_path = os.path.join(validation_path, "thermal")
# Training domains are shuffled independently below, so their listing order does not matter.
visible_training_files = [os.path.join(visible_training_path, filename) for filename in os.listdir(visible_training_path)]
thermal_training_files = [os.path.join(thermal_training_path, filename) for filename in os.listdir(thermal_training_path)]
# Validation images are paired by position after sorting, so the folders must match one-to-one.
visible_validation_files = [os.path.join(visible_validation_path, filename) for filename in sorted(os.listdir(visible_validation_path))]
thermal_validation_files = [os.path.join(thermal_validation_path, filename) for filename in sorted(os.listdir(thermal_validation_path))]
# zip would silently drop unmatched files and misalign every pair after a gap.
assert len(visible_validation_files) == len(thermal_validation_files), (
    "validation folders differ in size: %d visible vs %d thermal images"
    % (len(visible_validation_files), len(thermal_validation_files)))
validation_files = list(zip(visible_validation_files, thermal_validation_files))
random.shuffle(visible_training_files)
random.shuffle(thermal_training_files)

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 1
visible_train = tf.data.Dataset.from_tensor_slices(visible_training_files)
thermal_train = tf.data.Dataset.from_tensor_slices(thermal_training_files)
data_val = tf.data.Dataset.from_tensor_slices(validation_files)
visible_train = create_dataset(visible_train, 'training', batch_size=BATCH_SIZE)
thermal_train = create_dataset(thermal_train, 'training', batch_size=BATCH_SIZE)
data_val = create_dataset(data_val, 'validation', batch_size=BATCH_SIZE)
# Training zips the two domains, so the number of usable pairs is the shorter of the two.
print("Number of training data (pairs):", min(len(visible_train), len(thermal_train)))
print("Number of validation data (pairs):", len(data_val))

In [None]:
POOL_SIZE = 50
# One discriminator per domain: 3-channel visible, 1-channel thermal, both 256x256.
disc_v = define_discriminator(input_shape=(256, 256, 3), name='Discriminator_V')
disc_t = define_discriminator(input_shape=(256, 256, 1), name='Discriminator_T')
# Generators in both directions, each with 9 residual blocks.
gen_v2t = define_generator(input_shape=(256, 256, 3), output_channel=1, n_resnet=9, name='Generator_V2T')
gen_t2v = define_generator(input_shape=(256, 256, 1), output_channel=3, n_resnet=9, name='Generator_T2V')
# Identical Adam settings (lr 2e-4, beta_1 0.5) for all four networks.
disc_v_opt = Adam(learning_rate=2e-4, beta_1=0.5)
disc_t_opt = Adam(learning_rate=2e-4, beta_1=0.5)
gen_v2t_opt = Adam(learning_rate=2e-4, beta_1=0.5)
gen_t2v_opt = Adam(learning_rate=2e-4, beta_1=0.5)
model = CycleGAN(
    disc_v=disc_v, disc_t=disc_t,
    gen_v2t=gen_v2t, gen_t2v=gen_t2v,
    disc_v_opt=disc_v_opt, disc_t_opt=disc_t_opt,
    gen_v2t_opt=gen_v2t_opt, gen_t2v_opt=gen_t2v_opt,
    pool_size=POOL_SIZE)

In [None]:
#keras.utils.plot_model(disc_v, to_file='Discriminator_V.png', show_shapes=True)
#keras.utils.plot_model(disc_t, to_file='Discriminator_T.png', show_shapes=True)
#keras.utils.plot_model(gen_v2t, to_file='Generator_V2T.png', show_shapes=True)
#keras.utils.plot_model(gen_t2v, to_file='Generator_T2V.png', show_shapes=True)

In [None]:
EPOCH = 25
# Each epoch ends by printing losses, plotting the learning curve and showing validation samples.
history = model.fit(visible_train=visible_train, thermal_train=thermal_train, data_val=data_val, epochs=EPOCH)

In [None]:
# Only the generators are exported, without optimizer state.
model.gen_v2t.save(filepath="./cyclegan/v2t", include_optimizer=False)
model.gen_t2v.save(filepath="./cyclegan/t2v", include_optimizer=False)