Skip to content

Commit

Permalink
[NNVM][TENSORFLOW] Mobilenet support. (apache#1335)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed Jun 26, 2018
1 parent dee005f commit 464fbdd
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 37 deletions.
52 changes: 35 additions & 17 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def __call__(self, inputs, attrs, *args):
self._ignores.append('use_cudnn_on_gpu')
self._ignores.append('_node_name')
self._ignores.append('is_training')
# Retain the names
try:
attrs['name'] = attrs['_node_name']
except KeyError:
pass
return AttrConvert(self._op_name, self._transforms, self._excludes,
self._disables, self._ignores, self._extras,
self._custom_check)(inputs, attrs, *args)
Expand Down Expand Up @@ -405,13 +410,19 @@ def _impl(inputs, attr, params):

def _reshape():
def _impl(inputs, attr, params):
pop_node = inputs.pop(1)
shape_arg = params[pop_node.list_output_names()[0]]
params.pop(pop_node.list_output_names()[0])
return AttrCvt(
op_name="reshape",
extras={'shape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
try:
pop_node = inputs[1]
shape_arg = params.pop(pop_node.list_output_names()[0])
inputs.pop(1)

return AttrCvt(
op_name="reshape",
extras={'shape':tuple(shape_arg.asnumpy())},
ignores=['Tshape'])(inputs, attr)
except KeyError:
return AttrCvt(
op_name="reshape_like",
ignores=['Tshape'])(inputs, attr)
return _impl

def _bias_add():
Expand All @@ -427,6 +438,18 @@ def _impl(inputs, attr, params):
ignores=['T'])(inputs, attr)
return _impl

def _fused_batch_norm():
def _impl(inputs, attr, params):
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# NNVM: (data, gamma, beta, moving_mean, moving_varience)
return AttrCvt(
op_name='batch_norm',
transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
extras={'axis': 3}, # Fix axis
ignores=['data_format'],
disables=['momentum'])(inputs, attr)
return _impl

def _batch_norm():
def _impl(inputs, attr, params):
# Rearrange inputs from
Expand All @@ -445,19 +468,14 @@ def _impl(inputs, attr, params):

def _relu6():
def _impl(inputs, attr, params):
return _sym.clip(inputs[0], a_min=0, a_max=6)
return _sym.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
return _impl

def _shape():
def _impl(inputs, attr, params):
input_shapes = attr['_input_shapes'][inputs[0]]

# Fix the -1 dimensions to 1
input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]]
params[attr['_node_name']] = tvm.nd.array(input_shapes[0])

return _sym.Variable(name=attr['_node_name'],
shape=params[attr['_node_name']].shape)
# Result of this operator is prominently used by reshape operator.
# Just pass the input as it is so that reshape_like can be used there.
return inputs[0]
return _impl

# compatible operators that do NOT require any conversion.
Expand Down Expand Up @@ -491,7 +509,7 @@ def _impl(inputs, attr, params):
'Add' : _elemwise('add'),
'Rsqrt' : _rsqrt(),
'Squeeze' : _squeeze(),
'FusedBatchNorm' : _batch_norm(),
'FusedBatchNorm' : _fused_batch_norm(),
'Relu6' : _relu6(),
'DepthwiseConv2dNative' : _depthwise_conv(),
'Shape' : _shape(),
Expand Down
70 changes: 50 additions & 20 deletions nnvm/python/nnvm/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,35 @@ def read_normalized_tensor_from_image_file(file_name,
np_array = normalized.eval()
return np_array

def get_workload(model_path):
""" Import workload from frozen protobuf
Parameters
----------
model_path: str
model_path on remote repository to download from.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
"""

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
model_name = os.path.basename(model_path)
model_url = os.path.join(repo_base, model_path)

from mxnet.gluon.utils import download
download(model_url, model_name)

# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return graph_def

def get_workload_inception_v3():
""" Import Inception V3 workload from frozen protobuf
Expand All @@ -168,23 +197,15 @@ def get_workload_inception_v3():
"""

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb'

image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)

from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)

normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))

# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (normalized, graph_def)
return (normalized, get_workload(model_path))

def get_workload_inception_v1():
""" Import Inception V1 workload from frozen protobuf
Expand All @@ -203,13 +224,11 @@ def get_workload_inception_v1():
"""

repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
model_name = 'classify_image_graph_def-with_shapes.pb'
model_url = os.path.join(repo_base, model_name)
model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb'
image_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, image_name)

from mxnet.gluon.utils import download
download(model_url, model_name)
download(image_url, image_name)

if not tf.gfile.Exists(os.path.join("./", image_name)):
Expand All @@ -221,9 +240,20 @@ def get_workload_inception_v1():
tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
tvm_data = np.array(tvm_data)

# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return (image_data, tvm_data, graph_def)
return (image_data, tvm_data, get_workload(model_path))

def get_workload_mobilenet():
""" Import mobilenet workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for mobilenet.
"""

return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
24 changes: 24 additions & 0 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,29 @@ def test_forward_inception_v1():

np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)

#######################################################################
# Mobilenet
# ---------
def test_forward_mobilenet():
'''test mobilenet model'''
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload_mobilenet()
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)

data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV1/Predictions/Reshape_1'

with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')

out_shape = tf_output.shape
tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32')
top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-10:][::-1]

np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)

#######################################################################
# Main
# ----
Expand All @@ -419,3 +442,4 @@ def test_forward_inception_v1():
test_forward_multi_input()
test_forward_inception_v3()
test_forward_inception_v1()
test_forward_mobilenet()

0 comments on commit 464fbdd

Please sign in to comment.