Skip to content

Commit

Permalink
Concatの単体テストでcaffe2とnnablaの結果を比較するよう修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Jan 26, 2018
1 parent caf4fa2 commit c8f8e44
Showing 1 changed file with 31 additions and 30 deletions.
61 changes: 31 additions & 30 deletions python/test/utils/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,36 +319,37 @@ def test_nnp_onnx_conversion_relu(tmpdir):
#print(c2, nnout)
assert np.allclose(c2, nnout)

#def test_onnx_nnp_conversion_concat(tmpdir):
# path = os.path.join(TEST_DATA_DIR, "concat.onnx")
# # Process onnx with caffe2 backend
# #model = onnx.load(path)
# #c2out = onnx_caffe2.backend.run_model(model, [])
# #print(c2out)
# # Process onnx with naabla
# r = OnnxReader(path)
# nnp = r.read()
# assert nnp is not None
# assert len(nnp.other_files) == 0
# assert nnp.protobuf is not None
# logger.log(99, nnp.protobuf)
#
# nnpex = NnpExporter(nnp, batch_size=0)
# nnpdir = tmpdir.mkdir("nnp")
# p = os.path.join(str(nnpdir), "concat.nnp")
# nnpex.export_nnp(p)
# # read exported nnp and run network
# #pdb.set_trace()
# nn_net = nnload.load([p])
# concat = nn_net.networks["concat_net"]
# id0 = concat.variables["in_data_0_0"]
# id1 = concat.variables["in_data_1_0"]
# print(id0.variable_instance.d)
# print(id1.variable_instance.d)
# out_data = concat.variables["out_data_1"]
# ovi = out_data.variable_instance
# ovi.forward()
# print(ovi.d)
def test_onnx_nnp_conversion_concat(tmpdir):
path = os.path.join(TEST_DATA_DIR, "concat.onnx")
# Process onnx with caffe2 backend
model = onnx.load(path)
c2out = onnx_caffe2.backend.run_model(model, [])
# Process onnx with naabla
r = OnnxReader(path)
nnp = r.read()
assert nnp is not None
assert len(nnp.other_files) == 0
assert nnp.protobuf is not None
logger.log(99, nnp.protobuf)

nnpex = NnpExporter(nnp, batch_size=0)
nnpdir = tmpdir.mkdir("nnp")
p = os.path.join(str(nnpdir), "concat.nnp")
nnpex.export_nnp(p)
# read exported nnp and run network
#pdb.set_trace()
nn_net = nnload.load([p])
concat = run_executor(nn_net, "exec_0")
#id0 = concat.variables["in_data_0_0"]
#id1 = concat.variables["in_data_1_0"]
#print(id0.variable_instance.d)
#print(id1.variable_instance.d)
OUT_DATA_NAME = "out_data_1"
nnout = concat.variables[OUT_DATA_NAME].variable_instance.d
c2 = c2out[OUT_DATA_NAME]
#print(c2, c2.shape)
#print(nnout, nnout.shape)
assert np.allclose(c2, nnout)
#
#def test_onnx_nnp_conversion_gap(tmpdir):
# path = os.path.join(TEST_DATA_DIR, "gap.onnx")
Expand Down

0 comments on commit c8f8e44

Please sign in to comment.