In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Concatenate, Dropout, LeakyReLU, Activation, Layer, InputSpec
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.losses import BinaryCrossentropy, MeanAbsoluteError
from tensorflow.keras.metrics import Mean
from tensorflow_addons.layers import InstanceNormalization
from IPython import display
from time import perf_counter
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import os
import random

# Request the 'gpu_private' GPU thread mode via the TF_GPU_THREAD_MODE environment variable
# (presumably dedicates threads to GPU work -- confirm against the TensorFlow docs).
os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'

# Turn on memory growth for every GPU TensorFlow can see, so VRAM usage is constrained.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Enable mixed-precision training: set the global Keras policy to 'mixed_float16'.
keras.mixed_precision.set_global_policy('mixed_float16')

In [None]:
def data_preprocessing(sample):
    """Load one visible/thermal JPEG pair and rescale it for the networks.

    Args:
        sample: indexable pair of file paths; element 0 is the visible image,
            element 1 the thermal image.

    Returns:
        Tuple ``(visible, thermal)`` of float32 tensors computed as
        ``pixel / 127.5 - 1`` (i.e. 0..255 pixel values map onto [-1, 1]).
    """
    def _load_scaled(path):
        # read raw bytes -> decode JPEG -> float32 -> rescale
        encoded = tf.io.read_file(path)
        pixels = tf.io.decode_jpeg(encoded)
        return tf.cast(pixels, tf.float32) / 127.5 - 1.

    visible = _load_scaled(sample[0])
    thermal = _load_scaled(sample[1])
    return visible, thermal

def data_augmentation(visible, thermal):
    """Randomly flip / rotate a visible-thermal pair, always applying the same transform to both.

    Each of the three transforms (left-right flip, up-down flip, 90-degree
    rotation) is applied independently with probability 0.5.

    The random draws use TensorFlow ops rather than Python's ``random`` module.
    This function is passed to ``Dataset.map``, which traces it into a graph
    once; a Python-level ``random.random()`` would be evaluated only during
    that single trace, freezing the same flip/rotation decisions for every
    element and every epoch.

    Args:
        visible: visible-light image tensor (single image or batch).
        thermal: thermal image tensor paired with ``visible``.

    Returns:
        Tuple ``(visible, thermal)`` after the jointly applied transforms.
    """
    def _maybe_apply(transform, visible, thermal):
        # One draw per transform, shared by both images so the pair stays aligned.
        apply = tf.random.uniform([]) < 0.5
        return tf.cond(
            apply,
            lambda: (transform(visible), transform(thermal)),
            lambda: (visible, thermal))

    visible, thermal = _maybe_apply(tf.image.flip_left_right, visible, thermal)
    visible, thermal = _maybe_apply(tf.image.flip_up_down, visible, thermal)
    visible, thermal = _maybe_apply(tf.image.rot90, visible, thermal)
    return visible, thermal

In [None]:
def define_discriminator(visible_shape=(256,256,3), thermal_shape=(256,256,1)):
    """Build the conditional patch discriminator.

    The visible and thermal images are concatenated along channels and passed
    through four stride-2 4x4 convolutions, one stride-1 4x4 convolution and a
    final single-channel 4x4 convolution named ``'Patch'``. No activation is
    applied to the final convolution, and it is forced to float32.

    Args:
        visible_shape: shape (without batch axis) of the visible input.
        thermal_shape: shape (without batch axis) of the thermal input.

    Returns:
        A Keras ``Model`` named ``'Discriminator'`` taking
        ``[visible_image, thermal_image]``.
    """
    init = RandomNormal(stddev=0.02)
    visible_image = Input(shape=visible_shape)
    thermal_image = Input(shape=thermal_shape)
    x = Concatenate()([visible_image, thermal_image])

    # First stage: biased convolution, no normalisation.
    x = Conv2D(64, 4, strides=2, padding='same', kernel_initializer=init)(x)
    x = LeakyReLU(alpha=0.2)(x)

    # Remaining stages: bias-free conv -> instance norm -> LeakyReLU.
    for filters, stride in ((128, 2), (256, 2), (512, 2), (512, 1)):
        x = Conv2D(filters, 4, strides=stride, padding='same', kernel_initializer=init, use_bias=False)(x)
        x = InstanceNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)

    patch_output = Conv2D(1, 4, padding='same', kernel_initializer=init, name='Patch', dtype=tf.float32)(x)
    return Model([visible_image, thermal_image], patch_output, name='Discriminator')

