Add support for Selu, and add error handling for when an unsupported attribute is encountered
Masato Hori committed Apr 12, 2018
1 parent e416a14 commit 6b6d5ca
Showing 4 changed files with 57 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/src/nnabla/utils/converter/onnx/ONNXOpCoverage.md
@@ -83,7 +83,7 @@ that indicates if each operator can be converted to NNP.
|ReduceSumSquare|:black_heart:||
|Relu|:green_heart:||
|Reshape|:yellow_heart:|implementing|
|Selu|:black_heart:||
|Selu|:green_heart:||
|Sigmoid|:green_heart:||
|Slice|:black_heart:||
|Softmax|:yellow_heart:|Supporting 2D input only|
6 changes: 6 additions & 0 deletions python/src/nnabla/utils/converter/onnx/exporter.py
Expand Up @@ -41,6 +41,7 @@
"BatchMatmul": "MatMul",
"LogicalNot": "Not",
"ELU": "Elu",
"SELU": "Selu",
# optype that gets converted
"Identity": "Dropout",
"Affine": "Gemm",
@@ -199,6 +200,11 @@ def convert_to_nodes(func, variables, input_types, output_types):
# Store the input/output tensor's name and convert it to boolean
input_types[n.input[0]] = TensorProto.BOOL
output_types[n.output[0]] = TensorProto.BOOL
elif func.type == "SELU":
sp = func.selu_param
a = onnx.helper.make_attribute("alpha", sp.alpha)
g = onnx.helper.make_attribute("gamma", sp.scale)
n.attribute.extend([a, g])
nl.append(n)
return nl

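[Note] A minimal sketch (not part of this commit) of the attribute mapping the exporter performs for SELU: NNabla's selu_param.alpha becomes the ONNX "alpha" attribute, and selu_param.scale becomes "gamma". The values below are the ONNX defaults, standing in for whatever the NNP file actually carries.

import onnx
from onnx import helper

# Hypothetical parameter values standing in for func.selu_param.
alpha, scale = 1.6732, 1.0507

# helper.make_node turns keyword arguments into node attributes,
# equivalent to the make_attribute/extend calls in the diff above.
n = helper.make_node("Selu", inputs=["x"], outputs=["y"],
                     alpha=alpha, gamma=scale)
print(n)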
39 changes: 39 additions & 0 deletions python/src/nnabla/utils/converter/onnx/reader.py
Expand Up @@ -56,6 +56,7 @@
"LeakyRelu": "LeakyReLU",
"Not": "LogicalNot",
"Elu": "ELU",
"Selu": "SELU",
# Constant does not get converted to a function
# but we list it here so we can accept it
"Constant": ""
@@ -456,6 +457,9 @@ def convert_to_functions(pb, network, node, base_name, initializers,
raise ValueError("broadcasting is currently not supported for {}".format(node.op_type))
# Add2 broadcasts by default so we do nothing here
#pass
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
func_list.append(func)
elif node.op_type == "Mul":
# We need the input buffer's dimension information here
@@ -486,6 +490,9 @@ def convert_to_functions(pb, network, node, base_name, initializers,
raise ValueError("broadcasting is currently not supported for {}".format(node.op_type))
# Mul2 broadcasts by default so we do nothing here
#pass
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
func_list.append(func)
elif node.op_type == "Constant":
# Convert a Constant node as an input parameter and not a function
@@ -510,6 +517,9 @@ def convert_to_functions(pb, network, node, base_name, initializers,
v.initializer.type = "Constant"
v.initializer.multiplier = 1.0
param_list.append(v)
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
# We do not add any function to the list here
# since the node is converted as a parameter
elif node.op_type == "Reshape":
@@ -522,6 +532,9 @@ def convert_to_functions(pb, network, node, base_name, initializers,
raise ValueError("Only INTS is supported for shape in {} op_type".format(node.op_type))
rp.shape.dim.extend(attr.ints)
shape_found = True
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
if len(func.input) == 2:
# Shape comes as input for Reshape-5.
# NNabla reshape expects a single input (data),
@@ -555,6 +568,9 @@ def convert_to_functions(pb, network, node, base_name, initializers,
if attr.type != AttributeProto.INTS:
raise ValueError("Only INTS is supported for perm in {} op_type".format(node.op_type))
tp.axes.extend(attr.ints)
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
func_list.append(func)
elif node.op_type == "LeakyRelu":
lrp = func.leaky_relu_param
@@ -564,6 +580,9 @@ def convert_to_functions(pb, network, node, base_name, initializers,
if attr.type != AttributeProto.FLOAT:
raise ValueError("Only FLOAT is supported for alpha in {} op_type".format(node.op_type))
lrp.alpha = attr.f
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
func_list.append(func)
elif node.op_type == "Elu":
ep = func.elu_param
@@ -573,6 +592,26 @@ def convert_to_functions(pb, network, node, base_name, initializers,
if attr.type != AttributeProto.FLOAT:
raise ValueError("Only FLOAT is supported for alpha in {} op_type".format(node.op_type))
ep.alpha = attr.f
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
func_list.append(func)
elif node.op_type == "Selu":
sp = func.selu_param
sp.alpha = 1.6732 # Default value for ONNX
sp.scale = 1.0507
for attr in node.attribute:
if attr.name == "alpha":
if attr.type != AttributeProto.FLOAT:
raise ValueError("Only FLOAT is supported for alpha in {} op_type".format(node.op_type))
sp.alpha = attr.f
elif attr.name == "gamma":
if attr.type != AttributeProto.FLOAT:
raise ValueError("Only FLOAT is supported for gamma in {} op_type".format(node.op_type))
sp.scale = attr.f
else:
raise ValueError("Unsupported attribute {} was specified at {}"
.format(attr.name, node.op_type))
func_list.append(func)
else:
# Simply add the function for all other conversions
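[Note] For reference, the defaults set in the reader (alpha = 1.6732, scale = 1.0507) are the ONNX defaults for Selu, with ONNX's "gamma" stored in NNabla's selu_param.scale. A short NumPy sketch of the function those parameters describe, following the ONNX operator spec (not code from this commit):

import numpy as np

def selu(x, alpha=1.6732, gamma=1.0507):
    # ONNX Selu: y = gamma * x                    for x > 0
    #            y = gamma * alpha * (exp(x) - 1) for x <= 0
    return np.where(x > 0, gamma * x, gamma * alpha * (np.exp(x) - 1.0))

print(selu(np.array([-1.0, 0.0, 1.0])))

With the new else branches, an unrecognized attribute on any of the ops above now fails fast with, e.g., "Unsupported attribute beta was specified at Selu" ("beta" being a made-up attribute name here), instead of being silently ignored.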
11 changes: 11 additions & 0 deletions python/test/utils/conversion/test_conversion.py
@@ -401,6 +401,17 @@ def test_nnp_onnx_conversion_elu(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(
tmpdir, TEST_DATA_DIR, "elu.nnp", "elu.onnx", "out_data_1", "exec_0")


def test_onnx_nnp_conversion_selu(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(
tmpdir, TEST_DATA_DIR, "selu.onnx", "selu.nnp", "out_data_1", "exec_0")


def test_nnp_onnx_conversion_selu(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(
tmpdir, TEST_DATA_DIR, "selu.nnp", "selu.onnx", "out_data_1", "exec_0")


def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
img = np.random.rand(1, 3, 224, 224).astype(np.float32)
convert_onnx_to_nnp_and_compare(
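[Note] The new tests assume selu.onnx and selu.nnp fixtures in TEST_DATA_DIR. A hypothetical sketch of how a single-node selu.onnx could be produced with onnx.helper (the repository's actual fixture-generation process is not shown in this commit, and the tensor shape is made up):

import onnx
from onnx import TensorProto, helper

# One Selu node with explicit attributes, so both the attribute-reading
# and default-value paths in reader.py can be exercised.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3])
node = helper.make_node("Selu", ["x"], ["y"], alpha=1.6732, gamma=1.0507)
graph = helper.make_graph([node], "selu_test", [x], [y])
model = helper.make_model(graph)
onnx.checker.check_model(model)
onnx.save(model, "selu.onnx")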
