Skip to content

Commit

Permalink
Orに対応
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 26, 2018
1 parent 6941a3a commit 8f08001
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 16 deletions.
2 changes: 1 addition & 1 deletion python/src/nnabla/utils/converter/onnx/ONNXOpCoverage.md
Expand Up @@ -61,7 +61,7 @@ that indicates if each operator can be converted to NNP.
|Mul|:yellow_heart:|broadcast will be converted to a BroadcastTo|
|Neg|:black_heart:||
|Not|:green_heart:||
|Or|:black_heart:||
|Or|:yellow_heart:|broadcast will be converted to a BroadcastTo|
|PRelu|:black_heart:||
|Pad|:black_heart:||
|Pow|:black_heart:||
Expand Down
13 changes: 5 additions & 8 deletions python/src/nnabla/utils/converter/onnx/exporter.py
Expand Up @@ -49,6 +49,7 @@
"Mean": "ReduceMean",
"Mul2": "Mul",
"LogicalAnd": "And",
"LogicalOr": "Or",
# optype that gets converted
"Identity": "Dropout",
"Affine": "Gemm",
Expand Down Expand Up @@ -270,19 +271,15 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
bp = func.broadcast_to_param
broadcast_target[func.output[0]] = (func.input[1], bp.axis)
# we do not append node here because BroadcastTo should disappear
elif func.type == "Add2":
elif (func.type == "Add2" or
func.type == "Mul2"):
# Check if the second input is a brodcast target.
bt = func.input[1]
if bt in broadcast_target:
merge_broadcast(n, func, bt, broadcast_target)
nl.append(n)
elif func.type == "Mul2":
# Check if the second input is a brodcast target.
bt = func.input[1]
if bt in broadcast_target:
merge_broadcast(n, func, bt, broadcast_target)
nl.append(n)
elif func.type == "LogicalAnd":
elif (func.type == "LogicalAnd" or
func.type == "LogicalOr"):
# Store the input/output tensor's name and convert it to boolean
input_types[n.input[0]] = TensorProto.BOOL
output_types[n.output[0]] = TensorProto.BOOL
Expand Down
12 changes: 5 additions & 7 deletions python/src/nnabla/utils/converter/onnx/reader.py
Expand Up @@ -66,6 +66,7 @@
"ReduceMax": "Max",
"ReduceProd": "Prod",
"And": "LogicalAnd",
"Or": "LogicalOr",
# Constant does not get converted to a function
# but we list it here so we can accept it
"Constant": ""
Expand Down Expand Up @@ -519,13 +520,10 @@ def convert_to_functions(pb, network, node, base_name, initializers,
if len(func.input) > 2:
raise ValueError("Sum operations with more than two input is currently not supported")
func_list.append(func)
elif node.op_type == "Add":
convert_broadcasting_operator(func_list, node, func, base_name, func_counter)
func_list.append(func)
elif node.op_type == "Mul":
convert_broadcasting_operator(func_list, node, func, base_name, func_counter)
func_list.append(func)
elif node.op_type == "And":
elif (node.op_type == "Add" or
node.op_type == "Mul" or
node.op_type == "And" or
node.op_type == "Or"):
convert_broadcasting_operator(func_list, node, func, base_name, func_counter)
func_list.append(func)
elif node.op_type == "Constant":
Expand Down
27 changes: 27 additions & 0 deletions python/test/utils/conversion/test_conversion.py
Expand Up @@ -515,6 +515,33 @@ def test_nnp_onnx_conversion_and_broadcast_axis1(tmpdir, nnp_fixture):
"and_broadcast_axis1.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_or_no_broadcast(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"or_no_broadcast.onnx",
"or_no_broadcast.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_or_no_broadcast(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"or_no_broadcast.nnp",
"or_no_broadcast.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_or_broadcast_axis1(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"or_broadcast_axis1.onnx",
"or_broadcast_axis1.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_or_broadcast_axis1(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"or_broadcast_axis1.nnp",
"or_broadcast_axis1.onnx",
"out_data_1", "exec_0")
# These following tests are invalidated due to a
# backend bug? decribed in the following issue:
# https://github.com/Microsoft/CNTK/issues/3127
Expand Down

0 comments on commit 8f08001

Please sign in to comment.