In [None]:
# Most of the feed-forward generation code below is the same as the training code; some code needed only for training has been removed.
import os
import sys
import scipy.io
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from PIL import Image
import numpy as np
import tensorflow as tf
import imageio
import h5py
import tables
import vgg 
from decimal import Decimal
import pickle
import gzip
import time
import functools
from functools import reduce
%matplotlib inline

In [None]:
# Location of the pretrained VGG-19 weights (.mat file).
# NOTE(review): nothing in this notebook's generation code reads vgg_path; likely a leftover from training.
vgg_path = "/".join(("..", "models", "imagenet-vgg-verydeep-19.mat"))

In [None]:
# Session configuration: allocate GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# NOTE(review): this session is bound to the default graph; the generation cell opens its own
# graph and session, so this one appears unused.
sess = tf.Session(config=config)

In [None]:
def weight_initializer(weight_input, output_channel_size, filter_size, deconv = False):
    """Create a square convolution kernel variable sized for `weight_input`.

    Args:
        weight_input: rank-4 tensor whose last static dimension gives the input channel count.
        output_channel_size: number of output channels of the layer.
        filter_size: spatial height/width of the kernel.
        deconv: if True, lay the kernel out as [k, k, out, in] (the filter layout
            tf.nn.conv2d_transpose expects); otherwise [k, k, in, out] for tf.nn.conv2d.

    Returns:
        A float32 tf.Variable initialised from a truncated normal (stddev 0.1, op seed 1).
    """
    # Unpacking exactly four dims also asserts the input is rank 4.
    _, _, _, in_channels = [dim.value for dim in weight_input.get_shape()]

    channel_dims = ([output_channel_size, in_channels] if deconv
                    else [in_channels, output_channel_size])
    kernel_shape = [filter_size, filter_size] + channel_dims

    initial_value = tf.truncated_normal(kernel_shape, stddev=0.1, seed=1)
    return tf.Variable(initial_value, dtype=tf.float32)

In [None]:
def instance_norm(net, train=True):
    """Instance-normalise `net` over its spatial axes with a learnable per-channel affine.

    Mean and variance are computed per sample and per channel over axes 1 and 2,
    then the result is scaled and shifted by trainable vectors of length `channels`.

    Args:
        net: rank-4 tensor; its last static dimension sizes the scale/shift vectors.
        train: unused; kept for interface compatibility.

    Returns:
        scale * (net - mean) / sqrt(variance + 1e-3) + shift
    """
    _, _, _, channels = [dim.value for dim in net.get_shape()]
    mean, variance = tf.nn.moments(net, [1, 2], keep_dims=True)

    # Creation order (shift, then scale) matters: variables are auto-named, and the
    # checkpoint restore presumably expects the same naming sequence as training.
    shift = tf.Variable(tf.zeros([channels]))
    scale = tf.Variable(tf.ones([channels]))

    variance_epsilon = 1e-3
    centered = net - mu_free if False else net - mean  # plain centering
    normalized = centered / (variance + variance_epsilon) ** (.5)
    return scale * normalized + shift

In [None]:
def conv_block(block_input, num_filters, filter_size, stride_length, use_relu = True):
    """Convolution -> instance norm -> optional ReLU.

    Args:
        block_input: rank-4 input tensor.
        num_filters: number of output channels.
        filter_size: square kernel size.
        stride_length: spatial stride applied to both height and width ('SAME' padding).
        use_relu: apply a ReLU after normalisation when True.
    """
    kernel = weight_initializer(block_input, num_filters, filter_size)
    convolved = tf.nn.conv2d(block_input, kernel,
                             [1, stride_length, stride_length, 1],
                             padding='SAME')
    normalized = instance_norm(convolved)
    return tf.nn.relu(normalized) if use_relu else normalized

In [None]:
def conv_transpose_block(block_input, num_filters, filter_size, stride_length):
    """Upsampling block: transposed convolution -> instance norm -> ReLU.

    The output spatial size is the input's static height/width multiplied by
    `stride_length`; the batch size is taken dynamically from the input tensor.
    Requires static height and width on `block_input`.
    """
    kernel = weight_initializer(block_input, num_filters, filter_size, deconv = True)

    _, rows, cols, _ = [dim.value for dim in block_input.get_shape()]
    target_rows = int(rows * stride_length)
    target_cols = int(cols * stride_length)
    output_shape = tf.stack([tf.shape(block_input)[0], target_rows, target_cols, num_filters])

    upsampled = tf.nn.conv2d_transpose(block_input, kernel, output_shape,
                                       [1, stride_length, stride_length, 1],
                                       padding='SAME')
    return tf.nn.relu(instance_norm(upsampled))

