In [10]:
import vgg, pdb, time
import tensorflow as tf, numpy as np, os
import transform
import scipy.misc, numpy as np, os
from functools import reduce

In [11]:
def save_img(out_path, img):
    """Write an image array to ``out_path`` as 8-bit data.

    Pixel values are clamped to the [0, 255] range and cast to ``uint8``
    before being passed to ``scipy.misc.imsave``.
    """
    as_uint8 = np.clip(img, 0, 255).astype(np.uint8)
    scipy.misc.imsave(out_path, as_uint8)

def scale_img(style_path, style_scale):
    """Load the image at ``style_path`` resized by the factor ``style_scale``.

    Args:
        style_path: path of the image file to read.
        style_scale: scale factor; anything ``float()`` accepts.

    Returns:
        The image returned by ``get_img`` for the scaled target shape.
    """
    scale = float(style_scale)
    # Read once only to learn the original dimensions.
    height, width, channels = scipy.misc.imread(style_path, mode='RGB').shape
    new_shape = (int(height * scale), int(width * scale), channels)
    # The loader in this notebook is `get_img`; `_get_img` is not defined.
    return get_img(style_path, img_size=new_shape)

def get_img(src, img_size=False):
    """Read the image at ``src`` in RGB mode, optionally resizing it.

    If the loaded array is not 3-dimensional with 3 channels, it is stacked
    three times along the depth axis. When ``img_size`` is given (anything
    other than ``False``), the result is passed through
    ``scipy.misc.imresize`` with that size.
    """
    pixels = scipy.misc.imread(src, mode='RGB')
    has_three_channels = len(pixels.shape) == 3 and pixels.shape[2] == 3
    if not has_three_channels:
        pixels = np.dstack((pixels, pixels, pixels))
    if img_size != False:
        pixels = scipy.misc.imresize(pixels, img_size)
    return pixels

def exists(p, msg):
    """Assert that path ``p`` exists, failing with ``msg`` otherwise."""
    path_found = os.path.exists(p)
    assert path_found, msg

def list_files(in_path):
    """Return the names of files directly inside ``in_path``.

    Only the top level is inspected (sub-directories are not descended
    into). An empty list is returned when the walk produces nothing.
    """
    top_level = next(os.walk(in_path), None)
    if top_level is None:
        return []
    # os.walk yields (dirpath, dirnames, filenames); keep only the file names.
    return list(top_level[2])
def _tensor_size(tensor):
    from operator import mul
    return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)

In [12]:
# Network layers whose Gram matrices are compared against the style image
# when computing the style loss in optimize().
STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
# Network layer compared between the content batch and the transformed output
# when computing the content loss in optimize().
CONTENT_LAYER = 'relu4_2'
# NOTE(review): not referenced in the cells shown here; presumably the name of
# the environment variable used to select GPUs — confirm before relying on it.
DEVICES = 'CUDA_VISIBLE_DEVICES'

