Skip to content

Commit

Permalink
Xorに対応
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 26, 2018
1 parent 174ba98 commit 5b1db87
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"Mul2": "Mul",
"LogicalAnd": "And",
"LogicalOr": "Or",
"LogicalXor": "Xor",
# optype that gets converted
"Identity": "Dropout",
"Affine": "Gemm",
Expand Down Expand Up @@ -279,7 +280,8 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
merge_broadcast(n, func, bt, broadcast_target)
nl.append(n)
elif (func.type == "LogicalAnd" or
func.type == "LogicalOr"):
func.type == "LogicalOr" or
func.type == "LogicalXor"):
# Store the input/output tensor's name and convert it to boolean
input_types[n.input[0]] = TensorProto.BOOL
output_types[n.output[0]] = TensorProto.BOOL
Expand Down
4 changes: 3 additions & 1 deletion python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"ReduceProd": "Prod",
"And": "LogicalAnd",
"Or": "LogicalOr",
"Xor": "LogicalXor",
# Constant does not get converted to a function
# but we list it here so we can accept it
"Constant": ""
Expand Down Expand Up @@ -523,7 +524,8 @@ def convert_to_functions(pb, network, node, base_name, initializers,
elif (node.op_type == "Add" or
node.op_type == "Mul" or
node.op_type == "And" or
node.op_type == "Or"):
node.op_type == "Or" or
node.op_type == "Xor"):
convert_broadcasting_operator(func_list, node, func, base_name, func_counter)
func_list.append(func)
elif node.op_type == "Constant":
Expand Down
29 changes: 29 additions & 0 deletions python/test/utils/conversion/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,35 @@ def test_nnp_onnx_conversion_or_broadcast_axis1(tmpdir, nnp_fixture):
"or_broadcast_axis1.nnp",
"or_broadcast_axis1.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_xor_no_broadcast(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"xor_no_broadcast.onnx",
"xor_no_broadcast.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_xor_no_broadcast(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"xor_no_broadcast.nnp",
"xor_no_broadcast.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_xor_broadcast_axis1(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"xor_broadcast_axis1.onnx",
"xor_broadcast_axis1.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_or_broadcast_axis1(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"or_broadcast_axis1.nnp",
"or_broadcast_axis1.onnx",
"out_data_1", "exec_0")

# These following tests are invalidated due to a
# backend bug? decribed in the following issue:
# https://github.com/Microsoft/CNTK/issues/3127
Expand Down

0 comments on commit 5b1db87

Please sign in to comment.