Implementation of [Universal Style Transfer via Feature Transforms](https://arxiv.org/pdf/1705.08086.pdf) by Li et al.

In [None]:
# Automatically reload external modules (such as `utils`, imported below) so
# that edits to them take effect without restarting the kernel.
# See http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Echo the value of every bare expression in a cell, not only the last one
# (several cells below rely on this to show intermediate checks).
from IPython.core.interactiveshell import InteractiveShell
setattr(InteractiveShell, "ast_node_interactivity", "all")

In [None]:
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from imageio import imread
import skimage.transform

from utils import *

## Settings

In [None]:
# VGG-19 block used throughout; features are taken from vgg_19/{layer}/{layer}_1.
layer = "conv4"

## Featurize images

In [None]:
# Precompute encoder activations for the training and validation image folders.
feature_layer = 'vgg_19/{0}/{0}_1'.format(layer)
compute_features('D:/test2017', out_layer_name=feature_layer, num_images=40000)
compute_features('D:/val2017', out_layer_name=feature_layer, num_images=200)

## Autoencoder training

In [None]:
# Shape (batch, height, width, channels) of the VGG-19 {layer}_1 activations
# for a 224x224x3 input image; used to size the decoder's feature placeholder.
# Each VGG block halves the spatial size of the previous one.
layer_to_sizes = {
    'conv1': (None, 224, 224, 64),
    'conv2': (None, 112, 112, 128),
    'conv3': (None, 56, 56, 256),
    'conv4': (None, 28, 28, 512),
    'conv5': (None, 14, 14, 512),
}

# Start from a clean default graph and open a session for training.
tf.reset_default_graph()
sess = tf.Session()
# Placeholders: raw RGB images and precomputed encoder features for `layer`.
image_shape = (None, 224, 224, 3)
images_ph = tf.placeholder('float', shape=image_shape, name='images')
features_ph = tf.placeholder('float', shape=layer_to_sizes[layer], name='features')

In [None]:
# Batched datasets over the same folders passed to compute_features
# (presumably reading the features it wrote -- see utils).
dataset_kwargs = dict(batch_size=32)
train_dataset = make_precomputed_dataset('D:/test2017', layer, num_images=40000, **dataset_kwargs)
val_dataset = make_precomputed_dataset('D:/val2017', layer, num_images=200, **dataset_kwargs)

In [None]:
# Decoder spec for conv4: tuples interleaved with 'upsample' markers.
# Presumably each tuple is (kernel_size, num_filters) -- confirm in make_decoder.
architecture_4 = [
    (3, 256), 'upsample',
    (3, 256), (3, 256), (3, 256), (3, 128), 'upsample',
    (3, 112), (3, 64), 'upsample',
    (3, 64), (3, 3),
]
architecture_1 = [(3, 3)]
reconstructed_image, regularized = make_decoder(features_ph, architecture_4, sess)
# Loss compares the reconstruction against the input in the encoder's feature space
# at the same VGG layer the decoder is trained to invert (see utils.create_loss).
feature_layer = 'vgg_19/{0}/{0}_1'.format(layer)
per_img_loss, total_loss = create_loss(
    images_ph, features_ph, reconstructed_image, feature_layer,
    regularized, sess, lambda_reg=1e-3,
)

In [None]:
# Build the optimizer step and (presumably) the merged summary op for the loss.
train_step, merged = setup_training(total_loss, train_dataset, sess, lr=5e-5)

In [None]:
# Run one epoch of decoder training, writing summaries every 50 steps.
train(
    per_img_loss, train_step, merged,
    train_dataset, val_dataset,
    images_ph, features_ph, sess,
    num_epochs=1, summary_freq=50,
)

In [None]:
# Checkpoint only the trainable variables under the 'decoder' scope.
decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'decoder')
saver = tf.train.Saver(var_list=decoder_vars)
saver.save(sess, './decoder4/decoder')

## Load pretrained decoder

