In [1]:
import os
import sys
import numpy as np
import scipy.io
import scipy.misc
import tensorflow as tf  # Import TensorFlow after Scipy or Scipy will break
from PIL import Image
from util import *

# Target size passed to load_image() as image_height / image_width for both
# the content and the style image.
IMAGE_HEIGHT = 227
IMAGE_WIDTH = 227
# Not referenced in the cells visible here -- presumably consumed inside util; verify.
COLOR_CHANNELS = 3


In [2]:
# Model file handed to load_vgg_model() by build_graph(); the path is relative
# to the notebook's working directory.
VGG_MODEL = 'imagenet-vgg-verydeep-19.mat'

class Weights(object):
    """A dict of named TensorFlow tensors/variables with element-wise helpers.

    Attributes:
        weights: dict mapping a name to a ``tf.Variable`` (freshly constructed
            objects) or to a plain tensor (objects returned by ``add``,
            ``sub`` and ``soft_thresh``).
        shape: dict mapping the same names to the shape of each entry.

    Construct with exactly one of the keyword arguments ``graph``, ``shape``
    or ``npzfile``; they are checked in that order.
    """

    def __init__(self, *args, **kwargs):
        """Create the variables.

        Keyword Args:
            graph: dict of name -> tensor. One zero variable named
                ``name + '_w'`` is created per entry, shaped like the tensor
                without its first dimension.
            shape: dict of name -> shape. One zero variable per entry.
            npzfile: mapping of name -> array (e.g. the result of ``np.load``
                on an ``.npz`` file). One variable per entry, initialised
                from the array.

        Raises:
            ValueError: if none of the three sources is usable, instead of
                the opaque ``AttributeError`` on ``None`` raised previously.
        """
        self.weights = {}
        graph = kwargs.get('graph', None)
        shape = kwargs.get('shape', None)
        npzfile = kwargs.get('npzfile', None)
        if graph: # initialize by corresponding graph structure
            for key in graph:
                # remove first dim for batch size
                weight_shape = graph[key].get_shape()[1:]
                self.weights[key+'_w'] = tf.Variable(tf.zeros(weight_shape), name = key+'_w')
            self.shape = { key: weight.get_shape() for key, weight in self.weights.items() }
        elif shape: # initialize w.r.t shape of existing Weight
            for key, s in shape.items():
                self.weights[key] = tf.Variable(tf.zeros(s), name = key)
            self.shape = shape
        elif npzfile is not None: # initialize from loaded file
            for key, value in npzfile.items():
                self.weights[key] = tf.Variable(value, name = key)
            self.shape = { key: weight.get_shape() for key, weight in self.weights.items() }
        else:
            raise ValueError(
                "Weights needs a non-empty 'graph' or 'shape', or an 'npzfile' keyword argument")

    def _derived(self, tensors):
        """Wrap already-built tensors in a Weights that shares this object's shapes.

        Bypasses ``__init__`` on purpose: going through ``Weights(shape=...)``
        allocated one placeholder ``tf.Variable`` per key that was overwritten
        immediately, leaving unused variables in the graph after every
        ``add`` / ``sub`` / ``soft_thresh`` call.
        """
        result = Weights.__new__(Weights)
        result.weights = tensors
        result.shape = self.shape
        return result

    def add(self, W): #  NOT inplace sum
        """Return a new Weights holding ``self + W`` key by key."""
        return self._derived(
            { key: tf.add(w, W.weights[key]) for key, w in self.weights.items() })

    def sub(self, W):
        """Return a new Weights holding ``self - W`` key by key."""
        return self._derived(
            { key: tf.sub(w, W.weights[key]) for key, w in self.weights.items() })

    def sqr_norm(self):
        """Return the sum of squares of every element of every entry."""
        return sum([tf.reduce_sum(tf.square(w)) for w in self.weights.values()])

    def soft_thresh(self, s):
        """Return a new Weights with ``sign(w) * max(|w| - s, 0)`` for each entry."""
        shrunk = {}
        for key, w in self.weights.items():
            magnitude = tf.maximum(tf.abs(w) - s, tf.zeros(w.get_shape()))
            shrunk[key] = tf.mul(tf.sign(w), magnitude)
        return self._derived(shrunk)

    def compute_reg(self, graph):
        """Sum, over all entries, the inner product of a weight with its graph tensor.

        ``key[:-2]`` drops the last two characters of the weight name (the
        ``'_w'`` suffix added by the ``graph`` constructor) to find the
        matching entry in ``graph``.
        """

        def _inner_prod(t1, t2):
            # use broadcast of tf.mul, reduce sum in consistent dims
            # (axes 1-3, so the first axis of the product is kept)
            return tf.reduce_sum(tf.mul(t1,t2), [1, 2, 3])

        # sum of inner products of output coeffs and weights
        return sum([_inner_prod(weight, graph[key[:-2]]) for key, weight in self.weights.items() ])

