# Theano Tutorial : A Neural Algorithm of Artistic Style

Author : Sihyeon Seong ( sihyun0826@kaist.ac.kr )


## Data preparation

load input images and subtract global mean


In [88]:
#THEANO_FLAGS='floatX=float32,device=cpu,nvcc.fastmath=True' python artistic_alexnet_train_rmsprop.py

import os
import sys
import timeit

import numpy as np
from os import listdir
from os.path import isfile, join

from scipy import misc
import matplotlib.pyplot as plt

import theano
import theano.tensor as T
from theano.tensor.signal import downsample, pool
from theano.tensor.nnet import conv2d

from pylearn2.expr.normalize import CrossChannelNormalization

theano.config.floatX='float32'

params_path = 'pretrained_weights/parameters_releasing'

img_mean = np.load('pretrained_weights/img_mean.npy')
print 'img_mean shape : ',img_mean.shape
print np.mean(np.mean(img_mean,axis=1),axis=1)

file_list = ['van_gogh_starry_night.jpg', 'kaist_n1.jpg']

for i in file_list: 

    f = misc.imread(i)
    min_dim, max_dim = np.argmin(f.shape[:2]), np.argmax(f.shape[:2])

    resize_scale = 227.0/f.shape[min_dim]

    f = misc.imresize(f,[int(f.shape[0]*resize_scale),int(f.shape[1]*resize_scale)])
    f = f[int((f.shape[0]-227.0)/2):int((f.shape[0]-227.0)/2)+227, int((f.shape[1]-227.0)/2):int((f.shape[1]-227.0)/2)+227, :]
    print 'image shape(before) : ',f.shape

    plt.imshow(f)
    plt.show()

    f = np.transpose(f,(2,0,1))
    print 'image shape(after) : ',f.shape

    preprocessed_img = np.asarray(f,dtype=np.float32)-img_mean[:,16:16+227,16:16+227]

    np.save(i[:len(i)-4]+'.npy',preprocessed_img)


img_mean shape :  (3, 256, 256)
[ 122.22585154  116.20915924  103.56548431]
image shape(before) :  (227, 227, 3)
image shape(after) :  (3, 227, 227)
image shape(before) :  (227, 227, 3)
image shape(after) :  (3, 227, 227)


## Build our model

1) Build Convolution & Pooling layer


In [89]:

class ConvPoolLayer(object):

    def __init__(self, input, filter_shape, image_shape, f_params_w, f_params_b, lrn=False, t_style=None, t_content=None, convstride=1, padsize =0, group=1, poolsize = 3, poolstride = 1):

        self.input = input

        if t_style is not None:
            self.t_style = np.asarray(np.load(t_style),dtype=theano.config.floatX)

        if t_content is not None:
            self.t_content = np.asarray(np.load(t_content),dtype=theano.config.floatX)

        if lrn is True:
            self.lrn_func = CrossChannelNormalization()

        if group == 1:
            self.W = theano.shared(np.asarray(np.transpose(np.load(os.path.join(params_path,f_params_w)),(3,0,1,2)),dtype=theano.config.floatX), borrow=True)
            self.b = theano.shared(np.asarray(np.load(os.path.join(params_path,f_params_b)),dtype=theano.config.floatX), borrow=True)
            conv_out = conv2d(input=self.input,filters=self.W,filter_shape=filter_shape,border_mode = padsize,subsample=(convstride, convstride),filter_flip=True)

        elif group == 2:
            self.filter_shape = np.asarray(filter_shape)
            self.image_shape = np.asarray(image_shape)
            self.filter_shape[0] = self.filter_shape[0] / 2
            self.filter_shape[1] = self.filter_shape[1] / 2
            self.image_shape[1] = self.image_shape[1] / 2
            self.W0 = theano.shared(np.asarray(np.transpose(np.load(os.path.join(params_path,f_params_w[0])),(3,0,1,2)),dtype=theano.config.floatX), borrow=True)
            self.W1 = theano.shared(np.asarray(np.transpose(np.load(os.path.join(params_path,f_params_w[1])),(3,0,1,2)),dtype=theano.config.floatX), borrow=True)
            self.b0 = theano.shared(np.asarray(np.load(os.path.join(params_path,f_params_b[0])),dtype=theano.config.floatX), borrow=True)
            self.b1 = theano.shared(np.asarray(np.load(os.path.join(params_path,f_params_b[1])),dtype=theano.config.floatX), borrow=True)
            conv_out0 = conv2d(input=self.input[:,:self.image_shape[1],:,:],filters=self.W0,filter_shape=tuple(self.filter_shape),border_mode = padsize,subsample=(convstride, convstride),filter_flip=True) + self.b0.dimshuffle('x', 0, 'x', 'x')
            conv_out1 = conv2d(input=self.input[:,self.image_shape[1]:,:,:],filters=self.W1,filter_shape=tuple(self.filter_shape),border_mode = padsize,subsample=(convstride, convstride),filter_flip=True) + self.b1.dimshuffle('x', 0, 'x', 'x')
            conv_out = T.concatenate([conv_out0, conv_out1],axis=1)

        else:
            raise AssertionError()

        relu_out = T.maximum(conv_out, 0)
        if poolsize != 1:
            self.output = pool.pool_2d(input=relu_out,ds=(poolsize,poolsize),ignore_border=True, st=(poolstride,poolstride),mode='average_exc_pad')
        else:
            self.output = relu_out

        if lrn is True:
            self.output = self.lrn_func(self.output)

    def style_error(self):
        gram_matrix_ori = T.dot(self.t_style.reshape((self.t_style.shape[1],self.t_style.shape[2]*self.t_style.shape[3])),self.t_style.reshape((self.t_style.shape[1],self.t_style.shape[2]*self.t_style.shape[3])).T)
        gram_matrix_gen = T.dot(self.output.reshape((self.t_style.shape[1],self.t_style.shape[2]*self.t_style.shape[3])),self.output.reshape((self.t_style.shape[1],self.t_style.shape[2]*self.t_style.shape[3])).T)
        return T.sum(T.sum((gram_matrix_gen-gram_matrix_ori)**2))/(4.0*(self.t_style.shape[1]**2)*((self.t_style.shape[2]*self.t_style.shape[3])**2))

    def content_error(self):
        return T.sum(T.sum(T.sum((self.output-self.t_content)**2)))/2.0
    


