In [2]:
import mxnet as mx
from mxnet import nd
import mxnet.gluon as gluon
from mxnet.gluon.model_zoo.vision import resnet18_v1

In [3]:
import time
import random
import os
import mxnet as mx
import numpy as np
np.set_printoptions(precision=2)
import argparse
import symbol

from skimage import io, transform, exposure, color

VGGPATH = '../vgg19-0000.params'
RESNETPATH= '../resnet18.params'
COCOPATH = '/home/ubuntu/data/train2014'

try:
    os.mkdir(args.model_name)
    os.mkdir('%s/data'%args.model_name)
    os.mkdir('%s/output'%args.model_name)
except:
    pass

def postprocess_img(im):
    im = im[0]
    im[0,:] += 123.68
    im[1,:] += 116.779
    im[2,:] += 103.939
    im = np.swapaxes(im, 0, 2)
    im = np.swapaxes(im, 0, 1)
    im[im<0] = 0
    im[im>255] = 255
    return im.astype(np.uint8)

def crop_img(im, size):
    im = io.imread(im)
    if len(im.shape) == 2:
        im = color.gray2rgb(im)
    if im.shape[0]*size[1] > im.shape[1]*size[0]:
        c = (im.shape[0]-1.*im.shape[1]/size[1]*size[0]) / 2
        c = int(c)
        im = im[c:-(1+c),:,:]
    else:
        c = (im.shape[1]-1.*im.shape[0]/size[0]*size[1]) / 2
        c = int(c)
        im = im[:,c:-(1+c),:]
    im = transform.resize(im, size)
    im = exposure.equalize_adapthist(im, kernel_size=(16,16), clip_limit=0.01)
    im *= 255
    return im

def preprocess_img(im):
    im = im.astype(np.float32)
    im = np.swapaxes(im, 0, 2)
    im = np.swapaxes(im, 1, 2)
    im[0,:] -= 123.68
    im[1,:] -= 116.779
    im[2,:] -= 103.939
    im = np.expand_dims(im, 0)
    return im

def get_mrf_executor(layer, patch_shape):
    patch_size = patch_shape[-1]
    data = mx.sym.Variable('conv')
    weight = mx.sym.Variable('weight')
    dist = mx.sym.Convolution(data=data, weight=weight, kernel=(patch_size, patch_size), num_filter=patch_shape[0], no_bias=True)
    dist_executor = dist.bind(args={'conv': layer, 'weight': mx.nd.zeros(patch_shape, mx.gpu())}, ctx=mx.gpu())
    return dist_executor

def get_tv_grad_executor(img, ctx, tv_weight):
    nchannel = img.shape[1]
    simg = mx.sym.Variable("img")
    skernel = mx.sym.Variable("kernel")
    channels = mx.sym.SliceChannel(simg, num_outputs=nchannel)
    out = mx.sym.Concat(*[
        mx.sym.Convolution(data=channels[i], weight=skernel,
                           num_filter=1,
                           kernel=(3, 3), pad=(1,1),
                           no_bias=True, stride=(1,1))
        for i in range(nchannel)])
    kernel = mx.nd.array(np.array([[0, -1, 0],
                                   [-1, 4, -1],
                                   [0, -1, 0]])
                         .reshape((1, 1, 3, 3)),
                         ctx) / 8.0
    out = out * tv_weight
    return out.bind(ctx, args={"img": img,
                               "kernel": kernel})

In [27]:

parser = argparse.ArgumentParser(description='mrf neural style')

parser.add_argument('--style-image', type=str)
parser.add_argument('--content-weight', nargs='+', type=float)
parser.add_argument('--style-weight', nargs='+', type=float)
parser.add_argument('--tv-weight', type=float)
parser.add_argument('--num-image', type=int)
parser.add_argument('--epochs', type=int)
parser.add_argument('--style-size', nargs='+', type=int)
parser.add_argument('--lr', type=float)
parser.add_argument('--model-name', type=str)
parser.add_argument('--num-res', type=int)
parser.add_argument('--num-rotation', type=int)
parser.add_argument('--num-scale', type=int)
parser.add_argument('--stride', type=int)
parser.add_argument('--patch-size', type=int)