In [7]:
def build_graph(image):
    """Load the VGG model on ``image`` and rescale every returned tensor.

    Args:
        image: input passed to ``load_vgg_model`` as ``input_image``.

    Returns:
        ``(graph, model_var)`` where ``graph`` is the dict returned by
        ``load_vgg_model`` with each value replaced by
        ``value * 0.1 / mean(value)``, and ``model_var`` is the list of
        variables that existed right after the model was loaded.
    """
    graph = load_vgg_model(VGG_MODEL, input_image = image)
    # Snapshot before any Weights object adds its own variables.
    model_var = tf.all_variables()

    # Iterate over a key snapshot so the dict can be rewritten in place on
    # Python 2 and Python 3 alike (the former iteritems() is Python-2-only).
    # NOTE(review): a tensor whose mean is 0 makes the scale factor infinite.
    for key in list(graph):
        val = graph[key]
        graph[key] = tf.scalar_mul(.1/tf.reduce_mean(val), val)

    return (graph, model_var)


def reg_loss(regs, labels):
    """Mean of the element-wise squared difference between ``regs`` and ``labels``."""
    squared_error = tf.squared_difference(regs, labels)
    return tf.reduce_mean(squared_error)

def residual_loss(beta, z, u):
    """Squared norm of the residual ``beta - z + u``."""
    residual = beta.sub(z).add(u)
    return residual.sqr_norm()

def start_session(model_var):
    """Open an interactive session and run both variable initialisers.

    The variables in ``model_var`` are initialised first, followed by an
    initialise-all pass. Returns the session.
    """
    session = tf.InteractiveSession()
    init_model_vars = tf.initialize_variables(model_var)
    session.run(init_model_vars)
    init_everything = tf.initialize_all_variables()
    session.run(init_everything)
    return session

def update_z(z, beta, u, s):
    """Build one grouped op that assigns ``soft_thresh(beta + u, s)`` to ``z``.

    Args:
        z: Weights whose entries support ``assign`` (i.e. variables).
        beta, u: Weights combined with ``beta.add(u)``.
        s: threshold forwarded to ``soft_thresh``.

    Returns:
        The ``tf.group`` of all per-key assign ops. Nothing is written until
        the returned op is run.
    """
    target = beta.add(u).soft_thresh(s)
    # .items() instead of the Python-2-only .iteritems()
    assignments = [z.weights[key].assign(val) for key, val in target.weights.items()]
    return tf.group(*assignments)

def check(z, beta, u, s):
    """Smallest magnitude in ``soft_thresh(beta + u, s)``, counting exact zeros as 1.

    Every element equal to 0 has 1.0 added to its magnitude before the
    minimum is taken, so zeros cannot win; the result is therefore the
    smallest non-zero magnitude, or 1.0 when that is smaller or when
    everything is zero.

    Args:
        z: unused; kept so existing call sites keep working.
        beta, u: Weights combined with ``beta.add(u)``.
        s: threshold forwarded to ``soft_thresh``.

    Returns:
        A scalar tensor: the minimum over all entries.
    """
    thresholded = beta.add(u).soft_thresh(s)
    per_entry_min = []
    # .values() instead of the Python-2-only .iteritems()
    for val in thresholded.weights.values():
        # 1.0 where val == 0, else 0.0 (the former scalar_mul(1, ...) was a no-op)
        zero_mask = tf.to_float(tf.equal(val, 0))
        per_entry_min.append(tf.reduce_min(tf.abs(val) + zero_mask))
    return tf.reduce_min(tf.pack(per_entry_min))