2) Build Our Models


In [90]:

batch_size = 1

rng = np.random.RandomState(23455)

lr = T.fscalar('lr')
x = theano.shared(input_img,borrow=True)
print('... building the model')


... building the model


In [91]:

layer1_input = x.reshape((batch_size, 3, 227, 227))

convpool_layer1 = ConvPoolLayer(input=layer1_input, image_shape=(batch_size, 3, 227, 227), 
                                filter_shape=(96, 3, 11, 11), 
                                f_params_w='W_0_65.npy', 
                                f_params_b='b_0_65.npy', 
                                t_style = 'cnn_features/van_gogh_starry_night_1.npy', 
                                t_content = 'cnn_features/kaist_n1_1.npy', 
                                lrn=True, convstride=4, padsize=0, group=1, poolsize=3, poolstride=2)

convpool_layer2 = ConvPoolLayer(input=convpool_layer1.output,image_shape=(batch_size, 96, 27, 27),
                                filter_shape=(256, 96, 5, 5), 
                                f_params_w=['W0_1_65.npy','W1_1_65.npy'],
                                f_params_b=['b0_1_65.npy','b1_1_65.npy'],
                                t_style = 'cnn_features/van_gogh_starry_night_2.npy', 
                                lrn=True, convstride=1, padsize=2, group=2, poolsize=3, poolstride=2)

convpool_layer3 = ConvPoolLayer(input=convpool_layer2.output,image_shape=(batch_size, 256, 13, 13),
                                filter_shape=(384, 256, 3, 3), 
                                f_params_w='W_2_65.npy', 
                                f_params_b='b_2_65.npy', 
                                t_style = 'cnn_features/van_gogh_starry_night_3.npy',
                                convstride=1, padsize=1, group=1,poolsize=1, poolstride=0)

convpool_layer4 = ConvPoolLayer(input=convpool_layer3.output,image_shape=(batch_size, 384, 13, 13),
                                filter_shape=(384, 384, 3, 3), 
                                f_params_w=['W0_3_65.npy','W1_3_65.npy'], 
                                t_style = 'cnn_features/van_gogh_starry_night_4.npy', 
                                f_params_b=['b0_3_65.npy','b1_3_65.npy'],
                                convstride=1, padsize=1, group=2,poolsize=1, poolstride=0)

convpool_layer5 = ConvPoolLayer(input=convpool_layer4.output,image_shape=(batch_size, 384, 13, 13),
                                filter_shape=(256, 384, 3, 3), 
                                f_params_w=['W0_4_65.npy','W1_4_65.npy'], 
                                t_style = 'cnn_features/van_gogh_starry_night_5.npy', 
                                f_params_b=['b0_4_65.npy','b1_4_65.npy'], 
                                convstride=1, padsize=1, group=2,poolsize=3, poolstride=2)