args = parser.parse_args("--style-image /home/ubuntu/vatsal/neural_style/images/bark.jpg --content-weight 1e-1 3e-1 3e-1 --style-weight 1 1 1 --tv-weight 1e-5 --num-image 2 --epochs 30 --style-size 512 512 --lr 3e3 --model-name ./bark_encoder --num-res 3 --num-rotation 2 --num-scale 2 --stride 4 --patch-size 3".split())

In [5]:
sym, arg_params, aux_params = mx.model.load_checkpoint('../resnet-18', 0)
mod = mx.mod.Module(symbol=sym, context=mx.gpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], 
         label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)
fe_sym = all_layers['stage4_unit2_conv2_output']
arg_names = fe_sym.list_arguments()
print(arg_names)
arg_dict = {}
arg_params = {('%s' % k) : v.as_in_context(mx.gpu()) for k, v in arg_params.items()}
aux_params = {('%s' % k) : v.as_in_context(mx.gpu()) for k, v in aux_params.items()}
arg_dict = arg_params
#print(arg_params.keys())

'''
vgg_symbol = symbol.descriptor_symbol(args.num_res)
arg_names = vgg_symbol.list_arguments()
print(arg_names)
arg_dict = {}
pretrained = mx.nd.load(VGGPATH)
for name in arg_names:
    if name == "data":
        continue
    key = "arg:" + name
    if key in pretrained:
        arg_dict[name] = pretrained[key].copyto(mx.gpu())
del pretrained
'''
img = None
args.style_size[0] = args.style_size[0] // 4 * 4
args.style_size[1] = args.style_size[1] // 4 * 4
size = [512, 512]
args.style_size = args.style_size[::-1]
rotations = [15*i for i in range(-args.num_rotation, args.num_rotation+1)]
scales = [1.05**i for i in range(-args.num_scale, args.num_scale+1)]

# extract patches
style_img = crop_img(args.style_image, args.style_size) 
patches = [[] for i in range(args.num_res)]
patches_normed = []
for s in scales:
    scaled = transform.rescale(style_img, s)
    arg_dict['data'] = mx.nd.zeros([len(rotations),3,scaled.shape[0],scaled.shape[1]], mx.gpu())
    print(arg_dict['data'].shape)
    for r in range(len(rotations)):
        arg_dict['data'][r:r+1] = preprocess_img(transform.rotate(scaled, rotations[r], mode='reflect'))
    #print(arg_dict.keys())
    #vgg_executor = vgg_symbol.bind(ctx=mx.gpu(), args=arg_dict, grad_req='null')
    vgg_executor = fe_sym.bind(ctx=mx.gpu(), args=arg_dict, aux_states=aux_params, grad_req='null')
    vgg_executor.forward()
    for l in range(args.num_res):
        tmp = vgg_executor.outputs[l].asnumpy()
        for jj in range(0, 464-args.patch_size+1, args.stride):
            for r in range(len(rotations)):
                patches[l].append(tmp[r,jj:jj+args.patch_size])
for l in range(args.num_res):
    patches[l] = np.array(patches[l])
    tmp = np.linalg.norm(np.reshape(patches[l], [patches[l].shape[0], np.prod(patches[l].shape[1:])]), axis=1)
    norm = np.reshape(tmp, [tmp.shape[0],1])
    patches_normed.append(patches[l]/norm)
    patches[l] = mx.nd.array(patches[l], mx.gpu())

arg_dict['data'] = mx.nd.zeros([1,3,size[0],size[1]], mx.gpu())
grad_dict = {"data": arg_dict["data"].copyto(mx.gpu())}
#vgg_executor = vgg_symbol.bind(ctx=mx.gpu(), args=arg_dict, args_grad=grad_dict, grad_req='write')
vgg_executor = fe_sym.bind(ctx=mx.gpu(), args=arg_dict, args_grad=grad_dict, aux_states=aux_params, grad_req='write')
tv_grad_executor = get_tv_grad_executor(vgg_executor.arg_dict['data'], mx.gpu(), args.tv_weight) 
optimizer = mx.optimizer.SGD(learning_rate=args.lr, wd=0e-0, momentum=0.9)

# get mrf executor

NameError: name 'all_layers' is not defined

In [6]:
for i in fe_sym.list_outputs():
    print(i)

NameError: name 'fe_sym' is not defined

