In [None]:
#import the necessary libraries
import os
import sys
import scipy.io
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from PIL import Image
import numpy as np
import tensorflow as tf
import imageio
import h5py
import tables
####################################################################################################################
# vgg.py is a separate module taken from https://github.com/anishathalye/neural-style
# download it and place it in the same folder as this notebook so that `import vgg` resolves
import vgg 
####################################################################################################################
from decimal import Decimal
import pickle
import gzip
import time
import functools
from functools import reduce
%matplotlib inline

In [None]:
# Path to the pre-trained VGG-19 weights (.mat file); it is passed to vgg.load_net further down.
# The file can be downloaded from http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat
# and must be placed at this location (relative to the notebook's working directory).
vgg_path = "../models/imagenet-vgg-verydeep-19.mat"

In [None]:
# GPU memory policy: allow_growth=True asks TensorFlow to grow its GPU memory allocation on demand.
# NOTE(review): `sess` is not used directly by the cells below (they open their own sessions in
# `with` blocks); presumably it is kept only so these GPU options are in effect - confirm before removing it.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)

In [None]:
# MSCOCO dataset (http://cocodataset.org/): an hdf5 file was built from 10000 of the ~80000 images
# in the dataset. You need to create this file yourself; it must expose the nodes
# `content_images` and `style_images` (both indexed as 4-D arrays below).
hdf5_path = "../data/image_coco_10000_data.hdf5"
# Both nodes are sliced completely into in-memory arrays, so the file handle is only needed while
# reading; the `with` block closes it instead of leaving it open for the lifetime of the kernel.
with tables.open_file(hdf5_path, mode='r') as hdf5_file:
    content_array = hdf5_file.root.content_images[:,:,:,:]
    style_array = hdf5_file.root.style_images[:,:,:,:]
batch_size = 8   # number of content images per training minibatch
style_num = 3    # index into style_array of the style image to learn

In [None]:
# build the filter variable for one conv / transposed-conv layer of the feed-forward network
def weight_initializer(weight_input, output_channel_size, filter_size, deconv = False):
    """Create the trainable filter tensor for a layer fed by `weight_input`.

    weight_input: 4-D tensor entering the layer; only its last static dimension
        (the input channel count) is used.
    output_channel_size: number of filters the layer produces.
    filter_size: height and width of the square kernel.
    deconv: when True the two channel axes are swapped, giving
        [filter, filter, out, in] instead of [filter, filter, in, out].

    Returns a float32 tf.Variable initialised from a truncated normal
    (stddev 0.1, op-level seed 1).
    """
    _, _, _, input_channel_size = weight_input.get_shape().as_list()

    channel_dims = [input_channel_size, output_channel_size]
    if deconv:
        channel_dims.reverse()
    weight_shape = [filter_size, filter_size] + channel_dims

    initial_values = tf.truncated_normal(weight_shape, stddev=0.1, seed=1)
    return tf.Variable(initial_values, dtype=tf.float32)

In [None]:
# instance normalisation of a layer's activations
def instance_norm(net, train=True):
    """Normalise `net` per sample over its spatial axes, then scale and shift per channel.

    net: 4-D tensor; mean and variance are taken over axes 1 and 2 with the
        dimensions kept, so every (sample, channel) pair gets its own statistics.
    train: accepted but not used by this function.

    Returns `scale * normalized + shift`, where `shift` (zeros) and `scale` (ones)
    are new trainable variables with one entry per channel.
    """
    _, _, _, channels = net.get_shape().as_list()
    mean, variance = tf.nn.moments(net, [1,2], keep_dims=True)

    # shift is created before scale; keep this order so the auto-generated variable names do not change
    shift = tf.Variable(tf.zeros([channels]))
    scale = tf.Variable(tf.ones([channels]))

    epsilon = 1e-3  # keeps the denominator away from zero
    std_dev = (variance + epsilon)**(.5)
    normalized = (net - mean)/std_dev
    return scale * normalized + shift

