Implementation of [Universal Style Transfer via Feature Transforms](https://arxiv.org/pdf/1705.08086.pdf) by Li et al.

In [None]:
# Auto-reload external modules (e.g. the local `utils` module) so edits to them
# are picked up without restarting the kernel.
# See http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Echo the value of every bare expression in a cell, instead of only the value
# of the final line (IPython's default behaviour).
from IPython.core.interactiveshell import InteractiveShell

setattr(InteractiveShell, "ast_node_interactivity", "all")

In [None]:
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from imageio import imread
import skimage.transform

from utils import *

## Featurize images

In [None]:
# Precompute VGG-19 conv4_1 activations for the images in ../val2017
# (num_images=500); presumably these are what make_precomputed_dataset loads below.
compute_features(
    '../val2017',
    out_layer_name='vgg_19/conv4/conv4_1',
    num_images=500,
)

## Create graph

In [None]:
# Activation shapes (batch, height, width, channels) of the VGG-19 blocks a
# decoder can be trained against, for the 224x224 RGB inputs used below.
layer_to_sizes = dict(
    conv1=(None, 224, 224, 64),
    conv2=(None, 112, 112, 128),
    conv4=(None, 28, 28, 512),
)

# Start from an empty default graph so re-running this cell does not stack ops.
tf.reset_default_graph()
sess = tf.Session()

# Decoder inputs: source images and their conv4 feature maps.
images_ph = tf.placeholder(dtype='float', shape=(None, 224, 224, 3))
features_ph = tf.placeholder(dtype='float', shape=layer_to_sizes['conv4'])

In [None]:
# Dataset over the precomputed conv4 features for the same 500 images
# (presumably (image, features) elements -- see the shape check at the end).
dataset = make_precomputed_dataset(
    '../val2017',
    'conv4',
    num_images=500,
)

In [None]:
# Decoder from conv4 features (features_ph) back to a 3-channel image.
# Each (k, n) entry is presumably a k x k conv with n filters, and each
# 'upsample' presumably doubles H and W (three of them: 28 -> 224).
# NOTE(review): layer_to_sizes lists 128 channels for conv2, but the decoder
# uses (3, 112) where a VGG mirror would have 128 -- intentional?
decoder_layers = [
    (3, 256),
    'upsample',
    (3, 256), (3, 256), (3, 256), (3, 128),
    'upsample',
    (3, 112), (3, 64),
    'upsample',
    (3, 64), (3, 3),
]
reconstructed_image = make_decoder(features_ph, decoder_layers, sess)

In [None]:
# Decoder training loss built from the input images, their features, the
# decoder output and the conv4_1 encoder layer name.
loss = create_loss(
    images_ph,
    features_ph,
    reconstructed_image,
    'vgg_19/conv4/conv4_1',
    sess,
)

In [None]:
# Fit the decoder: 15 epochs at learning rate 5e-4.
train(
    loss,
    dataset,
    images_ph,
    features_ph,
    sess,
    num_epochs=15,
    lr=5e-4,
)

In [None]:
# Checkpoint only the trainable variables under the 'decoder' scope.
decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'decoder')
saver = tf.train.Saver(var_list=decoder_vars)
decoder_save_path = './decoder4/decoder'
# Saver.save refuses to write when the parent directory is missing, so make
# sure it exists first (no-op if it already does).
os.makedirs(os.path.dirname(decoder_save_path), exist_ok=True)
saver.save(sess, decoder_save_path)

## Style transfer

In [None]:
# Content image for the experiments below
# (alternative: '../val2017/000000369751.jpg').
content_image_path = 'C:/users/rtimpe/downloads/cat.jpg'
img = imread(content_image_path)
# Normalise to exactly 3 channels before resizing: resizing straight to
# (224, 224, 3) would interpolate across the channel axis for RGBA input and
# mishandle grayscale (2-D) input.
if img.ndim == 2:
    img = np.stack([img] * 3, axis=-1)
img = img[:, :, :3]
# For uint8 input, resize returns floats in [0, 1] (preserve_range=False),
# so rescale to 0-255 after adding the batch axis.
img = skimage.transform.resize(img, (224, 224, 3))
img = img[np.newaxis, :, :, :]
img = img * 255
encoded_img = encode_image(img, 'vgg_19/conv4/conv4_1')

