From ee3e963bf63acb82711d711b41bb65ef49cc83a4 Mon Sep 17 00:00:00 2001
From: Jon Bolin
Date: Mon, 18 Sep 2023 00:19:10 +0000
Subject: [PATCH 1/5] Cache OpSharding in mark_sharding

---
 torch_xla/csrc/init_python_bindings.cpp | 16 +++++----
 torch_xla/csrc/ir.cpp                   | 10 +++---
 torch_xla/experimental/xla_sharding.py  | 44 +++++++++++++++----------
 3 files changed, 42 insertions(+), 28 deletions(-)

diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 2484642aad06..608c707f1285 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1543,18 +1543,20 @@ void InitXlaModuleBindings(py::module m) {
               weight_decay, eps, amsgrad, maximize, use_adamw);
         }
       });
+  py::class_<xla::OpSharding>(m, "OpSharding")
+      .def(py::init([](const py::list& tile_assignment,
+                       const py::list& group_assignment,
+                       const py::list& replication_groups, int sharding_type) {
+        return ShardingUtil::CreateOpSharding(
+            tile_assignment, group_assignment, replication_groups,
+            ShardingUtil::ShardingType(sharding_type));
+      }));
   m.def("_xla_mark_sharding", [](const at::Tensor& input,
-                                 const py::list& tile_assignment,
-                                 const py::list& group_assignment,
-                                 const py::list& replication_groups,
-                                 int sharding_type) {
+                                 xla::OpSharding sharding) {
     TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
     XLA_CHECK(UseVirtualDevice())
         << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-    xla::OpSharding sharding = ShardingUtil::CreateOpSharding(
-        tile_assignment, group_assignment, replication_groups,
-        ShardingUtil::ShardingType(sharding_type));
     auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
         sharding, MakeShapeWithDeviceLayout(
                       xtensor->shape(),
diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index a52e20b7ed5a..0fc2d77a47fc 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -199,10 +199,12 @@ void XlaNode::UpdateShardingHash() {
       sharding_hash_ = torch::lazy::HashCombine(
           sharding_hash_, (uint32_t)tile_assignment_dimension);
     }
-    for (const auto& tile_assignment_device :
-         sharding->tile_assignment_devices()) {
-      sharding_hash_ = torch::lazy::HashCombine(
-          sharding_hash_, (uint32_t)tile_assignment_device);
+    {
+      const int64_t* data = sharding->tile_assignment_devices().data();
+      const size_t size_in_bytes =
+          sharding->tile_assignment_devices().size() * sizeof(*data);
+      sharding_hash_ =
+          torch::lazy::HashBlock(data, size_in_bytes, sharding_hash_);
     }
     for (const auto& last_tile_dim : sharding->last_tile_dims()) {
       sharding_hash_ =
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index f8888728bae2..a8f3969798c1 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -9,7 +9,7 @@
 import numpy as np
 import itertools
 
-from typing import Tuple, Union, List, Sequence, Any, Optional, Set
+from typing import Tuple, Union, List, Sequence, Any, Optional, Set, Dict
 
 from enum import IntEnum
 
@@ -45,6 +45,7 @@ class Mesh:
   device_ids: np.ndarray
   mesh_shape: Tuple[int, ...]
   axis_names: Tuple[str, ...]
+ _op_sharding_cache: Dict[Tuple, torch_xla._XLAC.OpSharding] def __init__(self, device_ids: Union[np.ndarray, List], @@ -59,6 +60,7 @@ def __init__(self, self.device_ids = device_ids self.mesh_shape = mesh_shape self.axis_names = axis_names + self._op_sharding_cache = {} assert all(d < self.size() for d in device_ids) def size(self): @@ -79,6 +81,27 @@ def get_axis_name_idx(self, name: str) -> int: return None return self.axis_names.index(name) + def get_op_sharding(self, + partition_spec: Tuple) -> torch_xla._XLAC.OpSharding: + """ + Return the OpSharding for the given partition spec. This is an expensive + operation as the mesh grows, so the value is cached for reuse. + """ + if partition_spec not in self._op_sharding_cache: + tile_assignment = _get_tile_assignment(self, partition_spec) + if len(tile_assignment.shape) > len(partition_spec): + # Use partial replication for sharding a tensor over a higher-rank mesh + sharding_type = ShardingType.PARTIAL + else: + sharding_type = _get_sharding_type(partition_spec, self.size()) + replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} + group_assignment, replication_groups = _get_group_assignment( + sharding_type, tile_assignment, len(partition_spec), replicate_dims) + self._op_sharding_cache[partition_spec] = torch_xla._XLAC.OpSharding( + tile_assignment.tolist(), group_assignment, replication_groups, + int(sharding_type)) + return self._op_sharding_cache[partition_spec] + # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4 @@ -475,25 +498,12 @@ def mark_sharding( assert len(specs) == len(np.unique(specs)), \ f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." 
- tile_assignment = _get_tile_assignment(mesh, partition_spec) - if len(tile_assignment.shape) > len(partition_spec): - # Use partial replication for sharding a tensor over a higher-rank mesh - sharding_type = ShardingType.PARTIAL - else: - sharding_type = _get_sharding_type(partition_spec, num_devices) - replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} - group_assignment, replication_groups = _get_group_assignment( - sharding_type, tile_assignment, len(partition_spec), replicate_dims) + op_sharding = mesh.get_op_sharding(partition_spec) if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding(t.global_tensor, - tile_assignment.tolist(), - group_assignment, replication_groups, - int(sharding_type)) + torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) return t - torch_xla._XLAC._xla_mark_sharding(t, tile_assignment.tolist(), - group_assignment, replication_groups, - int(sharding_type)) + torch_xla._XLAC._xla_mark_sharding(t, op_sharding) return XLAShardedTensor(t) From a7b7cbcf2d2c04479d31c3d79fea0965854934f8 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Mon, 18 Sep 2023 13:59:36 +0000 Subject: [PATCH 2/5] Cache global_runtime_device_count --- torch_xla/runtime.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index a50e92b55a5d..3087f3c80f64 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -205,6 +205,7 @@ def global_runtime_device_attributes() -> List[Dict[str, object]]: @requires_pjrt +@functools.lru_cache() def global_runtime_device_count() -> int: """Returns the total number of runtime devices across all processes/hosts.""" return len(torch_xla._XLAC._xla_get_all_runtime_devices()) From 9828c7893c30d69cecb08f8f0f2b581fe9473903 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Mon, 18 Sep 2023 14:07:52 +0000 Subject: [PATCH 3/5] Add metric and unit test --- test/spmd/test_xla_sharding.py | 19 +++++++++++++++++++ torch_xla/csrc/xla_sharding_util.cpp | 1 + 2 files changed, 20 insertions(+) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 5f278ddec99a..a79807794b25 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -847,6 +847,25 @@ def test_shard_device_data_ir_after_mark_step(self): self.assertNotEqual(torch_xla._XLAC._get_xla_sharding_spec(xla_x), '') self.assertTrue(torch.allclose(xla_x.cpu(), x)) + def test_op_sharding_cache(self): + met.clear_all() + mesh = self._get_mesh((1, self.n_devices)) + + t = torch.randn(1, self.n_devices).to(xm.xla_device()) + xs.mark_sharding(t, mesh, (0, 1)) + self.assertIn("CreateOpSharding", met.counter_names()) + self.assertEqual(met.counter_value("CreateOpSharding"), 1) + + # Sharding with the same partition spec should not result in another call + u = torch.randn(1, self.n_devices).to(xm.xla_device()) + xs.mark_sharding(u, mesh, (0, 1)) + self.assertEqual(met.counter_value("CreateOpSharding"), 1) + + # Changing the partition spec will result in another CreateOpSharding + v = torch.randn(1, self.n_devices).to(xm.xla_device()) + xs.mark_sharding(v, mesh, (0, None)) + self.assertEqual(met.counter_value("CreateOpSharding"), 2) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 16bc70fbc8a5..aadece85c072 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -194,6 +194,7 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a, 
xla::OpSharding ShardingUtil::CreateOpSharding( const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, ShardingType sharding_type) { + TORCH_LAZY_COUNTER("CreateOpSharding", 1); xla::OpSharding sharding; switch (sharding_type) { case ShardingType::MANUAL: { From 588eed182d96adda0cda03ea058bedd13e0a58b0 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Mon, 18 Sep 2023 22:02:58 +0000 Subject: [PATCH 4/5] Use functools.lru_cache in get_op_sharding --- torch_xla/experimental/xla_sharding.py | 31 ++++++++++++-------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index a8f3969798c1..9b53cbf9b65c 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -8,8 +8,9 @@ import torch_xla.runtime as xr import numpy as np +import functools import itertools -from typing import Tuple, Union, List, Sequence, Any, Optional, Set, Dict +from typing import Tuple, Union, List, Sequence, Any, Optional, Set from enum import IntEnum @@ -45,7 +46,6 @@ class Mesh: device_ids: np.ndarray mesh_shape: Tuple[int, ...] axis_names: Tuple[str, ...] - _op_sharding_cache: Dict[Tuple, torch_xla._XLAC.OpSharding] def __init__(self, device_ids: Union[np.ndarray, List], @@ -60,7 +60,6 @@ def __init__(self, self.device_ids = device_ids self.mesh_shape = mesh_shape self.axis_names = axis_names - self._op_sharding_cache = {} assert all(d < self.size() for d in device_ids) def size(self): @@ -81,26 +80,24 @@ def get_axis_name_idx(self, name: str) -> int: return None return self.axis_names.index(name) + @functools.lru_cache(maxsize=None) def get_op_sharding(self, partition_spec: Tuple) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. 
""" - if partition_spec not in self._op_sharding_cache: - tile_assignment = _get_tile_assignment(self, partition_spec) - if len(tile_assignment.shape) > len(partition_spec): - # Use partial replication for sharding a tensor over a higher-rank mesh - sharding_type = ShardingType.PARTIAL - else: - sharding_type = _get_sharding_type(partition_spec, self.size()) - replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} - group_assignment, replication_groups = _get_group_assignment( - sharding_type, tile_assignment, len(partition_spec), replicate_dims) - self._op_sharding_cache[partition_spec] = torch_xla._XLAC.OpSharding( - tile_assignment.tolist(), group_assignment, replication_groups, - int(sharding_type)) - return self._op_sharding_cache[partition_spec] + tile_assignment = _get_tile_assignment(self, partition_spec) + if len(tile_assignment.shape) > len(partition_spec): + # Use partial replication for sharding a tensor over a higher-rank mesh + sharding_type = ShardingType.PARTIAL + else: + sharding_type = _get_sharding_type(partition_spec, self.size()) + replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} + group_assignment, replication_groups = _get_group_assignment( + sharding_type, tile_assignment, len(partition_spec), replicate_dims) + return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), group_assignment, + replication_groups, int(sharding_type)) # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4 From 3c2d0b4c9a66aad28ea7d28ceb46a69d63c47235 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Mon, 18 Sep 2023 22:05:42 +0000 Subject: [PATCH 5/5] yapf --- torch_xla/experimental/xla_sharding.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py index 9b53cbf9b65c..95f4a88128bb 100644 --- a/torch_xla/experimental/xla_sharding.py +++ b/torch_xla/experimental/xla_sharding.py @@ -96,8 +96,9 @@ def get_op_sharding(self, replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} group_assignment, replication_groups = _get_group_assignment( sharding_type, tile_assignment, len(partition_spec), replicate_dims) - return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), group_assignment, - replication_groups, int(sharding_type)) + return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), + group_assignment, replication_groups, + int(sharding_type)) # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4