Skip to content

Commit

Permalink
Merge pull request #9 from sintel-dev/amend-transformation
Browse files Browse the repository at this point in the history
Support a 3d list (batch, output, samples)
  • Loading branch information
Linh-nk authored May 4, 2024
2 parents 5824724 + 517a9b1 commit 1d91935
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 29 deletions.
15 changes: 7 additions & 8 deletions sigllm/primitives/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _from_string_to_integer(text, sep=',', trunc=None, errors='ignore'):
else:
raise KeyError(f"Unknown errors strategy {errors}.")

clean = np.array(clean).astype(float)
clean = np.array(clean, dtype=float)

if trunc:
clean = clean[:trunc]
Expand All @@ -74,7 +74,7 @@ def _from_string_to_integer(text, sep=',', trunc=None, errors='ignore'):
def format_as_integer(strings, sep=',', trunc=None, errors='ignore'):
"""Format a nested list of text into an array of integers.
Transforms a list of list pf string input as 2-D array of integers,
Transforms a list of lists of strings into a 3-D array of integers,
separated by the indicated separator and truncated based on `trunc`.
Args:
Expand All @@ -98,18 +98,17 @@ def format_as_integer(strings, sep=',', trunc=None, errors='ignore'):
result = list()
for string_list in strings:
sample = list()
string = string_list
if not isinstance(string, list):
string = [string]
if not isinstance(string_list, list):
raise ValueError("Input is not a list of lists.")

for text in string:
for text in string_list:
scalar = _from_string_to_integer(text, sep, trunc, errors)
sample.extend(scalar)
sample.append(scalar)

result.append(sample)

output = np.array(result, dtype=object)
if output.ndim >= 2:
if output.ndim >= 3:
output = output.astype(float)

return output
Expand Down
55 changes: 34 additions & 21 deletions tests/primitives/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,24 +151,19 @@ def test__from_string_to_integer_error(self):
_from_string_to_integer(data, errors='unknown')


def test_format_as_integer_one():
data = ['1,2,3,4,5']
def test_format_as_integer_one_list():
data = ['1,2,3,4,5', '6,7,8,9,10']

expected = np.array([[
1, 2, 3, 4, 5
]])

output = format_as_integer(data)

np.testing.assert_equal(output, expected)
with pytest.raises(ValueError):
format_as_integer(data)


def test_format_as_integer_list():
data = [['1,2,3,4,5']]
def test_format_as_integer_list_of_list():
data = [['1,2,3,4,5', '6,7,8,9,10']]

expected = np.array([[
1, 2, 3, 4, 5
]])
expected = np.array([
[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
])

output = format_as_integer(data)

Expand All @@ -179,30 +174,48 @@ def test_format_as_integer_2d_shape_mismatch():
data = [['1,2,3,4,5'], ['1, 294., 3 , j34,5'], ['!232, 23,3,4,5']]

expected = np.array([
[1, 2, 3, 4, 5],
[1, 3, 5],
[23, 3, 4, 5]
[np.array([1., 2, 3, 4, 5])],
[np.array([1., 3, 5])],
[np.array([23., 3, 4, 5])]
], dtype=object)

output = format_as_integer(data)

np.testing.assert_equal(output, expected)
for out, exp in list(zip(output, expected)):
for o, e in list(zip(out, exp)):
np.testing.assert_equal(o, e)


def test_format_as_integer_2d_trunc():
data = [['1,2,3,4,5'], ['1,294.,3,j34,5'], ['!232, 23,3,4,5']]

expected = np.array([
[1, 2],
[1, 3],
[23, 3]
[[1, 2]],
[[1, 3]],
[[23, 3]]
])

output = format_as_integer(data, trunc=2)

np.testing.assert_equal(output, expected)


def test_format_as_integer_3d():
    """A batch of two entries, each holding two strings, parses to a 3-D array."""
    batch = [
        ['1,2,3,4,5', '6,7,8,9,10'],
        ['11,12,13,14,15', '16,17,18,19,20'],
    ]

    # The integers 1..20 laid out with shape (2, 2, 5): two batch entries,
    # two strings per entry, five values per string.
    wanted = np.arange(1, 21).reshape(2, 2, 5)

    result = format_as_integer(batch)

    np.testing.assert_equal(result, wanted)


class Float2ScalarTest(unittest.TestCase):

def test_transform_default(self):
Expand Down

0 comments on commit 1d91935

Please sign in to comment.