Skip to content

Commit

Permalink
setではなくOrderedDictを使用して順番が保持されるよう修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Feb 19, 2018
1 parent 3bd0b43 commit 8b3c6e3
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import struct
from collections import OrderedDict
import nnabla.logger as logger
from nnabla.utils import nnabla_pb2
from onnx import (ModelProto, TensorProto, AttributeProto)
Expand Down Expand Up @@ -211,7 +212,9 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
network = pb.network.add()
network.name = graph.name

all_vars = set()
# We use an OrderedDict and not a set
# to preserve order
all_vars = OrderedDict()
func_counter = {}
# convert nodes
for n in graph.node:
Expand All @@ -221,13 +224,15 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
f = convert_to_function(n, graph.name, func_counter)
#Gather all unique names for input and output
for i in f.input:
all_vars.add(i)
all_vars[i] = None
for o in f.output:
all_vars.add(o)
all_vars[o] = None
network.function.extend([f])

# convert parameters
param_vars = set()
# We use an OrderedDict and not a set
# to preserve order
param_vars = OrderedDict()
for init in graph.initializer:
if init.data_type != TensorProto.FLOAT:
raise ValueError("Only floating point data is supported for parameters {} (Got {})"
Expand All @@ -242,7 +247,7 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
p.data.extend(data)
p.need_grad = False
# Keep the list of all initializer names
param_vars.add(init.name)
param_vars[init.name] = None
# We need to distinguish constant parameters (which become 'Parameter' in NNabla)
# from input/output variables (which become 'Buffer' in NNabla).
# Contant parameters appear in the initializer list so we keep
Expand All @@ -267,19 +272,19 @@ def onnx_graph_to_nnp_protobuf(pb, graph):
# This input is a variable
v.type ="Buffer"
in_list.append(v)
all_vars.remove(v.name)
del all_vars[v.name]
for o in graph.output:
v = onnx_value_info_proto_to_variable(o, network)
v.type = "Buffer"
out_list.append(v)
all_vars.remove(v.name)
del all_vars[v.name]

for varg in all_vars:
# We add all remaining variables as intermediate buffer
# We leave the buffer size of all intermediate buffers empty
v = network.variable.add()
v.type = "Buffer"
v.name = varg
# We calculate the buffer size of all intermediate buffers here

#pdb.set_trace()

Expand Down

0 comments on commit 8b3c6e3

Please sign in to comment.