In [None]:
# convolution -> instance norm -> optional relu, the basic unit of the feed fwd network
def conv_block(block_input, num_filters, filter_size, stride_length, use_relu = True):
    """Apply a SAME-padded 2-D convolution followed by instance normalisation.

    block_input: 4-D input tensor.
    num_filters: number of output channels.
    filter_size: side length of the square kernel.
    stride_length: stride used along both spatial axes.
    use_relu: when True a ReLU is applied after the normalisation.

    Returns the resulting tensor.
    """
    kernel = weight_initializer(block_input, num_filters, filter_size)
    convolved = tf.nn.conv2d(block_input, kernel, [1,stride_length,stride_length,1], padding='SAME')
    normalized = instance_norm(convolved)

    return tf.nn.relu(normalized) if use_relu else normalized

In [None]:
# transposed convolution block: the upsampling counterpart of conv_block
def conv_transpose_block(block_input, num_filters, filter_size, stride_length):
    """Upsample `block_input` with a SAME-padded transposed convolution, then instance norm and ReLU.

    block_input: 4-D tensor whose rows and columns are statically known.
    num_filters: number of output channels.
    filter_size: side length of the square kernel.
    stride_length: stride along both spatial axes; the output has
        rows * stride_length by cols * stride_length spatial size.

    Returns the activated output tensor.
    """
    kernel = weight_initializer(block_input, num_filters, filter_size, deconv = True)

    _, rows, cols, _ = block_input.get_shape().as_list()
    # the batch dimension is read from the tensor at run time rather than from the static shape
    dynamic_batch = tf.shape(block_input)[0]
    output_shape = tf.stack([dynamic_batch,
                             int(rows * stride_length),
                             int(cols * stride_length),
                             num_filters])

    upsampled = tf.nn.conv2d_transpose(block_input, kernel, output_shape,
                                       [1,stride_length,stride_length,1], padding='SAME')

    return tf.nn.relu(instance_norm(upsampled))

In [None]:
# residual block: two stride-1 conv blocks plus a skip connection
def residual_block(block_input, filter_size = 3, num_filters = 128):
    """Return `block_input` plus the output of two stacked conv blocks.

    The first conv block uses a ReLU, the second does not. The result of the
    second block is added to `block_input`, so the input's shape has to be
    compatible with a `num_filters`-channel output.
    """
    activated = conv_block(block_input, num_filters, filter_size, 1)
    residual = conv_block(activated, num_filters, filter_size, 1, use_relu = False)
    return block_input + residual

In [None]:
#the feed-forward image transformation network
def fast_model(input_image):
    """Build the transformation network on top of `input_image` and return its output tensor.

    Layout: three conv blocks (32 filters 9x9 stride 1, 64 filters 3x3 stride 2,
    128 filters 3x3 stride 2), five residual blocks with 128 filters, two
    transposed-conv blocks (64 then 32 filters, stride 2) and a final 9x9 conv
    block with 3 filters and no ReLU. The result is squashed with tanh and
    mapped to tanh(x)*150 + 127.5.
    """
    net = conv_block(input_image, 32, 9, 1)
    net = conv_block(net, 64, 3, 2)
    net = conv_block(net, 128, 3, 2)

    # five identical residual stages, chained one after the other
    for _ in range(5):
        net = residual_block(net, 3, 128)

    net = conv_transpose_block(net, 64, 3, 2)
    net = conv_transpose_block(net, 32, 3, 2)
    net = conv_block(net, 3, 9, 1, use_relu = False)

    return tf.nn.tanh(net)*150 + 255.0/2

In [None]:
#content cost between the activations of the content image and of the generated image
def content_cost(activation_content, activation_generated):
    """Return the normalised squared difference between two activation tensors.

    The value is 2 * l2_loss(generated - content) divided by
    4 * height * width * channels * batch_size, where height/width/channels come
    from the static shape of `activation_generated` and `batch_size` is the
    module-level constant (not the tensor's own batch dimension).
    """
    _, height, width, number_of_channels = activation_generated.get_shape().as_list()
    normalizer = float(4*height*width*number_of_channels*batch_size)
    squared_error = 2.0*tf.nn.l2_loss(activation_generated - activation_content)
    return squared_error/normalizer

