Skip to content

Commit

Permalink
Powに対応
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 26, 2018
1 parent b9ede66 commit 34ebd6e
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/src/nnabla/utils/converter/onnx/ONNXOpCoverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ that indicates if each operator can be converted to NNP.
|Or|:yellow_heart:|broadcast will be converted to a BroadcastTo|
|PRelu|:black_heart:||
|Pad|:black_heart:||
|Pow|:black_heart:||
|Pow|:yellow_heart:|broadcast will be converted to a BroadcastTo|
|RNN|:black_heart:||
|RandomNormal|:black_heart:||
|RandomNormalLike|:black_heart:||
Expand Down
4 changes: 3 additions & 1 deletion python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"Mean": "ReduceMean",
"Mul2": "Mul",
"Div2": "Div",
"Pow2": "Pow",
"LogicalAnd": "And",
"LogicalOr": "Or",
"LogicalXor": "Xor",
Expand Down Expand Up @@ -275,7 +276,8 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
# we do not append node here because BroadcastTo should disappear
elif (func.type == "Add2" or
func.type == "Mul2" or
func.type == "Div2"):
func.type == "Div2" or
func.type == "Pow2"):
# Check if the second input is a brodcast target.
bt = func.input[1]
if bt in broadcast_target:
Expand Down
2 changes: 2 additions & 0 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"Add": "Add2",
"Mul": "Mul2",
"Div": "Div2",
"Pow": "Pow2",
"MatMul": "BatchMatmul",
"LeakyRelu": "LeakyReLU",
"Not": "LogicalNot",
Expand Down Expand Up @@ -525,6 +526,7 @@ def convert_to_functions(pb, network, node, base_name, initializers,
elif (node.op_type == "Add" or
node.op_type == "Mul" or
node.op_type == "Div" or
node.op_type == "Pow" or
node.op_type == "And" or
node.op_type == "Or" or
node.op_type == "Xor"):
Expand Down
29 changes: 29 additions & 0 deletions python/test/utils/conversion/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,35 @@ def test_nnp_onnx_conversion_div_broadcast_axis1(tmpdir, nnp_fixture):
"div_broadcast_axis1.nnp",
"div_broadcast_axis1.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_pow_no_broadcast(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"pow_no_broadcast.onnx",
"pow_no_broadcast.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_pow_no_broadcast(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"pow_no_broadcast.nnp",
"pow_no_broadcast.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_pow_broadcast_axis1(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"pow_broadcast_axis1.onnx",
"pow_broadcast_axis1.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_pow_broadcast_axis1(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"pow_broadcast_axis1.nnp",
"pow_broadcast_axis1.onnx",
"out_data_1", "exec_0")

# These following tests are invalidated due to a
# backend bug? decribed in the following issue:
# https://github.com/Microsoft/CNTK/issues/3127
Expand Down

0 comments on commit 34ebd6e

Please sign in to comment.