Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ TEST_F(XLAShardingTest, ShardTensor) {
EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
}

// Verifies EqualShardingSpecs: identical specs (tiled or replicated) compare
// equal; specs with different tile assignments, or tiled vs. replicated,
// compare unequal.
TEST_F(XLAShardingTest, EqualShardingSpecs) {
  XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
                                       {0, 1, 2, 3},
                                       {4, 5, 6, 7},
                                   })
                                       .ToProto());
  XLATensor::ShardingSpec tiled_3d(
      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto());
  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto());
  EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
  // EXPECT_FALSE reports clearer failure messages than EXPECT_TRUE(!...).
  EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
  EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
  EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, replicated));
}

TEST_F(XLAShardingTest, CreateTensorsData) {
if (xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") {
GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
Expand Down
27 changes: 17 additions & 10 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,14 +1456,10 @@ void InitXlaModuleBindings(py::module m) {
bool replicated = false, bool manual = false) {
xla::OpSharding sharding =
ShardingUtil::CreateOpSharding(tile_assignment, replicated, manual);
auto new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);

// Existing annotation must be cleared explicitly. We do not clear and
// overwrite the existing sharding on user's behalf, since it could lead to
// confusion/error.
XLA_CHECK(xtensor->sharding_spec() == nullptr)
<< "Existing annotation must be cleared first.";

at::Tensor cpu_tensor;
if (xla::sys_util::GetEnvBool("XLA_USE_SPMD", false) &&
xtensor->CurrentTensorData().has_value()) {
Expand All @@ -1472,6 +1468,19 @@ void InitXlaModuleBindings(py::module m) {
// the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes, the
// same input is used sharding annotation, in which case we can skip if
// it's the same sharding; however, if it's the same input with a
// different sharding then we block & ask the user to clear the existing
// sharding first.
auto current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
Expand All @@ -1481,14 +1490,12 @@ void InitXlaModuleBindings(py::module m) {
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}

auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(sharding);
auto xla_data = CreateTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{sharding_spec},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*sharding_spec);
xtensor->SetShardingSpec(*new_sharding_spec);
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
Expand Down
25 changes: 15 additions & 10 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,23 @@ XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const {
return nullptr;
}

void XLATensor::SetShardingSpec(const ShardingSpec& sharding_spec) {
XLA_CHECK(GetIrValue().node != nullptr) << "Tyring to access a null cursor";
dynamic_cast<XlaNode*>(GetIrValue().node.get())
->SetSharding(sharding_spec.sharding);
void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
// Existing annotation must be cleared explicitly. We do not clear and
// overwrite the existing sharding on the user's behalf. This is a no-op if
// the same sharding already applied.
if (sharding_spec() == nullptr ||
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we throw an error or log something if we call SetShardingSpec on a tensor with different sharding?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

silently no-op seems dangerous.

!ShardingUtil::EqualShardingSpecs(sharding, *sharding_spec())) {
TORCH_LAZY_COUNTER("SetShardingSpec", 1);
XLA_CHECK(GetIrValue().node != nullptr) << "Tyring to access a null cursor";
dynamic_cast<XlaNode*>(GetIrValue().node.get())
->SetSharding(sharding.sharding);
}
}
// Removes any sharding annotation from this tensor's current IR node.
// Safe to call when no IR value exists; a no-op if the node carries no
// sharding.
void XLATensor::ClearShardingSpec() {
  torch::lazy::Value ir_value = CurrentIrValue();
  // Check the node pointer explicitly rather than relying on the Value's
  // implicit bool conversion, since we dereference `ir_value.node` below.
  if (ir_value.node != nullptr) {
    dynamic_cast<XlaNode*>(ir_value.node.get())->ClearSharding();
  }
}

Expand Down Expand Up @@ -292,10 +298,9 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
ShardingSpecPtr sharding = sharding_spec();
if (sharding != nullptr) {
// Sharded xla_data is accompanied by sharding annotation.
// Use unsynced ir_value or xla_data to hold the annotation.
// TODO(yeounoh): This does not propagate sharding to views.
if (!ir_value) {
// Create a tensor node if applicable, re-use the current IR otherwise.
// TODO(yeounoh) this has some performance implications for convolution.
ir_value = GetIrValue();
}
dynamic_cast<XlaNode*>(ir_value.node.get())
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <unordered_map>

#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
Expand Down Expand Up @@ -60,6 +61,11 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) {
return is_sharded;
}

// Two sharding specs are considered equal iff their underlying
// xla::OpSharding protos are byte-wise equivalent.
bool ShardingUtil::EqualShardingSpecs(const XLATensor::ShardingSpec& a,
                                      const XLATensor::ShardingSpec& b) {
  const xla::OpSharding& lhs = a.sharding;
  const xla::OpSharding& rhs = b.sharding;
  return xla::protobuf_util::ProtobufEquals(lhs, rhs);
}

xla::OpSharding ShardingUtil::CreateOpSharding(const py::list& tile_assignment,
bool replicated, bool manual) {
XLA_CHECK(!(replicated && manual))
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class ShardingUtil {
// building the computation; otherwise, this is a no-op.
static bool SetHloSharding(LoweringContext* lowering_ctx);

  // Returns true if two sharding specs are the same.
  // Review nit: consider renaming to ShardingSpecsAreEqual.
  static bool EqualShardingSpecs(const XLATensor::ShardingSpec& a,
                                 const XLATensor::ShardingSpec& b);

// Create an xla::OpSharding from `tile_assignment` (ndarray).
static xla::OpSharding CreateOpSharding(const py::list& tile_assignment,
bool replicated = false,
Expand Down