cost= 0.2*(convpool_layer1.style_error() + convpool_layer2.style_error() 
           + convpool_layer3.style_error() + convpool_layer4.style_error() 
           + convpool_layer5.style_error()) + 0.00002*convpool_layer1.content_error()



In [92]:

print('... train')

params = x

grads = T.grad(cost, params)

#####RMSprop
decay = 0.9
max_scaling=1e5
epsilon = 1. / max_scaling

vels = theano.shared(params.get_value() * 0.) 

new_mean_squared_grad = (decay * vels + (1 - decay) * T.sqr(grads))
rms_grad_t = T.sqrt(new_mean_squared_grad)
delta_x_t = - lr * grads / rms_grad_t

updates=[]
updates.append((params,params + delta_x_t))
updates.append((vels,new_mean_squared_grad))


inference_model = theano.function(
    [],
    [convpool_layer1.output, convpool_layer2.output, convpool_layer3.output, convpool_layer4.output, convpool_layer5.output],
    givens={
        x: input_img
    }
)


... train


In [93]:

print('... inference')

x.set_value(np.load('van_gogh_starry_night.npy').astype(np.float32).reshape(1,3,227,227))

results1 = inference_model()

x.set_value(np.load('kaist_n1.npy').astype(np.float32).reshape(1,3,227,227))

results2 = inference_model()

for i in xrange(5):
    filestr = 'cnn_features/'+file_list[0][:len(file_list[0])-4]+'_%d.npy'%(i+1)
    np.save(filestr,results1[i])
    filestr = 'cnn_features/'+file_list[1][:len(file_list[1])-4]+'_%d.npy'%(i+1)
    np.save(filestr,results2[i])
    print '%s saved'%filestr

img_out = theano.function([],x)

train_model = theano.function([lr],cost,updates=updates)
    

... inference
cnn_features/kaist_n1_1.npy saved
cnn_features/kaist_n1_2.npy saved
cnn_features/kaist_n1_3.npy saved
cnn_features/kaist_n1_4.npy saved
cnn_features/kaist_n1_5.npy saved


### Generate input images

In [None]:

x.set_value(np.random.normal(0.0, 1.0, size=(1,3,227,227)).astype(np.float32))

img_mean = np.load('pretrained_weights/img_mean.npy')
img_mean = np.transpose(img_mean,(1,2,0))

tmp = np.transpose(np.squeeze(input_img),(1,2,0))

recon = tmp+img_mean[16:16+227,16:16+227,:]

fig = plt.figure()
fig_handle = plt.imshow(recon.astype(np.uint8))
fig.show()


In [None]:

n_epochs = 3000
learning_rate = 1.0
schedules = [5000,12000]
lr_phase = 0
for i in xrange(1, n_epochs):
    img_gen = np.transpose(np.squeeze(img_out()),(1,2,0))
    recon = img_gen+img_mean[16:16+227,16:16+227,:]
    fig_handle.set_data(recon.astype(np.uint8))
    fig.canvas.draw()

    print train_model(learning_rate), ' / epochs : ', i

    if i%1000==0:
        results = img_out()
        np.save('results.npy',results)
        

1897.91088867  / epochs :  1
1542.4642334  / epochs :  2
1337.92810059  / epochs :  3
1198.8614502  / epochs :  4
1092.32348633  / epochs :  5
1005.84790039  / epochs :  6
932.965454102  / epochs :  7
870.639343262  / epochs :  8
816.007873535  / epochs :  9
766.907714844  / epochs :  10
723.359924316  / epochs :  11
683.648071289  / epochs :  12
647.439880371  / epochs :  13
614.358093262  / epochs :  14
583.717895508  / epochs :  15
555.446777344  / epochs :  16
529.023803711  / epochs :  17
504.585327148  / epochs :  18
481.651123047  / epochs :  19
460.347106934  / epochs :  20
440.267333984  / epochs :  21
421.392456055  / epochs :  22
403.350982666  / epochs :  23
386.774505615  / epochs :  24
370.733306885  / epochs :  25
355.711761475  / epochs :  26
341.130096436  / epochs :  27
327.312225342  / epochs :  28
314.278594971  / epochs :  29
301.829223633  / epochs :  30
290.003662109  / epochs :  31
278.669494629  / epochs :  32
267.842773438  / epochs :  33
257.593261719  / epoc

In [None]:
plt.close()