In [None]:
def residual_block(block_input, filter_size = 3, num_filters = 128):
    """Two stride-1 conv blocks (ReLU only after the first) plus an identity skip connection.

    `num_filters` must equal the channel count of `block_input` for the addition to be valid.
    """
    hidden = conv_block(block_input, num_filters, filter_size, 1)
    residual = conv_block(hidden, num_filters, filter_size, 1, use_relu = False)
    return block_input + residual

In [None]:
# Define the feed-forward image transformation network.
def fast_model(input_image):
    """Build the style-transfer network: 3 conv blocks, 5 residual blocks, 2 upsampling blocks, output conv.

    Layers are created in a fixed order so that the auto-named variables line up with
    the checkpoint restored later.

    Returns:
        tanh(logits) * 150 + 127.5, i.e. values in roughly [-22.5, 277.5]; callers clip to [0, 255].
    """
    net = input_image

    # Encoder: (filters, kernel size, stride)
    for filters, kernel_size, stride in ((32, 9, 1), (64, 3, 2), (128, 3, 2)):
        net = conv_block(net, filters, kernel_size, stride)

    # Five residual blocks at 128 channels.
    for _ in range(5):
        net = residual_block(net, 3, 128)

    # Decoder: two 2x upsampling blocks, then a linear 3-channel output conv.
    for filters in (64, 32):
        net = conv_transpose_block(net, filters, 3, 2)
    net = conv_block(net, 3, 9, 1, use_relu = False)

    return tf.nn.tanh(net) * 150 + 255.0 / 2

In [None]:
def total_variation(layer):
    """Total-variation loss of a rank-4 tensor.

    Sums, over the vertical and horizontal neighbour differences, tf.nn.l2_loss of
    the differences divided by their element count.
    NOTE(review): not called anywhere in this notebook; likely a leftover from training.
    """
    def _normalized_l2(diff):
        # l2_loss is sum(diff ** 2) / 2; divide by the number of elements.
        return tf.nn.l2_loss(diff) / tf.cast(tf.size(diff), tf.float32)

    dynamic_shape = tf.shape(layer)
    height, width = dynamic_shape[1], dynamic_shape[2]
    everything = [-1, -1, -1, -1]

    vertical = (tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, height - 1, -1, -1]))
                - tf.slice(layer, [0, 1, 0, 0], everything))
    horizontal = (tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, -1, width - 1, -1]))
                  - tf.slice(layer, [0, 0, 1, 0], everything))

    return _normalized_l2(horizontal) + _normalized_l2(vertical)

In [None]:
# Stylize one content image with the trained feed-forward network, display it, and save the result.
CONTENT_IMAGE_PATH = "../data/content/image_to_convert.jpg"
MODEL_CHECKPOINT_PATH = '../models/style/20180817-1/style_model_2'
OUTPUT_IMAGE_PATH = "../data/generated/converted_image.jpg"

# Reuse the GPU session config (allow_growth) defined in the setup cell.
with tf.Graph().as_default(), tf.Session(config=config) as sess:

    new_img = imageio.imread(CONTENT_IMAGE_PATH)
    # The network's placeholder is fixed at 3 channels: promote grayscale (with or
    # without alpha) to RGB and drop any alpha channel from RGBA images.
    if new_img.ndim == 2:
        new_img = new_img[:, :, np.newaxis]
    if new_img.shape[2] < 3:
        new_img = np.repeat(new_img[:, :, :1], 3, axis=2)
    new_img = new_img[:, :, :3]
    img_height, img_width = new_img.shape[0], new_img.shape[1]

    content_input = tf.placeholder(tf.float32, shape=(1, img_height, img_width, 3), name='content_ip')
    generated_image = fast_model(content_input / 255.0)
    sess.run(tf.global_variables_initializer())

    restore_model = True
    generate_image = True

    # Load the trained weights. The variables are auto-named, so this relies on
    # fast_model creating them in the same order as during training.
    if restore_model:
        saver = tf.train.Saver()
        saver.restore(sess, MODEL_CHECKPOINT_PATH)

    if generate_image:
        gen_img = sess.run(generated_image, {content_input: new_img[np.newaxis, ...]})[0]
        # Two stride-2 downsamplings followed by two 2x upsamplings yield dimensions
        # rounded up to a multiple of 4; crop back to the content image size.
        gen_img = gen_img[:img_height, :img_width]
        gen_img = np.clip(gen_img, 0, 255).astype('uint8')

        # Show input and output side by side (drawing both on one axes would hide the input).
        fig, (ax_content, ax_generated) = plt.subplots(1, 2, figsize=(12, 6))
        ax_content.imshow(new_img)
        ax_content.set_title("Content image")
        ax_content.axis("off")
        ax_generated.imshow(gen_img)
        ax_generated.set_title("Stylized image")
        ax_generated.axis("off")
        plt.show()

        os.makedirs(os.path.dirname(OUTPUT_IMAGE_PATH), exist_ok=True)
        imageio.imwrite(OUTPUT_IMAGE_PATH, gen_img)
        
        