In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.applications.vgg19 as vgg19

from tensorflow import einsum
from tensorflow.keras.layers import Lambda, MaxPool2D, AvgPool2D, Layer, Input, Subtract, Multiply, Add
from tensorflow.keras.models import Model

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append('../utilities')
from utilities import load_image, show_image, vgg19_process_image, vgg19_deprocess_image

In [4]:
# Pretrained VGG19 convolutional base (classifier head dropped) with ImageNet weights.
base_model = vgg19.VGG19(weights='imagenet', include_top=False)

In [5]:
# Gram matrix used by the style-output Lambda layers
def gram_matrix(activations):
    """Return the Gram matrix of a batch of feature maps, averaged over locations.

    For every batch element, sums the outer product of the channel vectors
    over all spatial positions, then divides by the number of positions
    (size of axis 1 times size of axis 2).

    Args:
        activations: rank-4 tensor indexed as (batch, i, j, channels) by the
            einsum subscripts below.

    Returns:
        Tensor of shape (batch, channels, channels).
    """
    dims = tf.shape(activations)
    # Number of spatial positions, as float32 for the division below.
    location_count = tf.cast(dims[1] * dims[2], tf.float32)
    # Contract over the two spatial axes, keeping both channel axes.
    channel_products = tf.linalg.einsum('aijb,aijc->abc', activations, activations)
    return channel_products / location_count

In [6]:
# Build the feature-extraction model.
# Requires: base_model, gram_matrix.
# Defines:  width, height, style_layers, content_layers, style_outputs,
#           content_outputs, num_style_outputs, extract_model.
width = height = 448

style_layers    = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']
content_layers  = ['block4_conv2']

# Fail fast on a typo: an unknown layer name would otherwise be skipped
# silently, and later cells that split the outputs at num_style_outputs
# would mis-assign style vs. content outputs.
_available_layers = {layer.name for layer in base_model.layers}
_missing_layers   = [name for name in style_layers + content_layers if name not in _available_layers]
if _missing_layers:
    raise ValueError(f'Layers not found in base_model: {_missing_layers}')

# Freeze the pretrained weights; the conv layers reused below are shared with
# base_model, so they are non-trainable in extract_model as well.
base_model.trainable   = False

style_outputs   = []
content_outputs = []

for i, layer in enumerate(base_model.layers):
    if i == 0:
        # Custom input size (replaces base_model's input layer).
        # NOTE(review): Keras image tensors are (height, width, channels); this
        # passes (width, height, 3), which only matches because width == height.
        input_      = tf.keras.layers.Input(shape = (width, height, 3))
        current     = input_
    elif isinstance(layer, MaxPool2D):
        # Replace max pooling with average pooling that has the same pool size,
        # strides, padding and name. from_config is a classmethod, so it is
        # called on the class rather than on a throwaway instance.
        current     = AvgPool2D.from_config(layer.get_config())(current)
    else:
        current = layer(current)

    if layer.name in style_layers:
        # Compute gram matrix as a style output
        style_outputs.append(Lambda(gram_matrix, name = f'gram_{layer.name}')(current))
    if layer.name in content_layers:
        # Keep some convolutional layers as outputs
        content_outputs.append(current)

# Derive the split point from what was actually collected (robust to
# duplicate names in style_layers), so downstream slicing always matches.
num_style_outputs = len(style_outputs)

# Outputs: all style Gram matrices first, then all content feature maps.
extract_model = Model(inputs = [input_], outputs = style_outputs + content_outputs, name = 'extractive_model')

In [7]:
# Source layers (with no/fake inputs)
class Source(Layer):
    """Layer whose output is a trainable weight tensor, independent of its input.

    The input only attaches the layer to a functional graph; ``call`` ignores
    it and returns the kernel. In this notebook it is constructed as
    ``Source((1, width, height, 3))``, so the kernel is the image being optimized.

    Args:
        output_dim: shape of the kernel, and therefore of the layer's output.
    """

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super().__init__(**kwargs)

    def build(self, input_shapes):
        # 'uniform' resolves to Keras's RandomUniform initializer.
        self.kernel = self.add_weight(name='kernel', shape=self.output_dim, initializer='uniform', trainable=True)
        super().build(input_shapes)

    def call(self, inputs):
        # Inputs are intentionally ignored; the output is the weight itself.
        return self.kernel

    def compute_output_shape(self, input_shape=None):
        """Return ``output_dim``; the output shape does not depend on the input.

        Keras calls this as ``compute_output_shape(input_shape)``, so the
        parameter must be accepted even though it is unused.
        """
        return self.output_dim

    def get_config(self):
        """Return the layer config, including ``output_dim``, for serialization.

        Overriding ``get_config`` (not just providing ``get_params``) matters:
        in TF 2.x, the base ``Layer.get_config`` refuses layers with extra
        ``__init__`` arguments unless ``get_config`` is overridden.
        """
        base_config = super().get_config()
        return {**base_config, 'output_dim' : self.output_dim}

    def get_params(self):
        """Backward-compatible alias for :meth:`get_config`."""
        return self.get_config()


