|
6 | 6 |
|
7 | 7 | import typing as ty |
8 | 8 | import unittest |
| 9 | +from unittest import mock |
9 | 10 |
|
10 | 11 | import torcharrow as ta |
11 | 12 | import torcharrow.dtypes as dt |
@@ -345,6 +346,222 @@ def base_test_regular_expressions(self): |
345 | 346 | ], |
346 | 347 | ) |
347 | 348 |
|
| 349 | + def base_test_is_unique(self): |
| 350 | + unique_column = ta.column( |
| 351 | + [f"test{x}" for x in range(3)], |
| 352 | + device=self.device, |
| 353 | + ) |
| 354 | + |
| 355 | + self.assertTrue(unique_column.is_unique) |
| 356 | + |
| 357 | + non_unique_column = ta.column( |
| 358 | + [ |
| 359 | + "test", |
| 360 | + "test", |
| 361 | + ], |
| 362 | + device=self.device, |
| 363 | + ) |
| 364 | + |
| 365 | + self.assertFalse(non_unique_column.is_unique) |
| 366 | + |
| 367 | + def base_test_is_monotonic_increasing(self): |
| 368 | + c = ta.column([f"test{x}" for x in range(5)], device=self.device) |
| 369 | + self.assertTrue(c.is_monotonic_increasing) |
| 370 | + self.assertFalse(c.is_monotonic_decreasing) |
| 371 | + |
| 372 | + def base_test_is_monotonic_decreasing(self): |
| 373 | + c = ta.column([f"test{x}" for x in range(5, 0, -1)], device=self.device) |
| 374 | + self.assertFalse(c.is_monotonic_increasing) |
| 375 | + self.assertTrue(c.is_monotonic_decreasing) |
| 376 | + |
| 377 | + def base_test_if_else(self): |
| 378 | + left_repr = ["a1", "a2", "a3", "a4"] |
| 379 | + right_repr = ["b1", "b2", "b3", "b4"] |
| 380 | + cond_repr = [True, False, True, False] |
| 381 | + cond = ta.column(cond_repr, device=self.device) |
| 382 | + left = ta.column(left_repr, device=self.device) |
| 383 | + right = ta.column(right_repr, device=self.device) |
| 384 | + |
| 385 | + # Ensure py-iterables work as intended |
| 386 | + expected = [left_repr[0], right_repr[1], left_repr[2], right_repr[3]] |
| 387 | + result = ta.if_else(cond, left_repr, right_repr) |
| 388 | + self.assertEqual(expected, list(result)) |
| 389 | + |
| 390 | + # Non common d type |
| 391 | + with self.assertRaisesRegex( |
| 392 | + expected_exception=TypeError, |
| 393 | + expected_regex="then and else branches must have compatible types, got.*and.*, respectively", |
| 394 | + ), mock.patch("torcharrow.icolumn.dt.common_dtype") as mock_common_dtype: |
| 395 | + mock_common_dtype.return_value = None |
| 396 | + |
| 397 | + ta.if_else(cond, left, right) |
| 398 | + |
| 399 | + mock_common_dtype.assert_called_once_with( |
| 400 | + left.dtype, |
| 401 | + right.dtype, |
| 402 | + ) |
| 403 | + |
| 404 | + with self.assertRaisesRegex( |
| 405 | + expected_exception=TypeError, |
| 406 | + expected_regex="then and else branches must have compatible types, got.*and.*, respectively", |
| 407 | + ), mock.patch( |
| 408 | + "torcharrow.icolumn.dt.common_dtype" |
| 409 | + ) as mock_common_dtype, mock.patch( |
| 410 | + "torcharrow.icolumn.dt.is_void" |
| 411 | + ) as mock_is_void: |
| 412 | + mock_is_void.return_value = True |
| 413 | + mock_common_dtype.return_value = dt.int64 |
| 414 | + |
| 415 | + ta.if_else(cond, left, right) |
| 416 | + mock_common_dtype.assert_called_once_with( |
| 417 | + left.dtype, |
| 418 | + right.dtype, |
| 419 | + ) |
| 420 | + mock_is_void.assert_called_once_with(mock_common_dtype.return_value) |
| 421 | + |
| 422 | + # Invalid condition input |
| 423 | + with self.assertRaisesRegex( |
| 424 | + expected_exception=TypeError, |
| 425 | + expected_regex="condition must be a boolean vector", |
| 426 | + ): |
| 427 | + ta.if_else( |
| 428 | + cond=left, |
| 429 | + left=left, |
| 430 | + right=right, |
| 431 | + ) |
| 432 | + |
| 433 | + def base_test_str(self): |
| 434 | + c = ta.column([f"test{x}" for x in range(5)], device=self.device) |
| 435 | + c.id = 321 |
| 436 | + |
| 437 | + expected = "Column(['test0', 'test1', 'test2', 'test3', 'test4'], id = 321)" |
| 438 | + self.assertEqual(expected, str(c)) |
| 439 | + |
| 440 | + def base_test_repr(self): |
| 441 | + c = ta.column([f"test{x}" for x in range(5)], device=self.device) |
| 442 | + |
| 443 | + expected = ( |
| 444 | + "0 'test0'\n" |
| 445 | + "1 'test1'\n" |
| 446 | + "2 'test2'\n" |
| 447 | + "3 'test3'\n" |
| 448 | + "4 'test4'\n" |
| 449 | + f"dtype: string, length: 5, null_count: 0, device: {self.device}" |
| 450 | + ) |
| 451 | + self.assertEqual(expected, repr(c)) |
| 452 | + |
| 453 | + def base_test_is_valid_at(self): |
| 454 | + c = ta.column([f"test{x}" for x in range(5)], device=self.device) |
| 455 | + |
| 456 | + # Normal access |
| 457 | + self.assertTrue(all(c.is_valid_at(x) for x in range(5))) |
| 458 | + |
| 459 | + # Negative access |
| 460 | + self.assertTrue(c.is_valid_at(-1)) |
| 461 | + |
| 462 | + def base_test_cast(self): |
| 463 | + c_repr = ["0", "1", "2", "3", "4", None] |
| 464 | + c_repr_after_cast = [0, 1, 2, 3, 4, None] |
| 465 | + c = ta.column(c_repr, device=self.device) |
| 466 | + |
| 467 | + result = c.cast(dt.int64) |
| 468 | + self.assertEqual(c_repr_after_cast, list(result)) |
| 469 | + |
| 470 | + def base_test_drop_null(self): |
| 471 | + c_repr = ["0", "1", "2", "3", "4", None] |
| 472 | + c = ta.column(c_repr, device=self.device) |
| 473 | + |
| 474 | + result = c.drop_null() |
| 475 | + |
| 476 | + self.assertEqual(c_repr[:-1], list(result)) |
| 477 | + |
| 478 | + with self.assertRaisesRegex( |
| 479 | + expected_exception=TypeError, |
| 480 | + expected_regex="how parameter for flat columns not supported", |
| 481 | + ): |
| 482 | + c.drop_null(how="any") |
| 483 | + |
| 484 | + def base_test_drop_duplicates(self): |
| 485 | + c_repr = ["test", "test2", "test3", "test"] |
| 486 | + c = ta.column(c_repr, device=self.device) |
| 487 | + |
| 488 | + result = c.drop_duplicates() |
| 489 | + |
| 490 | + self.assertEqual(c_repr[:-1], list(result)) |
| 491 | + |
| 492 | + # TODO: Add functionality for last |
| 493 | + with self.assertRaises(expected_exception=AssertionError): |
| 494 | + c.drop_duplicates(keep="last") |
| 495 | + |
| 496 | + with self.assertRaisesRegex( |
| 497 | + expected_exception=TypeError, |
| 498 | + expected_regex="subset parameter for flat columns not supported", |
| 499 | + ): |
| 500 | + c.drop_duplicates(subset=c_repr[:2]) |
| 501 | + |
| 502 | + def base_test_fill_null(self): |
| 503 | + c_repr = ["0", "1", None, "3", "4", None] |
| 504 | + expected_fill = "TEST" |
| 505 | + expected_repr = ["0", "1", expected_fill, "3", "4", expected_fill] |
| 506 | + c = ta.column(c_repr, device=self.device) |
| 507 | + |
| 508 | + result = c.fill_null(expected_fill) |
| 509 | + |
| 510 | + self.assertEqual(expected_repr, list(result)) |
| 511 | + |
| 512 | + with self.assertRaisesRegex( |
| 513 | + expected_exception=TypeError, |
| 514 | + expected_regex="fill_null with bytes is not supported", |
| 515 | + ): |
| 516 | + c.fill_null(expected_fill.encode()) |
| 517 | + |
| 518 | + def base_test_isin(self): |
| 519 | + c_repr = [f"test{x}" for x in range(5)] |
| 520 | + c = ta.column(c_repr, device=self.device) |
| 521 | + self.assertTrue(all(c.isin(values=c_repr + ["test_123"]))) |
| 522 | + self.assertFalse(any(c.isin(values=["test5", "test6", "test7"]))) |
| 523 | + |
| 524 | + def base_test_bool(self): |
| 525 | + c = ta.column([f"test{x}" for x in range(5)], device=self.device) |
| 526 | + with self.assertRaisesRegex( |
| 527 | + expected_exception=ValueError, |
| 528 | + expected_regex=r"The truth value of a.*is ambiguous. Use a.any\(\) or a.all\(\).", |
| 529 | + ): |
| 530 | + bool(c) |
| 531 | + |
| 532 | + def base_test_flatmap(self): |
| 533 | + c = ta.column(["test1", "test2", None, None, "test3"], device=self.device) |
| 534 | + expected_result = [ |
| 535 | + "test1", |
| 536 | + "test1", |
| 537 | + "test2", |
| 538 | + "test2", |
| 539 | + None, |
| 540 | + None, |
| 541 | + None, |
| 542 | + None, |
| 543 | + "test3", |
| 544 | + "test3", |
| 545 | + ] |
| 546 | + result = c.flatmap(lambda xs: [xs, xs]) |
| 547 | + self.assertEqual(expected_result, list(result)) |
| 548 | + |
| 549 | + def base_test_any(self): |
| 550 | + c_some = ta.column(["test1", "test2", None, None, "test3"], device=self.device) |
| 551 | + c_none = ta.column([], dtype=dt.string, device=self.device) |
| 552 | + c_none = c_none.append([None]) |
| 553 | + self.assertTrue(any(c_some)) |
| 554 | + self.assertFalse(any(c_none)) |
| 555 | + |
| 556 | + def base_test_all(self): |
| 557 | + c_all = ta.column(["test", "test2", "test3"], device=self.device) |
| 558 | + c_partial = ta.column(["test", "test2", None, None], device=self.device) |
| 559 | + c_none = ta.column([], dtype=dt.string, device=self.device) |
| 560 | + c_none = c_none.append([None]) |
| 561 | + self.assertTrue(all(c_all)) |
| 562 | + self.assertFalse(all(c_partial)) |
| 563 | + self.assertFalse(all(c_none)) |
| 564 | + |
348 | 565 |
|
349 | 566 | if __name__ == "__main__": |
350 | 567 | unittest.main() |
0 commit comments