In [None]:
import tensorflow as tf
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline

In [None]:
# Position of each layer of interest inside `vgg16.layers`; the loss helpers
# look layers up by these indices.
# NOTE(review): the indices assume the stock Keras VGG16 layer ordering with
# the input layer at position 0 -- confirm if the model definition changes.
layer_pos_dict = {
    "conv1_1": 1,
    "conv2_1": 4,
    "conv3_1": 7,
    "conv4_1": 11,
    "conv5_1": 15,
}

# Layer whose activations are compared for the content loss, and the
# multiplier applied to that loss.
content_layer = "conv3_1"
content_loss_weight = 0.1

# Multiplier applied to the style loss of each layer.
style_loss_weights = {
    "conv1_1": 3000,
    "conv2_1": 750,
    "conv3_1": 250,
    "conv4_1": 100,
    "conv5_1": 50,
}

# Multiplier applied to the total-variation loss.
tv_loss_weight = 1e-4

In [None]:
# Per-channel mean pixel values (0-255 scale) that process_image subtracts
# and restore_image adds back, indexed by position along the channel axis.
# NOTE(review): the order looks like B, G, R, matching the arrays returned by
# cv2.imread in this notebook -- confirm before feeding RGB-ordered images.
IMAGENET_MEANS = [103.939, 116.779, 123.68]
# Disabled alternative: the same means for pixel values scaled to [0, 1].
# IMAGENET_MEANS = [0.40760392, 0.45795686, 0.48501961]

def process_image(img):
    """Return a float32 copy of ``img`` with the channel means subtracted.

    Args:
        img: array-like image whose axis 2 holds at least three channels.

    Returns:
        A new ``np.float32`` array of the same shape; the input is not
        modified. Pixel values are not rescaled, only shifted by
        ``IMAGENET_MEANS``.
    """
    # np.array always copies, so the in-place subtraction below cannot leak
    # back into the caller's array.
    centered = np.array(img, dtype=np.float32)
    for channel, channel_mean in enumerate(IMAGENET_MEANS[:3]):
        centered[:, :, channel] -= channel_mean
    return centered
        
def restore_image(img):
    """Undo ``process_image`` and return a displayable ``uint8`` image.

    Adds ``IMAGENET_MEANS`` back onto the first three channels of axis 2,
    clamps every value to the range [0, 255] and converts to ``np.uint8``.

    Args:
        img: floating-point array-like image whose axis 2 holds at least
            three channels.

    Returns:
        A new ``np.uint8`` array of the same shape; the input is not
        modified.
    """
    # np.array copies, so the in-place additions do not touch the caller's data.
    restored_image = np.array(img)
    for x in range(3):
        restored_image[:, :, x] += IMAGENET_MEANS[x]
    # ndarray.clip returns a new array rather than clamping in place; write
    # the result back via ``out`` so that out-of-range values are really
    # saturated instead of wrapping around in the uint8 cast below.
    np.clip(restored_image, 0, 255, out=restored_image)
    return restored_image.astype(np.uint8)

