In [None]:
import time
import random
import os
import mxnet as mx
import numpy as np
np.set_printoptions(precision=2)
import argparse
import symbol_resnet as symbol
import copy
from skimage import io, transform, exposure, color
from mxnet.gluon.model_zoo.vision import resnet18_v1
def _build_parser():
    """Create the command-line parser for the MRF neural-style run.

    Every option is declared in a single table so the flag names and their
    argparse settings can be read at a glance.
    """
    cli = argparse.ArgumentParser(description='mrf neural style')
    option_table = (
        ('--style-image', dict(type=str)),
        ('--content-weight', dict(nargs='+', type=float)),
        ('--style-weight', dict(nargs='+', type=float)),
        ('--tv-weight', dict(type=float)),
        ('--num-image', dict(type=int)),
        ('--epochs', dict(type=int)),
        ('--style-size', dict(nargs='+', type=int)),
        ('--lr', dict(type=float)),
        ('--model-name', dict(type=str)),
        ('--num-res', dict(type=int)),
        ('--num-rotation', dict(type=int)),
        ('--num-scale', dict(type=int)),
        ('--stride', dict(type=int)),
        ('--patch-size', dict(type=int)),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli


parser = _build_parser()

# The argument vector is passed explicitly to parse_args, so sys.argv is not
# read; the values below are what this notebook runs with.
_NOTEBOOK_ARGV = "--style-image /home/ubuntu/vatsal/neural_style/images/bark.jpg --content-weight 1e-1 3e-1 3e-1 --style-weight 1 1 1 --tv-weight 1e-5 --num-image 2 --epochs 2 --style-size 512 512 --lr 3e3 --model-name ./bark_encoder --num-res 4 --num-rotation 2 --num-scale 2 --stride 4 --patch-size 5".split()
args = parser.parse_args(_NOTEBOOK_ARGV)

# Locations of weights and training data.
# NOTE(review): these are machine-specific paths — confirm before running elsewhere.
VGGPATH = '../vgg19-0000.params'
RESNETPATH = '../resnet18.params'
COCOPATH = '/home/ubuntu/data/train2014'


In [2]:

# Make sure the model directory and its data/ and output/ subdirectories exist.
# Each directory is checked on its own: with a single try/except around all
# three mkdir calls, an already existing model directory would abort the block
# before the subdirectories were created. Genuine failures (e.g. permission
# errors) are no longer swallowed and surface as OSError.
for _out_dir in (args.model_name,
                 '%s/data' % args.model_name,
                 '%s/output' % args.model_name):
    if not os.path.isdir(_out_dir):
        os.makedirs(_out_dir)

def postprocess_img(im):
    """Turn the first item of a channel-first batch into an (H, W, C) uint8 image.

    The per-channel constants that ``preprocess_img`` subtracts are added back
    to channels 0, 1 and 2, the axes are reordered from (C, H, W) to
    (H, W, C), and values are clamped to [0, 255] before the cast to uint8.

    The input array is left untouched: all arithmetic is done on a private
    floating-point copy, so calling this function repeatedly on the same
    buffer gives the same result every time.

    Parameters
    ----------
    im : numpy.ndarray
        Batch whose first element has at least three channels on its leading
        axis followed by two spatial axes.

    Returns
    -------
    numpy.ndarray
        uint8 array with the channel axis last.
    """
    batch = np.asarray(im)
    # np.array(..., dtype=...) copies; promoting with float32 keeps float
    # inputs in their own precision and lets integer inputs accept the
    # fractional offsets.
    chw = np.array(batch[0], dtype=np.result_type(batch.dtype, np.float32))
    for channel, offset in enumerate((123.68, 116.779, 103.939)):
        chw[channel] += offset
    hwc = chw.transpose(1, 2, 0)
    return np.clip(hwc, 0, 255).astype(np.uint8)

def crop_img(im, size):
    """Load an image, centre-crop it to the aspect ratio of ``size`` and resize it.

    Steps: read the file with ``io.imread``, expand grayscale images to three
    channels, trim the longer dimension equally on both sides so the aspect
    ratio matches ``size``, resize to ``size``, apply
    ``exposure.equalize_adapthist`` and multiply the result by 255.

    The crop removes ``c`` rows (or columns) from each side. The bounds are
    written as ``c : dim - c`` so that nothing is removed when ``c`` is 0 and
    the crop stays symmetric.

    Parameters
    ----------
    im : str
        Location of the image, passed straight to ``io.imread``.
    size : sequence of two ints
        Target shape; ``size[0]`` is matched against the first image axis and
        ``size[1]`` against the second.

    Returns
    -------
    numpy.ndarray
        The equalized image multiplied by 255.
        NOTE(review): presumably float values in [0, 255] — confirm against
        the installed scikit-image version.
    """
    im = io.imread(im)
    if len(im.shape) == 2:
        im = color.gray2rgb(im)
    rows, cols = im.shape[0], im.shape[1]
    if rows * size[1] > cols * size[0]:
        # Image is relatively taller than the target: trim rows.
        c = int((rows - 1. * cols / size[1] * size[0]) / 2)
        im = im[c:rows - c, :, :]
    else:
        # Image is relatively wider than (or equal to) the target: trim columns.
        c = int((cols - 1. * rows / size[0] * size[1]) / 2)
        im = im[:, c:cols - c, :]
    im = transform.resize(im, size)
    im = exposure.equalize_adapthist(im, kernel_size=(16,16), clip_limit=0.01)
    im *= 255
    return im

def preprocess_img(im):
    """Convert a channel-last image into a mean-subtracted, channel-first batch of one.

    The image is copied to float32, its third axis is moved to the front,
    a fixed constant is subtracted from each of the first three channels, and
    a leading batch axis of length 1 is added. The input is not modified.
    """
    channel_first = np.moveaxis(im.astype(np.float32), 2, 0)
    for channel, offset in enumerate((123.68, 116.779, 103.939)):
        channel_first[channel] -= offset
    return channel_first[np.newaxis]

def get_mrf_executor(layer, patch_shape):
    """Bind a bias-free convolution whose filters are placeholders for patches.

    The symbol convolves the input named 'conv' (bound to ``layer``) with a
    'weight' array of shape ``patch_shape``. The weight is bound to zeros on
    the CPU here, so the caller is expected to fill it in before use.
    ``patch_shape[0]`` is used as the number of filters and
    ``patch_shape[-1]`` as the square kernel edge.

    Returns the bound executor (CPU context).
    """
    kernel_edge = patch_shape[-1]
    filter_count = patch_shape[0]
    conv_input = mx.sym.Variable('conv')
    patch_bank = mx.sym.Variable('weight')
    response = mx.sym.Convolution(
        data=conv_input,
        weight=patch_bank,
        kernel=(kernel_edge, kernel_edge),
        num_filter=filter_count,
        no_bias=True,
    )
    bound_arrays = {
        'conv': layer,
        'weight': mx.nd.zeros(patch_shape, mx.cpu()),
    }
    return response.bind(args=bound_arrays, ctx=mx.cpu())

def get_tv_grad_executor(img, ctx, tv_weight):
    """Bind an executor that filters every channel of ``img`` with a fixed 3x3 kernel.

    The image is split along axis 1 into ``img.shape[1]`` single-channel
    slices. Each slice is convolved (padding 1, stride 1, no bias) with the
    same kernel: centre 4, the four direct neighbours -1, corners 0, all
    divided by 8. The results are concatenated again and multiplied by
    ``tv_weight``.

    Parameters
    ----------
    img : mx.nd.NDArray
        Array bound to the "img" input; its second axis is the channel axis.
    ctx : mx.Context
        Context for the kernel array and for binding.
    tv_weight : float
        Scalar multiplier applied to the filtered output.
    """
    plane_count = img.shape[1]
    img_var = mx.sym.Variable("img")
    kernel_var = mx.sym.Variable("kernel")
    planes = mx.sym.SliceChannel(img_var, num_outputs=plane_count)

    filtered = []
    for idx in range(plane_count):
        filtered.append(
            mx.sym.Convolution(data=planes[idx], weight=kernel_var,
                               num_filter=1,
                               kernel=(3, 3), pad=(1,1),
                               no_bias=True, stride=(1,1)))
    stacked = mx.sym.Concat(*filtered)

    stencil = np.array([[0, -1, 0],
                        [-1, 4, -1],
                        [0, -1, 0]]).reshape((1, 1, 3, 3))
    kernel_nd = mx.nd.array(stencil, ctx) / 8.0

    weighted = stacked * tv_weight
    return weighted.bind(ctx, args={"img": img,
                                    "kernel": kernel_nd})


In [3]:
# Build a symbolic graph of the pretrained ResNet-18 feature extractor and
# pick four intermediate outputs from it.
resnet = resnet18_v1(pretrained=True)
resnet1 = resnet.features
data = mx.sym.Variable("data")
out = resnet1(data)
all_layers = out.get_internals()

# Positions in the internals list of the four tapped outputs. The recorded
# output of `resnet_symbol[:3]` further down names the first three as
# resnetv10_stage{1,2,3}_activation0.
# NOTE(review): these indices depend on the exact graph layout — re-check them
# if the model or MXNet version changes.
sym0 = all_layers[43]
sym1 = all_layers[84]
sym2 = all_layers[125]
sym3 = all_layers[166]

# Group the taps into one multi-output symbol here: the parameter-loading and
# patch-extraction code that follows uses `resnet_symbol`, so it has to exist
# before those cells run on a fresh kernel.
resnet_symbol = mx.sym.Group([sym0, sym1, sym2, sym3])

# Write the feature-extractor weights to disk so they can be reloaded as a
# plain name -> NDArray dict.
resnet1.save_params('resnet18.params')


# Reload the saved weights as a dict of NDArrays.
pretrained = mx.nd.load('resnet18.params')

# Symbol argument names carry the network's name prefix; the saved keys are
# looked up with that prefix stripped, so remember how many characters to drop.
# NOTE(review): `resnet_symbol` must already be defined when this cell runs.
length = len(resnet.name+'_')

# Learnable arguments: copy every weight that has a saved counterpart.
arg_names = resnet_symbol.list_arguments()
arg_dict = {}
for name in arg_names:
    key = name[length:]
    if name != "data" and key in pretrained.keys():
        arg_dict[name] = pretrained[key].copyto(mx.cpu())

# Auxiliary states: same lookup.
aux_names = resnet_symbol.list_auxiliary_states()
aux_dict = {}
for name in aux_names:
    key = name[length:]
    if name != "data" and key in pretrained:
        aux_dict[name] = pretrained[key].copyto(mx.cpu())


img = None

# Round the first two style-size entries down to a multiple of 4 (in place),
# then store the list in reversed order.
# NOTE(review): re-running this cell reverses the list again, so it is not
# idempotent for non-square sizes.
for axis in (0, 1):
    args.style_size[axis] = args.style_size[axis] // 4 * 4
size = [512, 512]
args.style_size = list(reversed(args.style_size))

# Augmentation grid: rotation angles in steps of 15 and scale factors as
# powers of 1.05, both symmetric around the identity.
rotations = [15 * step for step in range(-args.num_rotation, args.num_rotation + 1)]
scales = [pow(1.05, step) for step in range(-args.num_scale, args.num_scale + 1)]



In [None]:
# Extract feature patches from the style image.
#
# For every scale factor the style image is rescaled, one rotated copy per
# entry of `rotations` is stacked into a batch, and the batch is pushed through
# `resnet_symbol`. Each of the first `args.num_res` outputs is then cut into
# patch_size x patch_size windows taken every `args.stride` positions.
# `patches[l]` collects the windows of output `l` across all scales/rotations.
# Relies on `resnet_symbol`, `arg_dict` and `aux_dict` existing in the kernel.
style_img = crop_img(args.style_image, args.style_size) 
patches = [[] for i in range(args.num_res)]
patches_normed = []
for s in scales:
    scaled = transform.rescale(style_img, s)
    #print("HHii", scaled.shape)
    # One batch row per rotation; the spatial size follows the rescaled image.
    arg_dict['data'] = mx.nd.zeros([len(rotations),3,scaled.shape[0],scaled.shape[1]], mx.cpu())
    for r in range(len(rotations)):
        arg_dict['data'][r:r+1] = preprocess_img(transform.rotate(scaled, rotations[r], mode='reflect'))
    #vgg_executor = vgg_symbol.bind(ctx=mx.cpu(), args=arg_dict, grad_req='null')
    #print(arg_dict['data'].shape)
    # A new executor is bound for every scale because the input shape changes;
    # grad_req='null' means no gradients are computed.
    resnet_executor = resnet_symbol.bind(ctx=mx.cpu(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    resnet_executor.forward()
    for l in range(args.num_res):
        tmp = resnet_executor.outputs[l].asnumpy()
        #print(vgg_executor.outputs[l].shape[2])
        # Slide a window over the two spatial axes (2 and 3) of the output.
        for ii in range(0, resnet_executor.outputs[l].shape[2]-args.patch_size+1, args.stride):
            for jj in range(0, resnet_executor.outputs[l].shape[3]-args.patch_size+1, args.stride):
                # Same window position, one patch per rotated copy in the batch.
                for r in range(len(rotations)):
                    patches[l].append(tmp[r,:,ii:ii+args.patch_size,jj:jj+args.patch_size])
# Stack the patches of every output, compute each patch's L2 norm over all of
# its elements, and keep both a normalised numpy copy (`patches_normed`) and
# the raw patches as an MXNet NDArray on the CPU.
# NOTE(review): a patch whose norm is 0 would be divided by zero here.
for l in range(args.num_res):
    patches[l] = np.array(patches[l])
    tmp = np.linalg.norm(np.reshape(patches[l], [patches[l].shape[0], np.prod(patches[l].shape[1:])]), axis=1)
    norm = np.reshape(tmp, [tmp.shape[0],1,1,1])
    patches_normed.append(patches[l]/norm)
    patches[l] = mx.nd.array(patches[l], mx.cpu())

#arg_dict['data'] = mx.nd.zeros([1,3,size[0],size[1]], mx.cpu())
#grad_dict = {"data": arg_dict["data"].copyto(mx.cpu())}
#vgg_executor = vgg_symbol.bind(ctx=mx.cpu(), args=arg_dict, args_grad=grad_dict, grad_req='write')
#vgg_executor = resnet_symbol.bind(ctx=mx.cpu(), args=arg_dict, args_grad=grad_dict, aux_states=aux_dict, grad_req='write')



In [None]:
# Debug: list the output names of the most recently bound executor
# (`resnet_executor` is left over from the patch-extraction loop).
print(resnet_executor.output_dict.keys())
#tv_grad_executor = get_tv_grad_executor(vgg_executor.arg_dict['data'], mx.cpu(), args.tv_weight) 
#tv_grad_executor = get_tv_grad_executor(arg_dict['data'], mx.cpu(), args.tv_weight) 
#optimizer = mx.optimizer.SGD(learning_rate=args.lr, wd=0e-0, momentum=0.9)


In [6]:
# Debug: show every argument name of the grouped symbol; these are the names
# the parameter-loading loop matches against the saved weights.
print(resnet_symbol.list_arguments())

['data', 'resnetv10_conv0_weight', 'resnetv10_batchnorm0_gamma', 'resnetv10_batchnorm0_beta', 'resnetv10_stage1_conv0_weight', 'resnetv10_stage1_batchnorm0_gamma', 'resnetv10_stage1_batchnorm0_beta', 'resnetv10_stage1_conv1_weight', 'resnetv10_stage1_batchnorm1_gamma', 'resnetv10_stage1_batchnorm1_beta', 'resnetv10_stage1_conv2_weight', 'resnetv10_stage1_batchnorm2_gamma', 'resnetv10_stage1_batchnorm2_beta', 'resnetv10_stage1_conv3_weight', 'resnetv10_stage1_batchnorm3_gamma', 'resnetv10_stage1_batchnorm3_beta', 'resnetv10_stage2_conv2_weight', 'resnetv10_stage2_batchnorm2_gamma', 'resnetv10_stage2_batchnorm2_beta', 'resnetv10_stage2_conv0_weight', 'resnetv10_stage2_batchnorm0_gamma', 'resnetv10_stage2_batchnorm0_beta', 'resnetv10_stage2_conv1_weight', 'resnetv10_stage2_batchnorm1_gamma', 'resnetv10_stage2_batchnorm1_beta', 'resnetv10_stage2_conv3_weight', 'resnetv10_stage2_batchnorm3_gamma', 'resnetv10_stage2_batchnorm3_beta', 'resnetv10_stage2_conv4_weight', 'resnetv10_stage2_batchno

In [4]:
# Debug: dump the loaded parameters.
# NOTE(review): this prints every weight array in full, which produces a very
# long output; printing only names or shapes would be easier to read.
print(arg_dict)

{'resnetv10_stage4_batchnorm3_beta': 
[-0.17 -0.3  -0.22 -0.29 -0.2  -0.23 -0.19 -0.19 -0.25 -0.16 -0.24 -0.25
 -0.19 -0.27 -0.33 -0.38 -0.3  -0.5  -0.03 -0.27 -0.2  -0.22 -0.29 -0.26
 -0.52 -0.2  -0.5  -0.27 -0.45 -0.4  -0.23 -0.18 -0.13 -0.27 -0.33 -0.32
 -0.33 -0.19 -0.07 -0.22 -0.21 -0.26 -0.16 -0.16 -0.21 -0.22 -0.25 -0.28
 -0.31 -0.2  -0.2  -0.02 -0.25 -0.42 -0.17 -0.19 -0.33 -0.15 -0.3  -0.23
 -0.32 -0.27 -0.16 -0.34 -0.41 -0.21 -0.06 -0.27 -0.28 -0.22 -0.21 -0.45
 -0.34 -0.32 -0.29 -0.2  -0.34 -0.21 -0.28 -0.29 -0.17 -0.22 -0.4  -0.25
 -0.21 -0.38 -0.38 -0.59 -0.16 -0.17 -0.22 -0.51 -0.25 -0.21 -0.43 -0.19
 -0.31 -0.26 -0.41 -0.06 -0.2  -0.13 -0.33 -0.26 -0.23 -0.09 -0.16 -0.29
 -0.33 -0.16 -0.51 -0.2  -0.17 -0.18 -0.23 -0.19 -0.34 -0.09 -0.31 -0.19
 -0.19 -0.03 -0.33 -0.17 -0.08 -0.24 -0.21 -0.39 -0.22 -0.29 -0.05 -0.31
 -0.07 -0.26 -0.22 -0.51 -0.29 -0.21 -0.31 -0.21 -0.19 -0.25 -0.55 -0.23
 -0.34 -0.35 -0.22 -0.36 -0.08 -0.15 -0.03 -0.14 -0.31 -0.31 -0.11 -0.27
 -0.27 -0.31 

In [10]:
# Combine the four tapped outputs into a single multi-output symbol.
# NOTE(review): cells that appear earlier in this notebook (parameter loading,
# patch extraction) already reference `resnet_symbol`, so on a fresh kernel
# the name must be defined before they run.
resnet_symbol = mx.sym.Group([sym0, sym1, sym2, sym3])


In [11]:
# Inspect the first three grouped outputs; the cell's value is displayed as its output.
resnet_symbol[:3]

<Symbol group [resnetv10_stage1_activation0, resnetv10_stage2_activation0, resnetv10_stage3_activation0]>