In [8]:
# Build the transfer (loss) model.
# Inputs : [fake_input, one target tensor per extract_model output, same order].
# Output : per-batch-element weighted sum of the style and content losses.
fake_input = Input(())
source     = Source((1 ,width, height, 3))

extract_outputs  = extract_model(source(fake_input))

style_content_weighting = 0.01

style_layer_weighting   = tf.constant(1/len(style_outputs) * style_content_weighting * (1/4))
content_layer_weighting = tf.constant((1/2))

transfer_inputs  = [fake_input]
transfer_outputs = []
for i, extracted in enumerate(extract_outputs):
    # Placeholder for the precomputed target activation of this output.
    # A distinct name avoids clobbering the extraction cell's `input_`.
    target_shape = extracted.shape
    target_input = Input(shape = target_shape[1:], batch_size = target_shape[0])
    transfer_inputs.append(target_input)
    difference = Subtract()([target_input, extracted])
    square     = Lambda(tf.square, name = f'square_{i}')(difference)
    # The first num_style_outputs outputs are Gram matrices (rank 3); the rest
    # are content feature maps (rank 4). Both reductions leave shape (batch,).
    if i < num_style_outputs:
        reduce = Lambda(lambda t : tf.reduce_mean(t, axis = [1,2]), name = f'mean_{i}')(square)
        weight = style_layer_weighting
    else:
        reduce = Lambda(lambda t : tf.reduce_sum(t, axis = [1,2,3]), name = f'sum_{i}')(square)
        weight = content_layer_weighting
    # Bind the weight now via a default argument. Otherwise the global name is
    # looked up each time the Lambda is traced, so reassigning it in another
    # cell would silently change an already-built model.
    scale = Lambda(lambda x, w = weight : x * w, name = f'weight_{i}')(reduce)
    transfer_outputs.append(scale)
output = Add()(transfer_outputs)

transfer_model = Model(inputs = transfer_inputs, outputs = output)

In [9]:
# Load test cases
content_path = '../dream-base-images/marco3.png'
style_path   = '../dataset/images/train/Max Ernst/21433.jpg'

content_image = load_image(content_path, cast = tf.float32)
style_image   = load_image(style_path, cast = tf.uint8)

# Turn each image into a VGG19-preprocessed batch of one.
# NOTE(review): tf.image.resize takes [new_height, new_width]; passing
# [width, height] is only equivalent because width == height.
images = [style_image, content_image]
for idx in range(len(images)):
    resized     = tf.image.resize(tf.cast(images[idx], tf.float32), [width, height])
    processed   = vgg19_process_image(resized)
    images[idx] = tf.expand_dims(processed, axis = 0)
style_image, content_image = images

# Target activations: extract_model emits style Gram matrices first, then
# content feature maps; keep the style part of one and the content part of the other.
all_style_outputs   = extract_model(style_image)
all_content_outputs = extract_model(content_image)
style_activations   = all_style_outputs[:num_style_outputs]
content_activations = all_content_outputs[num_style_outputs:]

In [10]:
# Zero placeholder of shape (1,): fed as the fake input to Source and as the (ignored) label.
dummy = tf.zeros(shape = (1,), dtype = tf.float32)

In [11]:
# Single-element dataset. Features are (dummy, *style targets, *content targets),
# ordered to match transfer_model's inputs; the label (dummy) is ignored by the loss.
transfer_input = tf.data.Dataset.from_tensors(((dummy, *style_activations, *content_activations), dummy))

In [12]:
def precomputed_loss(dummy, loss):
    """Keras loss that ignores the target and returns the prediction unchanged.

    The transfer model already outputs its own loss value, so the "prediction"
    is the loss to minimize.

    Args:
        dummy: ``y_true`` slot; unused.
        loss: ``y_pred`` slot; the model output.

    Returns:
        ``loss``, unchanged.
    """
    del dummy  # unused: the label carries no information
    return loss

In [13]:
# precomputed_loss returns y_pred, so training minimizes the model's own output.
# NOTE(review): Adam's per-step update is roughly bounded by its learning rate
# (default 1e-3). With one step per epoch, 50 epochs can move each pixel only
# slightly on (presumably) 0-255-scale VGG inputs; consider a larger rate.
transfer_model.compile(loss = precomputed_loss, optimizer = 'adam')

In [None]:
# The dataset holds a single element, so each epoch is exactly one optimizer step.
transfer_model.fit(x = transfer_input, epochs = 50)

In [None]:
# Render the loss-model graph (Keras requires pydot and graphviz for this).
keras.utils.plot_model(transfer_model)