Skip to content

Commit

Permalink
Divに対応
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 26, 2018
1 parent 5b1db87 commit bf4c446
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
4 changes: 3 additions & 1 deletion python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"Sum": "ReduceSum",
"Mean": "ReduceMean",
"Mul2": "Mul",
"Div2": "Div",
"LogicalAnd": "And",
"LogicalOr": "Or",
"LogicalXor": "Xor",
Expand Down Expand Up @@ -273,7 +274,8 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
broadcast_target[func.output[0]] = (func.input[1], bp.axis)
# we do not append node here because BroadcastTo should disappear
elif (func.type == "Add2" or
func.type == "Mul2"):
func.type == "Mul2" or
func.type == "Div2"):
# Check if the second input is a brodcast target.
bt = func.input[1]
if bt in broadcast_target:
Expand Down
2 changes: 2 additions & 0 deletions python/src/nnabla/utils/converter/onnx/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"Gemm": "Affine",
"Add": "Add2",
"Mul": "Mul2",
"Div": "Div2",
"MatMul": "BatchMatmul",
"LeakyRelu": "LeakyReLU",
"Not": "LogicalNot",
Expand Down Expand Up @@ -523,6 +524,7 @@ def convert_to_functions(pb, network, node, base_name, initializers,
func_list.append(func)
elif (node.op_type == "Add" or
node.op_type == "Mul" or
node.op_type == "Div" or
node.op_type == "And" or
node.op_type == "Or" or
node.op_type == "Xor"):
Expand Down
43 changes: 38 additions & 5 deletions python/test/utils/conversion/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@
from nnabla.utils.converter.nnabla import NnpReader, NnpExporter
from nnabla.utils.converter.onnx import OnnxReader, OnnxExporter, onnx_model_to_nnp_protobuf

# The directory of which the input ONNX files will be at
TEST_DATA_DIR = "nnabla-sample-data/conversion_data"

# Set a path to this parameter (preferably the same as TEST_DATA_DIR)
# if you want to update all the NNP files
DEFAULT_NNP_EXPORT_PATH = None

def print_buffer_shape(net):
for k, v in net.functions.items():
out = v.outputs[0]
Expand All @@ -47,7 +52,7 @@ def convert_onnx_to_nnp_and_compare(
in_img=None, in_name="",
compare_values=True, show_onnx=False, show_nnp=False,
show_output=False, atol=1e-08,
export_nnp_path=None):
export_nnp_path=DEFAULT_NNP_EXPORT_PATH):
"""Convert specified ONNX to NNP and compare each results ran by Caffe2 and NNabla"""
path = os.path.join(onnx_dir, onnx_name)
backend_out = None
Expand Down Expand Up @@ -565,12 +570,39 @@ def test_onnx_nnp_conversion_xor_broadcast_axis1(tmpdir, nnp_fixture):
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_or_broadcast_axis1(tmpdir, nnp_fixture):
def test_nnp_onnx_conversion_xor_broadcast_axis1(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"or_broadcast_axis1.nnp",
"or_broadcast_axis1.onnx",
"xor_broadcast_axis1.nnp",
"xor_broadcast_axis1.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_div_no_broadcast(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"div_no_broadcast.onnx",
"div_no_broadcast.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_div_no_broadcast(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"div_no_broadcast.nnp",
"div_no_broadcast.onnx",
"out_data_1", "exec_0")


def test_onnx_nnp_conversion_div_broadcast_axis1(tmpdir, nnp_fixture):
convert_onnx_to_nnp_and_compare(tmpdir, TEST_DATA_DIR,
"div_broadcast_axis1.onnx",
"div_broadcast_axis1.nnp",
"out_data_1", "exec_0")


def test_nnp_onnx_conversion_div_broadcast_axis1(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(tmpdir, TEST_DATA_DIR,
"div_broadcast_axis1.nnp",
"div_broadcast_axis1.onnx",
"out_data_1", "exec_0")
# These following tests are invalidated due to a
# backend bug? decribed in the following issue:
# https://github.com/Microsoft/CNTK/issues/3127
Expand Down Expand Up @@ -613,14 +645,15 @@ def test_nnp_onnx_conversion_squeezenet(tmpdir, nnp_fixture):
# tmpdir, TEST_DATA_DIR, "inception_v2.onnx", "inception_v2.nnp", "prob_1", "exec_0",
# in_name="data_0", in_img=img)


@pytest.mark.slow
def test_onnx_nnp_conversion_densenet121(tmpdir, nnp_fixture):
img = np.random.rand(1, 3, 224, 224).astype(np.float32)
convert_onnx_to_nnp_and_compare(
tmpdir, TEST_DATA_DIR, "densenet121.onnx", "densenet121.nnp", "fc6_1", "exec_0",
in_name="data_0", in_img=img, atol=1e-5)


@pytest.mark.slow
def test_nnp_onnx_conversion_densenet121(tmpdir, nnp_fixture):
img = np.random.rand(1, 3, 224, 224).astype(np.float32)
convert_nnp_to_onnx_and_compare(
Expand Down

0 comments on commit bf4c446

Please sign in to comment.