Skip to content

Commit

Permalink
Investigate string separator for SplitString operator
Browse files (browse the repository at this point in the history)
  • Loading branch information
xadupre committed Oct 20, 2023
1 parent 92d17b0 commit 578a6e4
Showing 1 changed file with 70 additions and 66 deletions.
136 changes: 70 additions & 66 deletions test/test_string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,79 +913,83 @@ def test_string_split_python(self):
onnx_model = _create_test_model_string_split('Py')
self.assertIn('op_type: "PyStringSplit"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
delimiter = np.array([","])

for skip in [True, False]:
with self.subTest(skip=skip):
skip_empty = np.array([skip])

txout = sess.run(
None, {'input': input, 'delimiter': delimiter,
'skip_empty': skip_empty})

if skip_empty:
exp_indices = np.array(
[[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
else:
exp_indices = np.array(
[[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
[2, 2], [3, 0]])
exp_text = np.array(
['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
exp_shape = np.array([4, 3])
self.assertEqual(exp_indices.tolist(), txout[0].tolist())
self.assertEqual(exp_text.tolist(), txout[1].tolist())
self.assertEqual(exp_shape.tolist(), txout[2].tolist())
for sep in [",", ":/", ",,"]:
with self.subTest(sep=sep):
input = np.array([f"a{sep}{sep}b", "", f"aa{sep}b{sep}c", "dddddd"])
delimiter = np.array([sep])

for skip in [True, False]:
with self.subTest(skip=skip):
skip_empty = np.array([skip])

txout = sess.run(
None, {'input': input, 'delimiter': delimiter,
'skip_empty': skip_empty})

if skip_empty:
exp_indices = np.array(
[[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
else:
exp_indices = np.array(
[[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
[2, 2], [3, 0]])
exp_text = np.array(
['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
exp_shape = np.array([4, 3])
self.assertEqual(exp_indices.tolist(), txout[0].tolist())
self.assertEqual(exp_text.tolist(), txout[1].tolist())
self.assertEqual(exp_shape.tolist(), txout[2].tolist())

def test_string_split_cc(self):
    """Check the C++ ``StringSplit`` custom op for several separators.

    For each separator (single- and multi-character) the op is run with
    ``skip_empty`` both enabled and disabled, and its three outputs
    (indices, tokens, dense shape) are compared with hand-written
    expectations.  When TensorFlow is installed, the outputs are also
    compared with ``tensorflow.raw_ops.StringSplit`` for single-character
    separators.
    """
    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    onnx_model = _create_test_model_string_split('')
    self.assertIn('op_type: "StringSplit"', str(onnx_model))
    sess = _ort.InferenceSession(
        onnx_model.SerializeToString(), so,
        providers=['CPUExecutionProvider'])

    # TensorFlow is an optional reference implementation; resolve it once
    # instead of re-importing on every iteration.
    try:
        from tensorflow.raw_ops import StringSplit
    except ImportError:
        StringSplit = None

    for sep in [",", ":/", ",,"]:
        with self.subTest(sep=sep):
            text_input = np.array(
                [f"a{sep}{sep}b", "", f"aa{sep}b{sep}c", "dddddd"])
            delimiter = np.array([sep])

            for skip in [True, False]:
                with self.subTest(skip=skip):
                    skip_empty = np.array([skip])

                    txout = sess.run(
                        None, {'input': text_input,
                               'delimiter': delimiter,
                               'skip_empty': skip_empty})

                    # The expectations below show the custom op treating
                    # the delimiter as one whole separator string.
                    # NOTE(review): per the TensorFlow docs, StringSplit
                    # treats every byte of `delimiter` as an independent
                    # split character, so both implementations are only
                    # expected to agree for one-character separators.
                    if StringSplit is not None and len(sep) == 1:
                        tfres = StringSplit(
                            input=text_input, delimiter=sep,
                            skip_empty=skip)
                        self.assertEqual(
                            [_.decode() for _ in tfres[1].numpy().tolist()],
                            txout[1].tolist())
                        self.assertEqual(
                            tfres[0].numpy().tolist(), txout[0].tolist())
                        self.assertEqual(
                            tfres[2].numpy().tolist(), txout[2].tolist())

                    if skip:
                        # Empty tokens are dropped; the empty input row
                        # (row 1) contributes no entries.
                        exp_indices = np.array(
                            [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2],
                             [3, 0]])
                        exp_text = np.array(
                            ['a', 'b', 'aa', 'b', 'c', 'dddddd'])
                    else:
                        # The doubled separator in row 0 yields one
                        # empty token between 'a' and 'b'.
                        exp_indices = np.array(
                            [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
                             [2, 2], [3, 0]])
                        exp_text = np.array(
                            ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
                    # The expected dense shape is the same in both modes.
                    exp_shape = np.array([4, 3])
                    self.assertEqual(exp_indices.tolist(), txout[0].tolist())
                    self.assertEqual(exp_text.tolist(), txout[1].tolist())
                    self.assertEqual(exp_shape.tolist(), txout[2].tolist())

def test_string_split_cc_sep2(self):
so = _ort.SessionOptions()
Expand Down

0 comments on commit 578a6e4

Please sign in to comment.