In [None]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.io
import skimage.transform
import tensorflow as tf

import custom_vgg19
import Lib

In [None]:
# --- Training configuration ---
BATCH_SIZE = 10
NEW_H, NEW_W = 256, 256  # resolution the content/training images are resized to
# Shape of the training batch fed to the transform net: (batch, height, width, 3 channels).
# Derived from NEW_H/NEW_W so the placeholder and the resized images can never disagree.
input_shape = [BATCH_SIZE, NEW_H, NEW_W, 3]

# VGG19 layers used for the style (Gram) loss and the content loss.
# Note: layers named 'conv*' are actually the ReLU outputs (e.g. 'conv1_1' is 'relu1_1').
STYLE_LAYERS = ('conv1_1', 'conv2_1', 'conv3_1', 'conv4_1')
# relu3_2 gave good results with the slow (optimization-based) neural-style at the same
# weights; it may be worth trying here as well.
CONTENT_LAYER = 'conv4_2'

# Relative weights of the two loss terms.
CONTENT_WEIGHT = 7.5
STYLE_WEIGHT = 100

In [None]:
# Load the style reference (no resize arguments passed) and the content image
# (resized to the training resolution NEW_H x NEW_W); print their batch shapes.
style_path = './picasso_selfport1907.jpg'
content_path = './brad_pitt.jpg'

styleimg = Lib.load_image_as_batch_with_optional_resize(style_path)
print(styleimg.shape)

contentimg = Lib.load_image_as_batch_with_optional_resize(content_path, newH=NEW_H, newW=NEW_W)
print(contentimg.shape)

In [None]:
# Extract the style targets: run VGG19 once on the style image and freeze the resulting
# Gram matrices as constants. The style target is fixed, so it is not recomputed per step.
sess = tf.Session()
styleimg_ph = tf.placeholder(tf.float32, shape=styleimg.shape)
vgg19factory = custom_vgg19.Vgg19Factory()
vgg19_pretrain = vgg19factory.build(styleimg_ph)

# Sanity check: fail loudly if a configured layer name does not exist on the VGG19 object.
try:
    style_layers_pretrain = [getattr(vgg19_pretrain, name) for name in STYLE_LAYERS]
    content_layer_pretrain = getattr(vgg19_pretrain, CONTENT_LAYER)
except AttributeError as ex:
    # Raise instead of the original `sys.exit(1)`: `sys` was never imported (NameError), and
    # exiting is the wrong failure mode in a notebook. Raising keeps the traceback visible.
    raise AttributeError(
        "%s incorrect layer name. Note: all layer named 'conv' is relu. "
        "e.g. 'conv1_1' is actually 'relu1_1'" % ex)

# NOTE(review): `gram_matrix` is not defined anywhere in this notebook, so the bare name fails
# on a fresh kernel. This assumes it lives in Lib next to compute_style_loss (which should use
# the same Gram normalization) — confirm.
styleimg_grams = [Lib.gram_matrix(layer) for layer in style_layers_pretrain]
styleimg_grams_np = sess.run(styleimg_grams, feed_dict={styleimg_ph: styleimg})
# No content target is extracted here: unlike slow neural-style, there is no single target
# content image during training (each training image is its own content target).
styleimg_grams = [tf.constant(g, dtype=tf.float32) for g in styleimg_grams_np]

In [None]:
# Build the image transform network: it maps the placeholder batch of training images
# (shape `input_shape`) to the predicted (stylized) batch.
img_train = tf.placeholder(dtype=tf.float32, shape=input_shape)
img_pred = Lib.buildTransformNet(img_train, expected_shape=input_shape)

In [None]:
# VGG19 over the transform net's output: supplies the predicted image's style/content features.
# The predicted image must already be in VGG19's expected scale and range ([0, 1]).
vgg19_pred = vgg19factory.build(img_pred)
style_layers_pred = [getattr(vgg19_pred, layer_name) for layer_name in STYLE_LAYERS]
content_layer_pred = getattr(vgg19_pred, CONTENT_LAYER)