In [7]:
print(vgg_executor.outputs[0].shape)
print(scaled.shape)
print(arg_dict['data'].shape)
print(tmp.shape, len(rotations), args.patch_size)
print(preprocess_img(transform.rotate(scaled, rotations[r], mode='reflect')).shape)
norm = np.reshape(tmp, [tmp.shape[0],1])

NameError: name 'vgg_executor' is not defined

In [8]:
mrf_executors = []
target_patch = []
for l in range(args.num_res):
    mrf_executors.append(get_mrf_executor(vgg_executor.outputs[l], patches[l].shape))
    mrf_executors[l].arg_dict['weight'][:] = patches_normed[l]
# get assign executor
pcs = []
ass_executors = []
nns = []
for l in range(args.num_res):
    pc = np.zeros(vgg_executor.outputs[l].shape)
    for i1 in range(0, vgg_executor.outputs[l].shape[2]-args.patch_size+1):
        for i2 in range(0, vgg_executor.outputs[l].shape[3]-args.patch_size+1):
            pc[0,:,i1:i1+args.patch_size,i2:i2+args.patch_size] += 1
    pc = mx.nd.array(pc, mx.gpu())
    nn = mx.nd.zeros([vgg_executor.outputs[l].shape[2]-args.patch_size+1, vgg_executor.outputs[l].shape[3]-args.patch_size+1], mx.gpu())
    assign_symbol = symbol.assign_symbol()
    assign_executor = assign_symbol.bind(args={'source':patches[l], 'nn':nn}, ctx=mx.gpu())
    assign_executor.forward()
    pcs.append(pc)
    ass_executors.append(assign_executor)
    nns.append(nn)

img_list = os.listdir(COCOPATH)

NameError: name 'vgg_executor' is not defined

In [9]:
l = 0
print(vgg_executor.outputs[0].shape)
print(patches[l].shape)
get_mrf_executor(vgg_executor.outputs[l], patches[l].shape)

NameError: name 'vgg_executor' is not defined

In [10]:
for idx in range(args.num_image):
    #   break
    optim_state = optimizer.create_state(0, arg_dict['data'])
    selected = np.random.randint(0, len(img_list))
    img = crop_img(os.path.join(COCOPATH, img_list[selected]), size)
    io.imsave('%s/data/image%dx.jpg'%(args.model_name, idx), img/255)
    img = preprocess_img(img)
    vgg_executor.arg_dict['data'][:] = img
    vgg_executor.forward()
    original_content = []
    for l in range(args.num_res):
        original_content.append(vgg_executor.outputs[l].copyto(mx.gpu()))
    for epoch in range(args.epochs):
        vgg_executor.forward(is_train=True)
        print(vgg_executor)
        if epoch % 10 == 0:
            mrf_executors[l].forward()
            nns[l][:] = mx.nd.argmax_channel(mrf_executors[l].outputs[0])[0]
            ass_executors[l].outputs[0][:] = 0
            ass_executors[l].forward()
            ass_executors[l].outputs[0][:] /= pcs[l]
            # compute target layer
            ass_executors[l].outputs[0][:] *= args.style_weight[l]
            ass_executors[l].outputs[0][:] += args.content_weight[l]*original_content[l]
            ass_executors[l].outputs[0][:] *= 1./(args.style_weight[l]+args.content_weight[l])
        tv_grad_executor.forward()
        if epoch > args.epochs - 30:
            for l in range(1,args.num_res):
                vgg_executor.outputs[l][:] = 0
            vgg_executor.outputs[0][:] -= ass_executors[0].outputs[0] # grad
            vgg_executor.outputs[0][:] *= (args.style_weight[0]+args.content_weight[0]) / np.prod(vgg_executor.outputs[0].shape)
        else:
            for l in range(args.num_res):
                vgg_executor.outputs[l][:] -= ass_executors[l].outputs[0] # grad
                vgg_executor.outputs[l][:] *= (args.style_weight[l]+args.content_weight[l]) / np.prod(vgg_executor.outputs[l].shape)

        vgg_executor.backward(vgg_executor.outputs)
        optimizer.update(0, vgg_executor.arg_dict['data'], vgg_executor.grad_dict['data']+tv_grad_executor.outputs[0], optim_state)

    img = postprocess_img(vgg_executor.arg_dict['data'].asnumpy())
    io.imsave('%s/data/image%d.jpg'%(args.model_name, idx), img)


NameError: name 'optimizer' is not defined