In [75]:
def optimize(content_targets, style_target, content_weight, style_weight,
             tv_weight, vgg_path, epochs=2, print_iterations=1000,
             batch_size=4, save_path='saver/fns.ckpt',
             learning_rate=1e-3):
    """Train the transform network and yield progress at each checkpoint.

    This is a generator: nothing runs until it is iterated.

    Args:
        content_targets: sequence of image paths used as training content.
            It is trimmed so its length is a multiple of ``batch_size``.
        style_target: style image array (its ``.shape`` is used to size the
            style placeholder).
        content_weight: multiplier applied to the content loss.
        style_weight: multiplier applied to the style loss.
        tv_weight: multiplier applied to the total-variation loss.
        vgg_path: path passed to ``vgg.net`` to build the loss network.
        epochs: number of passes over ``content_targets``.
        print_iterations: a checkpoint is written every this many iterations
            (and always after the final batch of the final epoch).
        batch_size: number of images per training step.
        save_path: checkpoint path handed to ``Saver.save``.
        learning_rate: learning rate for the Adam optimizer.

    Yields:
        ``(preds, losses, iterations, epoch)`` after each checkpoint, where
        ``losses`` is ``(style_loss, content_loss, tv_loss, total_loss)``
        evaluated on the most recent batch.
    """
    mod = len(content_targets) % batch_size
    if mod > 0:
        print("Train set has been trimmed slightly..")
        content_targets = content_targets[:-mod]

    style_features = {}

    batch_shape = (batch_size,256,256,3)
    style_shape = (1,) + style_target.shape
    print(style_shape)

    # precompute style features (Gram matrix per style layer) in a throwaway
    # graph so they can be used as constants in the training graph below
    with tf.Graph().as_default(), tf.device('/cpu:0'), tf.Session() as sess:
        style_image = tf.placeholder(tf.float32, shape=style_shape, name='style_image')
        style_image_pre = vgg.preprocess(style_image)
        print(vgg_path)
        net = vgg.net(vgg_path, style_image_pre)
        style_pre = np.array([style_target])
        for layer in STYLE_LAYERS:
            features = net[layer].eval(feed_dict={style_image:style_pre})
            features = np.reshape(features, (-1, features.shape[3]))
            gram = np.matmul(features.T, features) / features.size
            style_features[layer] = gram

    with tf.Graph().as_default(), tf.Session() as sess:
        X_content = tf.placeholder(tf.float32, shape=batch_shape, name="X_content")
        X_pre = vgg.preprocess(X_content)

        # content features of the unmodified input batch
        content_features = {}
        content_net = vgg.net(vgg_path, X_pre)
        content_features[CONTENT_LAYER] = content_net[CONTENT_LAYER]

        # transformed images and their features in the loss network
        preds = transform.net(X_content/255.0)
        preds_pre = vgg.preprocess(preds)

        net = vgg.net(vgg_path, preds_pre)

        content_size = _tensor_size(content_features[CONTENT_LAYER])*batch_size
        assert _tensor_size(content_features[CONTENT_LAYER]) == _tensor_size(net[CONTENT_LAYER])
        content_loss = content_weight * (2 * tf.nn.l2_loss(
            net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) / content_size
        )

        style_losses = []
        for style_layer in STYLE_LAYERS:
            layer = net[style_layer]
            bs, height, width, filters = [i.value for i in layer.get_shape()]
            size = height * width * filters
            feats = tf.reshape(layer, (bs, height * width, filters))
            feats_T = tf.transpose(feats, perm=[0,2,1])
            grams = tf.matmul(feats_T, feats) / size
            style_gram = style_features[style_layer]
            style_losses.append(2 * tf.nn.l2_loss(grams - style_gram)/style_gram.size)

        style_loss = style_weight * reduce(tf.add, style_losses) / batch_size

        # total variation denoising
        tv_y_size = _tensor_size(preds[:,1:,:,:])
        tv_x_size = _tensor_size(preds[:,:,1:,:])
        y_tv = tf.nn.l2_loss(preds[:,1:,:,:] - preds[:,:batch_shape[1]-1,:,:])
        x_tv = tf.nn.l2_loss(preds[:,:,1:,:] - preds[:,:,:batch_shape[2]-1,:])
        tv_loss = tv_weight*2*(x_tv/tv_x_size + y_tv/tv_y_size)/batch_size

        # overall loss
        loss = content_loss + style_loss + tv_loss
        print(loss)
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        # Build the Saver once, after every variable (including the
        # optimizer's) exists; constructing it per checkpoint would add new
        # save/restore ops to the graph each time.
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        import random
        # Random tag so log lines from concurrent runs can be told apart.
        uid = random.randint(1, 100)
        print("UID: %s" % uid)
        for epoch in range(epochs):
            print(epoch)
            num_examples = len(content_targets)
            iterations = 0
            while iterations * batch_size < num_examples:
                start_time = time.time()
                curr = iterations * batch_size
                step = curr + batch_size
                X_batch = np.zeros(batch_shape, dtype=np.float32)
                for j, img_p in enumerate(content_targets[curr:step]):
                    X_batch[j] = get_img(img_p, (256,256,3)).astype(np.float32)

                iterations += 1
                assert X_batch.shape[0] == batch_size

                feed_dict = {
                   X_content:X_batch
                }

                train_step.run(feed_dict=feed_dict)
                end_time = time.time()
                delta_time = end_time - start_time
                print("UID: %s, batch time: %s" % (uid, delta_time))
                is_print_iter = int(iterations) % print_iterations == 0
                is_last = epoch == epochs - 1 and iterations * batch_size >= num_examples
                should_print = is_print_iter or is_last
                if should_print:
                    # Re-evaluate the losses and outputs on the batch just
                    # trained on, checkpoint, then hand control to the caller.
                    to_get = [style_loss, content_loss, tv_loss, loss, preds]
                    tup = sess.run(to_get, feed_dict=feed_dict)
                    _style_loss,_content_loss,_tv_loss,_loss,_preds = tup
                    losses = (_style_loss, _content_loss, _tv_loss, _loss)
                    saver.save(sess, save_path)
                    yield(_preds, losses, iterations, epoch)

In [76]:
import sys
import transform, numpy as np, vgg, pdb, os
import scipy.misc
import tensorflow as tf
from collections import defaultdict
import time

