Skip to content

Commit

Permalink
broadcast無しのMulの変換に対応
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Mar 15, 2018
1 parent d8161c9 commit b125910
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
3 changes: 2 additions & 1 deletion python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
"Add2": "Add",
# optype that gets converted
"Identity": "Dropout",
"Affine": "Gemm"
"Affine": "Gemm",
"Mul2": "Mul"
}


Expand Down
32 changes: 31 additions & 1 deletion python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
"AveragePool": "AveragePooling",
"Sum": "Add2",
"Gemm": "Affine",
"Add": "Add2"
"Add": "Add2",
"Mul": "Mul2"
}


Expand Down Expand Up @@ -412,6 +413,35 @@ def convert_to_functions(node, base_name, initializers, func_counter):
raise ValueError("broadcasting is currently not supported for {}".format(node.op_type))
# Add2 broadcasts by default so we do nothing here
#pass
elif node.op_type == "Mul":
# We need the input buffer's dimension information here
# in order to reshape the bias vector correctly.
# Therefore we cannot support broadcasting unless we get an operator like ReshapeTo
# which allows reshaping without shape specification.
reshaped_postfix = "_reshaped"
input = node.input[:]
for attr in node.attribute:
if attr.name == "axis":
pass
## Reshape the input bias so it fits
## the specifed axis's broadcasted shape
#rin = node.input[1]
#rout = rin+reshaped_postfix
#rs = nnabla_pb2.Function()
#rs.type = "Reshape"
#set_function_name(rs, node.name, base_name, func_counter)
#rs.input.extend([rin])
#rs.output.extend([rout])
## Calculate the reshaped size for the bias.
## We calculate this from the input buffer's dimension and
## the specified axis.
#rp = rs.reshape_param
#rp.shape.dim.extend(reshaped)
#input[1] = rout # rewire input to reshaped input
elif attr.name == "broadcast":
raise ValueError("broadcasting is currently not supported for {}".format(node.op_type))
# Mul2 broadcasts by default so we do nothing here
#pass
func_list.append(func)
return func_list

Expand Down
8 changes: 8 additions & 0 deletions python/test/utils/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def test_nnp_onnx_conversion_add_no_broadcast(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(
tmpdir, TEST_DATA_DIR, "add_no_broadcast.nnp", "add_no_broadcast.onnx", "out_data_1", "exec_0")

def test_onnx_nnp_conversion_mul_no_broadcast(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(
tmpdir, TEST_DATA_DIR, "mul_no_broadcast.onnx", "mul_no_broadcast.nnp", "out_data_1", "exec_0")

def test_nnp_onnx_conversion_mul_no_broadcast(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(
tmpdir, TEST_DATA_DIR, "mul_no_broadcast.nnp", "mul_no_broadcast.onnx", "out_data_1", "exec_0")

def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
onnx_dir = TEST_DATA_DIR
onnx_name = "squeezenet.onnx"
Expand Down

0 comments on commit b125910

Please sign in to comment.