Skip to content

Commit

Permalink
Constant対応に備えてリファクタ
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 4, 2018
1 parent 6f249d8 commit 1971dac
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
}


def add_value_info_as_variable(info, network):
def add_value_info_as_variable(network, info):
if not info.type.HasField("tensor_type"): # accepting only tensor
raise ValueError("Only TensorProto is allowed as ValueInfoProto's type for info.name (Got {})"
.format(info.name, info.type))
Expand All @@ -62,15 +62,15 @@ def add_value_info_as_variable(info, network):
v.shape.dim.extend([x.dim_value for x in t.shape.dim])
return v

def add_value_info_as_parameter(info, network):
v = add_value_info_as_variable(info, network)
def add_value_info_as_parameter(network, info):
v = add_value_info_as_variable(network, info)
v.type = "Parameter"
v.initializer.type = "Constant"
v.initializer.multiplier = 1.0
return v

def add_value_info_as_buffer(info, network):
v = add_value_info_as_variable(info, network)
def add_value_info_as_buffer(network, info):
v = add_value_info_as_variable(network, info)
v.type = "Buffer"
return v

Expand Down Expand Up @@ -623,11 +623,11 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
for i in graph.input:
if i.name in param_vars:
# This input is a parameter
v = add_value_info_as_parameter(i, network)
v = add_value_info_as_parameter(network, i)
param_list.append(v)
else:
# This input is a buffer
v = add_value_info_as_buffer(i, network)
v = add_value_info_as_buffer(network, i)
in_list.append(v)
if i.name in all_vars:
del all_vars[i.name]
Expand All @@ -638,12 +638,17 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
# No one is using this buffer so we show a warning and ignore it.
logger.warning("Input buffer {} is not used as input for any node.".format(i.name))
for o in graph.output:
v = add_value_info_as_buffer(o, network)
v = add_value_info_as_buffer(network, o)
out_list.append(v)
del all_vars[v.name]

for varg in all_vars:
# We add all remaining variables as intermediate buffer
# We add all remaining variables as intermediate buffer,
# except for the ones that was converted to a parameter.
# A conversion of a buffer to a parameter may occur in functions
# such as Constant
if varg in param_vars:
pass
# We leave the buffer size of all intermediate buffers empty
v = network.variable.add()
v.type = "Buffer"
Expand Down

0 comments on commit 1971dac

Please sign in to comment.