Preparing for Gemm support
Masato Hori committed Mar 13, 2018
1 parent 80f8d0c commit 1baf995
Showing 2 changed files with 63 additions and 9 deletions.
68 changes: 59 additions & 9 deletions python/src/nnabla/utils/converter/onnx/reader.py
@@ -43,6 +43,7 @@
"MaxPool": "MaxPooling",
"AveragePool": "AveragePooling",
"Sum": "Add2",
"Gemm": "Affine"
}
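
Note: ONNX Gemm computes Y = alpha * A' * B' + beta * C (A and B optionally transposed, C optionally broadcast), while NNabla's Affine computes y = xW + b, so the table entry above is an exact match only for alpha = beta = 1 with no transposition; that is why the Gemm branch added further down has to inspect transA, transB, and broadcast. A minimal numpy sketch of the correspondence, with illustrative shapes:

import numpy as np

A = np.random.randn(2, 3)  # input
B = np.random.randn(3, 4)  # weight
C = np.random.randn(4)     # bias, broadcast across rows

# ONNX Gemm with alpha=1, beta=1, transA=0, transB=0: Y = A @ B + C
gemm_out = 1.0 * (A @ B) + 1.0 * C

# NNabla Affine computes the same x @ W + b form.
affine_out = A @ B + C
assert np.allclose(gemm_out, affine_out)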


@@ -100,8 +101,10 @@ def set_kernel_parameter(node, kp):
    if kernel:
        kp.kernel.dim.extend(kernel[:dim])

def convert_to_function(node, base_name, func_counter):
    """Convert given node to corresponding function"""

def generate_default_function(node, base_name, func_counter):
    """Generate a default function from the given node
    """
    ft = onnx_optype_to_nnabla_function_type.get(node.op_type)
    if ft is None:
        raise ValueError("op_type {} is currently not supported for NNP conversion".format(node.op_type))
@@ -129,7 +132,16 @@ def convert_to_function(node, base_name, func_counter):
        func_counter[func.type] = count+1
    func.input.extend(node.input)
    func.output.extend(node.output)
    return func

def convert_to_functions(node, base_name, func_counter):
    """Convert given node to corresponding functions.
    A node is usually converted to a single function,
    but some nodes end up as a composition of functions.
    """
    func_list = []
    if node.op_type == "Concat":
        func = generate_default_function(node, base_name, func_counter)
        # Since concat axis is currently not required in ONNX,
        # the default axis depends on which backend we use.
        # For now we are comparing with caffe2, so we are
@@ -145,7 +157,9 @@ def convert_to_function(node, base_name, func_counter):
            else:
                raise ValueError("Unsupported attribute {} was specified at {}"
                                 .format(attr.name, node.op_type))
        func_list.append(func)
    elif node.op_type == "Softmax":
        func = generate_default_function(node, base_name, func_counter)
        logger.warning(SOFTMAX_WARNING)
        # default to channel axis
        func.softmax_param.axis = DEFAULT_SOFTMAX_AXIS
@@ -157,7 +171,9 @@ def convert_to_function(node, base_name, func_counter):
            else:
                raise ValueError("Unsupported attribute {} was specified at {}"
                                 .format(attr.name, node.op_type))
        func_list.append(func)
    elif node.op_type == "Dropout":
        func = generate_default_function(node, base_name, func_counter)
        # Dropout requires a ratio to be set
        for attr in node.attribute:
            if attr.name == "is_test":
@@ -185,7 +201,9 @@ def convert_to_function(node, base_name, func_counter):
            else:
                raise ValueError("Unsupported attribute {} was specified at {}"
                                 .format(attr.name, node.op_type))
        func_list.append(func)
    elif node.op_type == "Conv":
        func = generate_default_function(node, base_name, func_counter)
        cp = func.convolution_param
        # We shouldn't need these default settings
        # since NNabla will set these for us
@@ -240,19 +258,25 @@ def convert_to_function(node, base_name, func_counter):
            # Set default values.
            # Do we really need this? (Default value should be set by NNabla)
            cp.dilation.dim.extend([1 for _ in range(dim)])
        func_list.append(func)
    elif node.op_type == "MaxPool":
        func = generate_default_function(node, base_name, func_counter)
        mpp = func.max_pooling_param
        set_kernel_parameter(node, mpp)
        # Always ignore borders in order to match ONNX(caffe2) results?
        # Not quite sure yet.
        mpp.ignore_border = True
        func_list.append(func)
    elif node.op_type == "AveragePool":
        func = generate_default_function(node, base_name, func_counter)
        app = func.average_pooling_param
        set_kernel_parameter(node, app)
        # Always ignore borders in order to match ONNX(caffe2) results?
        # Not quite sure yet.
        app.ignore_border = True
        func_list.append(func)
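
Note: ignore_border decides whether a trailing partial window still produces an output element. A small worked example of the size arithmetic, with illustrative values, assuming ignore_border=False corresponds to the ceil-style output size that counts the partial border window:

import math
n, k, s = 5, 2, 2  # input length, kernel, stride (illustrative values)
print(math.floor((n - k) / s) + 1)  # ignore_border=True  -> 2 outputs
print(math.ceil((n - k) / s) + 1)   # ignore_border=False -> 3 outputs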
elif node.op_type == "BatchNormalization":
func = generate_default_function(node, base_name, func_counter)
# We need to rearrange the input data order.
# ONNX BatchNormalization input order: X, scale, bias, mean, variance
# NNabla BatchNormalization input order: X, beta, gamma, mean, variance
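
Note: a sketch of the rearrangement the comments above describe; the index shuffling below is an assumption about how the rewiring could look, not code from this commit:

# ONNX order:   [X, scale (gamma), bias (beta), mean, variance]
# NNabla order: [X, beta, gamma, mean, variance]
onnx_inputs = list(node.input)
nnabla_inputs = [onnx_inputs[0],  # X
                 onnx_inputs[2],  # beta  <- ONNX bias
                 onnx_inputs[1],  # gamma <- ONNX scale
                 onnx_inputs[3],  # mean
                 onnx_inputs[4]]  # variance
del func.input[:]
func.input.extend(nnabla_inputs)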
@@ -289,7 +313,32 @@ def convert_to_function(node, base_name, func_counter):
            else:
                raise ValueError("Unsupported attribute {} was specified at {}"
                                 .format(attr.name, node.op_type))
    return func
        func_list.append(func)
elif node.op_type == "Gemm":
for attr in node.attribute:
if attr.name == "transA":
if attr.type != AttributeProto.INT:
raise ValueError("Only INT is supported for transA in {} op_type".format(node.op_type))
# We need to transpose the input weight beforehand
# since NNabla does not support transpose with Affine.
# Add a new intermediate buffer for transposition,
# and rewire the buffer as input.
elif attr.name == "transB":
if attr.type != AttributeProto.INT:
raise ValueError("Only INT is supported for transB in {} op_type".format(node.op_type))
# We need to transpose the input weight beforehand
# since NNabla does not support transpose with Affine.
# Add a new intermediate buffer for transposition,
# and rewire the buffer as input.
elif attr.name == "broadcast":
if attr.type != AttributeProto.INT:
raise ValueError("Only INT is supported for broadcast in {} op_type".format(node.op_type))
# Add a new intermediate buffer for broadcasting
# and rewire the buffer as input,
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
    return func_list
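
Note: the Gemm branch above only validates attributes and appends no function yet, matching the commit message ("preparing for Gemm support"). One way the transB case might later be lowered into the composition of functions the docstring mentions, sketched with assumed names, and assuming NNabla's Transpose protobuf parameter is transpose_param.axes:

    elif node.op_type == "Gemm":
        func = generate_default_function(node, base_name, func_counter)
        for attr in node.attribute:
            if attr.name == "transB" and attr.i:
                # Emit a Transpose ahead of the Affine and rewire the
                # weight input through a new intermediate buffer.
                transposed = node.input[1] + "_transposed"
                tf = nnabla_pb2.Function()
                tf.type = "Transpose"
                tf.name = "Transpose_" + node.input[1]
                tf.input.extend([node.input[1]])
                tf.output.extend([transposed])
                tf.transpose_param.axes.extend([1, 0])
                func_list.append(tf)
                func.input[1] = transposed
        func_list.append(func)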

def convert_parameter_shape(pb):
    """Convert the shape of some parameters so they fit NNabla's requirements.
@@ -343,13 +392,14 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
        # We do not allow any operator from an unknown domain
        if not (n.domain == '' or n.domain == NNABLA_DOMAIN):
            raise ValueError("Unsupported operator from domain {} was found".format(n.domain))
        f = convert_to_function(n, graph.name, func_counter)
        fl = convert_to_functions(n, graph.name, func_counter)
        # Gather all unique names for input and output
        for i in f.input:
            all_vars[i] = None
        for o in f.output:
            all_vars[o] = None
        network.function.extend([f])
        for f in fl:
            for i in f.input:
                all_vars[i] = None
            for o in f.output:
                all_vars[o] = None
        network.function.extend(fl)

    # convert parameters
    # We use an OrderedDict and not a set
4 changes: 4 additions & 0 deletions python/test/utils/test_conversion.py
@@ -224,6 +224,10 @@ def test_nnp_onnx_conversion_batch_normalization(tmpdir, nnp_fixture):
    convert_nnp_to_onnx_and_compare(
        tmpdir, TEST_DATA_DIR, "batch_norm.nnp", "batch_norm.onnx", "out_data_1", "exec_0", atol=1e-05)

#def test_onnx_nnp_conversion_gemm(tmpdir, nnp_fixture):
#    convert_onnx_to_nnp_and_compare(
#        tmpdir, TEST_DATA_DIR, "gemm.onnx", "gemm.nnp", "out_data_1", "exec_0")
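
Note: the test stays commented out until the Gemm conversion lands. For reference, a gemm.onnx fixture like the one it expects could be generated roughly as follows; the tensor names and shapes are assumptions, while "out_data_1" comes from the test arguments:

import onnx
from onnx import helper, TensorProto

node = helper.make_node("Gemm", ["A", "B", "C"], ["out_data_1"])
inputs = [helper.make_tensor_value_info(n, TensorProto.FLOAT, s)
          for n, s in (("A", [2, 3]), ("B", [3, 4]), ("C", [2, 4]))]
outputs = [helper.make_tensor_value_info("out_data_1", TensorProto.FLOAT, [2, 4])]
graph = helper.make_graph([node], "gemm_test", inputs, outputs)
onnx.save(helper.make_model(graph), "gemm.onnx")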

def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
    onnx_dir = TEST_DATA_DIR
    onnx_name = "squeezenet.onnx"
