In [1]:
import os
import sys
import numpy as np
import scipy.io
import scipy.misc
import tensorflow as tf  # Import TensorFlow after Scipy or Scipy will break
from PIL import Image
from util import *

# Height and width handed to load_image() for the content and style images.
IMAGE_HEIGHT = 227
IMAGE_WIDTH = 227
# Not referenced by any cell visible here; presumably the RGB channel count
# used by helpers in util -- TODO confirm.
COLOR_CHANNELS = 3


In [2]:
# Path of the model file that build_graph() passes to load_vgg_model().
# NOTE(review): presumably the pre-trained VGG-19 weights in MATLAB .mat
# format, resolved relative to the working directory -- confirm.
VGG_MODEL = 'imagenet-vgg-verydeep-19.mat'

class Weights:
    """A dictionary of named weight tensors with element-wise arithmetic helpers.

    Construct it with exactly one of the following keyword arguments
    (positional arguments are accepted for backward compatibility but ignored):

    * ``graph`` -- dict mapping layer name -> layer output tensor.  For every
      layer a zero-initialised ``tf.Variable`` named ``<layer>_w`` is created
      whose shape is the layer's shape with the first (batch) dimension
      dropped.
    * ``shape`` -- dict mapping weight name -> shape.  A zero-initialised
      ``tf.Variable`` is created for each entry under exactly that name.

    Attributes:
        weights: dict mapping weight name -> tensor / variable.
        shape:   dict mapping weight name -> shape of that weight.

    Raises:
        ValueError: if neither a non-empty ``graph`` nor a ``shape`` is given.
    """

    def __init__(self, *args, **kwargs):
        self.weights = {}
        graph = kwargs.get('graph', None)
        shape = kwargs.get('shape', None)
        if graph:
            # .items() instead of .iteritems() so this runs on Python 2 and 3.
            for key, layer in graph.items():
                # remove first dim for batch size
                weight_shape = layer.get_shape()[1:]
                self.weights[key + '_w'] = tf.Variable(tf.zeros(weight_shape), name=key + '_w')
            self.shape = {key: weight.get_shape() for key, weight in self.weights.items()}
        elif shape is not None:
            for key, s in shape.items():
                self.weights[key] = tf.Variable(tf.zeros(s), name=key)
            self.shape = shape
        else:
            # Previously this fell through to ``None.iteritems()`` and failed
            # with an unhelpful AttributeError.
            raise ValueError("Weights requires a non-empty 'graph' or a 'shape' keyword argument")

    def _combine(self, W, op):
        """Return a new Weights holding ``op(self[k], W[k])`` for every key of self."""
        # An empty shape dict makes __init__ create no variables.  The earlier
        # implementation passed self.shape here, which allocated a complete set
        # of zero variables that were overwritten on the next line but still
        # remained in the TensorFlow graph.
        result = Weights(shape={})
        result.weights = {key: op(w, W.weights[key]) for key, w in self.weights.items()}
        result.shape = self.shape
        return result

    def add(self, W): #  NOT inplace sum
        """Return a new Weights equal to the element-wise ``self + W``."""
        return self._combine(W, lambda a, b: a + b)

    def sub(self, W):
        """Return a new Weights equal to the element-wise ``self - W``."""
        # The overloaded operator is used instead of tf.sub, which only exists
        # in pre-1.0 TensorFlow releases.
        return self._combine(W, lambda a, b: a - b)

    def sqr_norm(self):
        """Return the sum of squared entries over all weights (``0`` when empty)."""
        return sum(tf.reduce_sum(tf.square(w)) for w in self.weights.values())

    def compute_reg(self, graph):
        """Sum, over all weights, the inner product of each weight with its layer.

        The layer for a weight is looked up as ``graph[key[:-2]]``, i.e. the
        weight name with its last two characters removed.  That matches the
        ``'_w'`` suffix appended when the instance is built from ``graph=``;
        names that do not carry such a suffix will not resolve correctly.

        Each product is reduced over axes 1, 2 and 3, so the result keeps only
        axis 0 of the layer tensors.
        """

        def _inner_prod(t1, t2):
            # rely on broadcasting of the element-wise product, then reduce
            # over a consistent set of dims
            return tf.reduce_sum(t1 * t2, [1, 2, 3])

        # sum of inner products of output coeffs and weights
        return sum(_inner_prod(weight, graph[key[:-2]]) for key, weight in self.weights.items())


def build_graph(image):
    """Load the VGG model on top of ``image`` and snapshot the variables.

    Returns:
        A ``(graph, model_var)`` tuple: whatever ``load_vgg_model`` returns for
        ``VGG_MODEL`` with ``image`` as its input, and the result of
        ``tf.all_variables()`` taken right after the model was loaded.
    """
    vgg_graph = load_vgg_model(VGG_MODEL, input_image=image)
    return (vgg_graph, tf.all_variables())

def reg_loss(regs, labels):
    """Mean of the element-wise squared difference between ``regs`` and ``labels``."""
    squared_error = tf.squared_difference(regs, labels)
    return tf.reduce_mean(squared_error)

def residual_loss(beta, z, u):
    """Squared norm of ``beta - z + u``.

    Each argument must provide ``sub`` / ``add`` methods whose result offers
    ``sqr_norm`` (as the ``Weights`` class does).
    """
    difference = beta.sub(z)
    residual = difference.add(u)
    return residual.sqr_norm()

