diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 6751c244c1e3..0a8bd8ef5d16 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -117,6 +117,21 @@ TEST_F(XLAShardingTest, ShardTensor) {
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
 }
 
+TEST_F(XLAShardingTest, EqualShardingSpecs) {
+  XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
+                                       {0, 1, 2, 3},
+                                       {4, 5, 6, 7},
+                                   })
+                                       .ToProto());
+  XLATensor::ShardingSpec tiled_3d(
+      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto());
+  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto());
+  EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
+  EXPECT_TRUE(!ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
+  EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
+  EXPECT_TRUE(!ShardingUtil::EqualShardingSpecs(tiled_2d, replicated));
+}
+
 TEST_F(XLAShardingTest, CreateTensorsData) {
   if (xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") {
     GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 3b6a9002baa7..28da8a109a7e 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1456,14 +1456,10 @@ void InitXlaModuleBindings(py::module m) {
                                  bool replicated = false, bool manual = false) {
     xla::OpSharding sharding =
         ShardingUtil::CreateOpSharding(tile_assignment, replicated, manual);
+    auto new_sharding_spec =
+        std::make_shared<XLATensor::ShardingSpec>(sharding);
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-    // Existing annotation must be cleared explicitly. We do not clear and
-    // overwrite the existing sharding on user's behalf, since it could lead to
-    // confusion/error.
-    XLA_CHECK(xtensor->sharding_spec() == nullptr)
-        << "Existing annotation must be cleared first.";
-
     at::Tensor cpu_tensor;
     if (xla::sys_util::GetEnvBool("XLA_USE_SPMD", false) &&
         xtensor->CurrentTensorData().has_value()) {
@@ -1472,6 +1468,19 @@
       // the sharded data transfer.
       cpu_tensor = xtensor->CurrentTensorData().value();
     } else {
+      // A new input tensor is not expected to be sharded. But sometimes, the
+      // same input is used with a sharding annotation, in which case we can
+      // skip if it's the same sharding; however, if it's the same input with
+      // a different sharding then we block & ask the user to clear the
+      // existing sharding first.
+      auto current_sharding_spec = xtensor->sharding_spec();
+      if (current_sharding_spec) {
+        XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
+                                                   *current_sharding_spec))
+            << "Existing annotation must be cleared first.";
+        return;
+      }
+
       // If the at::Tensor data is not present, we need to re-download the
       // tensor from the physical device to CPU. In that case, the value
       // must be present on the backend device.
@@ -1481,14 +1490,12 @@ void InitXlaModuleBindings(py::module m) {
       std::vector<XLATensorPtr> xla_tensors{xtensor};
       cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
     }
-
-    auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(sharding);
     auto xla_data = CreateTensorsData(
         std::vector<at::Tensor>{cpu_tensor},
-        std::vector<XLATensor::ShardingSpecPtr>{sharding_spec},
+        std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
         std::vector<std::string>{GetVirtualDevice().toString()})[0];
     xtensor->SetXlaData(xla_data);
-    xtensor->SetShardingSpec(*sharding_spec);
+    xtensor->SetShardingSpec(*new_sharding_spec);
   });
   m.def("_xla_clear_sharding", [](const at::Tensor& input) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 4ca0d4b92609..ae7c361bee49 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -228,17 +228,23 @@ XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const {
   return nullptr;
 }
 
-void XLATensor::SetShardingSpec(const ShardingSpec& sharding_spec) {
-  XLA_CHECK(GetIrValue().node != nullptr) << "Tyring to access a null cursor";
-  dynamic_cast<XlaNode*>(GetIrValue().node.get())
-      ->SetSharding(sharding_spec.sharding);
+void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
+  // Existing annotation must be cleared explicitly. We do not clear and
+  // overwrite the existing sharding on the user's behalf. This is a no-op if
+  // the same sharding is already applied.
+  if (sharding_spec() == nullptr ||
+      !ShardingUtil::EqualShardingSpecs(sharding, *sharding_spec())) {
+    TORCH_LAZY_COUNTER("SetShardingSpec", 1);
+    XLA_CHECK(GetIrValue().node != nullptr) << "Trying to access a null cursor";
+    dynamic_cast<XlaNode*>(GetIrValue().node.get())
+        ->SetSharding(sharding.sharding);
+  }
 }
 
 void XLATensor::ClearShardingSpec() {
   torch::lazy::Value ir_value = CurrentIrValue();
   if (ir_value) {
-    if (ir_value.node != nullptr) {
-      dynamic_cast<XlaNode*>(ir_value.node.get())->ClearSharding();
-    }
+    // This should be a no-op if there is no sharding.
+    dynamic_cast<XlaNode*>(ir_value.node.get())->ClearSharding();
   }
 }
 
@@ -292,10 +298,9 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
 void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
   ShardingSpecPtr sharding = sharding_spec();
   if (sharding != nullptr) {
-    // Sharded xla_data is accompanied by sharding annotation.
-    // Use unsynced ir_value or xla_data to hold the annotation.
-    // TODO(yeounoh): This does not propagate sharding to views.
     if (!ir_value) {
+      // Create a tensor node if applicable; re-use the current IR otherwise.
+      // TODO(yeounoh): this has some performance implications for convolution.
       ir_value = GetIrValue();
     }
     dynamic_cast<XlaNode*>(ir_value.node.get())
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 1007f0e2cbb4..94572e8043e1 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -6,6 +6,7 @@
 #include
 #include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
@@ -60,6 +61,11 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) {
   return is_sharded;
 }
 
+bool ShardingUtil::EqualShardingSpecs(const XLATensor::ShardingSpec& a,
+                                      const XLATensor::ShardingSpec& b) {
+  return xla::protobuf_util::ProtobufEquals(a.sharding, b.sharding);
+}
+
 xla::OpSharding ShardingUtil::CreateOpSharding(const py::list& tile_assignment,
                                                bool replicated, bool manual) {
   XLA_CHECK(!(replicated && manual))
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index 8d26ab320022..b39ca8f27275 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -18,6 +18,10 @@ class ShardingUtil {
   // building the computation; otherwise, this is a no-op.
   static bool SetHloSharding(LoweringContext* lowering_ctx);
 
+  // Returns true if two sharding specs are the same.
+  static bool EqualShardingSpecs(const XLATensor::ShardingSpec& a,
+                                 const XLATensor::ShardingSpec& b);
+
   // Create an xla::OpSharding from `tile_assignment` (ndarray).
   static xla::OpSharding CreateOpSharding(const py::list& tile_assignment,
                                           bool replicated = false,
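
Note (not part of the patch): the sketch below summarizes the re-annotation behavior this change gives the sharding binding. It is an illustrative fragment, not additional code to apply; it assumes the torch_xla context of the lambda shown above, with `xtensor` and `new_sharding_spec` being the same placeholders used in the patch, and it only relies on symbols that appear in this diff (ShardingUtil::EqualShardingSpecs, XLA_CHECK, _xla_clear_sharding).

    // An un-annotated input proceeds to the (possibly sharded) data transfer.
    // An already-annotated input is accepted only if the new spec matches.
    auto current_sharding_spec = xtensor->sharding_spec();
    if (current_sharding_spec != nullptr) {
      // Same sharding: nothing to do, the tensor keeps its existing annotation.
      // Different sharding: refuse and ask for _xla_clear_sharding first.
      XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
                                                 *current_sharding_spec))
          << "Existing annotation must be cleared first.";
      return;
    }
    // Otherwise transfer the data with the new spec (CreateTensorsData) and
    // record it on the tensor via SetShardingSpec.

At the XLATensor level, SetShardingSpec applies the same idea: it updates the IR node (and bumps the "SetShardingSpec" counter) only when the incoming spec is new or differs from the current one, so repeated identical annotations stay cheap, and ClearShardingSpec remains the explicit way to drop an annotation.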