In [None]:
import tensorflow as tf
import cv2
import numpy as np

In [None]:
# Convolutional part of VGG16 (no fully-connected head) with ImageNet weights.
vgg16 = tf.contrib.keras.applications.vgg16.VGG16(
    include_top=False,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
)

# Mark every layer as non-trainable: the network is used as a fixed feature
# extractor here.
# NOTE(review): this flag is a Keras-level setting; confirm that whatever
# optimiser is used downstream actually honours it.
for frozen_layer in vgg16.layers:
    frozen_layer.trainable = False

In [None]:
# Print the model's layer-by-layer summary (useful for cross-checking the
# layer positions listed in layer_pos_dict).
vgg16.summary()

In [None]:
# Short layer name -> position of that layer in ``vgg16.layers``.
# get_embedding() applies layers 1..position, so the value is the index of the
# last layer whose output is used.
layer_pos_dict = {
    "conv1_2": 2,
    "conv2_2": 5,
    "conv3_2": 8,
    "conv4_2": 12,
    "conv5_2": 16,
}

# Key of the layer whose activations are used for the reconstruction.
layer_pick = "conv2_2"

In [None]:
# Load the style image from disk. The flag 1 requests a 3-channel colour image.
style_image_path = "../../data/images/VanGogh.jpg"
img = cv2.imread(style_image_path, 1)

# cv2.imread reports a missing or unreadable file by returning None rather than
# raising, which would otherwise surface as an opaque error inside cv2.resize.
if img is None:
    raise IOError("Could not read image file: " + style_image_path)

# Resize to the 224x224 spatial size expected by the target_image placeholder.
img = cv2.resize(img, (224, 224))

In [None]:
# Per-channel offsets, indexed by channel position (0, 1, 2): process_image
# subtracts them and restore_image adds them back.
# NOTE(review): presumably the ImageNet channel means in BGR order, matching
# the channel order of images loaded with cv2.imread; confirm.
IMAGENET_MEANS = [103.939, 116.779, 123.68]

def process_image(img):
    """Return a float32 copy of ``img`` with the channel means removed.

    Args:
        img: image-like array indexed as [row, column, channel] with at least
            three channels. It is copied, never modified.

    Returns:
        A new ``np.float32`` array of the same shape in which
        ``IMAGENET_MEANS[c]`` has been subtracted from channel ``c`` for
        ``c`` in 0..2.
    """
    pixels = np.array(img)
    centered = pixels.astype(np.float32)
    for channel in (0, 1, 2):
        centered[:, :, channel] -= IMAGENET_MEANS[channel]
    return centered
        
def restore_image(img):
    """Undo the mean subtraction of ``process_image`` and return a uint8 image.

    Args:
        img: array indexed as [row, column, channel] with at least three
            channels. It is copied, never modified.
            NOTE(review): the means are added in place to the copy, so a
            floating-point array is expected here; confirm against callers.

    Returns:
        A new ``np.uint8`` array of the same shape: ``IMAGENET_MEANS[c]`` is
        added to channel ``c`` (for ``c`` in 0..2), the result is limited to
        the range [0, 255] and then cast to uint8 (fractions are truncated).
    """
    restored_image = np.array(img)
    for x in range(3):
        restored_image[:, :, x] += IMAGENET_MEANS[x]
    # ndarray.clip returns a new array instead of clipping in place, so the
    # result must be kept. Without it, out-of-range values wrap around in the
    # uint8 cast below instead of saturating at 0 / 255.
    restored_image = restored_image.clip(0, 255)
    return restored_image.astype(np.uint8)

In [None]:
# Input slot for the (already mean-centred) reference image.
target_image = tf.placeholder(
    dtype=tf.float32,
    shape=(224, 224, 3),
    name="target_image",
)

# The image being optimised; it starts out as normally distributed random noise.
recovered_image = tf.Variable(
    initial_value=tf.random_normal(shape=[224, 224, 3]),
    name="recovered_image",
)

In [None]:
def get_embedding(image):
    """Run ``image`` through the first VGG16 layers and return the activations.

    A leading batch axis is added to ``image`` and the result is passed through
    ``vgg16.layers[1]`` up to and including the layer at position
    ``layer_pos_dict[layer_pick]`` (layer 0 is skipped).

    Args:
        image: tensor for a single image without a batch axis.

    Returns:
        The output of the last applied layer.
    """
    activation = tf.expand_dims(image, axis=0)
    final_position = layer_pos_dict[layer_pick]
    for position in range(1, final_position + 1):
        activation = vgg16.layers[position](activation)
    return activation

In [None]:
def gram_matrix(embedding):
    """Compute the Gram matrix of a 4-D activation tensor.

    The last axis of ``embedding`` is moved to the front, each slice along it
    is flattened into one row, and the rows are multiplied with their own
    transpose.

    Args:
        embedding: rank-4 tensor.
            NOTE(review): presumably (batch, height, width, filters), as
            produced by get_embedding; confirm.

    Returns:
        A square matrix whose entry (i, j) is the inner product of the
        flattened slices i and j of the last axis.
    """
    backend = tf.contrib.keras.backend
    per_filter = tf.transpose(embedding, perm=[3, 0, 1, 2])
    rows = backend.batch_flatten(per_filter)
    return tf.matmul(rows, rows, transpose_b=True)

In [None]:
# Gram matrices of the VGG16 activations (taken at the layer selected by
# layer_pick) for the fed-in target image and for the image being optimised.
target_gram = gram_matrix(get_embedding(target_image))
recovered_gram = gram_matrix(get_embedding(recovered_image))

In [None]:
# Scalar loss: sum over all entries of the squared difference between the
# target and the recovered Gram matrix.
style_loss = tf.reduce_sum(tf.squared_difference(target_gram, recovered_gram))

In [None]:
# Training op: one Adam step (learning rate 1e-4) on the style loss.
# var_list restricts the update to the synthesised image. Without it,
# minimize() optimises every variable in the graph's TRAINABLE_VARIABLES
# collection, which may also contain the VGG16 weights created by Keras, so
# the feature extractor itself could be altered while fitting the image.
adam = tf.train.AdamOptimizer(1e-4).minimize(style_loss, var_list=[recovered_image])

In [None]:
# Reuse the Keras backend session instead of opening a fresh tf.Session().
# The pretrained VGG16 weights live in the Keras session; in a new session,
# tf.global_variables_initializer() would give the network freshly
# initialised (non-pretrained) weights.
# The session is shared with Keras, so it is deliberately left open.
sess = tf.contrib.keras.backend.get_session()

# Initialise only the variables that have no value yet in this session
# (e.g. recovered_image and the optimiser's internal variables), leaving
# already-loaded weights untouched.
uninitialized_names = set(sess.run(tf.report_uninitialized_variables()))
uninitialized_vars = [
    variable
    for variable in tf.global_variables()
    if variable.op.name.encode() in uninitialized_names
]
sess.run(tf.variables_initializer(uninitialized_vars))

# Feed the mean-centred style image and run a single optimisation step.
feed_dict = {target_image: process_image(img)}
sess.run(adam, feed_dict=feed_dict)