def start_session(model_var):
    """Open an interactive TensorFlow session and initialise variables.

    The variables listed in ``model_var`` are initialised first, followed by
    an initialisation of all variables.  Returns the new session.
    """
    session = tf.InteractiveSession()
    model_init = tf.initialize_variables(model_var)
    session.run(model_init)
    global_init = tf.initialize_all_variables()
    session.run(global_init)
    return session

In [5]:
# Content image to use.
CONTENT_IMAGE = 'images/inputs/hummingbird-photo_p1-rot.jpg' #'images/inputs/hummingbird-small.jpg'
content_image = load_image(CONTENT_IMAGE, image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
# Style image to use.
STYLE_IMAGE = 'images/inputs/Nr2_original_p1-ds.jpg' #'images/inputs/Nr2_orig.jpg'
style_image = load_image(STYLE_IMAGE, image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
# One target value per input image, in the order they are concatenated below
# (content first, style second).
# NOTE(review): assumes load_image returns one image with a leading batch axis -- confirm.
labels = tf.constant([0 , 1], dtype = 'float32')

# Join both images along axis 0 and build the network on that batch.
# tf.concat(axis, values) is the pre-1.0 TensorFlow argument order.
graph, model_var = build_graph(tf.concat(0, [content_image, style_image]))

beta = Weights(graph = graph)
regs = beta.compute_reg(graph)

z = Weights(graph = graph)
u = Weights(graph = graph)

loss = reg_loss(regs, labels) + residual_loss(beta, z, u)
opt = tf.train.AdamOptimizer(learning_rate=0.00001)
# Only beta's variables are optimised; z and u are not in var_list.
# list(...) gives the optimizer a real list on Python 3 as well, where
# dict.values() is a view object.
opt_op = opt.minimize(loss, var_list=list(beta.weights.values()))

sess = start_session(model_var)

# .values() instead of .iteritems() keeps this line valid on Python 2 and 3.
print( "total number of weight variables: %.4e" % sess.run(sum([tf.size(v) for v in beta.weights.values()])) )


Exception AssertionError: AssertionError("Nesting violated for default stack of <type 'weakref'> objects",) in <bound method InteractiveSession.__del__ of <tensorflow.python.client.session.InteractiveSession object at 0x116e188d0>> ignored


total number of weight variables: 1.7198e+07


In [6]:
def print_layer_stats(graph, session):
    """Evaluate every tensor in ``graph`` and print summary statistics for it.

    For each entry the median, the maximum and the root-mean-square of the
    evaluated array are printed.  The third figure is printed under the label
    "mean" but is computed as ``sqrt(mean(v**2))``.

    Args:
        graph:   dict mapping layer name -> tensor.
        session: session used to evaluate each tensor via ``session.run``.
    """
    for key, val in graph.items():
        print("%s layer :" % key)
        v = session.run(val)
        print("median : %.4e, max: %.4e, mean: %.4e" % (np.median(v), v.max(), np.sqrt(np.mean(v**2))))


def normalize_graph(graph, eps=1e-30):
    """Rescale every tensor in ``graph`` in place by the inverse of its mean.

    The mean is taken over axes 0, 1 and 2 with the reduced dimensions kept,
    so the scale factor broadcasts against the original tensor.  ``eps`` is
    added to the mean before inverting it to avoid a division by zero.

    Only the dictionary entries are rebound; tensors that were already built
    from the previous entries are not affected.

    Args:
        graph: dict mapping layer name -> tensor; modified in place.
        eps:   small constant added to the mean (defaults to the original
               hard-coded 1e-30).
    """
    # Iterate over a snapshot because entries are rebound inside the loop.
    for key, val in list(graph.items()):
        axis_mean = tf.reduce_mean(val, [0, 1, 2], True)
        graph[key] = (1.0 / (axis_mean + eps)) * val


# Statistics before normalisation.
print_layer_stats(graph, sess)

normalize_graph(graph)

# Statistics after normalisation.
print_layer_stats(graph, sess)

conv1_1 layer :
median : 0.0000e+00, max: 6.6271e+02, mean: 5.8276e+01
conv1_2 layer :
median : 1.8818e+01, max: 2.5672e+03, mean: 2.4785e+02
conv5_4 layer :
median : 0.0000e+00, max: 8.0207e+01, mean: 2.1472e+00
conv5_1 layer :
median : 0.0000e+00, max: 6.2764e+02, mean: 3.3763e+01
conv5_3 layer :
median : 0.0000e+00, max: 1.6487e+02, mean: 6.4895e+00
conv5_2 layer :
median : 0.0000e+00, max: 2.9968e+02, mean: 1.4071e+01
conv4_4 layer :
median : 0.0000e+00, max: 2.2615e+03, mean: 7.3880e+01
conv4_1 layer :
median : 3.3254e+01, max: 4.4528e+03, mean: 4.0544e+02
conv4_2 layer :
median : 0.0000e+00, max: 4.5533e+03, mean: 3.2589e+02
conv4_3 layer :
median : 0.0000e+00, max: 3.5280e+03, mean: 1.8845e+02
conv3_4 layer :
median : 1.2348e+01, max: 6.2507e+03, mean: 4.9011e+02
conv3_3 layer :
median : 1.3243e+02, max: 3.4848e+03, mean: 2.8468e+02
conv3_2 layer :
median : 2.0286e+01, max: 3.3438e+03, mean: 1.6056e+02
conv3_1 layer :
median : 0.0000e+00, max: 3.5254e+03, mean: 1.4124e+02
input 