
Commit

Adding Reshape support (in progress)
Masato Hori committed Apr 6, 2018
1 parent d37b4a7 commit 9bc08d7
Showing 2 changed files with 53 additions and 4 deletions.
53 changes: 49 additions & 4 deletions python/src/nnabla/utils/converter/onnx/reader.py
@@ -35,6 +35,7 @@
"Dropout": "Dropout",
"Softmax": "Softmax",
"BatchNormalization": "BatchNormalization",
"Reshape": "Reshape",
# optype with different names
"Relu": "ReLU",
"Concat": "Concatenate",
@@ -176,7 +177,8 @@ def generate_transpose(node_name, in_name, out_name, base_name, func_counter):
tp.axes.extend([1, 0]) # switch H and W
return trans

def convert_to_functions(pb, network, node, base_name, initializers, func_counter, param_vars, param_list):
def convert_to_functions(pb, network, node, base_name, initializers,
func_counter, param_vars, param_list, merged_inputs):
"""Convert given node to corresponding functions.
A node is usually converted to a single function,
but some nodes end up as a composition of functions,
@@ -189,7 +191,7 @@ def convert_to_functions(pb, network, node, base_name, initializers, func_counte
elif node.op_type == "GlobalAveragePool":
func_list.append(func)
elif node.op_type == "Concat":
# Concat axis was not required for old versions of Concat,
# Concat axis was not required for Concat-1 (it is required from Concat-4),
# so the default axis depended on which backend we use.
# Since we are comparing with caffe2, we are
# defaulting to the channel axis if the axis is not specified.
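
(For illustration only, not part of this commit: the defaulting described in the comment above amounts to roughly the following sketch, assuming NCHW layout where axis 1 is the channel axis; the exact NNabla parameter field name is an assumption.)

# Sketch of the Concat-1 default-axis rule described above.
axis = 1  # caffe2's default: the channel axis in NCHW layout
for attr in node.attribute:
    if attr.name == "axis":
        axis = attr.i  # Concat-4 and later always set this attribute
func.concatenate_param.axis = axis  # field name assumed from NNabla protobuf
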
@@ -348,7 +350,7 @@ def convert_to_functions(pb, network, node, base_name, initializers, func_counte
raise ValueError("Only FLOAT is supported for momentum in BatchNormalization op_type")
bnp.decay_rate = attr.f
elif attr.name == "consumed_inputs":
# Old BatchNormalization has this field.
# BatchNormalization-1 has this field.
# Since NNabla does not need this, we ignore it
pass
else:
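
(As a hedged illustration, not from this diff: a BatchNormalization-1 node of the kind this branch must tolerate can be built with onnx.helper as below. The momentum attribute maps onto NNabla's decay_rate, and the legacy consumed_inputs attribute is simply skipped.)

from onnx import helper

# Hypothetical BatchNormalization-1 node: momentum becomes bnp.decay_rate,
# consumed_inputs is a legacy attribute the converter ignores.
bn_node = helper.make_node(
    "BatchNormalization",
    inputs=["x", "scale", "bias", "mean", "var"],
    outputs=["y"],
    momentum=0.9,
    consumed_inputs=[0, 0, 0, 1, 1])
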
@@ -498,6 +500,40 @@ def convert_to_functions(pb, network, node, base_name, initializers, func_counte
param_list.append(v)
# We do not add any function to the list here
# since the node is converted as a parameter
elif node.op_type == "Reshape":
rp = func.reshape_param
shape_found = False
for attr in node.attribute:
if attr.name == "shape":
# Shape comes as attribute for Reshape-1
if attr.type != AttributeProto.INTS:
raise ValueError("Only INTS is supported for shape in {} op_type".format(node.op_type))
rp.shape.dim.extend(attr.ints)
shape_found = True
if len(func.input) == 2:
# Shape comes as input for Reshape-5.
# NNabla reshape expects a single input (data),
# while Reshape-5 will have two inputs (data, shape),
# so we convert the shape input to a parameter
shape_input = func.input[1]
# look for the initializer for matching input
for init in initializers:
if init.name == shape_input:
if init.data_type != TensorProto.INT64:
raise ValueError("Only INT64 is supported for shape in {} op_type".format(node.op_type))
# copy shape size from initializer
if init.raw_data:
rp.shape.dim.extend(np.frombuffer(init.raw_data, dtype=np.int64))
elif init.int64_data:
rp.shape.dim.extend(init.int64_data)
shape_found = True
break
# Store the merged input so we can ignore it later
merged_inputs.append(shape_input)
del func.input[1]
if not shape_found:
raise ValueError("Shape information was not found in {} op_type".format(node.op_type))
func_list.append(func)
return func_list
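
(For reference, a minimal sketch of the two Reshape variants this branch distinguishes, not part of the commit itself: Reshape-1 carries the target shape as an INTS attribute, while Reshape-5 moves it to a second input that exporters typically back with an INT64 initializer.)

import numpy as np
from onnx import helper, numpy_helper

# Reshape-1: the target shape travels as an attribute.
reshape_v1 = helper.make_node(
    "Reshape", inputs=["data"], outputs=["reshaped"], shape=[1, 6])

# Reshape-5: the target shape is a second input backed by an INT64
# initializer, which the code above folds into reshape_param and
# records in merged_inputs.
shape_init = numpy_helper.from_array(
    np.array([1, 6], dtype=np.int64), name="shape")
reshape_v5 = helper.make_node(
    "Reshape", inputs=["data", "shape"], outputs=["reshaped"])
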

def convert_parameter_shape(pb):
@@ -568,13 +604,14 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
param_vars = OrderedDict() # Dictionary for input parameters.
all_vars = OrderedDict() # Dictionary for all variables
param_list = [] # list of parameter variables
merged_inputs = [] # list of input buffers that were merged into a function
func_counter = {} # a counter for all functions
# convert nodes
for n in graph.node:
check_domain(n.domain)
fl = convert_to_functions(pb, network,
n, graph.name, graph.initializer,
func_counter, param_vars, param_list)
func_counter, param_vars, param_list, merged_inputs)
# Gather all unique names for input and output
for f in fl:
for i in f.input:
@@ -585,6 +622,10 @@

# convert parameters
for init in graph.initializer:
if init.name in merged_inputs:
# Ignore any initializer that has already been merged
# into a function node
continue
add_tensor_as_parameter(pb, init)
# Keep the list of all initializer names
param_vars[init.name] = None
@@ -600,6 +641,10 @@
in_list = [] # list of input variables
out_list = [] # list of output variables
for i in graph.input:
if i.name in merged_inputs:
# Ignore any input that has already been merged
# into a function node
continue
if i.name in param_vars:
# This input is a parameter
v = add_value_info_as_parameter(network, i)
4 changes: 4 additions & 0 deletions python/test/utils/test_conversion.py
@@ -252,6 +252,10 @@ def test_onnx_nnp_conversion_constant(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(
tmpdir, TEST_DATA_DIR, "constant.onnx", "constant.nnp", "Pooling33_Output_0", "exec_0")

def test_onnx_nnp_conversion_reshape(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(
tmpdir, TEST_DATA_DIR, "reshape.onnx", "reshape.nnp", "out_data_1", "exec_0")

def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
onnx_dir = TEST_DATA_DIR
onnx_name = "squeezenet.onnx"
