Skip to content

Commit

Permalink
padsの仕様を少し勘違いしていたので修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 24, 2018
1 parent 6575fc1 commit ed02034
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
14 changes: 10 additions & 4 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def set_kernel_parameter(node, kp):
if attr.type != AttributeProto.INTS:
raise ValueError("Only INTS are supported for pads in {}"
.format(node.op_type))
if len(attr.ints) > 4:
# pads with more than 4 (NCHW) dimension means
# it has a start and end specified.
# NNabla does not support different padding for start and end.
raise ValueError("NNabla does not support different padding for start and end of each axis")
pads.extend(attr.ints)
dims.append(len(pads))
elif attr.name == "kernel_shape":
Expand All @@ -119,20 +124,21 @@ def set_kernel_parameter(node, kp):
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
# NNabla requires for the dimensions of strides, pads, kernels to match.
# We align the dimensions for all three attributes to the shortest one
# We align the dimensions for all three attributes to the shortest one.
# We use the dimensions from the end (hence the negative dim).
dim = min(dims)
if strides:
kp.stride.dim.extend(strides[:dim])
kp.stride.dim.extend(strides[-dim:])
if pads:
kp.pad.dim.extend(pads[:dim])
kp.pad.dim.extend(pads[-dim:])
else:
# In case we don't have padding set,
# we set zero padding just in case NNabla does not set the
# default padding values correctly (such as in AveragePooling).
# This code should not be needed if NNabla handles default values correctly.
kp.pad.dim.extend([0]*dim)
if kernel:
kp.kernel.dim.extend(kernel[:dim])
kp.kernel.dim.extend(kernel[-dim:])


def update_function_counter(func_type, func_counter, count):
Expand Down
15 changes: 12 additions & 3 deletions python/test/utils/conversion/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@

TEST_DATA_DIR = "nnabla-sample-data/conversion_data"

def print_buffer_shape(net):
for k, v in net.functions.items():
out = v.outputs[0]
print(out.name, net.variables[out.name].variable_instance.shape)

def run_executor(nn_net, exec_name):
"""Run specified executor and return its network"""
Expand All @@ -51,10 +55,14 @@ def convert_onnx_to_nnp_and_compare(
if show_onnx:
print(model)
c2out = None
rep = oc2.prepare(model)
if type(in_img) is np.ndarray:
c2out = oc2.run_model(model, [in_img])
c2out = rep.run([in_img])
else:
c2out = oc2.run_model(model, [])
c2out = rep.run([])
#for k in rep.workspace.Blobs():
# v = rep.workspace.FetchBlob(k)
# print(k, v.shape)
backend_out = c2out[out_name]
elif backend == "cntk":
n = cntkf.Function.load(path, format=cntk.ModelFormat.ONNX)
Expand All @@ -80,13 +88,14 @@ def convert_onnx_to_nnp_and_compare(
p = os.path.join(str(nnpdir), nnp_name)
nnpex.export_nnp(p)
# read exported nnp and run network
# pdb.set_trace()
nn_net = nnload.load([p])
if type(in_img) is np.ndarray:
net = nn_net.executors[exec_name].network
in_data = net.variables[in_name]
in_data.variable_instance.d = in_img
# pdb.set_trace()
exe = run_executor(nn_net, exec_name)
#print_buffer_shape(exe)
# in_data = exe.variables["in_data_0"]
# print(in_data.variable_instance.d)
nnout = exe.variables[out_name].variable_instance.d
Expand Down

0 comments on commit ed02034

Please sign in to comment.