Skip to content

Commit

Permalink
ONNXのhelperを使うよう修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Feb 2, 2018
1 parent f88f400 commit fa99734
Showing 1 changed file with 25 additions and 54 deletions.
79 changes: 25 additions & 54 deletions python/test/utils/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import nnabla.utils.load as nnload
from nnabla.utils import nnabla_pb2
import onnx
import onnx.helper
from onnx import (ModelProto, TensorProto, GraphProto,
TensorShapeProto, AttributeProto, NodeProto)
import nnabla.logger as logger
Expand Down Expand Up @@ -343,51 +344,37 @@ def read(self):
return onnx_model_to_nnp_protobuf(model_proto)

def convert_to_node(func, variables):
n = NodeProto()
n.name = func.name
n.op_type = nnabla_function_type_to_onnx_optype.get(func.type, func.type)
n.input.extend(func.input)
n.output.extend(func.output)
n = onnx.helper.make_node(
nnabla_function_type_to_onnx_optype.get(func.type, func.type),
func.input,
func.output,
name=func.name)
if func.type == "Concatenate":
# ONNX requires axis setting as a parameter
# for the concat op_type.
attr = n.attribute.add()
attr.name = "axis"
attr.type = AttributeProto.INT
# If no value is set for axis,
# the default value 0 will be set
attr.i = func.concatenate_param.axis
attr = onnx.helper.make_attribute("axis", func.concatenate_param.axis)
n.attribute.extend([attr])
elif func.type == "Dropout":
# NNP Dropout is always is_test=false
# since we always apply dropout when it is
# included in a network.
attr = n.attribute.add()
attr.name = "is_test"
attr.type = AttributeProto.INT
attr.i = 0
attr = onnx.helper.make_attribute("is_test", 0)
n.attribute.extend([attr])
elif func.type == "Identity":
# Convert Identity to a Dropout with is_test=true
# so we just copy the input to output
n.op_type = "Dropout"
attr = n.attribute.add()
attr.name = "is_test"
attr.type = AttributeProto.INT
attr.i = 1
attr = onnx.helper.make_attribute("is_test", 1)
n.attribute.extend([attr])
elif func.type == "MaxPooling":
mpp = func.max_pooling_param
# Copy kernel, stride, and pads values
k = n.attribute.add()
k.name = "kernel_shape"
k.type = AttributeProto.INTS
k.ints.extend(mpp.kernel.dim)
s = n.attribute.add()
s.name = "strides"
s.type = AttributeProto.INTS
s.ints.extend(mpp.stride.dim)
p = n.attribute.add()
p.name = "pads"
p.type = AttributeProto.INTS
p.ints.extend(mpp.pad.dim)
k = onnx.helper.make_attribute("kernel_shape", mpp.kernel.dim)
s = onnx.helper.make_attribute("strides", mpp.stride.dim)
p = onnx.helper.make_attribute("pads", mpp.pad.dim)
n.attribute.extend([k, s, p])
elif func.type == "Convolution":
cp = func.convolution_param
# Calculate the kernel_shape from input weight data.
Expand All @@ -405,27 +392,12 @@ def convert_to_node(func, variables):
weight_shape = weight_var[0].shape
# The base axis for weights is the next axis from the data's base axis
weight_base = cp.base_axis + 1
k = n.attribute.add()
k.name = "kernel_shape"
k.type = AttributeProto.INTS
k.ints.extend(weight_shape.dim[weight_base:])

d = n.attribute.add()
d.name = "dilations"
d.type = AttributeProto.INTS
d.ints.extend(cp.dilation.dim)
s = n.attribute.add()
s.name = "strides"
s.type = AttributeProto.INTS
s.ints.extend(cp.stride.dim)
p = n.attribute.add()
p.name = "pads"
p.type = AttributeProto.INTS
p.ints.extend(cp.pad.dim)
g = n.attribute.add()
g.name = "group"
g.type = AttributeProto.INT
g.i = cp.group
k = onnx.helper.make_attribute("kernel_shape", weight_shape.dim[weight_base:])
d = onnx.helper.make_attribute("dilations", cp.dilation.dim)
s = onnx.helper.make_attribute("strides", cp.stride.dim)
p = onnx.helper.make_attribute("pads", cp.pad.dim)
g = onnx.helper.make_attribute("group", cp.group)
n.attribute.extend([k, d, s, p, g])
return n

def nnp_model_to_onnx_graph(graph, nnp):
Expand Down Expand Up @@ -693,8 +665,7 @@ def change_to_copy(node):

c2out = onnx_caffe2.backend.run_model(model, [img])
# Process onnx with naabla
r = OnnxReader(path)
nnp = r.read()
nnp = onnx_model_to_nnp_protobuf(model)
assert nnp is not None
assert len(nnp.other_files) == 0
assert nnp.protobuf is not None
Expand All @@ -705,9 +676,9 @@ def change_to_copy(node):
nnpdir = tmpdir.mkdir("nnp")
p = os.path.join(str(nnpdir), nnp_name)
nnpex.export_nnp(p)
pdb.set_trace()
#pdb.set_trace()
# read exported nnp and run network
nn_net = nnload.load([p])
#nn_net = nnload.load([p])
#exe = run_executor(nn_net, exec_name)
##in_data = exe.variables["in_data_0"]
##print(in_data.variable_instance.d)
Expand Down

0 comments on commit fa99734

Please sign in to comment.