In [None]:
# total variation cost: penalises differences between neighbouring pixels so the generated image is smoother
def total_variation(layer):
    """Return the total variation cost of the 4-D tensor `layer`.

    Differences are taken between each element and its neighbour one step along
    axis 1 (y) and one step along axis 2 (x). Each set of differences
    contributes l2_loss(diff) / size(diff), and the two contributions are summed.
    """
    dims = tf.shape(layer)
    height, width = dims[1], dims[2]
    origin = [0,0,0,0]
    to_end = [-1,-1,-1,-1]

    # all rows but the last, minus all rows but the first
    y = tf.slice(layer, origin, tf.stack([-1,height-1,-1,-1])) - tf.slice(layer, [0,1,0,0], to_end)
    # all columns but the last, minus all columns but the first
    x = tf.slice(layer, origin, tf.stack([-1,-1,width-1,-1])) - tf.slice(layer, [0,0,1,0], to_end)

    x_term = tf.nn.l2_loss(x) / tf.to_float(tf.size(x))
    y_term = tf.nn.l2_loss(y) / tf.to_float(tf.size(y))
    return x_term + y_term

In [None]:
#load the vgg network weights and mean pixel
vgg_weights,vgg_mean_pixel = vgg.load_net(vgg_path)

#vgg layers used for the style cost, each paired with a per-layer weight
style_features = {}
style_layers = [('conv1_1',0.2),('conv2_1',0.2),('conv3_1',0.2),('conv4_1',0.2),('conv5_1',0.2)]
####################################################################################################################
#pre-calculate the gram matrix of the chosen style image at every style layer (on the cpu, in its own graph).
#the style image never changes, so the training graph can use these constants instead of
#re-running vgg on the style image at every step.
with tf.Graph().as_default(),tf.device('/cpu:0'), tf.Session() as sess:
    style_shape = (1,) + tuple(style_array.shape[1:])
    style_input = tf.placeholder(tf.float32,shape=style_shape,name='style_ip')
    style_net = vgg.net_preloaded(vgg_weights, (style_input - vgg_mean_pixel), 'max')

    #single style image, with a leading batch axis of size 1
    style_batch = style_array[style_num][np.newaxis, ...]

    for layer_name, _ in style_layers:
        activations = style_net[layer_name].eval(feed_dict={style_input:style_batch})
        #flatten to (positions, channels); gram = F^T F normalised by the number of activations
        flat = activations.reshape(-1, activations.shape[3])
        style_features[layer_name] = np.matmul(flat.T, flat)/flat.size

