diff --git a/torcharrow/icolumn.py b/torcharrow/icolumn.py index 457f5f659..dc6c1ff34 100644 --- a/torcharrow/icolumn.py +++ b/torcharrow/icolumn.py @@ -269,18 +269,21 @@ def __len__(self): # printing ---------------------------------------------------------------- def __str__(self): - return f"Column([{', '.join(str(i) for i in self)}], id = {self.id})" + item_padding = "'" if dt.is_string(self.dtype) else "" + return f"Column([{', '.join(f'{item_padding}{i}{item_padding}' for i in self)}], id = {self.id})" def __repr__(self): - rows = [[l if l is not None else "None"] for l in self] + item_padding = "'" if dt.is_string(self.dtype) else "" + rows = [ + [f"{item_padding}{l}{item_padding}" if l is not None else "None"] + for l in self + ] tab = tabulate( rows, tablefmt="plain", showindex=True, ) - typ = ( - f"dtype: {self._dtype}, length: {len(self)}, null_count: {self.null_count}" - ) + typ = f"dtype: {self._dtype}, length: {len(self)}, null_count: {self.null_count}, device: {self.device}" return tab + dt.NL + typ # selectors/getters ------------------------------------------------------- diff --git a/torcharrow/test/test_numerical_column.py b/torcharrow/test/test_numerical_column.py index f323e24f9..6b86798ff 100644 --- a/torcharrow/test/test_numerical_column.py +++ b/torcharrow/test/test_numerical_column.py @@ -4,13 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator import statistics import typing as ty import unittest +from collections import defaultdict from math import ceil, floor, isnan, log import numpy as np import numpy.testing +import pandas as pd import torcharrow as ta import torcharrow.dtypes as dt from torcharrow.icolumn import Column @@ -32,7 +35,7 @@ def base_test_empty(self): return empty_i64_column def base_test_full(self): - col = ta.column([i for i in range(4)], dtype=dt.int64, device=self.device) + col = ta.column(list(range(4)), dtype=dt.int64, device=self.device) # self.assertEqual(col._offset, 0) self.assertEqual(len(col), 4) @@ -43,7 +46,7 @@ def base_test_full(self): return col def base_test_is_immutable(self): - col = ta.column([i for i in range(4)], dtype=dt.int64, device=self.device) + col = ta.column(list(range(4)), dtype=dt.int64, device=self.device) with self.assertRaises(AttributeError): # AssertionError: can't append a finalized list col._append(None) @@ -164,6 +167,13 @@ def base_test_map_where_filter(self): # Values that are not found in the dict are converted to None self.assertEqual(list(col.map({3: 33})), [None, None, None, 33, None, None]) + # maps default dict + d_dict = defaultdict(lambda: 1, {None: 2}) + self.assertEqual( + list(col.map(arg=d_dict)), + [2, 2, 2, 1, 1, 1], + ) + # maps None self.assertEqual( list(col.map({None: 1, 3: 33})), @@ -196,6 +206,18 @@ def base_test_map_where_filter(self): # filter self.assertEqual(list(col.filter([True, False] * 3)), [None, None, 4]) + with self.assertRaisesRegex( + expected_exception=TypeError, + expected_regex="columns parameter for flat columns not supported", + ): + col.filter([True, False], columns=["test", "test2"]) + + with self.assertRaisesRegex( + expected_exception=TypeError, + expected_regex="predicate must be a unary boolean predicate or iterable of booleans", + ): + col.filter(123) + @staticmethod def _accumulate(col, val): if len(col) == 0: @@ -217,6 +239,26 @@ def base_test_reduce(self): ) self.assertEqual(list(d), [1, 3, 6]) + col_no_init = c.reduce( + fun=operator.add, + ) + self.assertEqual(sum(c), col_no_init) + + c_empty = ta.column(dtype=dt.int64, device=self.device) + result = c_empty.reduce( + fun=TestNumericalColumn._accumulate, + initializer=c, + ) + self.assertTrue(all(c == result)) + + with self.assertRaisesRegex( + expected_exception=TypeError, + expected_regex="reduce of empty sequence with no initial value", + ): + c_empty.reduce( + fun=TestNumericalColumn._accumulate, + ) + def base_test_sort_stuff(self): col = ta.column([2, 1, 3], device=self.device) @@ -795,6 +837,46 @@ def base_test_batch_collate(self): it = c.batch(2) self.assertEqual(list(Column.unbatch(it)), [1, 2, 3, 4, 5, 6, 7]) + def base_test_str(self): + c = ta.column(list(range(5)), device=self.device) + c.id = 123 + + expected = "Column([0, 1, 2, 3, 4], id = 123)" + self.assertEqual(expected, str(c)) + + def base_test_repr(self): + c = ta.column(list(range(5)), device=self.device) + expected_repr = ( + "0 0\n" + "1 1\n" + "2 2\n" + "3 3\n" + "4 4\n" + f"dtype: int64, length: 5, null_count: 0, device: {self.device}" + ) + + self.assertEqual(expected_repr, repr(c)) + + def base_test_to_pandas(self): + c_repr = list(range(10)) + c = ta.column(c_repr, device=self.device) + expected = pd.Series(c_repr) + self.assertTrue(all(expected == c.to_pandas())) + + def base_test_transform(self): + c_repr = list(range(10)) + c = ta.column(c_repr, device=self.device) + + result = c.transform(lambda x: x * 10) + + self.assertEqual([x * 10 for x in c_repr], list(result)) + + with self.assertRaisesRegex( + expected_exception=TypeError, + expected_regex="columns parameter for flat columns not supported", + ): + c.transform(lambda x: x * 10, columns=["test"]) + if __name__ == "__main__": diff --git a/torcharrow/test/test_numerical_column_cpu.py b/torcharrow/test/test_numerical_column_cpu.py index a4b6a790c..1ea74cc00 100644 --- a/torcharrow/test/test_numerical_column_cpu.py +++ b/torcharrow/test/test_numerical_column_cpu.py @@ -83,6 +83,18 @@ def test_batch_collate(self): def test_cast(self): return self.base_test_cast() + def test_str(self): + self.base_test_str() + + def test_repr(self): + self.base_test_repr() + + def test_to_pandas(self): + self.base_test_to_pandas() + + def test_transform(self): + self.base_test_transform() + if __name__ == "__main__": unittest.main()