class ReflectionPadding2D(Layer):
    """Reflection-pad the two middle axes of a 4-D tensor (``tf.pad`` mode ``'REFLECT'``).

    Args:
        padding: pair ``(width_pad, height_pad)``. ``height_pad`` elements are
            added on both sides of axis 1 and ``width_pad`` elements on both
            sides of axis 2; axes 0 and 3 are left untouched.
        **kwargs: forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, padding, **kwargs):
        # Run the base initialiser first: it resets internal state such as the
        # input spec, so attributes assigned before it can be silently discarded.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, input_shape):
        """Return the padded shape, using the same axis/padding pairing as ``call``."""
        width_pad, height_pad = self.padding

        def _grow(dim, pad):
            # Unknown (None) dimensions stay unknown instead of raising on None + int.
            return None if dim is None else dim + 2*pad

        return (
            input_shape[0],
            _grow(input_shape[1], height_pad),
            _grow(input_shape[2], width_pad),
            input_shape[3],
        )

    def call(self, x, mask=None):
        """Pad ``x`` by reflection; ``mask`` is accepted but not used."""
        width_pad, height_pad = self.padding
        return tf.pad(x, [[0, 0], [height_pad, height_pad], [width_pad, width_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        """Return the layer config including ``padding`` so the layer can be rebuilt from config."""
        config = super(ReflectionPadding2D, self).get_config()
        config.update({'padding': self.padding})
        return config

def resnet_block(n_filters, input_layer):
    """Append one two-convolution block with a skip connection to ``input_layer``.

    Two rounds of reflection-pad(1) -> bias-free 3x3 conv -> instance norm are
    applied, with a ReLU after the first round only. The result is then
    concatenated with ``input_layer`` along the channel axis.

    NOTE(review): the skip connection is a channel-wise ``Concatenate`` rather
    than an element-wise add, so the channel count grows by ``n_filters`` with
    every block -- confirm this is intended.

    Args:
        n_filters: number of filters of both convolutions.
        input_layer: tensor the block is attached to.

    Returns:
        The concatenation of the block output and ``input_layer``.
    """
    init = RandomNormal(stddev=0.02)
    x = input_layer
    for with_relu in (True, False):
        x = ReflectionPadding2D(padding=(1, 1))(x)
        x = Conv2D(n_filters, 3, kernel_initializer=init, use_bias=False)(x)
        x = InstanceNormalization()(x)
        if with_relu:
            x = Activation('relu')(x)
    return Concatenate()([x, input_layer])

def define_generator(input_shape=(256,256,3), n_resnet=9):
    """Build the encoder / residual-stack / decoder generator.

    Layout: 7x7 conv (64) -> two stride-2 3x3 convs (128, 256) ->
    ``n_resnet`` calls to ``resnet_block(256, ...)`` -> two stride-2 transposed
    convs (128, 64) -> 7x7 conv with a single output channel -> tanh (float32).

    Args:
        input_shape: shape (without batch axis) of the input image.
        n_resnet: number of ``resnet_block`` stages in the middle of the network.

    Returns:
        An unnamed Keras ``Model`` mapping the input image to a one-channel image.
    """
    init = RandomNormal(stddev=0.02)
    input_image = Input(shape=input_shape)

    def _pad_conv_norm_relu(x, filters, kernel_size, pad, strides=1):
        # reflection pad -> bias-free conv -> instance norm -> ReLU
        x = ReflectionPadding2D(padding=(pad, pad))(x)
        x = Conv2D(filters, kernel_size, strides=strides, kernel_initializer=init, use_bias=False)(x)
        x = InstanceNormalization()(x)
        return Activation('relu')(x)

    # Stem and the two down-sampling stages.
    x = _pad_conv_norm_relu(input_image, 64, 7, pad=3)
    for filters in (128, 256):
        x = _pad_conv_norm_relu(x, filters, 3, pad=1, strides=2)

    # Middle of the network.
    for _ in range(n_resnet):
        x = resnet_block(256, x)

    # Two up-sampling stages.
    for filters in (128, 64):
        x = Conv2DTranspose(filters, 3, strides=2, padding='same', kernel_initializer=init, use_bias=False)(x)
        x = InstanceNormalization()(x)
        x = Activation('relu')(x)

    # Output head: the only biased convolution; tanh is computed in float32.
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = Conv2D(1, 7, kernel_initializer=init)(x)
    output_image = Activation('tanh', dtype=tf.float32)(x)
    return Model(input_image, output_image)

"""
def encoder_block(input_layer, n_filters, layernorm=True):
    init = RandomNormal(stddev=0.02)
    if layernorm:
        g = Conv2D(n_filters, 4, strides=2, padding='same', kernel_initializer=init, use_bias=False)(input_layer)
        g = InstanceNormalization()(g)
    else:
        g = Conv2D(n_filters, 4, strides=2, padding='same', kernel_initializer=init)(input_layer)
    g_a = LeakyReLU(alpha=0.2)(g)
    return g, g_a

