In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import tensorflow_addons as tfa

import os
import glob
import numpy as np
import matplotlib.pyplot as plt

# Notebook integrations: inline matplotlib figures and the TensorBoard extension.
%matplotlib inline
%load_ext tensorboard

# Shared by every `num_parallel_calls=` argument in the input pipeline.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Drop Keras state left over from earlier runs in the same kernel before building models.
tf.keras.backend.clear_session()

# Quick environment report: visible GPUs and the TensorFlow version in use.
print('GPU::', tf.config.list_physical_devices('GPU'))
print('version', tf.__version__)

In [None]:
# Side length (pixels) of the square images fed to the encoder and to VGG19.
IMAGE_SIZE = 224
VGG_TENSOR_RC = IMAGE_SIZE//8 # input of 224 into vgg_19 outputs 28x28x512 @ block4_conv1
# Side length of the style/content thumbnails pasted onto the TensorBoard preview images.
THUMB_SIZE = IMAGE_SIZE//4
batch_size = 18 #18 for enc & dec, 36 for enc only, 30 for dec only
# Number of batches per epoch (passed to fit() as steps_per_epoch), not a sample count.
DATASET_LENGTH = 70000 // batch_size # style: 81445, content: 82612 # content face: 70000
number_of_epochs = 20

# Where checkpoints and TensorBoard event files are written.
CHECKPOINT_DIR = './042.ckpts'
LOGS_DIR = './logs/adain.encoder.decoder.v417'

total_steps = DATASET_LENGTH * number_of_epochs

# Number of fixed style/content samples kept for previews; at least one full batch
# because the training callback slices `batch_size` images out of them.
test_sample = max(5, batch_size)

# VGG19 layers whose activations are compared by the style loss; the last entry
# is also the feature map the encoder is trained to reproduce.
layer_names = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1']

In [None]:
# ImageNet-pretrained VGG19 without its classifier head, built on a fixed
# IMAGE_SIZE x IMAGE_SIZE RGB input.
vgg = tf.keras.applications.VGG19(
    weights='imagenet',
    include_top=False,
    input_tensor=keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3)),
)

# Frozen multi-output view of VGG19: one activation per entry of `layer_names`,
# returned in that order.
vgg_model_internal = tf.keras.Model(
    inputs=vgg.input,
    outputs=[vgg.get_layer(layer_name).output for layer_name in layer_names],
)
vgg_model_internal.trainable = False
vgg_model_internal.summary()

In [None]:
# Alternative content source, kept for reference:
# all_image_paths = ['../data/train2014/' + f for f in os.listdir('../data/train2014')]

# glob.glob already returns a fresh list. It is called without recursive=True,
# so each `**` below matches exactly one directory level.
all_image_paths = glob.glob('../ffhq-dataset/images224x224/**/*.png')
all_style_image_paths = glob.glob('../data/wikiart/**/*')

def preprocess_image_from_path(path):
    """Load an image file and return an IMAGE_SIZE x IMAGE_SIZE x 3 float tensor in [0, 1].

    The image is rescaled (aspect ratio preserved) so that its shorter side is
    IMAGE_SIZE, then a random IMAGE_SIZE x IMAGE_SIZE window is cropped from it.
    Padding is only used as a fallback when the rescaled image is still too
    small to crop.

    Args:
        path: scalar string tensor (or str) naming the image file.

    Returns:
        float32 tensor of shape (IMAGE_SIZE, IMAGE_SIZE, 3), values divided by 255.
    """
    img = tf.io.read_file(path)
    img = tf.cast(tf.image.decode_image(img, channels=3, expand_animations = False), tf.float32)

    shape = tf.shape(img)

    # The 10000 bound is just "large enough": with preserve_aspect_ratio=True the
    # IMAGE_SIZE side is the one that limits the scale.
    # Square images now take the `else` branch too. Previously neither `<` nor `>`
    # matched, so large square images were cropped at native resolution while
    # near-square ones were downscaled first.
    if shape[0] < shape[1]:
        img = tf.image.resize(img, (IMAGE_SIZE, 10000), preserve_aspect_ratio=True)
    else:
        img = tf.image.resize(img, (10000, IMAGE_SIZE), preserve_aspect_ratio=True)

    # Decide crop-vs-pad from the *resized* image. The original test used the
    # pre-resize shape, so small images were upscaled and then shrunk again
    # into a letterboxed frame.
    resized_shape = tf.shape(img)

    if resized_shape[0] >= IMAGE_SIZE and resized_shape[1] >= IMAGE_SIZE:
        img = tf.image.random_crop(img, (IMAGE_SIZE, IMAGE_SIZE, 3))
    else:
        img = tf.image.resize_with_pad(img, IMAGE_SIZE, IMAGE_SIZE)

    return img / 255. # normalize to [0,1] range

