From 0064983e80c214ff7a26f9870194b493b95ab00f Mon Sep 17 00:00:00 2001 From: andrewaikens87 <16356240+andrewaikens87@users.noreply.github.com> Date: Wed, 6 Jul 2022 08:18:07 -0700 Subject: [PATCH] Improve icolumn.py cc 2/4 test_map_column (#416) Summary: Pull Request resolved: https://github.com/pytorch/torcharrow/pull/416 Pull Request resolved: https://github.com/pytorch/torcharrow/pull/406 Improves test coverage for icolumn.py through test_map_column. Reviewed By: wenleix Differential Revision: D37492649 fbshipit-source-id: c31250ac49a8759519d57f1578205741a361630c --- torcharrow/test/test_map_column.py | 52 ++++++++++++++++++++++++++ torcharrow/test/test_map_column_cpu.py | 9 +++++ 2 files changed, 61 insertions(+) diff --git a/torcharrow/test/test_map_column.py b/torcharrow/test/test_map_column.py index 13531fdaf..be5754625 100644 --- a/torcharrow/test/test_map_column.py +++ b/torcharrow/test/test_map_column.py @@ -63,6 +63,58 @@ def base_test_keys_values_get(self): self.assertEqual(list(c.maps.values()), [[123], [45, 67], None]) self.assertEqual(list(c.maps.get("de", 0)), [0, 45, None]) + def base_test_get_operator(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + indicies = [0, 2] + expected = [col_rep[i] for i in indicies] + result = [c[i] for i in indicies] + self.assertEqual(expected, result) + + def base_test_slice_operation(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + {"london": [], "new york": [500]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + expected_slice_every_other = col_rep[0:4:2] + result_every_other = c[0:4:2] + self.assertEqual(expected_slice_every_other, list(result_every_other)) + + expected_slice_most = col_rep[1:] + result_most = c[1:4:1] + self.assertEqual(expected_slice_most, list(result_most)) + + def base_test_equality_operators(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {"boston": [-4.0]}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + {"london": [], "new york": [500]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + c2 = ta.column( + col_rep, + device=self.device, + ) + self.assertTrue(all(c == c2)) + self.assertFalse(any(c != c2)) + if __name__ == "__main__": unittest.main() diff --git a/torcharrow/test/test_map_column_cpu.py b/torcharrow/test/test_map_column_cpu.py index 59b521447..5e5d7f3a0 100644 --- a/torcharrow/test/test_map_column_cpu.py +++ b/torcharrow/test/test_map_column_cpu.py @@ -22,6 +22,15 @@ def test_infer(self): def test_keys_values_get(self): self.base_test_keys_values_get() + def test_get_operator(self): + self.base_test_get_operator() + + def test_slice_operation(self): + self.base_test_slice_operation() + + def test_equality_operators(self): + self.base_test_equality_operators() + if __name__ == "__main__": unittest.main()