In [None]:
# WCT whitening of the content features.
# encoded_img is fed to the (None, 28, 28, 512) features_ph below, so channels
# are on the LAST axis; shape[2] would be a spatial size (28), not channels.
num_channels = encoded_img.shape[-1]
# (channels, pixels) matrix: one row per channel.
v = encoded_img.reshape(-1, num_channels).T
# Centre each channel on its own mean (the paper subtracts the mean vector);
# a single global mean leaves per-channel offsets in the covariance.
v_centered = v - np.mean(v, axis=1, keepdims=True)
w, _ = whiten(v_centered)
# NOTE(review): assumes whiten returns the features in the same
# (channels, pixels) layout it was given -- confirm in utils. Transpose back
# before restoring the (batch, H, W, C) shape, otherwise channels and pixels mix.
w = w.T.reshape(encoded_img.shape)

In [None]:
# Decode the raw (un-whitened) content features and drop size-1 axes such as
# the batch axis, leaving an image array for display.
reconstructed = np.squeeze(
    sess.run(reconstructed_image, feed_dict={features_ph: encoded_img})
)

In [None]:
# Inspect the raw decoder output values.
reconstructed

In [None]:
# Largest value in the whitened feature map (sanity check).
np.max(w)

In [None]:
# Clip to the 8-bit range before casting: astype(np.uint8) does not saturate
# out-of-range floats, so decoder overshoot shows up as garbage pixels.
plt.imshow(np.clip(reconstructed, 0, 255).astype(np.uint8))

In [None]:
# Show the resized content image: drop size-1 axes, then cast for display.
plt.imshow(np.squeeze(img).astype(np.uint8))

## Scratch / debugging cells

In [None]:
# One-shot iterator that the scratch cells below use to pull dataset elements.
it = dataset.make_one_shot_iterator()

In [None]:
# Pull the next dataset element (presumably an (image, features) pair) and
# keep its first component as `img`, then echo it.
next_element = it.get_next()
img, _ = sess.run(next_element)
img

In [None]:
# Display the dataset image pulled above.
# NOTE(review): imshow expects uint8 or floats in [0, 1]; if dataset images
# are 0-255 floats this renders clipped -- confirm the dtype/range.
plt.imshow(img)

In [None]:
# Run the decoder on the sample image. `input_ph` is never defined in this
# notebook, and the decoder graph reads features_ph, so encode the image with
# the same encoder layer first.
# NOTE(review): assumes dataset images share the 0-255 range the content image
# is scaled to before encode_image above -- confirm.
sample_features = encode_image(img[np.newaxis, :, :, :], 'vgg_19/conv4/conv4_1')
decoded_img = sess.run(reconstructed_image, feed_dict={features_ph: sample_features})

In [None]:
# Drop size-1 axes (the batch axis) from the decoder output.
decoded_img = decoded_img.squeeze()

In [None]:
# Saturate to the valid 8-bit range on BOTH ends; negative decoder outputs
# otherwise do not survive the uint8 cast used for display.
decoded_img = np.clip(decoded_img, 0, 255)

In [None]:
# Clip before casting so out-of-range values saturate instead of producing
# garbage pixels (harmless if the values are already clamped).
plt.imshow(np.clip(decoded_img, 0, 255).astype(np.uint8))

In [None]:
# Dump the current graph for TensorBoard. Flush right away so the event file
# is written now rather than at the writer's next periodic flush.
writer = tf.summary.FileWriter('summaries', sess.graph)
writer.flush()

In [None]:
# Peek at the first dataset element via a fresh one-shot iterator.
first_element = dataset.make_one_shot_iterator().get_next()
sess.run(first_element)

In [None]:
# Shape of component 1 of the first dataset element.
it = dataset.make_one_shot_iterator()
np.shape(sess.run(it.get_next())[1])

In [None]:
# make_encoder(images_ph, sess)
# Re-dump the current graph for TensorBoard and flush it to disk immediately
# instead of waiting for the writer's periodic flush.
writer = tf.summary.FileWriter('summaries', sess.graph)
writer.flush()