def preprocess_content_image(img):
    """Apply the random augmentation chain used for content images.

    The steps run in a fixed order: horizontal flip, vertical flip, hue,
    brightness, contrast, saturation. Each entry below is the augmentation
    function together with the extra arguments it is called with.

    Args:
        img: image tensor as produced by `preprocess_image_from_path`.

    Returns:
        The augmented image tensor.
    """
    augmentation_chain = (
        (tf.image.random_flip_left_right, ()),
        (tf.image.random_flip_up_down, ()),
        (tf.image.random_hue, (.1,)),
        (tf.image.random_brightness, (.1,)),
        (tf.image.random_contrast, (.85, 1.)),
        (tf.image.random_saturation, (.85, 1.)),
    )

    for augment, extra_args in augmentation_chain:
        img = augment(img, *extra_args)

    return img

# Input pipeline: two endless, independently shuffled image streams (style and
# content) zipped into (style_batch, content_batch) pairs.
#
# Shuffling happens on the path strings, before any decoding, with a buffer that
# covers the whole list. The previous 50-element buffer applied after decoding
# left the stream in almost the same order as the glob listing.
style_dataset = tf.data.Dataset.from_tensor_slices(all_style_image_paths)
style_dataset = style_dataset.shuffle(buffer_size=max(1, len(all_style_image_paths))).repeat()
style_dataset = style_dataset.map(preprocess_image_from_path, num_parallel_calls=AUTOTUNE)

content_dataset = tf.data.Dataset.from_tensor_slices(all_image_paths)
content_dataset = content_dataset.shuffle(buffer_size=max(1, len(all_image_paths))).repeat()
content_dataset = content_dataset.map(preprocess_image_from_path, num_parallel_calls=AUTOTUNE)
content_dataset = content_dataset.map(preprocess_content_image, num_parallel_calls=AUTOTUNE)

# Both inputs already repeat forever, so the zipped dataset needs no extra
# repeat(). prefetch lets preprocessing overlap with the training step.
dataset = tf.data.Dataset.zip((style_dataset, content_dataset))
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(AUTOTUNE)

# Fixed samples (lists of numpy arrays) reused for the preview images.
base_style = list(style_dataset.take(test_sample).as_numpy_iterator())
base_content = list(content_dataset.take(test_sample).as_numpy_iterator())