In [None]:
# Restore the trained decoder into its own graph. Importing the meta graph into
# the default graph would collide with the same-named ops left there by the
# training cells (if they ran in this kernel), and the by-name tensor lookups
# in the next cell could then resolve to the old, uninitialized copies.
decoder_graph = tf.Graph()
with decoder_graph.as_default():
    saver = tf.train.import_meta_graph('decoder4/decoder.meta')
sess = tf.Session(graph=decoder_graph)
writer = tf.summary.FileWriter('summaries', sess.graph)
writer.flush()  # make sure the graph event reaches disk for TensorBoard
saver.restore(sess, tf.train.latest_checkpoint('decoder4/'))

In [None]:
# Re-acquire the decoder output and the feature placeholder by name from the
# graph bound to the restored session.
loaded_graph = sess.graph
reconstructed_image = loaded_graph.get_tensor_by_name('decoder/conv_layer_11/Relu:0')
features_ph = loaded_graph.get_tensor_by_name('features:0')

## Testing: reconstruct a single image

In [None]:
# Load a test image as a float RGB array in pixel units at the 224x224 input size.
img = imread('C:/users/rtimpe/downloads/cat.jpg')
if img.ndim == 2:
    # Grayscale: replicate the single channel so the image has 3 channels.
    img = np.stack([img] * 3, axis=-1)
# Drop any alpha channel: resizing (H, W, 4) directly to (224, 224, 3) would
# interpolate along the channel axis and blend alpha into the colors.
img = img[..., :3]
img = skimage.transform.resize(img, (224, 224, 3))
# resize() returns floats in [0, 1] for integer input; scale back to [0, 255].
img *= 255

In [None]:
# Round trip: encode the image with VGG at `layer`, then decode it back.
batch = np.expand_dims(img, axis=0)
encoded_img = encode_image(batch, 'vgg_19/{0}/{0}_1'.format(layer))
reconstructed = sess.run(reconstructed_image,
                         feed_dict={features_ph: encoded_img}).squeeze()

In [None]:
# The decoder output is not bounded above (the restored output tensor is a
# ReLU), and astype(np.uint8) does not saturate out-of-range values, so clip
# to the displayable [0, 255] range before casting.
plt.imshow(np.clip(reconstructed, 0, 255).astype(np.uint8))

## Style transfer (whitening and coloring transform)

In [None]:
# Encoder features at `layer` for the content image and the style image.
content, style = (encode_file(path, layer) for path in (
    'C:/users/rtimpe/downloads/face.jpg',
    'C:/users/rtimpe/downloads/lights.jpg',
))

In [None]:
# first whiten
v = np.transpose(content, (3, 0, 1, 2)).reshape(content.shape[-1], -1)
v_centered = v - np.mean(v, axis=1)[:, np.newaxis]
w, _ = whiten(v_centered)

# now color
style_r = np.transpose(style, (3, 0, 1, 2)).reshape(style.shape[-1], -1)
style_centered = style_r - np.mean(style_r, axis=1)[:, np.newaxis]
cs, _ = color(w, style_centered)
cs = cs + style_r.mean(axis=1)[:, np.newaxis]
cs_r = np.reshape(cs, (1, 28, 28, 512))

In [None]:
# Sanity check: whitened features should have (near) identity covariance.
# Normalize by (positions - 1) taken from w itself instead of a hard-coded
# 783 (= 28*28 - 1), so the check also works for other layers/image sizes.
np.linalg.norm(np.dot(w, w.T) / (w.shape[1] - 1) - np.eye(w.shape[0]))

In [None]:
# Display the shape of the colored feature matrix (channels, positions).
np.shape(cs)

In [None]:
# Sanity check: the colored features should reproduce the style covariance.
# cs already has the style mean added back, so center it first; otherwise the
# uncentered second moment is compared against a covariance. Each side is
# normalized by its own (positions - 1) rather than a hard-coded 783.
cs_centered = cs - cs.mean(axis=1)[:, np.newaxis]
np.linalg.norm(
    np.dot(cs_centered, cs_centered.T) / (cs_centered.shape[1] - 1)
    - np.dot(style_centered, style_centered.T) / (style_centered.shape[1] - 1)
)

