Tentatively implement NNP to ONNX conversion for BatchNormalization
Masato Hori committed Feb 28, 2018
1 parent 38bac96 commit 92e058f
Showing 3 changed files with 94 additions and 8 deletions.
91 changes: 85 additions & 6 deletions python/src/nnabla/utils/converter/onnx/exporter.py
@@ -23,6 +23,7 @@
    # optype with same names
    "Dropout": "Dropout",
    "Softmax": "Softmax",
    "BatchNormalization": "BatchNormalization",
    # optype with different names
    "ReLU": "Relu",
    "Concatenate": "Concat",
@@ -124,8 +125,91 @@ def convert_to_node(func, variables):
s = onnx.helper.make_attribute("strides", app.stride.dim)
p = onnx.helper.make_attribute("pads", app.pad.dim)
n.attribute.extend([k, s, p])
elif func.type == "BatchNormalization":
# We need to rearrange the input data order.
# NNabla BatchNormalization input order: X, beta, gamma, mean, variance
# ONNX BatchNormalization input order: X, scale, bias, mean, variance
onnx_order = [0, 2, 1, 3, 4]
if len(func.input) != len(onnx_order):
raise ValueError("The number of BatchNormalization input must be {}".format(len(onnx_order)))
onnx_input = [func.input[i] for i in onnx_order]
del n.input[:]
n.input.extend(onnx_input)
bpp = func.batch_normalization_param
if bpp.batch_stat:
# Batch normalization for training is currently not supported
raise ValueError("BatchNormalization with batch_stat=True is currently not supported for ONNX conversion")
t = onnx.helper.make_attribute("is_test", not bpp.batch_stat)
attrs = [t]
# Set values if a valid value has been set
if bpp.eps != 0.0:
e = onnx.helper.make_attribute("epsilon", bpp.eps)
attrs.append(e)
if bpp.decay_rate != 0.0:
m = onnx.helper.make_attribute("momentum", bpp.decay_rate)
attrs.append(m)
# We set an undocumented attribute 'consumed_inputs' here because
# ONNX will check if BatchNormalization has the attribute set.
# consumed_inputs is basically a list of flags indicating which
# input data will be updated in-place (meaning the input and output will have
# same names). The value we are setting here is showing that the mean and variance
# will be specified as an in-place input.
# This should not be needed when is_test=True
# since we will not be outputting mean or variance, but since ONNX is enforcing
# the check we need to set it.
ci = onnx.helper.make_attribute("consumed_inputs", [0, 0, 0, 1, 1])
attrs.append(ci)
n.attribute.extend(attrs)
return n
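For reference, a minimal standalone sketch of the reordering and attribute setup above. The tensor names and the epsilon/momentum values are hypothetical; only `onnx.helper` calls the exporter itself uses are assumed:

```python
import onnx

# NNabla orders BatchNormalization inputs as X, beta, gamma, mean, variance;
# ONNX expects X, scale, bias, mean, variance, so gamma and beta swap places.
nnabla_inputs = ["x", "beta", "gamma", "mean", "variance"]  # hypothetical names
onnx_inputs = [nnabla_inputs[i] for i in [0, 2, 1, 3, 4]]

n = onnx.helper.make_node("BatchNormalization", onnx_inputs, ["y"])
n.attribute.extend([
    onnx.helper.make_attribute("is_test", 1),
    onnx.helper.make_attribute("epsilon", 1e-5),
    onnx.helper.make_attribute("momentum", 0.9),
    # consumed_inputs flags mean and variance as in-place inputs
    onnx.helper.make_attribute("consumed_inputs", [0, 0, 0, 1, 1]),
])

print(onnx_inputs)  # ['x', 'gamma', 'beta', 'mean', 'variance']
```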

def create_dim(val):
    """Create a dimension message for a given dimension"""
    dim = TensorShapeProto.Dimension()
    dim.dim_value = val
    return dim

def convert_parameter_shape(graph):
    """Convert the shape of some parameters so they fit ONNX's requirements.
    We do this as a post conversion because in the future we may be able to
    delete the whole conversion if NNabla's code gets changed"""
    batch_norm_constants = []
    for n in graph.node:
        if n.op_type == "BatchNormalization":
            # BatchNormalization in ONNX requires the scale, bias, mean, and variance inputs to be
            # one dimensional (https://github.com/onnx/onnx/blob/master/docs/Operators.md#batchnormalization).
            # However, in NNabla these inputs must have a specific shape that matches the input shape.
            # For example, if the input shape is (1,3,3,3), the above variables must have the shape (1,3,1,1) and not (3).
            # (1,3,1,1) is actually the same as a one-dimensional tensor of size 3,
            # but NNabla's check currently does not allow this.
            # Thus, we convert the shape of the above inputs so we can pass ONNX's check.
            # If NNabla or ONNX relaxes the requirements, we should be able to remove this conversion.
            batch_norm_constants.extend(n.input[1:5])  # copy all input names for scale, bias, mean, variance

    # This loop may be fairly slow since we loop through all initializers and inputs per constant
    for c in batch_norm_constants:
        # Reshape all BatchNormalization constant inputs assuming the size is (1,C,1,1)
        for i in graph.initializer:
            if i.name == c:
                size = i.dims
                if not (len(size) == 4 and size[2] == 1 and size[3] == 1):
                    raise ValueError(
                        "beta, gamma, mean, and variance parameters "
                        "must have the shape of N*C*1*1 in BatchNormalization")
                chan = size[1]
                del i.dims[:]
                i.dims.extend([chan])
                break
        for i in graph.input:
            if i.name == c:
                size = i.type.tensor_type.shape.dim
                if not (len(size) == 4 and size[2].dim_value == 1 and size[3].dim_value == 1):
                    raise ValueError(
                        "beta, gamma, mean, and variance parameters "
                        "must have the shape of N*C*1*1 in BatchNormalization")
                chan = size[1].dim_value
                del i.type.tensor_type.shape.dim[:]
                i.type.tensor_type.shape.dim.extend([create_dim(chan)])
                break
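A minimal sketch of the reshape this function performs on an initializer, assuming a hypothetical (1, 3, 1, 1) gamma parameter built with `onnx.numpy_helper`:

```python
import numpy as np
from onnx import numpy_helper

# A (1, C, 1, 1) parameter; the values themselves are hypothetical.
gamma = np.ones((1, 3, 1, 1), dtype=np.float32)
init = numpy_helper.from_array(gamma, name="gamma")
assert list(init.dims) == [1, 3, 1, 1]

# Same reshape as above: keep only the channel dimension.
chan = init.dims[1]
del init.dims[:]
init.dims.extend([chan])
assert list(init.dims) == [3]  # raw data is untouched; only shape metadata changed
```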

def nnp_model_to_onnx_graph(graph, nnp):
    if len(nnp.network) != 1:
@@ -155,12 +239,6 @@ def nnp_model_to_onnx_graph(graph, nnp):

    # Add all the constant parameters for all nodes
    # and the first node's input as input
    def create_dim(val):
        """Create a dimension message for a given dimension"""
        dim = TensorShapeProto.Dimension()
        dim.dim_value = val
        return dim

    for iv in exe.data_variable:
        i = graph.input.add()
        i.name = iv.variable_name
@@ -180,6 +258,7 @@ def create_dim(val):
        o.type.tensor_type.elem_type = TensorProto.FLOAT
        dims = [create_dim(d) for d in var_dict[ov.variable_name].dim]
        o.type.tensor_type.shape.dim.extend(dims)
    convert_parameter_shape(graph)


def nnp_model_to_onnx_protobuf(nnp):
7 changes: 5 additions & 2 deletions python/src/nnabla/utils/converter/onnx/reader.py
@@ -291,6 +291,8 @@ def convert_parameter_shape(pb):
"""Convert the shape of some parameters so they fit NNabla's requirements.
We do this as a post conversion because in the future we may be able to
delete the whole conversion if NNabla's code gets changed"""
if len(pb.network) != 1:
raise ValueError("NNP with more then a single network is currently not supported")
net = pb.network[0]
batch_norm_constants = []
for f in net.function:
@@ -299,10 +301,11 @@ def convert_parameter_shape(pb):
            # one dimensional (https://github.com/onnx/onnx/blob/master/docs/Operators.md#batchnormalization).
            # However, in NNabla these inputs must have a specific shape that matches the input shape.
            # For example, if the input shape is (1,3,3,3), the above variables must have the shape (1,3,1,1) and not (3).
            # (1,3,1,1) is actually the same as a one-dimensional tensor of size 3, but NNabla's check currently does not allow this.
            # (1,3,1,1) is actually the same as a one-dimensional tensor of size 3,
            # but NNabla's check currently does not allow this.
            # Thus, we convert the shape of the above inputs so we can pass NNabla's check.
            # If NNabla relaxes the shape check, we should be able to remove this conversion.
            batch_norm_constants.extend(f.input[1:])  # We copy all inputs except the first one
            batch_norm_constants.extend(f.input[1:5])  # copy all input names for scale, bias, mean, variance

    # This loop may be fairly slow since we loop through all variables and parameters per constant
    for c in batch_norm_constants:
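The reshape itself sits below the truncated portion of this diff; per the comments above, it is the mirror of the exporter-side conversion. A hedged sketch of the transformation it implies, with a hypothetical helper name:

```python
# Hypothetical sketch: ONNX stores BatchNormalization parameters as (C,),
# while NNabla expects (1, C, 1, 1) for an input of shape (N, C, H, W).
def to_nnabla_bn_shape(onnx_shape):
    assert len(onnx_shape) == 1
    return [1, onnx_shape[0], 1, 1]

assert to_nnabla_bn_shape([3]) == [1, 3, 1, 1]
```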
4 changes: 4 additions & 0 deletions python/test/utils/test_conversion.py
@@ -220,6 +220,10 @@ def test_onnx_nnp_conversion_batch_normalization(tmpdir, nnp_fixture):
    convert_onnx_to_nnp_and_compare(
        tmpdir, TEST_DATA_DIR, "batch_norm.onnx", "batch_norm.nnp", "out_data_1", "exec_0", atol=1e-05)

def test_nnp_onnx_conversion_batch_normalization(tmpdir, nnp_fixture):
    convert_nnp_to_onnx_and_compare(
        tmpdir, TEST_DATA_DIR, "batch_norm.nnp", "batch_norm.onnx", "out_data_1", "exec_0", atol=1e-05)
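The compare helpers are assumed to check outputs elementwise within the given absolute tolerance; a minimal sketch of such a check, with hypothetical values:

```python
import numpy as np

# Hypothetical outputs from the NNP and ONNX executions of batch_norm
nnp_out = np.array([0.0, 1.0, 2.0], dtype=np.float32)
onnx_out = nnp_out + 5e-6  # differs by less than atol

assert np.allclose(nnp_out, onnx_out, atol=1e-05)
```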

def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
    onnx_dir = TEST_DATA_DIR
    onnx_name = "squeezenet.onnx"
