Skip to content

Commit

Permalink
ReduceMax, ReduceProdの対応を部分的に追加
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 16, 2018
1 parent a842cf3 commit 47e63ad
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
10 changes: 10 additions & 0 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
"ReduceSum": "Sum",
"ReduceMean": "Mean",
"ReduceMin": "Min",
"ReduceMax": "Max",
"ReduceProd": "Prod",
# Constant does not get converted to a function
# but we list it here so we can accept it
"Constant": ""
Expand Down Expand Up @@ -643,6 +645,14 @@ def convert_to_functions(pb, network, node, base_name, initializers,
mp = func.min_param
set_reduction_attrs(mp, node)
func_list.append(func)
elif node.op_type == "ReduceMax":
mp = func.max_param
set_reduction_attrs(mp, node)
func_list.append(func)
elif node.op_type == "ReduceProd":
pp = func.prod_param
set_reduction_attrs(pp, node)
func_list.append(func)
else:
# Simply add the function for all other conversions
func_list.append(func)
Expand Down
17 changes: 16 additions & 1 deletion python/test/utils/conversion/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,27 @@ def test_nnp_onnx_conversion_reduce_mean(tmpdir, nnp_fixture):
tmpdir, TEST_DATA_DIR, "reduce_mean.nnp", "reduce_mean.onnx", "out_data_1", "exec_0")


# These following tests are invalidated due to a
# backend bug? decribed in the following issue:
# https://github.com/Microsoft/CNTK/issues/3127
#def test_onnx_nnp_conversion_reduce_min(tmpdir, nnp_fixture):
# convert_onnx_to_nnp_and_compare(
# tmpdir, TEST_DATA_DIR, "reduce_min.onnx", "reduce_min.nnp",
# "ReduceElements7_Output_0", "exec_0",
# backend="cntk")

#
#
#def test_onnx_nnp_conversion_reduce_max(tmpdir, nnp_fixture):
# convert_onnx_to_nnp_and_compare(
# tmpdir, TEST_DATA_DIR, "reduce_max.onnx", "reduce_max.nnp",
# "ReduceElements7_Output_0", "exec_0",
# backend="cntk")
#
#def test_onnx_nnp_conversion_reduce_prod(tmpdir, nnp_fixture):
# convert_onnx_to_nnp_and_compare(
# tmpdir, TEST_DATA_DIR, "reduce_prod.onnx", "reduce_prod.nnp",
# "ReduceElements7_Output_0", "exec_0",
# backend="cntk")

def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
img = np.random.rand(1, 3, 224, 224).astype(np.float32)
Expand Down

0 comments on commit 47e63ad

Please sign in to comment.