# VGG19 over the raw training images: their content features are the ground-truth content target.
# TODO: this is the third VGG19 built from the same factory, only because each build needs a
# different input tensor (two placeholders of different shapes plus the predicted image).
# Find a way to share a single network across these inputs.
vgg19_extractContent = vgg19factory.build(img_train)
content_layer_target = getattr(vgg19_extractContent, CONTENT_LAYER)


In [None]:
# Per-layer style losses: each precomputed style Gram matrix vs. the matching predicted layer.
style_losses = [Lib.compute_style_loss(target_gram, pred_layer)
                for target_gram, pred_layer in zip(styleimg_grams, style_layers_pred)]
# Content loss: training images' own features vs. the prediction's features.
content_loss = Lib.compute_content_loss(content_layer_target, content_layer_pred)

total_style_loss = reduce(tf.add, style_losses)
loss = STYLE_WEIGHT * total_style_loss + CONTENT_WEIGHT * content_loss

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(loss)

**Integration test:** load a small batch of pictures and check that the network can overfit to them.

In [None]:
# Build a fixed overfitting batch from the first BATCH_SIZE COCO images in ./data.
# (`os` is imported in the top import cell.) sorted() makes the selection reproducible:
# os.listdir returns entries in arbitrary order.
test_batch_f = sorted(fname for fname in os.listdir('data') if fname.startswith('COCO'))[:BATCH_SIZE]
assert len(test_batch_f) == BATCH_SIZE, ('not enough files', len(test_batch_f))

# float32 to match the img_train placeholder dtype.
test_batch_np = np.zeros(input_shape, dtype=np.float32)
for i, fname in enumerate(test_batch_f):
    # The loader lives in Lib (as in the image-loading cell); the bare name was a NameError.
    test_batch_np[i] = Lib.load_image_as_batch_with_optional_resize(
        os.path.join('data', fname), newH=NEW_H, newW=NEW_W)

sess.run(tf.initialize_all_variables())

In [None]:
MAX_ITER = 200
for i in xrange(MAX_ITER):
    l = sess.run([train_op, loss]+ style_losses +[content_loss], feed_dict={img_train: test_batch_np})
    print l[1:]

In [None]:
img_pred_np = sess.run(img_pred, feed_dict={img_train: test_batch_np})

# Show the first 9 predicted images in a 3x3 grid (row-major), clipped to [0, 1] for display.
f, axarr = plt.subplots(3, 3, figsize=(10, 10))
for idx, ax in enumerate(axarr.flat):
    ax.imshow(np.clip(img_pred_np[idx], 0, 1))
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
plt.show()

In [None]:
# Show the first 9 training inputs in a 3x3 grid (for comparison with the predictions)
# and save each one to disk.
f, axarr = plt.subplots(3, 3)
for i in xrange(3):
    for j in xrange(3):
        idx = i * 3 + j
        img = test_batch_np[idx]
        axarr[i][j].imshow(img)
        axarr[i][j].xaxis.set_visible(False)
        axarr[i][j].yaxis.set_visible(False)
        # Fix: the original `plt.imsave(str(i*3+j)+)` was a SyntaxError that broke the cell.
        # NOTE(review): the intended filename/format is unknown; '<idx>.png' is a guess.
        # imsave rejects float RGB values outside [0, 1], so clip like the prediction plot does.
        plt.imsave(str(idx) + '.png', np.clip(img, 0, 1))
plt.show()

In [None]:
# Persist all trained variables. tf.train.Saver.save fails when the parent directory does
# not exist, so create it first (isdir check instead of exist_ok to stay Python 2 compatible).
CHECKPOINT_DIR = 'chkpt'
if not os.path.isdir(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)
saver = tf.train.Saver()
saver.save(sess, os.path.join(CHECKPOINT_DIR, 'cur.ckpt'))