Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit 25b6e4d

Browse files
andrewaikens87 authored and facebook-github-bot committed
Enforce 90% cc for icolumn.py
Differential Revision: D37289878
fbshipit-source-id: b7daf93df345747b11740b38bd7b8f9fd6659322
1 parent c13aa1b commit 25b6e4d

File tree

4 files changed

+271
-23
lines changed

4 files changed

+271
-23
lines changed

torcharrow/icolumn.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -705,7 +705,7 @@ def filter(
705705
dtype: boolean, length: 2, null_count: 0
706706
"""
707707
if columns is not None:
708-
raise TypeError(f"columns parameter for flat columns not supported")
708+
raise TypeError("columns parameter for flat columns not supported")
709709

710710
if not isinstance(predicate, ty.Iterable) and not callable(predicate):
711711
raise TypeError(
@@ -1006,8 +1006,6 @@ def fill_null(self, fill_value: ty.Union[dt.ScalarTypes, ty.Dict]):
10061006
"""
10071007
self._prototype_support_warning("fill_null")
10081008

1009-
if not isinstance(fill_value, Column._scalar_types):
1010-
raise TypeError(f"fill_null with {type(fill_value)} is not supported")
10111009
if isinstance(fill_value, Column._scalar_types):
10121010
res = Scope._EmptyColumn(self.dtype.constructor(nullable=False))
10131011
for m, i in self._items():
@@ -1017,7 +1015,9 @@ def fill_null(self, fill_value: ty.Union[dt.ScalarTypes, ty.Dict]):
10171015
res._append_value(fill_value)
10181016
return res._finalize()
10191017
else:
1020-
raise TypeError(f"fill_null with {type(fill_value)} is not supported")
1018+
raise TypeError(
1019+
f"fill_null with {type(fill_value).__name__} is not supported"
1020+
)
10211021

10221022
@trace
10231023
@expression
@@ -1050,7 +1050,7 @@ def drop_null(self, how: ty.Literal["any", "all", None] = None):
10501050

10511051
if how is not None:
10521052
# "any" or "all" is only used for DataFrame
1053-
raise TypeError(f"how parameter for flat columns not supported")
1053+
raise TypeError("how parameter for flat columns not supported")
10541054

10551055
if dt.is_primitive(self.dtype):
10561056
res = Scope._EmptyColumn(self.dtype.constructor(nullable=False))
@@ -1076,7 +1076,7 @@ def drop_duplicates(
10761076
# TODO Add functionality for first and last
10771077
assert keep == "first"
10781078
if subset is not None:
1079-
raise TypeError(f"subset parameter for flat columns not supported")
1079+
raise TypeError("subset parameter for flat columns not supported")
10801080
res = Scope._EmptyColumn(self._dtype)
10811081
res._extend(list(OrderedDict.fromkeys(self)))
10821082
return res._finalize()

torcharrow/test/test_string_column.py

Lines changed: 217 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import typing as ty
88
import unittest
9+
from unittest import mock
910

1011
import torcharrow as ta
1112
import torcharrow.dtypes as dt
@@ -345,6 +346,222 @@ def base_test_regular_expressions(self):
345346
],
346347
)
347348