In [None]:
# Adapted from several open-source implementations, most likely
# https://github.com/jonrei/tf-AdaIN/blob/master/AdaIN.py
class AdaIN(layers.Layer):
    """Adaptive instance normalization with a blend factor.

    Called with `[style_features, content_features, alpha]`. The content features
    are normalized per sample and per channel (moments over axes 1 and 2), then
    rescaled and shifted with the style features' statistics. `alpha` blends that
    result with the untouched content features: 1 = fully stylized, 0 = content.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @tf.function
    def call(self, _input):
        style_features, content_features, alpha = _input

        epsilon = 1e-5
        spatial_axes = [1, 2]

        style_mean, style_var = tf.nn.moments(style_features, spatial_axes, keepdims=True)
        content_mean, content_var = tf.nn.moments(content_features, spatial_axes, keepdims=True)

        # Spelled out by hand: batch_normalization breaks on tflite's Android GPU delegate.
        content_std = tf.math.sqrt(content_var + epsilon)
        style_std = tf.math.sqrt(style_var + epsilon)

        # epsilon is applied both inside the square roots above and to the divisor here.
        centered = content_features - content_mean
        whitened = centered / (content_std + epsilon)
        stylized = whitened * style_std + style_mean

        return alpha * stylized + (1 - alpha) * content_features

# Reflection padding used in front of every convolution, to avoid the
# checkerboard pattern in the output.
class ReflectionPad(layers.Layer):
    """Pads a 4-D (batch, rows, cols, channels) tensor by reflecting its borders.

    Args:
        padding: pair `(width_pad, height_pad)`. `call` adds `height_pad` rows
            on both sides of axis 1 and `width_pad` columns on both sides of
            axis 2.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Base initializer first, attributes afterwards: the order Keras uses for
        # custom layers, so that nothing set here can be reset by it.
        super(ReflectionPad, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [layers.InputSpec(ndim=4)]

    def compute_output_shape(self, input_shape):
        # Must mirror `call`: axis 1 grows by the height pad (padding[1]) and
        # axis 2 by the width pad (padding[0]). The two used to be swapped here.
        padding_width, padding_height = self.padding
        rows, cols = input_shape[1], input_shape[2]
        return (
            input_shape[0],
            None if rows is None else rows + 2 * padding_height,  # unknown dims stay unknown
            None if cols is None else cols + 2 * padding_width,
            input_shape[3],
        )

    def get_config(self):
        # Lets the layer be re-created from its config (e.g. when cloning a model).
        config = super(ReflectionPad, self).get_config()
        config.update({'padding': self.padding})
        return config

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        return tf.pad(input_tensor, [[0,0], [padding_height, padding_height], [padding_width, padding_width], [0,0] ], 'REFLECT')

def _batch_mean_of_summed_log_cosh(target, prediction):
    """Log-cosh between two tensors, summed per sample, then averaged over the batch."""
    per_position = keras.losses.log_cosh(target, prediction)
    # Flatten everything except the leading axis. The batch size is read from
    # the tensor itself rather than from the global `batch_size`, so any batch
    # size (including a short final batch) works.
    per_sample = tf.reshape(per_position, (tf.shape(per_position)[0], -1))
    return tf.reduce_mean(tf.reduce_sum(per_sample, axis=1))


def get_style_loss(encoded_style, encoded_y):
    """Style loss between two lists of feature maps.

    For every pair of feature maps, the per-channel mean and standard deviation
    (over axes 1 and 2) are compared twice: with mean squared error and with
    log-cosh. All terms are summed.

    Args:
        encoded_style: iterable of 4-D feature tensors for the style image.
        encoded_y: iterable of 4-D feature tensors for the generated image,
            paired with `encoded_style` by position.

    Returns:
        Scalar tensor with the summed loss (the int 0 if the inputs are empty).
    """
    loss = 0
    epsilon = 1e-5

    for style, y in zip(encoded_style, encoded_y):
        mean_style, variance_style = tf.nn.moments(style, [1,2], keepdims=True)
        mean_y, variance_y = tf.nn.moments(y, [1,2], keepdims=True)

        # epsilon keeps the square root differentiable when a channel is constant
        std_style = tf.math.sqrt(variance_style + epsilon)
        std_y = tf.math.sqrt(variance_y + epsilon)

        loss += tf.reduce_mean(tf.math.square(mean_style - mean_y))
        loss += tf.reduce_mean(tf.math.square(std_style - std_y))

        # log_cosh terms sit on top of the MSE terms; their effect on output
        # quality versus MSE alone has not been measured.
        loss += _batch_mean_of_summed_log_cosh(mean_style, mean_y)
        loss += _batch_mean_of_summed_log_cosh(std_style, std_y)

    return loss


In [None]:
# Running means of the two training losses; updated and reported by CustomModel.train_step.
e_loss_tracker = keras.metrics.Mean(name="encoder_loss")
d_loss_tracker = keras.metrics.Mean(name="decoder_loss")

# NOTE(review): not referenced by any other visible cell; the decoder builds its own AdaIN layer.
adain_layer = AdaIN(name="AdaIN", trainable=False)
# issues with mediaPipe to pass a scalar
# so the blend factor is a (batch_size, 1, 1, 512) tensor filled with 0.5
const_alpha = tf.ones((batch_size, 1, 1, 512)) * .5

# NOTE(review): never assigned again in the visible cells; CustomCallback keeps its own manager.
checkpoint_manager = None

def embed(container, image, i, j):
    """Paste `image` into `container` in place, top-left corner at row `i`, column `j`.

    Args:
        container: writable array of shape (batch, rows, cols, channels).
        image: either a batch of images (batch, h, w, channels), pasted one per
            container entry, or a single image (h, w, channels), pasted into
            `container[0]` only.
        i: first row of the target window.
        j: first column of the target window.

    Raises:
        ValueError: if `image` is neither 3-D nor 4-D.
    """
    image = np.asarray(image)

    if image.ndim == 3:
        # Promote a single image to a batch of one. Indexing it with a batch
        # index (as the old code did) selected its first pixel row instead.
        image = image[np.newaxis, ...]
    elif image.ndim != 4:
        raise ValueError(
            'embed expects a 3-D or 4-D image array, got {} dimensions'.format(image.ndim)
        )

    batch, ii, ij = image.shape[:3]

    for b in range(batch):
        container[b, i:i + ii, j:j + ij] = image[b]

class CustomCallback(keras.callbacks.Callback):
    """Checkpointing and TensorBoard logging for a `CustomModel` training run.

    * Restores the latest checkpoint from CHECKPOINT_DIR when training starts.
    * Saves a checkpoint at the end of every epoch, at the end of training and
      whenever a preview image is written.
    * Every 100 batches, writes the encoder/decoder losses and the encoder
      optimizer's learning rate to LOGS_DIR.
    * At batch 1, at batches 250/500/750 and every 1000 batches, stylizes the
      fixed `base_content` images with the fixed `base_style` images and logs
      the results as images.
    """

    def __init__(self, patience=0):
        # `patience` is accepted but not read anywhere in this class.
        super(CustomCallback, self).__init__()
        self.tick = 0  # batches seen since this callback was created
        self.writer = tf.summary.create_file_writer(LOGS_DIR)
        self.e_train_loss = tf.keras.metrics.Mean('encoder_train_loss', dtype=tf.float32)
        self.d_train_loss = tf.keras.metrics.Mean('decoder_train_loss', dtype=tf.float32)
        self.learning_rate = tf.keras.metrics.Mean('learning_rate', dtype=tf.float32)
        # Blend factor fed to the decoder for previews (0.5 everywhere).
        self.alpha = tf.ones((batch_size, 1, 1, 512)) * .5

    def on_train_begin(self, logs=None):
        # The checkpoint tracks the encoder/decoder both directly and under `net`;
        # this layout is kept as-is so that existing checkpoints still restore.
        net = tf.train.Checkpoint(
            encoder=self.model.encoder,
            decoder=self.model.decoder
        )
        ckpt = tf.train.Checkpoint(
            step=tf.Variable(1),
            net=net,
            e_optimizer=self.model.e_optimizer,
            d_optimizer=self.model.d_optimizer,
            encoder=self.model.encoder,
            decoder=self.model.decoder
        )

        self.manager = tf.train.CheckpointManager(ckpt, CHECKPOINT_DIR, max_to_keep=3)
        ckpt.restore(self.manager.latest_checkpoint)

    def on_train_end(self, logs=None):
        # `logs` defaults to None; treat that as "no keys" instead of crashing.
        keys = list((logs or {}).keys())
        print("Stop training; got log keys: {}".format(keys))
        self.manager.save()

    def on_epoch_end(self, epoch, logs=None):
        self.manager.save()

    def on_train_batch_end(self, batch, logs=None):
        self.tick += 1

        if self.tick % 100 == 0:
            # These Mean metrics are never reset, so the logged scalars are
            # running averages over every value sampled so far.
            self.e_train_loss(logs.get('encoder_loss'))
            self.d_train_loss(logs.get('decoder_loss'))
            self.learning_rate(self.model.e_optimizer.lr)
            with self.writer.as_default():
                tf.summary.scalar('encoder_train_loss', self.e_train_loss.result(), step=self.tick)
                tf.summary.scalar('decoder_train_loss', self.d_train_loss.result(), step=self.tick)
                tf.summary.scalar('learning_rate', self.learning_rate.result(), step=self.tick)

        if self.tick == 1 or (self.tick % 250 == 0 and self.tick < 1000) or self.tick % 1000 == 0:
            batched_bs = tf.slice(base_style, [0, 0, 0, 0], [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3])
            batched_bc = tf.slice(base_content, [0, 0, 0, 0], [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3])

            # Style features come from the frozen VGG19, content features from
            # the encoder being trained.
            preprocessed_input_style = keras.applications.vgg19.preprocess_input(batched_bs * 255.)
            b_style = vgg_model_internal(preprocessed_input_style, training=False)[-1]
            b_content = self.model.encoder(batched_bc, training=False)

            output, _ = self.model.decoder([b_style, b_content, self.alpha], training=False)

            resized_bs = tf.image.resize(batched_bs, (THUMB_SIZE, THUMB_SIZE), preserve_aspect_ratio=True)
            resized_bc = tf.image.resize(batched_bc, (THUMB_SIZE, THUMB_SIZE), preserve_aspect_ratio=True)

            # Paste thumbnails onto each result: style at row/col (0, 0),
            # content at row IMAGE_SIZE - THUMB_SIZE, col 0.
            # (A former `np.reshape(output, ...)` here discarded its result and
            # has been dropped.)
            output = output.numpy()
            embed(output, resized_bs.numpy(), 0, 0)
            embed(output, resized_bc.numpy(), IMAGE_SIZE - THUMB_SIZE, 0)

            with self.writer.as_default():
                tf.summary.image("5 outputs", output, max_outputs=5, step=self.tick)

            self.manager.save()
            
class CustomModel(keras.Model):
    """Trains the lightweight encoder and the AdaIN decoder in one `fit()` loop.

    Each training step does two independent updates:

    1. Encoder: regress the frozen VGG19's last selected feature map for the
       content image (mean squared error), using `e_optimizer`.
    2. Decoder: stylize the content features with the style image's VGG
       features and minimise style loss + AdaIN-feature reconstruction loss +
       total variation, using `d_optimizer`.

    Args:
        encoder: Keras model mapping [0, 1] images to content features.
        decoder: Keras model mapping `[style_features, content_features, alpha]`
            to `(image in [0, 1], adain_features)`.
    """

    def __init__(self, encoder, decoder):
        super(CustomModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def compile(self, e_optimizer, d_optimizer):
        super(CustomModel, self).compile()
        self.e_optimizer = e_optimizer
        self.d_optimizer = d_optimizer

    @property
    def metrics(self):
        # Listing the trackers lets Keras reset them at the start of every epoch.
        # Without this they averaged over the whole run.
        return [e_loss_tracker, d_loss_tracker]

    def train_step(self, data):
        """Run one encoder update and one decoder update.

        Args:
            data: pair `(style_images, content_images)`, both in [0, 1].

        Returns:
            Dict with the running means `encoder_loss` and `decoder_loss`.
        """
        #TODO: need to find a better way for the training cycles

        # avoids vgg19.preprocess_input from mutating the cached image
        _input_style = tf.identity(data[0])
        _input_content = tf.identity(data[1])

        # input images are [0, 1]
        preprocessed_input_style = keras.applications.vgg19.preprocess_input(_input_style * 255.)
        encoded_style = vgg_model_internal(preprocessed_input_style, training=False)
        y_style = encoded_style[-1]

        preprocessed_input_content = keras.applications.vgg19.preprocess_input(_input_content * 255.)
        encoded_content = vgg_model_internal(preprocessed_input_content, training=False)

        # --- encoder update: imitate VGG's last selected feature map ---
        with tf.GradientTape() as tape:
            # input images are [0, 1]
            y_content = self.encoder(_input_content, training=True)

            encoder_loss = tf.reduce_mean(tf.math.square(y_content - encoded_content[-1]))

        trainable_vars = self.encoder.trainable_variables
        grads = tape.gradient(encoder_loss, trainable_vars)
        self.e_optimizer.apply_gradients(zip(grads, trainable_vars))

        # --- decoder update ---
        # `y_content` is the encoder output from before the update above. It
        # enters the decoder as plain input; only decoder variables get gradients.
        with tf.GradientTape() as tape:
            y, adaIn_output = self.decoder([y_style, y_content, const_alpha], training=True)

            # y is [0, 1]
            preprocessed_y = keras.applications.vgg19.preprocess_input(y * 255)
            encoded_y = vgg_model_internal(preprocessed_y, training=False)

            loss = get_style_loss(encoded_style, encoded_y) * 1e2
            loss += tf.reduce_mean(tf.math.square(adaIn_output - encoded_y[-1]))
            # total_variation yields one value per image. Averaging keeps `loss`
            # a scalar; adding the raw vector broadcast the other terms across
            # the batch. The reported value and the balance between terms are
            # unchanged; the gradient is only scaled by 1 / batch size.
            loss += tf.reduce_mean(tf.image.total_variation(y)) * .5

        trainable_vars = self.decoder.trainable_variables
        grads = tape.gradient(loss, trainable_vars)
        self.d_optimizer.apply_gradients(zip(grads, trainable_vars))

        d_loss_tracker.update_state(loss)
        e_loss_tracker.update_state(encoder_loss)

        return {"encoder_loss": e_loss_tracker.result(), "decoder_loss": d_loss_tracker.result()}
        
def conv(x, ch, name, strides=1):
    """Shared building block: reflection pad -> 3x3 separable conv -> instance norm -> LeakyReLU.

    Args:
        x: 4-D input tensor.
        ch: number of output channels of the separable convolution.
        name: prefix for the padding ('<name>reflection') and convolution
            ('<name>sepconv') layer names.
        strides: stride of the convolution.

    Returns:
        The activated output tensor.
    """
    padded = ReflectionPad((1, 1), name=name + 'reflection')(x)

    # NOTE(review): SeparableConv2D documents depthwise_/pointwise_regularizer;
    # `kernel_regularizer` is presumably not applied to any weight here - confirm.
    convolved = layers.SeparableConv2D(
        ch,
        (3, 3),
        strides=strides,
        name=name + 'sepconv',
        use_bias=True,
        kernel_regularizer=keras.regularizers.l2(0.001),
    )(padded)

    normalized = tfa.layers.InstanceNormalization(axis=3, center=True, scale=True)(convolved)

    return tf.keras.layers.LeakyReLU()(normalized)

def get_decoder_model(x):
    """Build the decoder graph on top of `x` and return the image tensor in [0, 1].

    The body is a fixed sequence of `conv` blocks with three 2x upsamplings in
    between, so the spatial size grows 8x overall. For the 28x28x512 input used
    in this notebook that is 28 -> 56 -> 112 -> 224. A final reflection-padded
    3x3 convolution maps to 3 channels and `tanh(x) * .5 + .5` squashes the
    result into [0, 1].
    """
    UPSAMPLE = None  # marker entry in the plan below

    # (channels, layer-name prefix) for conv blocks, UPSAMPLE for a 2x upsampling.
    plan = (
        (512, 'block_5_'),
        UPSAMPLE,
        (256, 'block_4a_'),
        (256, 'block_4b_'),
        (256, 'block_4c_'),
        (256, 'block_4d_'),
        UPSAMPLE,
        (128, 'block_3a_'),
        (128, 'block_3b_'),
        (64, 'block_2a_'),
        UPSAMPLE,
        (64, 'block_1b_'),
    )

    for step in plan:
        if step is UPSAMPLE:
            x = tf.keras.layers.UpSampling2D()(x)
        else:
            channels, prefix = step
            x = conv(x, channels, prefix)

    # Output head: plain (non-separable) convolution down to RGB.
    initializer = tf.keras.initializers.HeUniform()
    x = ReflectionPad((1, 1), name='block_reflection')(x)
    x = layers.Conv2D(3, (3, 3), name='block_conv', kernel_initializer=initializer)(x)
    return tf.math.tanh(x) * .5 + .5

def get_encoder_model(x):
    """Build the encoder graph on top of `x` (RGB images in [0, 1]).

    The input is first converted in-graph: channel order reversed, scaled to
    [0, 255] and shifted by fixed per-channel offsets. Eight `conv` blocks
    follow, three of them with stride 2, so the spatial size shrinks 8x. A
    224x224x3 input therefore ends up as 28x28x512.
    """
    x = tf.reverse(x, [3])
    x = x * 255. - [103.939, 116.779, 123.68]

    # (channels, layer-name prefix, stride); stride 2 halves the spatial size.
    blocks = (
        (64, 'block1_conv1', 1),
        (64, 'block1_conv2', 1),
        (128, 'block2_conv1', 2),
        (128, 'block2_conv2', 1),
        (256, 'block3_conv1', 2),
        (256, 'block3_conv2', 1),
        (512, 'block4_conv1', 2),
        (512, 'block4_conv2', 1),
    )

    for channels, prefix, stride in blocks:
        x = conv(x, channels, prefix, strides=stride)

    return x

In [None]:
# Encoder: [0, 1] RGB image -> content feature map.
_input = keras.Input((IMAGE_SIZE, IMAGE_SIZE, 3))
x = _input
x = get_encoder_model(x)

encoder = keras.Model(_input, x)

# Decoder inputs: style features, content features and the per-channel blend factor.
_input_style = keras.Input((VGG_TENSOR_RC, VGG_TENSOR_RC, 512))
_input_content = keras.Input((VGG_TENSOR_RC, VGG_TENSOR_RC, 512))
_input_alpha = keras.Input(shape=(1, 1, 512))

adaIn_output = AdaIN(name="AdaIN")([_input_style, _input_content, _input_alpha])
output = get_decoder_model(adaIn_output)

# The decoder also returns the AdaIN tensor; train_step uses it as the target of
# the feature reconstruction loss.
decoder = keras.Model([_input_style, _input_content, _input_alpha], [output, adaIn_output])

# One Adam optimizer per sub-model, identically configured.
wrapper_model = CustomModel(encoder=encoder, decoder=decoder)
wrapper_model.compile(
    e_optimizer=keras.optimizers.Adam(learning_rate=1e-3, beta_1=0.9, amsgrad=True),
    d_optimizer=keras.optimizers.Adam(learning_rate=1e-3, beta_1=0.9, amsgrad=True)
)

In [None]:
%%script false --no-raise-error
# Disabled cell: the magic above hands the body to `false`, so nothing below runs.
# Remove the magic line to print both model summaries.

wrapper_model.encoder.summary()
wrapper_model.decoder.summary()

In [None]:
%%script false --no-raise-error
# Disabled cell: the magic above hands the body to `false`, so nothing below runs.
# Remove the magic line to restore the latest checkpoint and preview one
# stylization (style image, content image, result) side by side.

# Same checkpoint layout as CustomCallback, minus the optimizers.
net = tf.train.Checkpoint(
    encoder=wrapper_model.encoder,
    decoder=wrapper_model.decoder
)

ckpt = tf.train.Checkpoint(
    step=tf.Variable(1),
    net=net,
    encoder=wrapper_model.encoder,
    decoder=wrapper_model.decoder
)

ckpt.restore(tf.train.latest_checkpoint(CHECKPOINT_DIR))

# First five fixed style/content samples.
batched_bs = tf.slice(base_style, [0, 0, 0, 0], [5, IMAGE_SIZE, IMAGE_SIZE, 3])
batched_bc = tf.slice(base_content, [0, 0, 0, 0], [5, IMAGE_SIZE, IMAGE_SIZE, 3])

preprocessed_input_style = keras.applications.vgg19.preprocess_input(batched_bs * 255.)
encoded_style = vgg_model_internal(preprocessed_input_style, training=False)

ys = encoded_style[-1]
yc = wrapper_model.encoder(batched_bc)

y, _ = wrapper_model.decoder([ys, yc, tf.ones((5, 1, 1, 512)) * .5])

# Show the fifth sample: style, content, stylized result.
fig = plt.figure(figsize = (20, 20))
fig.add_subplot(3,3,1)
plt.imshow((batched_bs[4].numpy() * 255.).astype(np.uint8))

fig.add_subplot(3,3,2)
plt.imshow((batched_bc[4].numpy() * 255.).astype(np.uint8))

fig.add_subplot(3,3,3)
# The decoder output is in [0, 1] like the inputs; scale it before the uint8
# cast, otherwise every pixel truncates to 0 and the image is black.
plt.imshow((y[4].numpy() * 255.).astype(np.uint8))
plt.show()

def get_loss():
    """Return the mean log-cosh distance between VGG and encoder features.

    Uses only the first image of `base_content`: the frozen VGG19's last
    selected feature map is the target, the encoder output is the prediction.

    Returns:
        The loss as a numpy scalar.
    """
    first_content = tf.slice(base_content, [0, 0, 0, 0], [1, IMAGE_SIZE, IMAGE_SIZE, 3])

    vgg_ready = keras.applications.vgg19.preprocess_input(first_content * 255.)
    vgg_target = vgg_model_internal(vgg_ready, training=False)[-1]

    # input images are [0, 1]
    encoder_prediction = wrapper_model.encoder(first_content)

    distance = tf.keras.losses.log_cosh(vgg_target, encoder_prediction)
    return tf.reduce_mean(distance).numpy()

print(get_loss())

In [None]:
# Train both sub-models. `dataset` already yields (style, content) batches of
# `batch_size` images, so fit() gets neither a `y` argument nor a `batch_size`;
# Keras documents both as not to be specified for dataset inputs.
# The dataset is endless, so `steps_per_epoch` defines the epoch length.
wrapper_model.fit(
    dataset,
    epochs=number_of_epochs,
    steps_per_epoch=DATASET_LENGTH,
    callbacks=[
        CustomCallback()
    ]
)

In [None]:
# %%script false --no-raise-error

# Export three SavedModels: the trained encoder, the VGG19 style-feature
# extractor and the decoder.

# 1) Encoder: [0, 1] RGB image -> content features.
_input = keras.Input((IMAGE_SIZE, IMAGE_SIZE, 3))
_output = wrapper_model.encoder(_input)
model = keras.Model(_input, _output)
tf.saved_model.save(model, './adain_encoder')

# 2) VGG19 wrapped so that it also accepts [0, 1] RGB images: the channel flip
#    and offset subtraction are done in-graph, as in get_encoder_model.
_input = keras.Input((IMAGE_SIZE, IMAGE_SIZE, 3))
# 'RGB'->'BGR'
y = tf.reverse(_input, [3])
y = y * 255. - [103.939, 116.779, 123.68]
_output = vgg_model_internal(y)
# Only the last selected layer's activation is exported.
model = keras.Model(_input, _output[-1])
tf.saved_model.save(model, './adain_vgg')

# 3) Decoder: (style features, content features, alpha) -> stylized image.
s_input = keras.Input((VGG_TENSOR_RC, VGG_TENSOR_RC, 512))
c_input = keras.Input((VGG_TENSOR_RC, VGG_TENSOR_RC, 512))
a_input = keras.Input((1, 1, 512))

# The decoder's second output (the AdaIN tensor) is dropped from the export.
_output, _ = wrapper_model.decoder([s_input, c_input, a_input])

model = keras.Model([s_input, c_input, a_input], _output)
tf.saved_model.save(model, './adain_decoder')

Conversion to web models uses the following CLI command (shown here for the encoder):

`tensorflowjs_converter ./adain_encoder/ ./new_web_models/encoder`