In [77]:
# NOTE(review): not referenced in the cells shown here; the functions below
# take their own batch_size argument.
BATCH_SIZE = 4
# Default device string; used as the default `device_t` of
# ffwd_different_dimensions.
DEVICE = '/gpu:0'

# get img_shape
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    """Run the trained transform network over images and save the results.

    Args:
        data_in: either a sequence of image paths, or an indexable collection
            of already-loaded image arrays (all of the same shape).
        paths_out: output path for each input, in the same order.
        checkpoint_dir: a checkpoint directory, or a checkpoint path that can
            be restored directly.
        device_t: device string the graph is built on.
        batch_size: maximum number of images per forward pass.

    Raises:
        AssertionError: if ``paths_out`` is empty, the input and output
            counts differ, or path inputs do not all share one shape.
        Exception: if ``checkpoint_dir`` is a directory with no checkpoint.
    """
    assert len(paths_out) > 0
    assert len(data_in) == len(paths_out)
    is_paths = isinstance(data_in[0], str)
    if is_paths:
        img_shape = get_img(data_in[0]).shape
    else:
        # In-memory images: take the shape from the first element.
        img_shape = data_in[0].shape

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), \
            tf.Session(config=soft_config) as sess:
        batch_shape = (batch_size,) + img_shape
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                         name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        # Only full batches are processed here; the placeholder has a fixed
        # batch dimension, so leftovers are handled after the loop.
        num_iters = len(paths_out) // batch_size
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos+batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos+batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions. ' +  \
                        'Resize images or use --allow-different-dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos+batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder:X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

        remaining_in = data_in[num_iters*batch_size:]
        remaining_out = paths_out[num_iters*batch_size:]
    # Images that did not fill a whole batch are processed one at a time.
    if len(remaining_in) > 0:
        ffwd(remaining_in, remaining_out, checkpoint_dir, 
            device_t=device_t, batch_size=1)

def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
    """Transform one image: a single-item, batch-size-1 call to ``ffwd``."""
    ffwd([in_path], [out_path], checkpoint_dir, device_t=device, batch_size=1)

def ffwd_different_dimensions(in_path, out_path, checkpoint_dir, 
            device_t=DEVICE, batch_size=4):
    """Run ``ffwd`` over images of mixed sizes, one group per image shape.

    Every input is loaded once to read its shape; inputs (and their matching
    outputs) are bucketed by that shape and each bucket is passed to ``ffwd``
    separately.
    """
    sources_by_shape = defaultdict(list)
    targets_by_shape = defaultdict(list)
    for idx, source in enumerate(in_path):
        target = out_path[idx]
        shape_key = "%dx%dx%d" % get_img(source).shape
        sources_by_shape[shape_key].append(source)
        targets_by_shape[shape_key].append(target)
    for shape_key, sources in sources_by_shape.items():
        print(('Processing images of shape %s' % shape_key))
        ffwd(sources, targets_by_shape[shape_key],
             checkpoint_dir, device_t, batch_size)


def check_opts(opts):
    """Validate the paths (and, for directory output, batch size) in ``opts``.

    Asserts that the checkpoint and input paths exist. The output-path and
    ``batch_size`` checks run only when ``opts.out_path`` is a directory.
    """
    exists(opts.checkpoint_dir, 'Checkpoint not found!')
    exists(opts.in_path, 'In path not found!')
    if not os.path.isdir(opts.out_path):
        return
    exists(opts.out_path, 'out dir not found!')
    assert opts.batch_size > 0

In [78]:
# Re-declares the device string with the same value as the earlier cell.
DEVICE = '/gpu:0'
# NOTE(review): not referenced in the cells shown here; presumably the
# fraction of GPU memory to use — confirm before relying on it.
FRAC_GPU = 1

In [79]:
def _get_files(img_dir):
    """Return ``img_dir``-prefixed paths for every file ``list_files`` finds."""
    full_paths = []
    for name in list_files(img_dir):
        full_paths.append(os.path.join(img_dir, name))
    return full_paths