349+
def base_test_is_unique(self):
350+
unique_column = ta.column(
351+
[f"test{x}" for x in range(3)],
352+
device=self.device,
353+
)
354+
355+
self.assertTrue(unique_column.is_unique)
356+
357+
non_unique_column = ta.column(
358+
[
359+
"test",
360+
"test",
361+
],
362+
device=self.device,
363+
)
364+
365+
self.assertFalse(non_unique_column.is_unique)
366+
367+
def base_test_is_monotonic_increasing(self):
368+
c = ta.column([f"test{x}" for x in range(5)], device=self.device)
369+
self.assertTrue(c.is_monotonic_increasing)
370+
self.assertFalse(c.is_monotonic_decreasing)
371+
372+
def base_test_is_monotonic_decreasing(self):
373+
c = ta.column([f"test{x}" for x in range(5, 0, -1)], device=self.device)
374+
self.assertFalse(c.is_monotonic_increasing)
375+
self.assertTrue(c.is_monotonic_decreasing)
376+
377+
def base_test_if_else(self):
378+
left_repr = ["a1", "a2", "a3", "a4"]
379+
right_repr = ["b1", "b2", "b3", "b4"]
380+
cond_repr = [True, False, True, False]
381+
cond = ta.column(cond_repr, device=self.device)
382+
left = ta.column(left_repr, device=self.device)
383+
right = ta.column(right_repr, device=self.device)
384+
385+
# Ensure py-iterables work as intended
386+
expected = [left_repr[0], right_repr[1], left_repr[2], right_repr[3]]
387+
result = ta.if_else(cond, left_repr, right_repr)
388+
self.assertEqual(expected, list(result))
389+
390+
# No common dtype
391+
with self.assertRaisesRegex(
392+
expected_exception=TypeError,
393+
expected_regex="then and else branches must have compatible types, got.*and.*, respectively",
394+
), mock.patch("torcharrow.icolumn.dt.common_dtype") as mock_common_dtype:
395+
mock_common_dtype.return_value = None
396+
397+
ta.if_else(cond, left, right)
398+
399+
mock_common_dtype.assert_called_once_with(
400+
left.dtype,
401+
right.dtype,
402+
)
403+
404+
with self.assertRaisesRegex(
405+
expected_exception=TypeError,
406+
expected_regex="then and else branches must have compatible types, got.*and.*, respectively",
407+
), mock.patch(
408+
"torcharrow.icolumn.dt.common_dtype"
409+
) as mock_common_dtype, mock.patch(
410+
"torcharrow.icolumn.dt.is_void"
411+
) as mock_is_void:
412+
mock_is_void.return_value = True
413+
mock_common_dtype.return_value = dt.int64
414+
415+
ta.if_else(cond, left, right)
416+
mock_common_dtype.assert_called_once_with(
417+
left.dtype,
418+
right.dtype,
419+
)
420+
mock_is_void.assert_called_once_with(mock_common_dtype.return_value)
421+
422+
# Invalid condition input
423+
with self.assertRaisesRegex(
424+
expected_exception=TypeError,
425+
expected_regex="condition must be a boolean vector",
426+
):
427+
ta.if_else(
428+
cond=left,
429+
left=left,
430+
right=right,
431+
)
432+
433+
def base_test_str(self):
434+
c = ta.column([f"test{x}" for x in range(5)], device=self.device)
435+
c.id = 321
436+
437+
expected = "Column(['test0', 'test1', 'test2', 'test3', 'test4'], id = 321)"
438+
self.assertEqual(expected, str(c))
439+
440+
def base_test_repr(self):
441+
c = ta.column([f"test{x}" for x in range(5)], device=self.device)
442+
443+
expected = (
444+
"0 'test0'\n"
445+
"1 'test1'\n"
446+
"2 'test2'\n"
447+
"3 'test3'\n"
448+
"4 'test4'\n"
449+
f"dtype: string, length: 5, null_count: 0, device: {self.device}"
450+
)
451+
self.assertEqual(expected, repr(c))
452+
453+
def base_test_is_valid_at(self):
454+
c = ta.column([f"test{x}" for x in range(5)], device=self.device)
455+
456+
# Normal access
457+
self.assertTrue(all(c.is_valid_at(x) for x in range(5)))
458+
459+
# Negative access
460+
self.assertTrue(c.is_valid_at(-1))
461+
462+
def base_test_cast(self):
463+
c_repr = ["0", "1", "2", "3", "4", None]
464+
c_repr_after_cast = [0, 1, 2, 3, 4, None]
465+
c = ta.column(c_repr, device=self.device)
466+
467+
result = c.cast(dt.int64)
468+
self.assertEqual(c_repr_after_cast, list(result))
469+
470+
def base_test_drop_null(self):
471+
c_repr = ["0", "1", "2", "3", "4", None]
472+
c = ta.column(c_repr, device=self.device)
473+
474+
result = c.drop_null()
475+
476+
self.assertEqual(c_repr[:-1], list(result))
477+
478+
with self.assertRaisesRegex(
479+
expected_exception=TypeError,
480+
expected_regex="how parameter for flat columns not supported",
481+
):
482+
c.drop_null(how="any")
483+
484+
def base_test_drop_duplicates(self):
485+
c_repr = ["test", "test2", "test3", "test"]
486+
c = ta.column(c_repr, device=self.device)
487+
488+
result = c.drop_duplicates()
489+
490+
self.assertEqual(c_repr[:-1], list(result))
491+
492+
# TODO: Add functionality for last
493+
with self.assertRaises(expected_exception=AssertionError):
494+
c.drop_duplicates(keep="last")
495+
496+
with self.assertRaisesRegex(
497+
expected_exception=TypeError,
498+
expected_regex="subset parameter for flat columns not supported",
499+
):
500+
c.drop_duplicates(subset=c_repr[:2])
501+
502+
def base_test_fill_null(self):
503+
c_repr = ["0", "1", None, "3", "4", None]
504+
expected_fill = "TEST"
505+
expected_repr = ["0", "1", expected_fill, "3", "4", expected_fill]
506+
c = ta.column(c_repr, device=self.device)
507+
508+
result = c.fill_null(expected_fill)
509+
510+
self.assertEqual(expected_repr, list(result))
511+
512+
with self.assertRaisesRegex(
513+
expected_exception=TypeError,
514+
expected_regex="fill_null with bytes is not supported",
515+
):
516+
c.fill_null(expected_fill.encode())
517+
518+
def base_test_isin(self):
519+
c_repr = [f"test{x}" for x in range(5)]
520+
c = ta.column(c_repr, device=self.device)
521+
self.assertTrue(all(c.isin(values=c_repr + ["test_123"])))
522+
self.assertFalse(any(c.isin(values=["test5", "test6", "test7"])))
523+
524+
def base_test_bool(self):
525+
c = ta.column([f"test{x}" for x in range(5)], device=self.device)
526+
with self.assertRaisesRegex(
527+
expected_exception=ValueError,
528+
expected_regex=r"The truth value of a.*is ambiguous. Use a.any\(\) or a.all\(\).",
529+
):
530+
bool(c)
531+
532+
def base_test_flatmap(self):
533+
c = ta.column(["test1", "test2", None, None, "test3"], device=self.device)
534+
expected_result = [
535+
"test1",
536+
"test1",
537+
"test2",
538+
"test2",
539+
None,
540+
None,
541+
None,
542+
None,
543+
"test3",
544+
"test3",
545+
]
546+
result = c.flatmap(lambda xs: [xs, xs])
547+
self.assertEqual(expected_result, list(result))
548+
549+
def base_test_any(self):
550+
c_some = ta.column(["test1", "test2", None, None, "test3"], device=self.device)
551+
c_none = ta.column([], dtype=dt.string, device=self.device)
552+
c_none = c_none.append([None])
553+
self.assertTrue(any(c_some))
554+
self.assertFalse(any(c_none))
555+
556+
def base_test_all(self):
557+
c_all = ta.column(["test", "test2", "test3"], device=self.device)
558+
c_partial = ta.column(["test", "test2", None, None], device=self.device)
559+
c_none = ta.column([], dtype=dt.string, device=self.device)
560+
c_none = c_none.append([None])
561+
self.assertTrue(all(c_all))
562+
self.assertFalse(all(c_partial))
563+
self.assertFalse(all(c_none))
564+
348565