def decoder_block(input_layer, skip_in, n_filters, dropout=True):
    init = RandomNormal(stddev=0.02)
    g = Conv2DTranspose(n_filters, 4, strides=2, padding='same', kernel_initializer=init, use_bias=False)(input_layer)
    g = InstanceNormalization()(g)
    if dropout:
        g = Dropout(0.5)(g)
    g = Concatenate()([g, skip_in])
    g = Activation('relu')(g)
    return g

def define_generator(visible_shape=(256,256,3)):
    init = RandomNormal(stddev=0.02)
    visible_image = Input(shape=visible_shape)
    e1, e1_a = encoder_block(visible_image, 64, layernorm=False)
    e2, e2_a = encoder_block(e1_a, 128)
    e3, e3_a = encoder_block(e2_a, 256)
    e4, e4_a = encoder_block(e3_a, 512)
    e5, e5_a = encoder_block(e4_a, 512)
    e6, e6_a = encoder_block(e5_a, 512)
    e7, e7_a = encoder_block(e6_a, 512)
    v = Conv2D(512, 4, strides=2, padding='same', activation='relu', kernel_initializer=init)(e7_a)
    d1 = decoder_block(v, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    thermal_image = Conv2DTranspose(1, 4, strides=2, padding='same', activation='tanh', dtype=tf.float32, kernel_initializer=init)(d7)
    model = Model(visible_image, thermal_image, name='Generator')
    return model
"""

In [None]:
def plot_history(history):
    """Plot the per-epoch series stored under ``history['gen_loss']`` and show the figure.

    Only the ``'gen_loss'`` entry is drawn; a ``'disc_loss'`` entry, if present,
    is not plotted.

    Args:
        history: mapping with a ``'gen_loss'`` sequence of per-epoch values.
    """
    generator_curve = history['gen_loss']
    axes = plt.gca()
    axes.plot(generator_curve, label='Generator Loss')
    axes.set_title('Learning Curve')
    axes.set_xlabel('Epoch')
    axes.set_ylabel('Loss')
    axes.legend()
    axes.grid()
    plt.tight_layout()
    plt.show()

def test_model(gen, data_val, num_samples=9):
    """Show generator outputs next to their inputs and targets.

    Draws a 3 x ``num_samples`` grid: row 1 the visible inputs, row 2 the
    generated thermal images, row 3 the real thermal images. Images are mapped
    from [-1, 1] back to [0, 1] for display.

    Args:
        gen: callable generator, invoked as ``gen(visible, training=False)``.
        data_val: dataset yielding ``(visible, thermal)`` batches; the first
            ``num_samples`` batches are shown.
        num_samples: number of columns (batches) to draw. Defaults to 9.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1, got %r" % (num_samples,))

    plt.figure(figsize=(num_samples+1, 4), dpi=300)
    for col, (visible, thermal) in enumerate(data_val.take(num_samples)):
        fake = gen(visible, training=False)
        # (image, colormap) per row; cmap=None is imshow's default (RGB data).
        rows = ((visible, None), (fake, 'gray'), (thermal, 'gray'))
        for row, (image, cmap) in enumerate(rows):
            plt.subplot(3, num_samples, row*num_samples + col + 1)
            # tf.squeeze drops the size-1 axes; assumes each batch holds a single
            # image -- TODO confirm if the batch size is ever raised.
            plt.imshow((tf.squeeze(image) + 1.)/2., cmap=cmap)
            plt.axis('off')
    # Half-a-cell margin on every side, no gaps between tiles.
    plt.subplots_adjust(
        top=1-0.5/4,
        bottom=0.5/4,
        left=0.5/(num_samples+1),
        right=1-0.5/(num_samples+1),
        wspace=0,
        hspace=0)
    plt.show()

In [None]:
class Pix2Pix:
    """Training harness for a conditional GAN translating visible images to thermal images.

    The discriminator is trained with binary cross-entropy on its raw outputs
    (``from_logits=True``); the generator is trained with the adversarial term
    plus ``l1_weight`` times the mean absolute error against the real thermal image.

    NOTE(review): this notebook enables the ``'mixed_float16'`` policy, but the
    gradients below are applied without loss scaling. The Keras mixed-precision
    guide recommends ``LossScaleOptimizer`` for custom training loops --
    consider adding it after checking it works together with ``jit_compile=True``.

    Args:
        disc: discriminator model, called as ``disc([visible, thermal])``.
        gen: generator model, called as ``gen(visible)``.
        disc_opt: optimizer for the discriminator.
        gen_opt: optimizer for the generator.
        l1_weight: weight of the MAE term in the generator loss. Defaults to 100.
    """

    def __init__(self, disc, gen, disc_opt, gen_opt, l1_weight=100):
        self.disc = disc
        self.gen = gen
        self.disc_opt = disc_opt
        self.gen_opt = gen_opt
        self.l1_weight = l1_weight
        self.BCE = BinaryCrossentropy(from_logits=True)
        self.MAE = MeanAbsoluteError()
        # Per-epoch averages, appended by fit().
        self.history = {'disc_loss':[], 'gen_loss':[]}
        self.disc_mean_loss = Mean(name='disc_loss')
        self.gen_mean_loss = Mean(name='gen_loss')

    @tf.function(jit_compile=True)
    def train_step(self, visible, thermal):
        """Run one discriminator update followed by one generator update.

        Args:
            visible: batch of visible images.
            thermal: batch of real thermal images paired with ``visible``.
        """
        # ---- discriminator update ----
        # The fake image is produced outside the tape: no gradient flows to the generator here.
        fake = self.gen(visible, training=False)
        with tf.GradientTape() as disc_tape:
            fake_pred = self.disc([visible, fake], training=True)
            thermal_pred = self.disc([visible, thermal], training=True)
            # Targets take the shape of the predictions, so any batch size / patch size works.
            disc_loss = (self.BCE(tf.zeros_like(fake_pred), fake_pred)
                         + self.BCE(tf.ones_like(thermal_pred), thermal_pred)) / 2.
        grad_of_disc = disc_tape.gradient(disc_loss, self.disc.trainable_variables)
        self.disc_opt.apply_gradients(zip(grad_of_disc, self.disc.trainable_variables))
        self.disc_mean_loss.update_state(disc_loss)

        # ---- generator update ----
        with tf.GradientTape() as gen_tape:
            # training=True so train-only layers (e.g. Dropout), if the generator has any, are active.
            fake = self.gen(visible, training=True)
            fake_pred = self.disc([visible, fake], training=False)
            bce_loss = self.BCE(tf.ones_like(fake_pred), fake_pred)
            mae_loss = self.MAE(thermal, fake)
            gen_loss = bce_loss + self.l1_weight*mae_loss
        grad_of_gen = gen_tape.gradient(gen_loss, self.gen.trainable_variables)
        self.gen_opt.apply_gradients(zip(grad_of_gen, self.gen.trainable_variables))
        # NOTE(review): the metric named 'gen_loss' tracks only the MAE term, not the
        # full generator loss -- kept as is; confirm this is the intended quantity.
        self.gen_mean_loss.update_state(mae_loss)

    def fit(self, data_train, data_val, epochs=1):
        """Train for ``epochs`` passes over ``data_train``, reporting after each epoch.

        After every epoch the averaged losses are appended to ``self.history``,
        the cell output is cleared, a summary line is printed, the learning
        curve is plotted and sample translations from ``data_val`` are shown.

        Args:
            data_train: iterable of ``(visible, thermal)`` training batches.
            data_val: dataset of ``(visible, thermal)`` batches for the visual check.
            epochs: number of epochs to run.

        Returns:
            The ``self.history`` dictionary.
        """
        for epoch in range(epochs):
            tic = perf_counter()
            self.disc_mean_loss.reset_states()
            self.gen_mean_loss.reset_states()
            for visible, thermal in tqdm(data_train):
                # Use this instance, not whatever global happens to be called `model`.
                self.train_step(visible, thermal)
            self.history['disc_loss'].append(self.disc_mean_loss.result().numpy())
            self.history['gen_loss'].append(self.gen_mean_loss.result().numpy())
            display.clear_output(wait=True)
            print("Epoch %d/%d - %.1fs | disc_loss: %.5f - gen_loss: %.5f"%(
                epoch+1, epochs, perf_counter() - tic, self.history['disc_loss'][-1], self.history['gen_loss'][-1]))
            plot_history(self.history)
            test_model(self.gen, data_val)
        return self.history

In [None]:
def _list_sorted_files(directory):
    """Return the full path of every entry in ``directory``, sorted by file name."""
    return [os.path.join(directory, filename) for filename in sorted(os.listdir(directory))]


def _pair_files(visible_files, thermal_files, split_name):
    """Zip visible and thermal paths into pairs, refusing empty or unequal listings.

    Pairs are formed purely by sorted position, so a missing or extra file in
    one directory would silently misalign every following pair if the lengths
    were not checked.
    """
    if len(visible_files) != len(thermal_files):
        raise ValueError(
            "%s set: %d visible files but %d thermal files; pairs cannot be aligned"
            % (split_name, len(visible_files), len(thermal_files)))
    if not visible_files:
        raise ValueError("%s set: no image files found" % split_name)
    return list(zip(visible_files, thermal_files))


# Directory layout: <data_path>/<split>/<modality>/<image files>
data_path = "./processed_data"
training_path = os.path.join(data_path, "training")
validation_path = os.path.join(data_path, "validation")
visible_training_path = os.path.join(training_path, "visible")
thermal_training_path = os.path.join(training_path, "thermal")
visible_validation_path = os.path.join(validation_path, "visible")
thermal_validation_path = os.path.join(validation_path, "thermal")

visible_training_files = _list_sorted_files(visible_training_path)
thermal_training_files = _list_sorted_files(thermal_training_path)
visible_validation_files = _list_sorted_files(visible_validation_path)
thermal_validation_files = _list_sorted_files(thermal_validation_path)

training_files = _pair_files(visible_training_files, thermal_training_files, "training")
validation_files = _pair_files(visible_validation_files, thermal_validation_files, "validation")
random.shuffle(training_files)

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 1

# Training: decode -> batch -> cache the decoded batches -> reshuffle -> augment -> prefetch.
# Caching happens before shuffle/augmentation so those are not frozen into the cache.
data_train = (
    tf.data.Dataset.from_tensor_slices(training_files)
    .map(data_preprocessing, num_parallel_calls=AUTOTUNE, deterministic=False)
    .batch(BATCH_SIZE, num_parallel_calls=AUTOTUNE, deterministic=False, drop_remainder=True)
    .cache()
    .shuffle(len(training_files))
    .map(data_augmentation, num_parallel_calls=AUTOTUNE, deterministic=False)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation: same decoding, no shuffle or augmentation.
data_val = (
    tf.data.Dataset.from_tensor_slices(validation_files)
    .map(data_preprocessing, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE, num_parallel_calls=AUTOTUNE, drop_remainder=True)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

print("Number of training data (pairs):", len(data_train))
print("Number of validation data (pairs):", len(data_val))

In [None]:
# Build both networks (discriminator first, then generator).
disc, gen = define_discriminator(), define_generator()
# One independent Adam optimizer per network, same hyper-parameters for both.
disc_opt, gen_opt = (Adam(learning_rate=2e-4, beta_1=0.5) for _ in range(2))
# Bundle networks and optimizers into the training harness.
model = Pix2Pix(disc, gen, disc_opt, gen_opt)

In [None]:
#keras.utils.plot_model(disc, to_file='Discriminator.png', show_shapes=True)
#keras.utils.plot_model(gen, to_file='Generator.png', show_shapes=True)

In [None]:
# Number of full passes over data_train.
EPOCH = 75
# Train; fit() returns the per-epoch loss history dictionary.
history = model.fit(data_train, data_val, epochs=EPOCH)

In [None]:
# Save only the generator to ./pix2pix, without optimizer state.
model.gen.save("./pix2pix", include_optimizer=False)

In [None]:
# Reload the generator saved to ./pix2pix, uncompiled (inference only).
gen_pix2pix = keras.models.load_model("./pix2pix", compile=False)
# Second generator for comparison; ./cyclegan/v2t is not written by this notebook,
# so it presumably has to exist beforehand -- verify before running.
gen_cyclegan = keras.models.load_model("./cyclegan/v2t", compile=False)

In [None]:
# Visual check of the in-memory generator on training batches
# (data_train is shuffled and augmented, so the samples shown vary between runs).
test_model(model.gen, data_train)

In [None]:
# Side-by-side comparison grid: one column per validation sample, one row per image kind.
COL = 9
ROW = 4
plt.figure(figsize=(COL+1, ROW+1), dpi=300)
for i, (real_v, real_t) in enumerate(data_val.take(COL)):
    fake_t_pix2pix = gen_pix2pix(real_v, training=False)
    fake_t_cyclegan = gen_cyclegan(real_v, training=False)
    # Top to bottom: visible input, pix2pix output, cyclegan output, real thermal.
    # cmap=None is imshow's default and is used for the RGB input.
    row_images = (
        (real_v, None),
        (fake_t_pix2pix, 'gray'),
        (fake_t_cyclegan, 'gray'),
        (real_t, 'gray'),
    )
    for row_idx, (image, cmap) in enumerate(row_images):
        plt.subplot(ROW, COL, i + row_idx*COL + 1)
        # undo the [-1, 1] scaling for display
        plt.imshow((tf.squeeze(image) + 1.)/2., cmap=cmap)
        plt.axis('off')
# Half-a-cell margin around the grid, tiles touching each other.
plt.subplots_adjust(
    top=1-0.5/(ROW+1),
    bottom=0.5/(ROW+1),
    left=0.5/(COL+1),
    right=1-0.5/(COL+1),
    wspace=0,
    hspace=0)
plt.show()