In [None]:
# Blend: here alpha weights the ORIGINAL content features and (1 - alpha)
# the stylized ones -- the reverse of the alpha convention in Li et al.
alpha = .4
interp = (1.0 - alpha) * cs_r + alpha * content

In [None]:
# Decode the blended features back to image space and drop singleton axes.
reconstructed = sess.run(reconstructed_image,
                         feed_dict={features_ph: interp}).squeeze()

In [None]:
# Show the result enlarged. Clip to [0, 255] before the uint8 cast: the
# decoder output is not bounded above and astype(np.uint8) does not saturate.
# resize() then returns floats in [0, 1], which imshow displays directly.
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(skimage.transform.resize(np.clip(reconstructed, 0, 255).astype(np.uint8), (600, 600, 3)))

## Other experiments

In [None]:
# Encoder features at `layer` for a single image, used for the whitening-only experiment.
cat_path = 'C:/users/rtimpe/downloads/cat.jpg'
content = encode_file(cat_path, layer)

In [None]:
# Whiten the content features, add the content mean back, and return to the
# original NHWC layout.
v = content.reshape(-1, content.shape[-1]).T  # (channels, positions)
v_centered = v - np.mean(v, axis=1)[:, np.newaxis]
w, _ = whiten(v_centered)
w = w + np.mean(v, axis=1)[:, np.newaxis]
# v was built as reshape(-1, C).T, so invert with .T followed by reshape.
# Reshaping the (C, positions) array directly to NHWC would scramble it.
w = w.T.reshape(content.shape)

In [None]:
# Mostly whitened features (weight 1 - alpha) with a little of the original
# content (weight alpha) mixed back in.
alpha = .15
interp = (1.0 - alpha) * w + alpha * content

In [None]:
# Decode the blended features back to image space and drop singleton axes.
reconstructed = sess.run(reconstructed_image,
                         feed_dict={features_ph: interp}).squeeze()

In [None]:
# Show the result enlarged. Clip to [0, 255] before the uint8 cast: the
# decoder output is not bounded above and astype(np.uint8) does not saturate.
# resize() then returns floats in [0, 1], which imshow displays directly.
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(skimage.transform.resize(np.clip(reconstructed, 0, 255).astype(np.uint8), (600, 600, 3)))

In [None]:
# Toy 2-D check of whiten/color. Rows are dimensions, columns are samples.
# c1 is a correlated covariance (singular: 1*4 - 2*2 == 0); c2 is the identity.
mu = np.array([0, 0])
c1 = np.array([[1, 2],
               [2, 4]])
c2 = np.eye(2)
x = np.random.multivariate_normal(mu, c1, 500).T
x -= x.mean(axis=1, keepdims=True)  # colored (covariance ~ c1), zero mean
y = np.random.multivariate_normal(mu, c2, 600).T
y -= y.mean(axis=1, keepdims=True)  # white (covariance ~ identity), zero mean

In [None]:
# Whiten the identity-covariance sample y, then color it with x's statistics.
[w, _] = whiten(y)
[c, _] = color(w, x)

In [None]:
# Sample covariance of the colored data; expected to approximate x's covariance
# (about c1). Normalize by c's own column count minus one rather than a
# hard-coded 599.
np.dot(c, c.T) / (c.shape[1] - 1)

In [None]:
# Per-dimension mean of x; should be ~0 because x was mean-centered.
np.mean(x, axis=1)

In [None]:
# Sample covariance of the whitened data; expected to be close to the identity.
# w comes from whitening y (600 samples, not x's 500), so normalize by w's own
# column count minus one, matching the (n - 1) convention of the other checks.
np.dot(w, w.T) / (w.shape[1] - 1)