Skip to content

Commit

Permalink
GlobalAveragePoolingの単体テストを追加
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Feb 15, 2018
1 parent e423e05 commit a979694
Showing 1 changed file with 7 additions and 40 deletions.
47 changes: 7 additions & 40 deletions python/test/utils/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def onnx_value_info_proto_to_variable(info, network):
"Relu": "ReLU",
"Concat": "Concatenate",
"Conv": "Convolution",
"GlobalAveragePool": "AveragePooling",
"GlobalAveragePool": "GlobalAveragePooling",
"MaxPool": "MaxPooling",
}

Expand All @@ -69,7 +69,7 @@ def onnx_value_info_proto_to_variable(info, network):
"ReLU": "Relu",
"Concatenate": "Concat",
"Convolution": "Conv",
"AveragePooling": "GlobalAveragePool",
"GlobalAveragePooling": "GlobalAveragePool",
"MaxPooling": "MaxPool",
}

Expand Down Expand Up @@ -198,14 +198,6 @@ def convert_to_function(node, base_name, func_counter):
# Set default values.
# Do we really need this? (Default value should be set by NNabla)
cp.dilation.dim.extend([1 for _ in range(dim)])
elif node.op_type == "GlobalAveragePool":
func.type = "Identity"
## We substitute GlobalAveragePool with an AveragePool
## that has the same kernel size as the input WxH
#app = func.average_pooling_param
#app.kernel.dim.extend([3,3])
#app.stride.dim.extend([3,3])
#app.pad.dim.extend([0,0])
elif node.op_type == "MaxPool":
mpp = func.max_pooling_param
dims = []
Expand Down Expand Up @@ -537,7 +529,7 @@ def convert_onnx_to_nnp_and_compare(
p = os.path.join(str(nnpdir), nnp_name)
nnpex.export_nnp(p)
# read exported nnp and run network
#pdb.set_trace()
pdb.set_trace()
nn_net = nnload.load([p])
exe = run_executor(nn_net, exec_name)
#in_data = exe.variables["in_data_0"]
Expand Down Expand Up @@ -663,6 +655,10 @@ def test_nnp_onnx_conversion_conv(tmpdir, nnp_fixture):
convert_nnp_to_onnx_and_compare(
tmpdir, TEST_DATA_DIR, "conv.nnp", "conv.onnx", "out_data_1", "exec_0")

def test_onnx_nnp_conversion_gap(tmpdir, nnp_fixture):
    """Round-trip check for GlobalAveragePool: convert gap.onnx to NNP and
    compare the executor output against the ONNX backend result."""
    onnx_name = "gap.onnx"
    nnp_name = "gap.nnp"
    out_name = "out_data_1"
    exec_name = "exec_0"
    convert_onnx_to_nnp_and_compare(
        tmpdir, TEST_DATA_DIR, onnx_name, nnp_name, out_name, exec_name)

def test_onnx_nnp_conversion_squeezenet(tmpdir, nnp_fixture):
onnx_dir = TEST_DATA_DIR
onnx_name = "squeezenet.onnx"
Expand Down Expand Up @@ -820,32 +816,3 @@ def test_nnp_onnx_conversion_squeezenet(tmpdir, nnp_fixture):
# print(nnout, nnout.shape)
# #assert np.allclose(c2, nnout)

##
#def test_onnx_nnp_conversion_gap(tmpdir):
# path = os.path.join(TEST_DATA_DIR, "gap.onnx")
# # Process onnx with caffe2 backend
# model = onnx.load(path)
# c2out = onnx_caffe2.backend.run_model(model, [])
# #print(c2out)
# # Process onnx with naabla
# r = OnnxReader(path)
# nnp = r.read()
# assert nnp is not None
# assert len(nnp.other_files) == 0
# assert nnp.protobuf is not None
# #logger.log(99, nnp.protobuf)
#
# nnpex = NnpExporter(nnp, batch_size=0)
# nnpdir = tmpdir.mkdir("nnp")
# p = os.path.join(str(nnpdir), "gap.nnp")
# nnpex.export_nnp(p)
# # read exported nnp and run network
# #pdb.set_trace()
# nn_net = nnload.load([p])
# gap = run_executor(nn_net, "exec_0")
# OUT_DATA_NAME = "out_data_1"
# out_data = gap.variables[OUT_DATA_NAME]
# nnout = gap.variables[OUT_DATA_NAME].variable_instance.d
# c2 = c2out[OUT_DATA_NAME]
# #print(c2, nnout)
# assert np.allclose(c2, nnout)

0 comments on commit a979694

Please sign in to comment.