In [4]:
import scipy.misc
import scipy.io
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from nst_utils import * 

%matplotlib inline

In [16]:
def compute_content_cost(sess,model,layer):
    content_I=sess.run(model[layer])
    generated_I=model[layer]
    J=0.25*tf.reduce_mean((content_I-generated_I)**2)
    return J

def compute_single_style_cost(contentDst,contentSrc):
    _, n_H, n_W, n_C = contentSrc.get_shape().as_list()
    dst=tf.reshape(contentDst,[n_H*n_W,n_C])
    src=tf.reshape(contentSrc,[n_H*n_W,n_C])
    R_dst=tf.matmul(tf.transpose(dst),dst)
    R_src=tf.matmul(tf.transpose(src),src)
    J=tf.reduce_sum((R_dst-R_src)**2)/(2*n_H*n_W*n_C)**2
    return J

def compute_style_cost(sess,model,layers):
    J=0
    for name,weight in layers:
        style_I=sess.run(model[name])
        generated_I=model[name]
        J+= weight*compute_single_style_cost(style_I,generated_I)
    return J

def compute_total_cost(J_content,J_style,alpha,beta):
    return alpha*J_content+beta*J_style


    
    

In [14]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

model = load_vgg_model("pretrained-model/imagenet-vgg-verydeep-19.mat")
content_image = scipy.misc.imread("images/louvre_small.jpg")
content_image = reshape_and_normalize_image(content_image)
print(content_image.shape)
style_image = scipy.misc.imread("images/monet.jpg")
style_image = reshape_and_normalize_image(style_image)
print(style_image.shape)
generated_image = generate_noise_image(content_image)
print(generated_image.shape)



(1, 300, 400, 3)
(1, 300, 400, 3)
(1, 300, 400, 3)


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  after removing the cwd from sys.path.
`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  import sys


In [20]:
sess.run(model['input'].assign(content_image))
J_content=compute_content_cost(sess,model,'conv4_2')

sess.run(model['input'].assign(style_image))
layers = [
    ('conv1_1', 0.2),
    ('conv2_1', 0.2),
    ('conv3_1', 0.2),
    ('conv4_1', 0.2),
    ('conv5_1', 0.2)]
J_style=compute_style_cost(sess,model,layers)
J=compute_total_cost(J_content,J_style,10,20)

optimizer = tf.train.AdamOptimizer(2.0)
train_op=optimizer.minimize(J)

In [21]:
def train_model(sess,input_image,num_iter=200):
    sess.run(tf.global_variables_initializer())
    sess.run(model['input'].assign(input_image))
    
    for i in range(num_iter):
        sess.run(train_op)
        generated_image = sess.run(model['input'])
        if i%20 == 0:
            Jt, Jc, Js = sess.run([J, J_content, J_style])
            print("Iteration " + str(i) + " :")
            print("total cost = " + str(Jt))
            print("content cost = " + str(Jc))
            print("style cost = " + str(Js))
            save_image("output/" + str(i) + ".png", generated_image)
    save_image('output/generated_image.jpg', generated_image)
    return generated_image

In [22]:
train_model(sess,generated_image)

Iteration 0 :
total cost = 2534359800.0
content cost = 7919.0483
style cost = 126714030.0
Iteration 20 :
total cost = 473896450.0
content cost = 15247.318
style cost = 23687198.0
Iteration 40 :
total cost = 243137340.0
content cost = 16756.848
style cost = 12148489.0
Iteration 60 :
total cost = 158345870.0
content cost = 17459.568
style cost = 7908564.0
Iteration 80 :
total cost = 116863580.0
content cost = 17764.785
style cost = 5834297.0
Iteration 100 :
total cost = 93018696.0
content cost = 18022.463
style cost = 4641923.5
Iteration 120 :
total cost = 77290330.0
content cost = 18243.812
style cost = 3855394.5
Iteration 140 :
total cost = 65919360.0
content cost = 18388.842
style cost = 3286773.5
Iteration 160 :
total cost = 57163844.0
content cost = 18504.707
style cost = 2848939.8
Iteration 180 :
total cost = 50283600.0
content cost = 18623.188
style cost = 2504868.5


array([[[[ -19.62501   ,  -28.33849   ,   18.568626  ],
         [ -12.1398735 ,  -41.11242   ,   18.823547  ],
         [ -12.810609  ,  -46.72897   ,    9.501727  ],
         ...,
         [ -18.688385  ,  -39.155567  ,    6.1682076 ],
         [ -20.637672  ,  -20.407604  ,   15.410741  ],
         [ -26.01847   ,  -27.456268  ,   10.056149  ]],

        [[ -46.131184  ,  -65.87976   ,  -20.294115  ],
         [ -50.537144  ,  -38.507336  ,  -12.715837  ],
         [ -31.69999   ,  -48.70993   ,  -26.394997  ],
         ...,
         [ -19.160255  ,  -37.85635   ,    1.5577611 ],
         [ -35.121815  ,  -63.264805  ,   14.420129  ],
         [ -40.003174  ,  -72.3035    ,   16.728476  ]],

        [[ -62.120274  ,  -75.07994   ,  -63.44629   ],
         [ -44.424698  ,  -56.378136  ,  -24.05177   ],
         [ -34.13938   ,  -55.166622  ,  -50.319424  ],
         ...,
         [ -29.739195  ,  -50.237865  ,  -12.106047  ],
         [ -23.49479   ,  -44.740005  ,   -7.220371  ],
  