From 6dda80d3ccb47346551e7acb817c2bf58d8184d8 Mon Sep 17 00:00:00 2001 From: Andrew Aikens Date: Wed, 6 Jul 2022 08:11:01 -0700 Subject: [PATCH 1/2] Improve icolumn.py cc 2/n test_map_column Differential Revision: D37492649 fbshipit-source-id: d7d51e753f6da86fa01984120fc802974e0d279f --- torcharrow/test/test_map_column.py | 52 ++++++++++++++++++++++++++ torcharrow/test/test_map_column_cpu.py | 9 +++++ 2 files changed, 61 insertions(+) diff --git a/torcharrow/test/test_map_column.py b/torcharrow/test/test_map_column.py index 13531fdaf..be5754625 100644 --- a/torcharrow/test/test_map_column.py +++ b/torcharrow/test/test_map_column.py @@ -63,6 +63,58 @@ def base_test_keys_values_get(self): self.assertEqual(list(c.maps.values()), [[123], [45, 67], None]) self.assertEqual(list(c.maps.get("de", 0)), [0, 45, None]) + def base_test_get_operator(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + indicies = [0, 2] + expected = [col_rep[i] for i in indicies] + result = [c[i] for i in indicies] + self.assertEqual(expected, result) + + def base_test_slice_operation(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + {"london": [], "new york": [500]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + expected_slice_every_other = col_rep[0:4:2] + result_every_other = c[0:4:2] + self.assertEqual(expected_slice_every_other, list(result_every_other)) + + expected_slice_most = col_rep[1:] + result_most = c[1:4:1] + self.assertEqual(expected_slice_most, list(result_most)) + + def base_test_equality_operators(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {"boston": [-4.0]}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + {"london": [], "new york": [500]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + c2 = ta.column( + col_rep, + device=self.device, + ) + self.assertTrue(all(c == c2)) + self.assertFalse(any(c != c2)) + if __name__ == "__main__": unittest.main() diff --git a/torcharrow/test/test_map_column_cpu.py b/torcharrow/test/test_map_column_cpu.py index 59b521447..5e5d7f3a0 100644 --- a/torcharrow/test/test_map_column_cpu.py +++ b/torcharrow/test/test_map_column_cpu.py @@ -22,6 +22,15 @@ def test_infer(self): def test_keys_values_get(self): self.base_test_keys_values_get() + def test_get_operator(self): + self.base_test_get_operator() + + def test_slice_operation(self): + self.base_test_slice_operation() + + def test_equality_operators(self): + self.base_test_equality_operators() + if __name__ == "__main__": unittest.main() From ba54d41024fa1c34f16fa0bdbcddaa19542419ce Mon Sep 17 00:00:00 2001 From: andrewaikens87 <16356240+andrewaikens87@users.noreply.github.com> Date: Wed, 6 Jul 2022 08:12:07 -0700 Subject: [PATCH 2/2] Improve icolumn.py cc 3/n test_list_column (#413) Summary: Pull Request resolved: https://github.com/pytorch/torcharrow/pull/413 Pull Request resolved: https://github.com/pytorch/torcharrow/pull/408 Improves test coverage for icolumn.py through test_list_column. Reviewed By: wenleix Differential Revision: D37492647 fbshipit-source-id: 9232d7b7ec20f326de08e8a6f9351da6f5830508 --- torcharrow/icolumn.py | 4 ++-- torcharrow/test/test_list_column.py | 17 +++++++++++++++++ torcharrow/test/test_list_column_cpu.py | 3 +++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/torcharrow/icolumn.py b/torcharrow/icolumn.py index dc6c1ff34..03ef09123 100644 --- a/torcharrow/icolumn.py +++ b/torcharrow/icolumn.py @@ -254,8 +254,8 @@ def cast(self, dtype): res._append_value(fun(i)) return res._finalize() else: - raise TypeError('f"{astype}({dtype}) is not supported")') - raise TypeError('f"{astype} for {type(self).__name__} is not supported")') + raise TypeError(f"{dtype} for {type(self).__name__} is not supported") + raise TypeError(f"{self.dtype} for {type(self).__name__} is not supported") # public simple observers ------------------------------------------------- diff --git a/torcharrow/test/test_list_column.py b/torcharrow/test/test_list_column.py index c81c1efa7..b18db0b52 100644 --- a/torcharrow/test/test_list_column.py +++ b/torcharrow/test/test_list_column.py @@ -208,6 +208,23 @@ def base_test_fixed_size_list(self): f"Unexpected failure reason: {str(ex.exception)}", ) + def base_test_cast(self): + list_dtype = dt.List(item_dtype=dt.int64, fixed_size=2) + c_list = ta.column( + [[1, 2], [3, 4]], + dtype=list_dtype, + device=self.device, + ) + + int_dtype = dt.int64 + # TODO: Nested cast should be supported in the future + for arg in (int_dtype, list_dtype): + with self.assertRaisesRegexp( + expected_exception=TypeError, + expected_regex=r"List\(int64, fixed_size=2\) for.*is not supported", + ): + c_list.cast(arg) + if __name__ == "__main__": unittest.main() diff --git a/torcharrow/test/test_list_column_cpu.py b/torcharrow/test/test_list_column_cpu.py index bd774bc7f..72db1bae7 100644 --- a/torcharrow/test/test_list_column_cpu.py +++ b/torcharrow/test/test_list_column_cpu.py @@ -46,6 +46,9 @@ def test_map_reduce_etc(self): def test_fixed_size_list(self): self.base_test_fixed_size_list() + def test_cast(self): + self.base_test_cast() + if __name__ == "__main__": unittest.main()