Skip to content

Commit 6e90569

Browse files
authored
mark_sharding over a replicated tensor is allowed. (#5513)
1 parent 33f1cd2 commit 6e90569

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

test/spmd/test_xla_sharding.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,16 @@ def test_clear_sharding(self):
442442
xs.clear_sharding(xt)
443443
self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(xt))
444444

445+
def test_replication_with_no_clear_sharding(self):
446+
xt = torch.randn(2, 4).to(xm.xla_device())
447+
# replication
448+
xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (None, None))
449+
# A sharding annotation over an existing replicated sharding is permitted.
450+
xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1))
451+
if self.n_devices > 1:
452+
self.assertFalse(
453+
"replicated" in torch_xla._XLAC._get_xla_sharding_spec(xt))
454+
445455
def test_deep_copy(self):
446456
xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
447457
xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)),

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,12 +1432,13 @@ void InitXlaModuleBindings(py::module m) {
14321432
cpu_tensor = xtensor->CurrentTensorData().value();
14331433
} else {
14341434
// A new input tensor is not expected to be sharded. But sometimes,
1435-
// the same input is used sharding annotation, in which case we can
1436-
// skip if it's the same sharding; however, if it's the same input
1437-
// with a different sharding then we block & ask the user to clear
1438-
// the existing sharding first.
1435+
// the same input is annotated for sharding over multiple steps,
1436+
// in which case we can skip if it's the same sharding; however, if it's
1437+
// the same input with a different sharding then we block & ask the user
1438+
// to clear the existing sharding first. An existing replicated sharding
// skips this check.
14391439
auto current_sharding_spec = xtensor->sharding_spec();
1440-
if (current_sharding_spec) {
1440+
if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
1441+
xla::OpSharding::REPLICATED)) {
14411442
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
14421443
*current_sharding_spec))
14431444
<< "Existing annotation must be cleared first.";

torch_xla/csrc/tensor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
237237
// Existing annotation must be cleared explicitly. We do not clear and
238238
// overwrite the existing sharding on the user's behalf. This is a no-op if
239239
// the same sharding is already applied.
240-
if (!sharding_spec()) {
240+
if (!sharding_spec() ||
241+
(sharding_spec()->sharding.type() == xla::OpSharding::REPLICATED)) {
241242
TORCH_LAZY_COUNTER("SetShardingSpec", 1);
242243
data()->sharding = std::make_shared<ShardingSpec>(sharding);
243244
} else {

0 commit comments

Comments
 (0)