19 changes: 19 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -847,6 +847,25 @@ def test_shard_device_data_ir_after_mark_step(self):
    self.assertNotEqual(torch_xla._XLAC._get_xla_sharding_spec(xla_x), '')
    self.assertTrue(torch.allclose(xla_x.cpu(), x))

  def test_op_sharding_cache(self):
    met.clear_all()
    mesh = self._get_mesh((1, self.n_devices))

    t = torch.randn(1, self.n_devices).to(xm.xla_device())
    xs.mark_sharding(t, mesh, (0, 1))
    self.assertIn("CreateOpSharding", met.counter_names())

Review comment (Collaborator): It's smart to use a counter here.

    self.assertEqual(met.counter_value("CreateOpSharding"), 1)

    # Sharding with the same partition spec should not result in another call
    u = torch.randn(1, self.n_devices).to(xm.xla_device())
    xs.mark_sharding(u, mesh, (0, 1))
    self.assertEqual(met.counter_value("CreateOpSharding"), 1)

    # Changing the partition spec will result in another CreateOpSharding
    v = torch.randn(1, self.n_devices).to(xm.xla_device())
    xs.mark_sharding(v, mesh, (0, None))
    self.assertEqual(met.counter_value("CreateOpSharding"), 2)


if __name__ == '__main__':
  test = unittest.main()
16 changes: 9 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1543,18 +1543,20 @@ void InitXlaModuleBindings(py::module m) {
             weight_decay, eps, amsgrad, maximize, use_adamw);
       }
     });
+  py::class_<xla::OpSharding>(m, "OpSharding")
+      .def(py::init([](const py::list& tile_assignment,
+                       const py::list& group_assignment,
+                       const py::list& replication_groups, int sharding_type) {
+        return ShardingUtil::CreateOpSharding(
+            tile_assignment, group_assignment, replication_groups,
+            ShardingUtil::ShardingType(sharding_type));
+      }));
   m.def("_xla_mark_sharding", [](const at::Tensor& input,
-                                 const py::list& tile_assignment,
-                                 const py::list& group_assignment,
-                                 const py::list& replication_groups,
-                                 int sharding_type) {
+                                 xla::OpSharding sharding) {
     TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
     XLA_CHECK(UseVirtualDevice())
         << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-    xla::OpSharding sharding = ShardingUtil::CreateOpSharding(
-        tile_assignment, group_assignment, replication_groups,
-        ShardingUtil::ShardingType(sharding_type));
     auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
         sharding, MakeShapeWithDeviceLayout(
                       xtensor->shape(),
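
With the binding change above, Python now builds a single xla::OpSharding object and hands it to _xla_mark_sharding, rather than passing the raw tile/group/replication lists on every call. Below is a minimal sketch of the new call shape, assuming an SPMD-enabled setup; the mesh shape and tensor are illustrative, and in the real code path the object comes from the Mesh.get_op_sharding helper added later in this PR.

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

xr.use_spmd()  # _xla_mark_sharding checks that SPMD mode is enabled

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (1, num_devices))
t = torch.randn(4, num_devices).to(xm.xla_device())

# Step 1: build the OpSharding object once on the Python side.
op_sharding = mesh.get_op_sharding((0, 1))
# Step 2: hand the ready-made object to the binding (normally done for you
# inside xs.mark_sharding).
torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
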
10 changes: 6 additions & 4 deletions torch_xla/csrc/ir.cpp
@@ -199,10 +199,12 @@ void XlaNode::UpdateShardingHash() {
     sharding_hash_ = torch::lazy::HashCombine(
         sharding_hash_, (uint32_t)tile_assignment_dimension);
   }
-  for (const auto& tile_assignment_device :
-       sharding->tile_assignment_devices()) {
-    sharding_hash_ = torch::lazy::HashCombine(
-        sharding_hash_, (uint32_t)tile_assignment_device);

Review comment (Collaborator): I guess this for loop is expensive as well.

+  {
+    const int64_t* data = sharding->tile_assignment_devices().data();
+    const size_t size_in_bytes =
+        sharding->tile_assignment_devices().size() * sizeof(*data);
+    sharding_hash_ =
+        torch::lazy::HashBlock(data, size_in_bytes, sharding_hash_);
   }
   for (const auto& last_tile_dim : sharding->last_tile_dims()) {
     sharding_hash_ =
1 change: 1 addition & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -194,6 +194,7 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
xla::OpSharding ShardingUtil::CreateOpSharding(
    const py::list& tile_assignment, const py::list& group_assignment,
    const py::list& replication_groups, ShardingType sharding_type) {
  TORCH_LAZY_COUNTER("CreateOpSharding", 1);
  xla::OpSharding sharding;
  switch (sharding_type) {
    case ShardingType::MANUAL: {
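
The new CreateOpSharding counter is what the unit test above keys on. A short sketch of observing it from Python (assumes an SPMD-enabled environment; the mesh and tensors are illustrative):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n)), (1, n))

met.clear_all()
t = torch.randn(1, n).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))
print(met.counter_value("CreateOpSharding"))  # 1: first distinct partition spec

u = torch.randn(1, n).to(xm.xla_device())
xs.mark_sharding(u, mesh, (0, 1))
print(met.counter_value("CreateOpSharding"))  # still 1: the cached OpSharding is reused
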
40 changes: 24 additions & 16 deletions torch_xla/experimental/xla_sharding.py
@@ -8,6 +8,7 @@
import torch_xla.runtime as xr

import numpy as np
import functools
import itertools
from typing import Tuple, Union, List, Sequence, Any, Optional, Set
from enum import IntEnum
@@ -79,6 +80,26 @@ def get_axis_name_idx(self, name: str) -> int:
      return None
    return self.axis_names.index(name)

  @functools.lru_cache(maxsize=None)
  def get_op_sharding(self,
                      partition_spec: Tuple) -> torch_xla._XLAC.OpSharding:
    """
    Return the OpSharding for the given partition spec. This is an expensive
    operation as the mesh grows, so the value is cached for reuse.
    """
    tile_assignment = _get_tile_assignment(self, partition_spec)
    if len(tile_assignment.shape) > len(partition_spec):
      # Use partial replication for sharding a tensor over a higher-rank mesh
      sharding_type = ShardingType.PARTIAL
    else:
      sharding_type = _get_sharding_type(partition_spec, self.size())
    replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
    group_assignment, replication_groups = _get_group_assignment(
        sharding_type, tile_assignment, len(partition_spec), replicate_dims)
    return torch_xla._XLAC.OpSharding(tile_assignment.tolist(),
                                      group_assignment, replication_groups,
                                      int(sharding_type))

Review comment (Collaborator): Why don't you use @functools.lru_cache() like global_runtime_device_count?
Reply (Author): I didn't think about functools.lru_cache until handling global_runtime_device_count, but I guess I reimplemented it for get_op_sharding! Let's use it here too.
Reply (Author): I set maxsize=None so that nothing gets evicted.
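
A small sketch of the resulting caching behavior (assuming a 1 x N device mesh; the partition specs are illustrative). functools.lru_cache keys on the (self, partition_spec) pair, so repeated calls with the same spec return the same cached object, and the spec must remain a hashable tuple:

import numpy as np
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n)), (1, n))

first = mesh.get_op_sharding((0, 1))
second = mesh.get_op_sharding((0, 1))
assert first is second  # cache hit: CreateOpSharding is not called again

third = mesh.get_op_sharding((0, None))  # different spec: builds a new OpSharding

Note that with maxsize=None nothing is ever evicted, and because the cache is keyed on the bound method's arguments it also keeps every Mesh instance it has seen alive for the life of the process; that is the trade-off the review thread above accepts.
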


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4

@@ -475,25 +496,12 @@ def mark_sharding(
   assert len(specs) == len(np.unique(specs)), \
     f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

-  tile_assignment = _get_tile_assignment(mesh, partition_spec)
-  if len(tile_assignment.shape) > len(partition_spec):
-    # Use partial replication for sharding a tensor over a higher-rank mesh
-    sharding_type = ShardingType.PARTIAL
-  else:
-    sharding_type = _get_sharding_type(partition_spec, num_devices)
-  replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
-  group_assignment, replication_groups = _get_group_assignment(
-      sharding_type, tile_assignment, len(partition_spec), replicate_dims)
+  op_sharding = mesh.get_op_sharding(partition_spec)

   if isinstance(t, XLAShardedTensor):
-    torch_xla._XLAC._xla_mark_sharding(t.global_tensor,
-                                       tile_assignment.tolist(),
-                                       group_assignment, replication_groups,
-                                       int(sharding_type))
+    torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
     return t
-  torch_xla._XLAC._xla_mark_sharding(t, tile_assignment.tolist(),
-                                     group_assignment, replication_groups,
-                                     int(sharding_type))
+  torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
   return XLAShardedTensor(t)


1 change: 1 addition & 0 deletions torch_xla/runtime.py
@@ -205,6 +205,7 @@ def global_runtime_device_attributes() -> List[Dict[str, object]]:


@requires_pjrt
@functools.lru_cache()
def global_runtime_device_count() -> int:
"""Returns the total number of runtime devices across all processes/hosts."""
return len(torch_xla._XLAC._xla_get_all_runtime_devices())
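
The decorator applies the same memoization idea at module level: the first call performs the runtime query and later calls return the cached count. A tiny self-contained sketch of the pattern (the function below is a stand-in for the real query; the underlying assumption of the change is that the set of runtime devices does not change during the life of the process):

import functools

@functools.lru_cache()
def device_count() -> int:
  print("querying runtime...")  # executes only on the first call
  return 8                      # stand-in value, not the real PJRT query

device_count()  # prints "querying runtime...", returns 8
device_count()  # served from the cache, no print
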