## The base code is adapted from the tutorial
## Fine-tuning an ONNX model with MXNet/Gluon
## https://mxnet.apache.org/versions/1.5.0/tutorials/onnx/fine_tuning_gluon.html

In [1]:
import mxnet as mx
from mxnet import gluon, nd
import mxnet.contrib.onnx as onnx_mxnet
import numpy as np

# Device context used when loading the pretrained weights below (CPU only).
ctx = mx.cpu()

In [2]:
# Record the MXNet version; the outputs in this notebook were produced with 1.6.0.
print(mx.__version__)

1.6.0


In [3]:
# model from https://github.com/onnx/models/tree/master/vision/classification/resnet/model
# Relative path: the .onnx file is expected in the current working directory.
model_file = 'resnet18-v1-7.onnx' 
# import_model returns the graph symbol plus two name-keyed dicts:
# learnable parameters (arg_params) and auxiliary states (aux_params).
sym, arg_params, aux_params = onnx_mxnet.import_model(model_file)

In [4]:
# Input/output tensor names and shapes declared by the ONNX graph.
model_metadata = onnx_mxnet.get_model_metadata(model_file)
print(model_metadata)

{'input_tensor_data': [('data', (1, 3, 224, 224))], 'output_tensor_data': [('resnetv15_dense0_fwd', (1, 1000))]}


In [5]:
# Names of the graph's input tensors (for this model: just 'data').
input_specs = model_metadata.get('input_tensor_data')
data_names = [spec[0] for spec in input_specs]
print(data_names)

['data']


# NOTE
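### In the original tutorial, `get_layer_output()` built `new_aux` by filtering `aux_params` against `net.list_arguments()` and returned it. With that version, `output = net(input_data)` fails with:
### RuntimeError: Parameter 'resnetv15_batchnorm0_running_mean' has not been initialized...
#### "Fine-tune with Pretrained Models" (https://mxnet.apache.org/versions/0.11.0/how_to/finetune.html) does not build `new_aux` at all. So I tried passing `aux_params` through directly, and it works. The root cause is that batch-norm running statistics are auxiliary states. They are listed by `net.list_auxiliary_states()`, not by `net.list_arguments()`, so the tutorial's filter produced an empty `new_aux`.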

In [6]:
def get_layer_output(symbol, arg_params, aux_params, layer_name):
    """Truncate *symbol* at *layer_name* and keep only the parameters it uses.

    The graph is cut at the internal output ``<layer_name>_output`` and
    flattened, so the result can be wrapped in a ``gluon.nn.SymbolBlock``.

    Args:
        symbol: Full network symbol (e.g. from ``onnx_mxnet.import_model``).
        arg_params: Dict of learnable parameters keyed by name.
        aux_params: Dict of auxiliary states (e.g. batch-norm running
            statistics) keyed by name.
        layer_name: Internal layer to cut at, without the ``_output`` suffix.

    Returns:
        ``(net, new_args, new_aux)``: the truncated, flattened symbol and the
        subsets of *arg_params* / *aux_params* that it references.
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name + '_output']
    net = mx.symbol.Flatten(data=net)
    # Query the truncated graph's names once instead of once per key.
    arg_names = set(net.list_arguments())
    # Auxiliary states are NOT reported by list_arguments(); filtering
    # aux_params against it yields an empty dict and leaves the batch-norm
    # running statistics uninitialized in the SymbolBlock.
    aux_names = set(net.list_auxiliary_states())
    new_args = {k: v for k, v in arg_params.items() if k in arg_names}
    new_aux = {k: v for k, v in aux_params.items() if k in aux_names}
    return (net, new_args, new_aux)

In [7]:
# Every internal output of the full graph; inspecting it (e.g. via
# all_layers.list_outputs()) is how a cut point such as 'flatten0' is chosen.
all_layers = sym.get_internals()
# Cut the network at 'flatten0' to obtain a feature extractor.
truncated = get_layer_output(
    symbol=sym,
    arg_params=arg_params,
    aux_params=aux_params,
    layer_name='flatten0',
)
new_sym, new_arg_params, new_aux_params = truncated

In [8]:
import warnings

# Wrap the truncated symbol as a Gluon block; any warnings raised during
# construction are suppressed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('data'))

# Copy the pretrained arguments first, then the auxiliary states, into the
# block's parameters; names the block does not own are skipped.
net_params = pre_trained.collect_params()
for source in (new_arg_params, new_aux_params):
    for name, value in source.items():
        if name in net_params:
            net_params[name]._load_init(value, ctx=ctx)

In [9]:
# Wrap the pretrained feature extractor in a HybridSequential so that extra
# layers can be stacked on top of it.
net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(pre_trained)
    # Append your own task-specific layers here: net.add(...)

In [10]:
EDGE = 224
SIZE = (EDGE, EDGE)


def transform(image):
    """Resize, center-crop and reorder an image from (H, W, C) to (C, H, W).

    The shorter side is resized to EDGE and a SIZE center crop is taken.
    No channel reordering and no mean/std normalization is applied here.
    """
    cropped, _ = mx.image.center_crop(mx.image.resize_short(image, EDGE), SIZE)
    return nd.transpose(cropped, axes=(2, 0, 1))

In [11]:
import os

import cv2

imgdir = "./"
imgname = "test.jpg"
img_path = os.path.join(imgdir, imgname)
# cv2.imread returns None (instead of raising) when the file is missing or
# unreadable; fail here with a clear message rather than a cryptic error
# from mx.nd.array(None).
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError("could not read image: %s" % img_path)
# OpenCV loads pixels in BGR order. The ONNX-zoo ResNet presumably expects
# RGB (as ImageNet models usually do) — confirm against the model card.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
ndimg = mx.nd.array(img)
timg = transform(ndimg)


In [12]:
# Duplicate the single transformed image along a new leading axis to build a
# batch of two: (2, C, H, W).
a = mx.nd.expand_dims(timg, axis=0)
input_data = mx.nd.concat(*([a] * 2), dim=0)

In [13]:
# Read the shape directly from the NDArray instead of copying the whole batch
# into a NumPy array just to inspect it.
input_data.shape

(2, 3, 224, 224)

In [14]:
# Forward the batch through the network truncated at 'flatten0'; for this
# input the result has shape (2, 512).
output = net(input_data)

In [15]:
# MXNet executes asynchronously: wait for the forward pass to finish so any
# error surfaces here, then read the shape without copying the features to NumPy.
output.wait_to_read()
output.shape

(2, 512)

# Done

In [16]:
### Want to see the model graph? Uncomment and run the next line (requires the graphviz package).
#mx.visualization.plot_network(new_sym,  node_attrs={"shape":"oval","fixedsize":"false"})