Skip to content

Commit

Permalink
broadcastの無いAdd対応し、Sumを使うときも入力数をチェックするよう修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Mar 15, 2018
1 parent 9ff5731 commit d8161c9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"GlobalAveragePooling": "GlobalAveragePool",
"MaxPooling": "MaxPool",
"AveragePooling": "AveragePool",
"Add2": "Sum",
"Add2": "Add",
# optype that gets converted
"Identity": "Dropout",
"Affine": "Gemm"
Expand Down
3 changes: 3 additions & 0 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,9 @@ def convert_to_functions(node, base_name, initializers, func_counter):
# update Gemm input with the converted inputs
del func.input[:]
func.input.extend(input)
elif node.op_type == "Sum":
if len(func.input) > 2:
raise ValueError("Sum operations with more than two input is currently not supported")
elif node.op_type == "Add":
# We need the input buffer's dimension information here
# in order to reshape the bias vector correctly.
Expand Down
8 changes: 6 additions & 2 deletions python/test/utils/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,13 @@ def test_nnp_onnx_conversion_gemm(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(
tmpdir, TEST_DATA_DIR, "gemm.nnp", "gemm.onnx", "out_data_1", "exec_0")

def test_onnx_nnp_conversion_add(tmpdir, nnp_fixture):
def test_onnx_nnp_conversion_add_no_broadcast(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(
tmpdir, TEST_DATA_DIR, "add.onnx", "add.nnp", "out_data_1", "exec_0")
tmpdir, TEST_DATA_DIR, "add_no_broadcast.onnx", "add_no_broadcast.nnp", "out_data_1", "exec_0")

def test_nnp_onnx_conversion_add_no_broadcast(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(
tmpdir, TEST_DATA_DIR, "add_no_broadcast.nnp", "add_no_broadcast.onnx", "out_data_1", "exec_0")

def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
onnx_dir = TEST_DATA_DIR
Expand Down

0 comments on commit d8161c9

Please sign in to comment.