In [None]:
def _load_resized_bgr(path, size=(224, 224)):
    """Read a colour image with OpenCV and resize it to ``size``.

    Args:
        path: location of the image file.
        size: target size handed to ``cv2.resize``.

    Returns:
        The resized image array as returned by ``cv2.resize``.

    Raises:
        IOError: if OpenCV could not read the file. ``cv2.imread`` signals
            failure by returning ``None`` instead of raising, which would
            otherwise show up later as an opaque ``cv2.resize`` error.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        raise IOError("Could not read image file: {}".format(path))
    return cv2.resize(image, size)


content_source = _load_resized_bgr("../../data/images/Amsterdam.jpg")

style_source = _load_resized_bgr("../../data/images/VanGogh.jpg")

In [None]:
# Fixed, mean-centred content and style images.
target_content = tf.constant(process_image(content_source))
target_style = tf.constant(process_image(style_source))

# The image being optimised: the only trainable tensor, started from
# unit-variance Gaussian noise.
recovered_image = tf.Variable(
    tf.random_normal([1, 224, 224, 3]),
    name="recovered_image",
    trainable=True,
)

# A single batch of three images so that one network pass yields the
# activations of all of them:
#   index 0 -> content image, index 1 -> style image, index 2 -> recovered image
fixed_targets = tf.stack([target_content, target_style], axis=0)
concatenated_input = tf.concat([fixed_targets, recovered_image], axis=0)

In [None]:
# Convolutional part of VGG16 with ImageNet weights, wired directly onto the
# three-image batch tensor instead of a fresh input placeholder.
vgg16 = tf.contrib.keras.applications.vgg16.VGG16(
    weights='imagenet',
    include_top=False,
    input_tensor=concatenated_input,
    input_shape=None,
)

# The network is only used as a fixed feature extractor.
for layer in vgg16.layers:
    layer.trainable = False

In [None]:
def get_content_loss(layer):
    """Return the content loss measured at one VGG16 layer.

    The loss is the sum of squared differences between the layer activations
    of batch entry 0 and batch entry 2 of the network input.

    Args:
        layer: key of ``layer_pos_dict`` naming the layer to use.

    Returns:
        Scalar tensor holding the un-normalised squared error.
    """
    activations = vgg16.layers[layer_pos_dict[layer]].output
    content_features = activations[0]
    recovered_features = activations[2]
    feature_gap = tf.squared_difference(content_features, recovered_features)
    return tf.reduce_sum(feature_gap)

In [None]:
def gram_matrix(embedding):
    """Return the Gram matrix of a rank-3 activation tensor.

    The last axis is moved to the front, each slice along it is flattened to
    a row vector, and the matrix of inner products between all pairs of rows
    is returned.

    Args:
        embedding: rank-3 tensor whose last axis indexes the filters.

    Returns:
        Square tensor with one row and one column per filter.
    """
    by_filter = tf.transpose(embedding, perm=[2, 0, 1])
    # One row per filter, holding all remaining positions flattened.
    rows = tf.contrib.keras.backend.batch_flatten(by_filter)
    return tf.matmul(rows, rows, transpose_b=True)

def get_style_loss(layer):
    """Return the normalised style loss measured at one VGG16 layer.

    Compares the Gram matrix of batch entry 1 with that of batch entry 2 of
    the layer output and divides the summed squared difference by
    ``4 * (filters * dim1 * dim2) ** 2``.

    Args:
        layer: key of ``layer_pos_dict`` naming the layer to use.

    Returns:
        Scalar tensor holding the style loss of that layer.
    """
    activations = vgg16.layers[layer_pos_dict[layer]].output

    # Static shape is [batch, dim1, dim2, filters]; only the product of the
    # last three entries matters for the normaliser.
    _, dim1, dim2, n_filters = activations.get_shape().as_list()
    normaliser = 4 * (n_filters * dim1 * dim2) ** 2

    style_gram = gram_matrix(activations[1])
    recovered_gram = gram_matrix(activations[2])

    gram_gap = tf.squared_difference(style_gram, recovered_gram)
    return tf.reduce_sum(gram_gap) / normaliser

In [None]:
def get_total_variation_loss(x):
    """Return the total-variation penalty of a batch of images.

    For every position that has a neighbour along both axis 1 and axis 2,
    the squared differences to those two neighbours are added up over the
    whole batch and all channels.

    Args:
        x: rank-4 tensor indexed as [batch, axis1, axis2, channels].

    Returns:
        Scalar tensor holding the summed squared neighbour differences.
    """
    # Slice relative to the tensor's own extent rather than a hard-coded
    # 224x224, so the loss stays correct for any image size.
    anchor = x[:, :-1, :-1, :]
    axis1_offset = tf.square(anchor - x[:, 1:, :-1, :])
    axis2_offset = tf.square(anchor - x[:, :-1, 1:, :])
    return tf.reduce_sum(axis1_offset + axis2_offset)

In [None]:
# Content term: weighted feature mismatch at the chosen content layer.
content_loss = content_loss_weight * get_content_loss(content_layer)

# Style term: weighted sum of the per-layer style losses. The terms are
# combined with tf.add_n instead of being accumulated onto a tf.Variable(0.),
# which would register an extra, pointless variable in the graph.
style_loss = tf.add_n([
    style_loss_weights[style_layer] * get_style_loss(style_layer)
    for style_layer in style_loss_weights
])

# Smoothness term on the image being optimised.
tv_loss = tv_loss_weight * get_total_variation_loss(recovered_image)

total_loss = content_loss + style_loss + tv_loss

In [None]:
# `adam` is the training op (not the optimizer object): running it performs
# one Adam update of `recovered_image` only.
adam = tf.train.AdamOptimizer(learning_rate=0.3).minimize(
    total_loss,
    var_list=[recovered_image],
)

In [None]:
# Everything evaluated per step: the four loss values for logging, plus the
# training op that updates the image.
fetches = [content_loss, style_loss, tv_loss, total_loss, adam]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(50):
        (current_content_loss, current_style_loss,
         current_tv_loss, current_total_loss, _) = sess.run(fetches)
        print(i, current_content_loss, current_style_loss, current_tv_loss, current_total_loss)
    # Pull the optimised image out of the graph and drop the batch axis.
    final_image = sess.run(recovered_image)[0]

In [None]:
# Show the content image, the style image and the generated result side by
# side, using explicit axes so each panel can carry its own title.
fig, axes = plt.subplots(1, 3, figsize=(15, 15))
panels = [
    ("Content", content_source),
    ("Style", style_source),
    ("Result", restore_image(final_image)),
]
for ax, (title, bgr_image) in zip(axes, panels):
    # The arrays are in OpenCV's B, G, R channel order; reverse the channels
    # because imshow interprets them as R, G, B.
    ax.imshow(bgr_image[:, :, [2, 1, 0]])
    ax.set_title(title)
# Render explicitly so the cell does not echo the last AxesImage repr.
plt.show()