def update_u(u, beta, z):
    """Build one grouped op that assigns ``u + beta - z`` back into ``u``.

    Args:
        u: Weights whose entries support ``assign`` (i.e. variables).
        beta, z: Weights combined as ``u.add(beta).sub(z)``.

    Returns:
        The ``tf.group`` of all per-key assign ops. Nothing is written until
        the returned op is run.
    """
    target = u.add(beta).sub(z)
    # .items() instead of the Python-2-only .iteritems()
    assignments = [u.weights[key].assign(val) for key, val in target.weights.items()]
    return tf.group(*assignments)

def tf_count_zero(t):
    """Return an int32 scalar tensor counting the elements of ``t`` equal to 0."""
    # Build the comparison once; it used to be created twice with one copy unused.
    is_zero = tf.equal(t, 0)
    return tf.reduce_sum(tf.cast(is_zero, tf.int32))

def load_Weight(filename):
    """Build a Weights object from a NumPy ``.npz`` archive.

    Args:
        filename: path (or file object) accepted by ``np.load``.

    Returns:
        ``Weights(npzfile=...)`` with one variable per array in the archive.
    """
    # np.load keeps the archive open; the context manager closes it once
    # Weights has read every array.
    with np.load(filename) as Wfile:
        return Weights(npzfile = Wfile)

