Skip to content

Commit

Permalink
padsの扱いが間違っていたのを修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 24, 2018
1 parent ed02034 commit 2f5a87c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
6 changes: 3 additions & 3 deletions python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
# Copy kernel, stride, and pads values
k = onnx.helper.make_attribute("kernel_shape", mpp.kernel.dim)
s = onnx.helper.make_attribute("strides", mpp.stride.dim)
p = onnx.helper.make_attribute("pads", mpp.pad.dim)
p = onnx.helper.make_attribute("pads", np.repeat(mpp.pad.dim, 2))
n.attribute.extend([k, s, p])
nl.append(n)
elif func.type == "Convolution":
Expand All @@ -134,7 +134,7 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
k = onnx.helper.make_attribute("kernel_shape", weight_shape.dim[weight_base:])
d = onnx.helper.make_attribute("dilations", cp.dilation.dim)
s = onnx.helper.make_attribute("strides", cp.stride.dim)
p = onnx.helper.make_attribute("pads", cp.pad.dim)
p = onnx.helper.make_attribute("pads", np.repeat(cp.pad.dim, 2))
g = onnx.helper.make_attribute("group", cp.group)
n.attribute.extend([k, d, s, p, g])
nl.append(n)
Expand Down Expand Up @@ -166,7 +166,7 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
# Copy kernel, stride, and pads values
k = onnx.helper.make_attribute("kernel_shape", app.kernel.dim)
s = onnx.helper.make_attribute("strides", app.stride.dim)
p = onnx.helper.make_attribute("pads", app.pad.dim)
p = onnx.helper.make_attribute("pads", np.repeat(app.pad.dim, 2))
n.attribute.extend([k, s, p])
nl.append(n)
elif func.type == "BatchNormalization":
Expand Down
43 changes: 28 additions & 15 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,6 @@ def set_kernel_parameter(node, kp):
if attr.type != AttributeProto.INTS:
raise ValueError("Only INTS are supported for pads in {}"
.format(node.op_type))
if len(attr.ints) > 4:
# pads with more than 4 (NCHW) dimension means
# it has a start and end specified.
# NNabla does not support different padding for start and end.
raise ValueError("NNabla does not support different padding for start and end of each axis")
pads.extend(attr.ints)
dims.append(len(pads))
elif attr.name == "kernel_shape":
Expand All @@ -125,20 +120,21 @@ def set_kernel_parameter(node, kp):
.format(attr.name, node.op_type))
# NNabla requires for the dimensions of strides, pads, kernels to match.
# We align the dimensions for all three attributes to the shortest one.
# We use the dimensions from the end (hence the negative dim).
dim = min(dims)
if strides:
kp.stride.dim.extend(strides[-dim:])
kp.stride.dim.extend(strides[:])
if pads:
kp.pad.dim.extend(pads[-dim:])
padval = check_padding(pads, dim)
kp.pad.dim.extend(padval)
else:
# In case we don't have padding set,
# we set zero padding just in case NNabla does not set the
# default padding values correctly (such as in AveragePooling).
# This code should not be needed if NNabla handles default values correctly.
# This code should not be needed
# if NNabla handles default values correctly.
kp.pad.dim.extend([0]*dim)
if kernel:
kp.kernel.dim.extend(kernel[-dim:])
kp.kernel.dim.extend(kernel[:])


def update_function_counter(func_type, func_counter, count):
Expand Down Expand Up @@ -258,6 +254,22 @@ def set_reduction_attrs(p, node):
.format(attr.name, node.op_type))


def check_padding(pads, dim):
"""Check each padding start/end value
so that they match, becuase NNabla cannot set
different values for start/end per axis."""
padval = []
for i in range(dim):
ofs = i*2 # start and end for each axis
s = pads[ofs]
e = pads[ofs+1]
if s != e:
raise ValueError("NNabla does not support different padding"
" for start and end of each axis")
# If the values match, we set it as the padding for current axis
padval.append(s)
return padval

def convert_to_functions(pb, network, node, base_name, initializers,
func_counter, param_vars, param_list, merged_inputs,
removed_outputs):
Expand Down Expand Up @@ -374,19 +386,20 @@ def convert_to_functions(pb, network, node, base_name, initializers,
# We align the dimensions for all three attributes to the shortest one
dim = min(dims)
if strides:
cp.stride.dim.extend(strides[:dim])
cp.stride.dim.extend(strides[:])
if pads:
cp.pad.dim.extend(pads[:dim])
padval = check_padding(pads, dim)
cp.pad.dim.extend(padval)
else:
# Set default values.
# Do we really need this? (Default value should be set by NNabla)
cp.pad.dim.extend([0 for _ in range(dim)])
cp.pad.dim.extend([0]*dim)
if dilations:
cp.dilation.dim.extend(dilations[:dim])
cp.dilation.dim.extend(dilations[:])
else:
# Set default values.
# Do we really need this? (Default value should be set by NNabla)
cp.dilation.dim.extend([1 for _ in range(dim)])
cp.dilation.dim.extend([1]*dim)
func_list.append(func)
elif node.op_type == "MaxPool":
mpp = func.max_pooling_param
Expand Down

0 comments on commit 2f5a87c

Please sign in to comment.