Skip to content

Commit

Permalink
MulScalarの変換を修正しPowScalarの変換を追加
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed May 30, 2018
1 parent 7726497 commit 0df4cb4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
52 changes: 41 additions & 11 deletions python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,29 @@
"RDivScalar": "Reciprocal",
# optype that gets converted
"Affine": "Gemm",
"MulScalar": "Neg",
"MulScalar": "Mul",
"MinimumScalar": "Clip",
"MaximumScalar": "Clip",
"AddScalar": "Add",
"PowScalar": "Pow",
#"SumPooling": "Mul",
# optype that should get merged
# with other operators
"BroadcastTo": ""
}

def generate_scalar_constant(name, tensor_name, scalar):
"""Convert a scalar value to a Constant buffer.
This is mainly used for xxScalar operators."""
t = onnx.helper.make_tensor(tensor_name,
data_type=TensorProto.FLOAT,
dims=[1], vals=[scalar])
c = onnx.helper.make_node("Constant",
[],
[name],
value=t)
return c

def merge_broadcast(node, func, target_name, broadcast_target):
# Set the broadcast attribute to the operator
# so we can combine BroadcastTo with this operator.
Expand Down Expand Up @@ -354,8 +368,17 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
nl.append(n)
elif func.type == "MulScalar":
mp = func.mul_scalar_param
if mp.val != -1.0:
raise ValueError("MulScalar can be converted to Neg only if val is -1")
if mp.val == -1.0:
# Convert to Neg
n.op_type = "Neg"
else:
# Convert the scalar param to a Const node and add it with input
x = func.input[0]
sval = x+"_scalar"
c = generate_scalar_constant(sval, func.name+"_scalar", mp.val)
del n.input[:]
n.input.extend([x, sval])
nl.append(c)
nl.append(n)
elif func.type == "MinimumScalar":
msp = func.minimum_scalar_param
Expand All @@ -372,18 +395,25 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
# Convert the scalar param to a Const node and add it with input
x = func.input[0]
sval = x+"_scalar"
t = onnx.helper.make_tensor(func.name+"_scalar",
data_type=TensorProto.FLOAT,
dims=[1], vals=[asp.val])
c = onnx.helper.make_node(
"Constant",
[],
[sval],
value=t)
c = generate_scalar_constant(sval, func.name+"_scalar", asp.val)
del n.input[:]
n.input.extend([x, sval])
nl.append(c)
nl.append(n)
elif func.type == "PowScalar":
psp = func.pow_scalar_param
# Convert the scalar param to a Const node and add it with input
x = func.input[0]
sval = x+"_scalar"
c = generate_scalar_constant(sval, func.name+"_scalar", psp.val)
del n.input[:]
n.input.extend([x, sval])
nl.append(c)
nl.append(n)
#elif func.type == "SumPooling":
# # SumPooling gets converted to AveragePooling+Mul.
# # Mul is used to counter the division in AveragePooling
# # since SumPooling is just summing the values in each kernel.
else:
# Simply append node to list
nl.append(n)
Expand Down
11 changes: 7 additions & 4 deletions python/test/utils/conversion/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
import pytest
import nnabla
import nnabla.utils.load as nnload
import onnx
import numpy as np
import pdb
import onnx
from collections import OrderedDict
import caffe2.python.onnx.backend as oc2
import cntk
import cntk.ops.functions as cntkf
from nnabla.utils.converter.nnabla import NnpReader, NnpExporter
from nnabla.utils.converter.onnx import (
OnnxReader, OnnxExporter,
onnx_model_to_nnp_protobuf,
)
try:
import caffe2.python.onnx.backend as oc2
import cntk
import cntk.ops.functions as cntkf
except:
print('Need to install Caffe2 and CNTK for testing.')

# The directory of which the input ONNX files will be at
TEST_DATA_DIR = "nnabla-sample-data/conversion_data"
Expand Down

0 comments on commit 0df4cb4

Please sign in to comment.