In [None]:
A = resnet18_v1(pretrained=True)
A.save_params('/home/ubuntu/vatsal/neural_style/fast_mrf_cnn/')


In [13]:
A = mx.nd.load('../resnet18.params')
print(A)

{'stage2_conv1_weight': 
[[[[-0.01 -0.01  0.  ]
   [-0.01  0.03  0.05]
   [-0.03  0.01  0.01]]

  [[ 0.04  0.03  0.01]
   [ 0.    0.02  0.03]
   [ 0.04  0.04  0.01]]

  [[ 0.02 -0.02 -0.04]
   [-0.02 -0.02 -0.06]
   [-0.02 -0.01  0.  ]]

  ..., 
  [[-0.01  0.01 -0.  ]
   [ 0.01  0.01 -0.  ]
   [ 0.01  0.    0.  ]]

  [[ 0.01  0.01  0.  ]
   [-0.01 -0.01 -0.01]
   [-0.    0.02  0.01]]

  [[ 0.01 -0.03 -0.01]
   [ 0.03  0.03 -0.02]
   [ 0.01  0.   -0.01]]]


 [[[ 0.01  0.02 -0.  ]
   [-0.02 -0.01  0.04]
   [-0.    0.02  0.02]]

  [[-0.01 -0.02 -0.01]
   [-0.01  0.04  0.01]
   [ 0.03  0.02  0.02]]

  [[-0.02 -0.02  0.02]
   [-0.03 -0.06 -0.07]
   [ 0.02  0.05  0.02]]

  ..., 
  [[-0.03 -0.04 -0.04]
   [-0.06 -0.07 -0.1 ]
   [-0.06 -0.08 -0.07]]

  [[ 0.03  0.03  0.03]
   [ 0.01  0.05  0.01]
   [-0.    0.02  0.01]]

  [[ 0.   -0.02 -0.  ]
   [-0.01 -0.03 -0.05]
   [-0.01 -0.   -0.04]]]


 [[[ 0.    0.01 -0.01]
   [ 0.02 -0.09 -0.08]
   [-0.01 -0.05 -0.01]]

  [[ 0.01  0.02 -0.05]
   [ 0.01

In [14]:
B = mx.nd.load('../vgg19.params')

In [16]:
C = resnet18_v1(pretrained=True)

In [20]:
C.hybridize()

In [22]:
C.collect_params()

resnetv11_ (
  Parameter resnetv11_conv0_weight (shape=(64, 3, 7, 7), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_batchnorm0_gamma (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_batchnorm0_beta (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_batchnorm0_running_mean (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_batchnorm0_running_var (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_stage1_conv0_weight (shape=(64, 64, 3, 3), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_stage1_batchnorm0_gamma (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_stage1_batchnorm0_beta (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_stage1_batchnorm0_running_mean (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_stage1_batchnorm0_running_var (shape=(0,), dtype=<class 'numpy.float32'>)
  Parameter resnetv11_stage1_conv1_weight (shape=(64, 64, 3, 3), dtype=<class 'numpy.fl

In [23]:
convs = ['conv0', 'conv1', 'conv2', 'conv3', 'conv4']

In [25]:
convs[:4]

['conv0', 'conv1', 'conv2', 'conv3']

In [32]:
import symbol_resnet_final as symbol
VGGPATH = '../vgg19-0000.params'
RESNETPATH= '../resnet18.params'
COCOPATH = '/home/ubuntu/data/train2014'

try:
    os.mkdir(args.model_name)
    os.mkdir('%s/data'%args.model_name)
    os.mkdir('%s/output'%args.model_name)
except:
    pass

def postprocess_img(im):
    im = im[0]
    im[0,:] += 123.68
    im[1,:] += 116.779
    im[2,:] += 103.939
    im = np.swapaxes(im, 0, 2)
    im = np.swapaxes(im, 0, 1)
    im[im<0] = 0
    im[im>255] = 255
    return im.astype(np.uint8)

def crop_img(im, size):
    im = io.imread(im)
    if len(im.shape) == 2:
        im = color.gray2rgb(im)
    if im.shape[0]*size[1] > im.shape[1]*size[0]:
        c = (im.shape[0]-1.*im.shape[1]/size[1]*size[0]) / 2
        c = int(c)
        im = im[c:-(1+c),:,:]
    else:
        c = (im.shape[1]-1.*im.shape[0]/size[0]*size[1]) / 2
        c = int(c)
        im = im[:,c:-(1+c),:]
    im = transform.resize(im, size)
    im = exposure.equalize_adapthist(im, kernel_size=(16,16), clip_limit=0.01)
    im *= 255
    return im

def preprocess_img(im):
    im = im.astype(np.float32)
    im = np.swapaxes(im, 0, 2)
    im = np.swapaxes(im, 1, 2)
    im[0,:] -= 123.68
    im[1,:] -= 116.779
    im[2,:] -= 103.939
    im = np.expand_dims(im, 0)
    return im

def get_mrf_executor(layer, patch_shape):
    patch_size = patch_shape[-1]
    data = mx.sym.Variable('conv')
    weight = mx.sym.Variable('weight')
    dist = mx.sym.Convolution(data=data, weight=weight, kernel=(patch_size, patch_size), num_filter=patch_shape[0], no_bias=True)
    dist_executor = dist.bind(args={'conv': layer, 'weight': mx.nd.zeros(patch_shape, mx.gpu())}, ctx=mx.gpu())
    return dist_executor

def get_tv_grad_executor(img, ctx, tv_weight):
    nchannel = img.shape[1]
    simg = mx.sym.Variable("img")
    skernel = mx.sym.Variable("kernel")
    channels = mx.sym.SliceChannel(simg, num_outputs=nchannel)
    out = mx.sym.Concat(*[
        mx.sym.Convolution(data=channels[i], weight=skernel,
                           num_filter=1,
                           kernel=(3, 3), pad=(1,1),
                           no_bias=True, stride=(1,1))
        for i in range(nchannel)])
    kernel = mx.nd.array(np.array([[0, -1, 0],
                                   [-1, 4, -1],
                                   [0, -1, 0]])
                         .reshape((1, 1, 3, 3)),
                         ctx) / 8.0
    out = out * tv_weight
    return out.bind(ctx, args={"img": img,
                               "kernel": kernel})


#sym, arg_params, aux_params = mx.model.load_checkpoint('../vgg19', 0000)

'''
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """
    symbol: the pretrained network symbol
    arg_params: the argument parameters of the pretrained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})
    return (net, new_args)

#(new_sym, new_args) = get_fine_tune_model(sym, arg_params, 1000)
#sym = mx.symbol.SoftmaxOutput(data=sym, name='prob')
#vgg_symbol = new_sym
#arg_names = vgg_symbol.list_arguments()
#print(arg_names)
#arg_dict = {}
#pretrained = mx.nd.load(VGGPATH)


sym, arg_params, aux_params = mx.model.load_checkpoint('../resnet-18', 0)
mod = mx.mod.Module(symbol=sym, context=mx.gpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], 
         label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)
all_layers = sym.get_internals()
fe_sym = all_layers['stage4_unit2_conv2_output']
arg_names = fe_sym.list_arguments()
print(arg_names)
arg_dict = {}
arg_params = {('%s' % k) : v.as_in_context(mx.gpu()) for k, v in arg_params.items()}
aux_params = {('%s' % k) : v.as_in_context(mx.gpu()) for k, v in aux_params.items()}
arg_dict = arg_params
#print(arg_params.keys())
'''

vgg_symbol = symbol.descriptor_resnet_symbol(args.num_res)
arg_names = vgg_symbol.list_arguments()
arg_dict = {}
pretrained = mx.nd.load(RESNETPATH)
for name in arg_names:
    if name == "data":
        continue
    key = name
    if key in pretrained.keys():
        arg_dict[name] = pretrained[key].copyto(mx.gpu())
del pretrained
#print(len(arg_dict.keys()), len(arg_names))
#for i in arg_dict.keys():
#    print(i)
#    if i in arg_names:
#        print(i)
img = None
args.style_size[0] = args.style_size[0] // 4 * 4
args.style_size[1] = args.style_size[1] // 4 * 4
size = [512, 512]
args.style_size = args.style_size[::-1]
rotations = [15*i for i in range(-args.num_rotation, args.num_rotation+1)]
scales = [1.05**i for i in range(-args.num_scale, args.num_scale+1)]

# extract patches
style_img = crop_img(args.style_image, args.style_size) 
patches = [[] for i in range(args.num_res)]
patches_normed = []
for s in scales:
    scaled = transform.rescale(style_img, s)
    arg_dict['data'] = mx.nd.zeros([len(rotations),3,scaled.shape[0],scaled.shape[1]], mx.gpu())
    for r in range(len(rotations)):
        arg_dict['data'][r:r+1] = preprocess_img(transform.rotate(scaled, rotations[r], mode='reflect'))
    #print(arg_dict.keys())
    vgg_executor = vgg_symbol.bind(ctx=mx.gpu(), args=arg_dict, grad_req='null')
    #vgg_executor = fe_sym.bind(ctx=mx.gpu(), args=arg_dict, aux_states=aux_params, grad_req='null')
    vgg_executor.forward()
    for l in range(args.num_res):
        tmp = vgg_executor.outputs[l].asnumpy()
        #print(vgg_executor.outputs[l].shape[2])
        for ii in range(0, vgg_executor.outputs[l].shape[2]-args.patch_size+1, args.stride):
            for jj in range(0, vgg_executor.outputs[l].shape[3]-args.patch_size+1, args.stride):
                for r in range(len(rotations)):
                    patches[l].append(tmp[r,:,ii:ii+args.patch_size,jj:jj+args.patch_size])
for l in range(args.num_res):
    patches[l] = np.array(patches[l])
    tmp = np.linalg.norm(np.reshape(patches[l], [patches[l].shape[0], np.prod(patches[l].shape[1:])]), axis=1)
    norm = np.reshape(tmp, [tmp.shape[0],1,1,1])
    patches_normed.append(patches[l]/norm)
    patches[l] = mx.nd.array(patches[l], mx.gpu())

arg_dict['data'] = mx.nd.zeros([1,3,size[0],size[1]], mx.gpu())
grad_dict = {"data": arg_dict["data"].copyto(mx.gpu())}
vgg_executor = vgg_symbol.bind(ctx=mx.gpu(), args=arg_dict, args_grad=grad_dict, grad_req='write')
#vgg_executor = fe_sym.bind(ctx=mx.gpu(), args=arg_dict, args_grad=grad_dict, aux_states=aux_params, grad_req='write')
tv_grad_executor = get_tv_grad_executor(vgg_executor.arg_dict['data'], mx.gpu(), args.tv_weight) 
optimizer = mx.optimizer.SGD(learning_rate=args.lr, wd=0e-0, momentum=0.9)

# get mrf executor
mrf_executors = []
target_patch = []
for l in range(args.num_res):
    mrf_executors.append(get_mrf_executor(vgg_executor.outputs[l], patches[l].shape))
    mrf_executors[l].arg_dict['weight'][:] = patches_normed[l]
# get assign executor
pcs = []
ass_executors = []
nns = []
for l in range(args.num_res):
    pc = np.zeros(vgg_executor.outputs[l].shape)
    for i1 in range(0, vgg_executor.outputs[l].shape[2]-args.patch_size+1):
        for i2 in range(0, vgg_executor.outputs[l].shape[3]-args.patch_size+1):
            pc[0,:,i1:i1+args.patch_size,i2:i2+args.patch_size] += 1
    pc = mx.nd.array(pc, mx.gpu())
    nn = mx.nd.zeros([vgg_executor.outputs[l].shape[2]-args.patch_size+1, vgg_executor.outputs[l].shape[3]-args.patch_size+1], mx.gpu())
    assign_symbol = symbol.assign_symbol()
    assign_executor = assign_symbol.bind(args={'source':patches[l], 'nn':nn}, ctx=mx.gpu())
    assign_executor.forward()
    pcs.append(pc)
    ass_executors.append(assign_executor)
    nns.append(nn)

img_list = os.listdir(COCOPATH)
for idx in range(args.num_image):
#    break
    optim_state = optimizer.create_state(0, arg_dict['data'])
    selected = np.random.randint(0, len(img_list))
    img = crop_img(os.path.join(COCOPATH, img_list[selected]), size)
    io.imsave('%s/data/image%dx.jpg'%(args.model_name, idx), img/255)
    img = preprocess_img(img)
    vgg_executor.arg_dict['data'][:] = img
    vgg_executor.forward()
    original_content = []
    for l in range(args.num_res):
        original_content.append(vgg_executor.outputs[l].copyto(mx.gpu()))
    for epoch in range(args.epochs):
        vgg_executor.forward(is_train=True)
        if epoch % 10 == 0:
            for l in range(args.num_res):
                mrf_executors[l].forward()
                nns[l][:] = mx.nd.argmax_channel(mrf_executors[l].outputs[0])[0]
                ass_executors[l].outputs[0][:] = 0
                ass_executors[l].forward()
                ass_executors[l].outputs[0][:] /= pcs[l]
                # compute target layer
                ass_executors[l].outputs[0][:] *= args.style_weight[l]
                ass_executors[l].outputs[0][:] += args.content_weight[l]*original_content[l]
                ass_executors[l].outputs[0][:] *= 1./(args.style_weight[l]+args.content_weight[l])
        tv_grad_executor.forward()
        if epoch > args.epochs - 30:
            for l in range(1,args.num_res):
                vgg_executor.outputs[l][:] = 0
            vgg_executor.outputs[0][:] -= ass_executors[0].outputs[0] # grad
            vgg_executor.outputs[0][:] *= (args.style_weight[0]+args.content_weight[0]) / np.prod(vgg_executor.outputs[0].shape)
        else:
            for l in range(args.num_res):
                vgg_executor.outputs[l][:] -= ass_executors[l].outputs[0] # grad
                vgg_executor.outputs[l][:] *= (args.style_weight[l]+args.content_weight[l]) / np.prod(vgg_executor.outputs[l].shape)

        vgg_executor.backward(vgg_executor.outputs)
        optimizer.update(0, vgg_executor.arg_dict['data'], vgg_executor.grad_dict['data']+tv_grad_executor.outputs[0], optim_state)

    img = postprocess_img(vgg_executor.arg_dict['data'].asnumpy())
    io.imsave('%s/data/image%d.jpg'%(args.model_name, idx), img)

# Train a generative network
vgg_symbol = symbol.descriptor_resnet_symbol(1)
vgg_executor = vgg_symbol.bind(ctx=mx.gpu(), args=arg_dict, args_grad=grad_dict, grad_req='write')
decoder = symbol.decoder_symbol()
arg_shapes, output_shapes, aux_shapes = decoder.infer_shape(data=vgg_executor.outputs[0].shape)
arg_names = decoder.list_arguments()
arg_dict = dict(zip(arg_names, [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in arg_shapes]))
aux_names = decoder.list_auxiliary_states()
aux_dict = dict(zip(aux_names, [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in aux_shapes]))
grad_dict = {}
for k in arg_dict:
    if k != 'data':
        grad_dict[k] = arg_dict[k].copyto(mx.gpu())
initializer = mx.init.Normal(1e-3)
for name in arg_names:
    if name != 'data':
        initializer(name, arg_dict[name])
deco_executor = decoder.bind(ctx=mx.gpu(), args=arg_dict, args_grad=grad_dict, aux_states=aux_dict, grad_req='write')



  warn("The default mode, 'constant', will be changed to 'reflect' in "
  .format(dtypeobj_in, dtypeobj_out))


ValueError: Length of aux_states does not match the number of arguments

In [31]:
print(vgg_symbol.list_auxiliary_states())

['batchnorm0_running_mean', 'batchnorm0_running_var', 'stage1_batchnorm0_running_mean', 'stage1_batchnorm0_running_var', 'stage1_batchnorm1_running_mean', 'stage1_batchnorm1_running_var', 'stage1_batchnorm2_running_mean', 'stage1_batchnorm2_running_var', 'stage1_batchnorm3_running_mean', 'stage1_batchnorm3_running_var', 'batchnorm0_moving_mean', 'batchnorm0_moving_var', 'stage2_batchnorm1_running_mean', 'stage2_batchnorm1_running_var', 'stage2_batchnorm2_running_mean', 'stage2_batchnorm2_running_var', 'stage2_batchnorm3_running_mean', 'stage2_batchnorm3_running_var', 'stage2_batchnorm4_running_mean', 'stage2_batchnorm4_running_var', 'batchnorm0_moving_mean', 'batchnorm0_moving_var', 'stage3_batchnorm1_running_mean', 'stage3_batchnorm1_running_var', 'stage3_batchnorm2_running_mean', 'stage3_batchnorm2_running_var', 'stage3_batchnorm3_running_mean', 'stage3_batchnorm3_running_var', 'stage3_batchnorm4_running_mean', 'stage3_batchnorm4_running_var']