#build the training graph and train the feed-forward network.
#using a plain tf.Session in a `with` block instead of tf.InteractiveSession makes the training process much faster:
# the training time for one epoch/iteration of 10000 imgs went from 720 to 580 secs
with tf.Graph().as_default(),tf.Session() as sess:
    number_of_content = content_array.shape[0]
    num_minibatch = number_of_content // batch_size
    if num_minibatch == 0:
        #with no minibatch the inner loop never runs and the cost printout would fail on undefined names
        raise ValueError("content_array holds fewer images than batch_size; nothing to train on")

    content_input = tf.placeholder(tf.float32,shape=(batch_size,content_array.shape[1],content_array.shape[2],content_array.shape[3]),name='content_ip')
    ####################################################################################################################
    #vgg activations of the content batch and of the image produced by the feed-forward network
    content_net = vgg.net_preloaded(vgg_weights, (content_input - vgg_mean_pixel), 'max')

    generated_image = fast_model(content_input/255.0)
    generated_net = vgg.net_preloaded(vgg_weights, (generated_image - vgg_mean_pixel), 'max')
    ####################################################################################################################
    #content cost, measured at conv4_2
    activation_content = content_net['conv4_2']
    activation_generated = generated_net['conv4_2']
    j_content = content_cost(activation_content, activation_generated)
    ###################################################################################################################
    #style cost: distance between the gram matrices of the generated image and the pre-computed
    #gram matrices of the style image, one term per style layer
    style_losses = []

    for layer_name, coeff in style_layers:
        layer = generated_net[layer_name]
        _, height, width, filters = layer.get_shape().as_list()
        size = height * width * filters
        feats = tf.reshape(layer, [-1, height * width, filters])
        feats_T = tf.transpose(feats, perm=[0,2,1])
        grams = tf.matmul(feats_T, feats) / size
        style_gram = style_features[layer_name]
        layer_loss = 2 * tf.nn.l2_loss(grams - style_gram)/style_gram.size
        #apply the per-layer weight declared in style_layers, so that editing those weights has an effect.
        #NOTE: beta multiplies this weighted sum; with five weights of 0.2 the style term is the mean
        #over the layers rather than their plain sum, so beta may need re-tuning.
        style_losses.append(coeff * layer_loss)

    j_style = functools.reduce(tf.add, style_losses) / batch_size
    ###################################################################################################################
    tv_loss = total_variation(generated_image)
    ###################################################################################################################
    #relative weights of the content, style and total variation costs
    alpha = 7.5
    beta = 100.0
    gamma = 200.0
    total_j = alpha*j_content + beta*j_style + gamma*tv_loss
    ###################################################################################################################
    train_step = tf.train.AdamOptimizer(0.01).minimize(total_j)
    num_iterations = 35 #also called epochs.
    ###################################################################################################################
    #change the boolean values below to restore and save model respectively. this is useful when you wanna
    #iteratively change the alpha beta gamma or the style layer weights to get the style look that you want
    #from a generated image. if youre happy with some weight values and you trained the network for some epochs,
    #you can save and restore the network weights to train for more epochs.
    restore_model = False
    save_model = True
    checkpoint_path = '../models/style/20180819-1/style_model_2'
    generated_dir = "../data/generated"

    #a single Saver, built once after every variable (including the optimizer's) exists.
    #constructing a new Saver at every checkpoint would add another set of save/restore ops to the graph each time.
    saver = tf.train.Saver()

    #make sure the output folders exist now, so a missing folder cannot fail a write after a full epoch of training
    os.makedirs(generated_dir, exist_ok=True)
    if save_model:
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    sess.run(tf.global_variables_initializer())

    #restore variable values. only the variable values are saved further below, not the graph.
    if(restore_model):
        saver.restore(sess, checkpoint_path)
    ###################################################################################################################
    #index of the first image of the batch used for the progress preview; the first generated image of that
    #batch is shown and written to disk after every epoch so you can see how the output changes during training.
    imagenum = 4

    #NOTE(review): this runs num_iterations+1 epochs (i = 0..num_iterations), as in the original code -
    #presumably so that the last epoch falls on a multiple of 5 and gets checkpointed; confirm.
    for i in range (num_iterations+1):
        print("iteration number = ", i)
        start_time = time.time()
        for j in range(num_minibatch):
            temp_content = content_array[j*batch_size:batch_size*(j+1),:,:,:]

            _,jc,js,jt,jtv = sess.run([train_step,j_content, j_style,total_j,tv_loss],
                                      feed_dict={content_input:temp_content})

        #the costs printed here are those of the last minibatch of the epoch.
        #.item() converts the numpy scalars to python floats (np.asscalar is deprecated and removed in newer numpy)
        print("total cost = ", '{:.9e}'.format(Decimal(jt.item())))
        print("content cost = ", jc)
        print("style cost = ", '{:.9e}'.format(Decimal(js.item())))
        print("tv cost = ", jtv)

        preview_batch = sess.run(generated_image,{content_input:content_array[imagenum:imagenum+batch_size,:,:,:]})
        gen_img = np.clip(preview_batch[0], 0, 255).astype('uint8')
        imshow(gen_img)
        #write the generated image to a file to see its progress
        imageio.imwrite("../data/generated/1-" + str(i) + ".jpg", gen_img)

        end_time = time.time()
        print("epoch time is ", (end_time - start_time)) #average epoch time is around 580 secs for 10000 imgs

        #save the model without the graph, every 5th epoch.
        if(save_model and i%5 == 0):
            saver.save(sess, checkpoint_path, write_meta_graph=False)
    
    
    ###################################################################################################################
        
        