diff --git a/torcharrow/test/test_map_column.py b/torcharrow/test/test_map_column.py index 13531fdaf..be5754625 100644 --- a/torcharrow/test/test_map_column.py +++ b/torcharrow/test/test_map_column.py @@ -63,6 +63,58 @@ def base_test_keys_values_get(self): self.assertEqual(list(c.maps.values()), [[123], [45, 67], None]) self.assertEqual(list(c.maps.get("de", 0)), [0, 45, None]) + def base_test_get_operator(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + indicies = [0, 2] + expected = [col_rep[i] for i in indicies] + result = [c[i] for i in indicies] + self.assertEqual(expected, result) + + def base_test_slice_operation(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + {"london": [], "new york": [500]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + expected_slice_every_other = col_rep[0:4:2] + result_every_other = c[0:4:2] + self.assertEqual(expected_slice_every_other, list(result_every_other)) + + expected_slice_most = col_rep[1:] + result_most = c[1:4:1] + self.assertEqual(expected_slice_most, list(result_most)) + + def base_test_equality_operators(self): + col_rep = [ + {"helsinki": [-1.0, 21.0], "moscow": [-4.0, 24.0]}, + {"boston": [-4.0]}, + {"nowhere": [], "algiers": [11.0, 25, 2], "kinshasa": [22.0, 26.0]}, + {"london": [], "new york": [500]}, + ] + c = ta.column( + col_rep, + device=self.device, + ) + c2 = ta.column( + col_rep, + device=self.device, + ) + self.assertTrue(all(c == c2)) + self.assertFalse(any(c != c2)) + if __name__ == "__main__": unittest.main() diff --git a/torcharrow/test/test_map_column_cpu.py b/torcharrow/test/test_map_column_cpu.py index 59b521447..5e5d7f3a0 100644 --- a/torcharrow/test/test_map_column_cpu.py +++ b/torcharrow/test/test_map_column_cpu.py @@ -22,6 +22,15 @@ def test_infer(self): def test_keys_values_get(self): self.base_test_keys_values_get() + def test_get_operator(self): + self.base_test_get_operator() + + def test_slice_operation(self): + self.base_test_slice_operation() + + def test_equality_operators(self): + self.base_test_equality_operators() + if __name__ == "__main__": unittest.main()