In [86]:
def style(styleImg, checkpointDir, vggPath, trainPath, testImg, testResDir, contentWeight, styleWeight, tvWeight, learningRate, checkPointIterations, epochs, batch_size):
    """Train a style-transfer network and render a test image per checkpoint.

    Args:
        styleImg: path of the style image.
        checkpointDir: directory where the ``fns.ckpt`` checkpoint is written
            and from which it is restored for the test render.
        vggPath: path of the VGG network data passed on to ``optimize``.
        trainPath: directory whose files are used as training images.
        testImg: path of the image rendered after every checkpoint.
        testResDir: directory receiving ``<epoch>_<iteration>.png`` renders.
        contentWeight, styleWeight, tvWeight: loss weights.
        learningRate: optimizer learning rate.
        checkPointIterations: iterations between checkpoints.
        epochs: number of training epochs.
        batch_size: training batch size.
    """
    style_target = get_img(styleImg)
    content_targets = _get_files(trainPath)

    kwargs = {
        "epochs":epochs,
        "print_iterations":checkPointIterations,
        "batch_size":batch_size,
        "save_path":os.path.join(checkpointDir,'fns.ckpt'),
        "learning_rate":learningRate
    }

    args = [
        content_targets,
        style_target,
        contentWeight,
        styleWeight,
        tvWeight,
        vggPath
    ]

    # optimize() yields right after saving a checkpoint, so the render below
    # always uses the weights that were just written.
    for preds, losses, i, epoch in optimize(*args, **kwargs):
        style_loss, content_loss, tv_loss, loss = losses
        print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, i, loss))
        to_print = (style_loss, content_loss, tv_loss)
        print('style: %s, content:%s, tv: %s' % to_print)
        preds_path = '%s/%s_%s.png' % (testResDir,epoch,i)
        ffwd_to_img(testImg, preds_path, checkpointDir)
    print("Training complete.")


In [87]:
def _ask_existing_path(prompt, missing_msg):
    """Prompt for a path and assert (with ``missing_msg``) that it exists."""
    path = input(prompt)
    exists(path, missing_msg)
    return path


def _ask_number(prompt, default, cast):
    """Prompt for a value; "d" selects ``default``, otherwise apply ``cast``."""
    answer = input(prompt)
    return default if answer == "d" else cast(answer)


def passArgument():
    """Interactively collect every training setting, then call ``style``.

    Paths are checked for existence as soon as they are entered. For each
    numeric setting, entering ``d`` picks the default shown in the code
    below; any other input is converted with ``float`` or ``int``.
    """
    styleImg = _ask_existing_path("Enter style image path",
                                  "style path not found!")
    checkpointDir = _ask_existing_path("Enter checkpointdir path",
                                       "checkpoint dir not found!")
    vggPath = _ask_existing_path("Enter path to vgg network file",
                                 "vgg network data not found!")
    trainPath = _ask_existing_path("Enter path to training images",
                                   "train path not found!")
    testImg = _ask_existing_path("Enter path to test images",
                                 "test image not found")
    testResDir = _ask_existing_path("Enter path to test result directory",
                                    "test result directory not found")

    contentWeight = _ask_number("Enter d for default contentWeight", 7.5e0, float)
    styleWeight = _ask_number("Enter d for default styleWeight", 1e2, float)
    tvWeight = _ask_number("Enter d for default tvWeight: ", 2e2, float)
    learningRate = _ask_number("Enter d for default learning rate: ", 1e-3, float)
    checkPointIterations = _ask_number(
        "Enter d for default checkPointIterations: ", 1000, int)
    epochs = _ask_number("Enter d for default epoch count: ", 2, int)
    batch_size = _ask_number("Enter d for default batch_size: ", 20, int)

    style(styleImg, checkpointDir, vggPath, trainPath, testImg, testResDir, contentWeight, styleWeight, tvWeight, learningRate, checkPointIterations, epochs, batch_size)
     
  

In [88]:
# Run the interactive prompts and start training with the collected settings.
passArgument()

Enter style image pathstarry-night.jpg
Enter checkpointdir pathcheckpoint1
Enter path to vgg network fileimagenet-vgg-verydeep-19.mat
Enter path to training imagesdata/images
Enter path to test imagesdata/test/COCO_train2014_000000014988.jpg
Enter path to test result directorydata/test1
Enter d for default contentWeightd
Enter d for default styleWeightd
Enter d for default tvWeight: d
Enter d for default learning rate: d
Enter d for default checkPointIterations: d
Enter d for default epoch count: d
Enter d for default batch_size: d


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  


Train set has been trimmed slightly..
(1, 600, 800, 3)
imagenet-vgg-verydeep-19.mat
Tensor("add_44:0", shape=(), dtype=float32)
UID: 50
0
assert


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


UID: 50, batch time: 27.898560762405396
assert
UID: 50, batch time: 26.635885000228882
1
assert
UID: 50, batch time: 26.367039442062378
assert
UID: 50, batch time: 26.210173845291138
Epoch 1, Iteration: 2, Loss: 49050156.0
style: 39907016.0, content:3820372.0, tv: 5322767.0
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from checkpoint1/fns.ckpt
Training complete.


`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  This is separate from the ipykernel package so we can avoid doing imports until
