Skip to content

Commit

Permalink
MaxPoolのときはconstantでreplicateと同等の働きをするよう修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Jun 1, 2018
1 parent bc30df3 commit 4234b96
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def add_value_info_as_buffer(network, info):
return v


def set_kernel_parameter_and_add_padding(node, kp, base_name, func_counter):
def set_kernel_parameter_and_add_padding(node, kp,
pad_mode, pad_val,
base_name, func_counter):
"""Set kernel related parameters(strides, pads, kernel_shape) to the given
parameter. This function also generates a padding function if we need a
seperate pad function for asymmetry padding.
Expand Down Expand Up @@ -170,7 +172,7 @@ def set_kernel_parameter_and_add_padding(node, kp, base_name, func_counter):
ends = pads[half:]
pad_width = [j for i in zip(starts, ends) for j in i]
padf = generate_pad(node.name, input, padded,
"replicate", pad_width, 0,
pad_mode, pad_width, pad_val,
base_name, func_counter)
kp.pad.dim.extend(padval)
else:
Expand Down Expand Up @@ -595,7 +597,9 @@ def convert_to_functions(pb, network, node, base_name, initializers,
func_list.append(func)
elif node.op_type == "MaxPool":
mpp = func.max_pooling_param
# We simulate replicate mode by padding with negative infinite
padf = set_kernel_parameter_and_add_padding(node, mpp,
"constant", -np.inf,
base_name, func_counter)
if padf:
# append a pad function if we need asymmetry padding.
Expand All @@ -610,6 +614,7 @@ def convert_to_functions(pb, network, node, base_name, initializers,
elif node.op_type == "AveragePool":
app = func.average_pooling_param
padf = set_kernel_parameter_and_add_padding(node, app,
"replicate", 0,
base_name, func_counter)
if padf:
# append a pad function if we need asymmetry padding
Expand Down
2 changes: 2 additions & 0 deletions python/test/utils/conversion/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def test_onnx_nnp_conversion_maxpool_p0_0_1_1_s1_k2(tmpdir, nnp_fixture):
"out_data_1", "exec_0")


# NNP to ONNX conversion for asymmetry maxpool padding
# currently ends in an unknown error at caffe2 backend
#def test_nnp_onnx_conversion_maxpool_p0_0_1_1_s1_k2(tmpdir, nnp_fixture):
# convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
# "maxpool_p0_0_1_1_s1_k2.nnp",
Expand Down

0 comments on commit 4234b96

Please sign in to comment.