In [4]:
# Content image to use.
CONTENT_IMAGE = 'images/inputs/hummingbird-photo_p1-rot.jpg' #'images/inputs/hummingbird-small.jpg'
content_image = load_image(CONTENT_IMAGE, image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
# Style image to use.
STYLE_IMAGE = 'images/inputs/Nr2_original_p1-ds.jpg' #'images/inputs/Nr2_orig.jpg'
style_image = load_image(STYLE_IMAGE, image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
# One regression target per image, in the same order as the concat below:
# 0 for the content image, 1 for the style image.
labels = tf.constant([0 , 1], dtype = 'float32')

# Both images are concatenated along dimension 0 and sent through the model together.
# Assumes load_image returns arrays with a leading batch axis -- TODO confirm in util.
graph, model_var = build_graph(tf.concat(0, [content_image, style_image]))

# beta: the trainable coefficients, one zero-initialised variable per graph entry.
beta = Weights(graph = graph)
# regs: sum over entries of the inner product between beta and the graph tensors.
regs = beta.compute_reg(graph)

# z and u share beta's layout; the optimizer never touches them (see var_list),
# they change only through the update_z / update_u assign ops.
# NOTE(review): beta/z/u with a soft-threshold step looks like ADMM for an L1 penalty -- confirm.
z = Weights(graph = graph)
u = Weights(graph = graph)

# Objective: squared regression error plus the squared norm of beta - z + u.
loss = reg_loss(regs, labels) + residual_loss(beta, z, u)
opt = tf.train.AdamOptimizer(learning_rate=0.0000001)
opt_op = opt.minimize(loss, var_list=beta.weights.values())

sess = start_session(model_var)
# itr and loss_bd are not read by any cell visible here.
itr = 0
# s: threshold passed to update_z / check / soft_thresh.
s = 2e-7
loss_bd = 1.0e-3


check minimum non-zero value: 1.0000e+00


In [6]:
# One manual iteration: optimizer step on beta, then the z and u assignments.
opt_op.run()
loss_val = sess.run(loss)
# update_z / check / update_u construct new graph ops each time this cell is executed.
sess.run(update_z(z, beta, u, s))
# Diagnostic only: evaluated after z has been assigned, before u is updated.
print("check minimum non-zero value: %.4e" % sess.run(check(z, beta, u, s)))
sess.run(update_u(u, beta, z))


check minimum non-zero value: 1.5632e-13


In [16]:
# Build the thresholded tensors once. soft_thresh() creates ops for every key
# on each call, so calling it inside the loop rebuilt the whole set per key.
thresholded = beta.soft_thresh(s)
for key in beta.weights.keys():
    magnitude = tf.abs(thresholded.weights[key])
    # 1st line per key: smallest magnitude with exact zeros counted as 1.
    # 2nd line per key: largest magnitude.
    print("%s : %.4e" % (key, tf.reduce_min(magnitude + tf.to_float(tf.equal(magnitude, 0))).eval()) )
    print("%s : %.4e" % (key, tf.reduce_max(magnitude).eval()) )
    

conv5_2_w : 1.0000e+00
conv5_2_w : 0.0000e+00
conv5_3_w : 1.0000e+00
conv5_3_w : 0.0000e+00
conv5_1_w : 1.0000e+00
conv5_1_w : 0.0000e+00
conv1_2_w : 1.0000e+00
conv1_2_w : 0.0000e+00
conv1_1_w : 1.0000e+00
conv1_1_w : 0.0000e+00
avgpool2_w : 1.0000e+00
avgpool2_w : 0.0000e+00
conv4_4_w : 1.0000e+00
conv4_4_w : 0.0000e+00
conv4_3_w : 1.0000e+00
conv4_3_w : 0.0000e+00
conv4_2_w : 1.0000e+00
conv4_2_w : 0.0000e+00
avgpool1_w : 1.0000e+00
avgpool1_w : 0.0000e+00
conv5_4_w : 1.0000e+00
conv5_4_w : 0.0000e+00
conv3_1_w : 1.0000e+00
conv3_1_w : 0.0000e+00
conv2_1_w : 1.0000e+00
conv2_1_w : 0.0000e+00
conv4_1_w : 1.0000e+00
conv4_1_w : 0.0000e+00
conv3_2_w : 1.0000e+00
conv3_2_w : 0.0000e+00
avgpool4_w : 1.0000e+00
avgpool4_w : 0.0000e+00
conv3_3_w : 1.0000e+00
conv3_3_w : 0.0000e+00
avgpool5_w : 1.0000e+00
avgpool5_w : 0.0000e+00
input_w : 9.9476e-14
input_w : 1.3480e-10
conv2_2_w : 1.0000e+00
conv2_2_w : 0.0000e+00
conv3_4_w : 1.0000e+00
conv3_4_w : 0.0000e+00
avgpool3_w : 1.0000e+00
avgpoo

In [4]:
# Opens the saved archive directly; `npzfile` is not read by any other cell
# visible here (load_Weight opens the same path again on its own).
npzfile = np.load("z_thresh-1.00e-06.npz")

In [8]:
# Rebinds z to a Weights object whose variables are initialised from the saved arrays.
z = load_Weight("z_thresh-1.00e-06.npz")

In [20]:
# Content image to use.
CONTENT_IMAGE = 'images/inputs/hummingbird-photo_p1-rot.jpg' #'images/inputs/hummingbird-small.jpg'
content_image = load_image(CONTENT_IMAGE, image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
# Content image as a constant tensor with all size-1 dimensions removed;
# used by the crop/resize experiment cell.
image = tf.squeeze(tf.constant( content_image ))
# Style image to use.
STYLE_IMAGE = 'images/inputs/Nr2_original_p1-ds.jpg' #'images/inputs/Nr2_orig.jpg'
style_image = load_image(STYLE_IMAGE, image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
# Same content/style concatenation along dimension 0 as in the setup cell.
images = tf.concat(0, [content_image, style_image])

In [17]:
# Scratch experiment: resize `image`, then crop a 100x100 window from it.
# The recorded output shape is (100, 100, 3).
# NOTE(review): positional args are presumably (new_height, new_width) for resize_images
# and (offset_height, offset_width, target_height, target_width) for the crop -- confirm
# against the installed TensorFlow version.
#tf.image.crop_to_bounding_box(image, 0, 0, 200, 200)
tf.image.crop_to_bounding_box(tf.image.resize_images(image, 200, 100), 100, 0, 100, 100)

<tf.Tensor 'Slice_1:0' shape=(100, 100, 3) dtype=float32>