349566
if __name__ == "__main__":
350567
unittest.main()

torcharrow/test/test_string_column_cpu.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,54 @@ def test_string_pattern_matching_methods(self):
5757
def test_regular_expressions(self):
5858
self.base_test_regular_expressions()
5959

60+
def test_is_unique(self):
61+
self.base_test_is_unique()
62+
63+
def test_is_monotonic_increasing(self):
64+
self.base_test_is_monotonic_increasing()
65+
66+
def test_is_monotonic_decreasing(self):
67+
self.base_test_is_monotonic_decreasing()
68+
69+
def test_if_else(self):
70+
self.base_test_if_else()
71+
72+
def test_repr(self):
73+
self.base_test_repr()
74+
75+
def test_str(self):
76+
self.base_test_str()
77+
78+
def test_is_valid_at(self):
79+
self.base_test_is_valid_at()
80+
81+
def test_cast(self):
82+
self.base_test_cast()
83+
84+
def test_drop_null(self):
85+
self.base_test_drop_null()
86+
87+
def test_drop_duplicates(self):
88+
self.base_test_drop_duplicates()
89+
90+
def test_fill_null(self):
91+
self.base_test_fill_null()
92+
93+
def test_isin(self):
94+
self.base_test_isin()
95+
96+
def test_bool(self):
97+
self.base_test_bool()
98+
99+
def test_flatmap(self):
100+
self.base_test_flatmap()
101+
102+
def test_any(self):
103+
self.base_test_any()
104+
105+
def test_all(self):
106+
self.base_test_all()
107+
60108

61109
if __name__ == "__main__":
62110
unittest.main()

torcharrow/velox_rt/string_column_cpu.py

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -187,23 +187,6 @@ def __gt__(self, other):
187187
def __ge__(self, other):
188188
return self._checked_binary_op_call(other, "gte")
189189

190-
# printing ----------------------------------------------------------------
191-
192-
def __str__(self):
193-
def quote(x):
194-
return f"'{x}'"
195-
196-
return f"Column([{', '.join('None' if i is None else quote(i) for i in self)}])"
197-
198-
def __repr__(self):
199-
tab = tabulate(
200-
[["None" if i is None else f"'{i}'"] for i in self],
201-
tablefmt="plain",
202-
showindex=True,
203-
)
204-
typ = f"dtype: {self.dtype}, length: {self.length}, null_count: {self.null_count}, device: cpu"
205-
return tab + dt.NL + typ
206-
207190
# interop
208191
def _to_tensor_default(self):
209192
# there are no string tensors, so we're using regular python list conversion

0 commit comments

Comments
 (0)