33 changes: 33 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1118,6 +1118,39 @@ def test_mark_manual_sharding(self):
# The following exception cannot be caught somehow.
# xt.global_tensor.cpu()

def test_spmd_full_to_shard_shape(self):
x = torch.zeros(8, 4).to(xm.xla_device())
with self.assertRaises(RuntimeError):
x = torch_xla._XLAC._spmd_full_to_shard_shape(x)

# Sharded shape
xt = xs.mark_sharding(x, self._get_mesh((1, self.n_devices)), (0, 1))
xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
self.assertEqual(xx.shape, (8, 4 // self.n_devices))
self.assertIn(f'%custom-call.2 = f32[8,{4//self.n_devices}]{{1,0}}', hlo)
self.assertIn(
f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")

# It looks like XLA doesn't like only having manual sharding in the HLO.
# It needs to be paired with SPMDFullToShardShape/SPMDShardToFullShape.
# The following exception cannot be caught somehow.
# xx.cpu()
Collaborator:
Do you intend to keep this xx.cpu() call?

Collaborator Author:
Yeah, it's more of a note that this won't work... I was trying to use with self.assertRaises, but that doesn't capture the exception. I have noticed this before too: when libtpu crashes, it's hard to catch at the Python level. Not sure why. Maybe you have some better ideas?

Collaborator:
Oh, I think I ran into a similar issue before. The way I handled it was ugly, through something like:

# crash will happen in an async execution thread, need to grab the lock again to
# surface that exception
dynamo_res = dynamo_linear(xla_x)
try:
print(dynamo_res)
except:
print('catch')
# it is hard to catch the C++ runtime error in Python; instead we can check whether
# dynamo_res is still a placeholder after printing, which means C++ crashed.
self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
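
For reference, a more self-contained version of that workaround might look like the sketch below. The linear module, tensor shapes, and the openxla dynamo backend are illustrative assumptions, and as the rest of the thread notes, this only detects crashes that surface at the PyTorch/XLA C++ level, not aborts inside libtpu.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
dynamo_linear = torch.compile(torch.nn.Linear(10, 20).to(device), backend='openxla')
xla_x = torch.randn(5, 10, device=device)

# Any crash happens in an async execution thread; touching the result
# forces it to surface.
dynamo_res = dynamo_linear(xla_x)
try:
  print(dynamo_res)
except Exception:
  print('catch')
# The C++ runtime error is hard to catch in Python, so check whether the
# result is still a placeholder, which indicates the C++ side crashed.
crashed = torch_xla._XLAC._is_placecholder(dynamo_res)
print('C++ execution crashed:', crashed)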

Collaborator Author:
A C++ crash at the PyTorch level can be caught with self.assertRaises, but not one at the libtpu level... I'm not sure why. And yeah, not even with this hack...

Collaborator Author:
cc @will-cromar Do you know how to catch a libtpu exception in Python? Appreciate your insights.

Collaborator:
I don't think you can. To make a proper runtime error, you have to raise an exception, and Google internal binaries don't generally do that. I wrote about a similar case in #6700 (comment)

Collaborator Author:
Thanks, Will. That makes a lot of sense now.


# Replicated shape
x = torch.zeros(8, 4).to(xm.xla_device())
xt = xs.mark_sharding(x, self._get_mesh((self.n_devices, 1)), (None, None))
xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
self.assertEqual(xx.shape, (8, 4))
self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
self.assertIn(
f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")


if __name__ == '__main__':
test = unittest.main()
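
Putting the pieces of the new test together, here is a minimal end-to-end sketch of how the binding is intended to be used. The mesh construction mirrors the test's _get_mesh helper, and the final shard-to-full conversion is only hinted at: that direction is not part of this diff, so the name in the last comment is hypothetical.

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n_devices)), (1, n_devices))

x = torch.zeros(8, 4).to(xm.xla_device())
xt = xs.mark_sharding(x, mesh, (0, 1))

# Switch to manual mode: the result carries the per-shard (local) shape and
# a {manual} sharding spec, lowered as an SPMDFullToShardShape custom call.
xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor)
assert tuple(xx.shape) == (8, 4 // n_devices)

# ... manually sharded computation on xx would go here ...

# Before the result can be consumed as a global tensor again, the HLO needs
# the reverse SPMDShardToFullShape conversion; that binding is not added in
# this diff, so the name below is hypothetical.
# full = torch_xla._XLAC._spmd_shard_to_full_shape(xx, ...)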
34 changes: 25 additions & 9 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -182,6 +182,12 @@ std::vector<XLATensorPtr> GetXlaTensors(const std::vector<at::Tensor>& tensors,
return xtensors;
}

bool IsNonDeviceDataIR(const at::Tensor& tensor) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
return xtensor->CurrentIrValue() &&
!DeviceData::Cast(xtensor->CurrentIrValue().node.get());
}

std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
std::vector<std::vector<int64_t>> replica_groups;
for (auto& group : groups) {
@@ -1939,16 +1945,26 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_mark_manual_sharding", [](const at::Tensor& input,
xla::OpSharding sharding) {
m.def("_mark_manual_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLA_CHECK(IsNonDeviceDataIR(input))
<< "Marking any data tensors as manual is not supported";
ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_spmd_full_to_shard_shape", [](const at::Tensor& input) -> at::Tensor {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
bool is_ir = xtensor->CurrentIrValue();
if (is_ir) {
is_ir = !DeviceData::Cast(xtensor->CurrentIrValue().node.get());
}
XLA_CHECK(is_ir) << "Marking any data tensors as manual is not supported";

ShardingUtil::XlaMarkSharding(input, sharding);
auto sharding_spec = xtensor->sharding_spec();
XLA_CHECK(sharding_spec != nullptr) << "Input tensor is not sharded";

auto shard_shape = xla::ShapeUtil::MakeShape(
MakeXlaPrimitiveType(xtensor->dtype(), &(xtensor->GetDevice())),
ShardingUtil::GetShardShape(sharding_spec));
auto output = xtensor->CreateFrom(torch::lazy::MakeNode<CustomSharding>(
xtensor->GetIrValue(), shard_shape,
CustomSharding::Type::kSPMDFullToShardShape));
output->SetShardingSpec(XLATensor::ShardingSpec(
xla::HloSharding::Manual().ToProto(), shard_shape));
return bridge::AtenFromXlaTensor(output);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
30 changes: 24 additions & 6 deletions torch_xla/csrc/ops/custom_sharding.cpp
@@ -5,24 +5,42 @@
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {
std::string TypeToString(const CustomSharding::Type& type) {
switch (type) {
case CustomSharding::Type::kSharding:
return "Sharding";
case CustomSharding::Type::kSPMDFullToShardShape:
return "SPMDFullToShardShape";
case CustomSharding::Type::kSPMDShardToFullShape:
return "SPMDShardToFullShape";
}
}
} // namespace

CustomSharding::CustomSharding(const torch::lazy::Value& input)
: XlaNode(xla_custom_sharding, {input}, GetXlaShape(input),
/*num_outputs=*/1, torch::lazy::MHash(std::string("Sharding"))) {}
CustomSharding::CustomSharding(const torch::lazy::Value& input,
const xla::Shape& output_shape,
const CustomSharding::Type& type)
: XlaNode(xla_custom_sharding, {input}, output_shape,
/*num_outputs=*/1, torch::lazy::MHash(static_cast<int>(type))),
type(type),
output_shape(output_shape) {}

torch::lazy::NodePtr CustomSharding::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<CustomSharding>(operands.at(0));
return torch::lazy::MakeNode<CustomSharding>(operands.at(0), output_shape,
type);
}

XlaOpVector CustomSharding::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp output = BuildCustomSharding(input);
xla::XlaOp output =
BuildCustomSharding(input, TypeToString(type), output_shape);
return ReturnOp(output, loctx);
}

std::string CustomSharding::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", Sharding";
ss << XlaNode::ToString() << ", " << TypeToString(type);
return ss.str();
}

15 changes: 14 additions & 1 deletion torch_xla/csrc/ops/custom_sharding.h
@@ -7,14 +7,27 @@ namespace torch_xla {

class CustomSharding : public XlaNode {
public:
// The following enum represents the custom_call_target being
// passed to xla builder. The actual sharding will still be
// attached to the XLATensor.
enum class Type {
kSharding,
kSPMDFullToShardShape,
kSPMDShardToFullShape,
};
Comment on lines +13 to +17
Collaborator:
This enum is really confusing; can you add some comments about what the values actually do? I was reading the SPMD code again: this op itself only means we want to shard the underlying value, and the actual sharding resides in the XLATensor or the base XLA IR object?

Collaborator Author:
Right, this is just the name of the custom call. The sharding annotation is on the XLATensor as usual. I can add more explanation.

Contributor:
Maybe we can explicitly annotate, in the enum class name or something, that this is the sharding type for the custom call.

Collaborator Author:
I guess the current approach sort of does it already? Can you be more specific? @yeounoh

Contributor:
I agree, Type is already defined under CustomSharding


// Make a custom call to Sharding.
CustomSharding(const torch::lazy::Value& input);
CustomSharding(const torch::lazy::Value& input,
const xla::Shape& output_shape, const Type& type);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

Type type;
xla::Shape output_shape;
};

} // namespace torch_xla
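
As the thread above settles, the enum only selects the custom_call_target string emitted in the lowered HLO; the sharding annotation itself still lives on the XLATensor. A quick way to see both halves from Python, reusing the xx tensor produced by _spmd_full_to_shard_shape in the test at the top of this diff:

# xx is the output of torch_xla._XLAC._spmd_full_to_shard_shape(...).
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])

# The enum value surfaces only as the custom-call target name ...
assert 'custom_call_target="SPMDFullToShardShape"' in hlo
# ... while the {manual} annotation comes from the spec attached to the tensor.
assert torch_xla._XLAC._get_xla_sharding_spec(xx) == '{manual}'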
8 changes: 4 additions & 4 deletions torch_xla/csrc/tensor_methods.cpp
@@ -40,7 +40,6 @@
#include "torch_xla/csrc/ops/count_nonzero.h"
#include "torch_xla/csrc/ops/cumprod.h"
#include "torch_xla/csrc/ops/cumsum.h"
#include "torch_xla/csrc/ops/custom_sharding.h"
#include "torch_xla/csrc/ops/dequant_tensor.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal.h"
@@ -522,9 +521,10 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(

void custom_sharding_(
const XLATensorPtr& input,
const std::shared_ptr<XLATensor::ShardingSpec>& sharding_spec) {
input->SetInPlaceIrValue(
torch::lazy::MakeNode<CustomSharding>(input->GetIrValue()));
const std::shared_ptr<XLATensor::ShardingSpec>& sharding_spec,
const CustomSharding::Type& type) {
input->SetInPlaceIrValue(torch::lazy::MakeNode<CustomSharding>(
input->GetIrValue(), input->shape().get(), type));
input->SetShardingSpec(*sharding_spec);
}

7 changes: 5 additions & 2 deletions torch_xla/csrc/tensor_methods.h
@@ -2,6 +2,7 @@
#define XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_

#include "torch_xla/csrc/cross_replica_reduces.h"
#include "torch_xla/csrc/ops/custom_sharding.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/tensor.h"

@@ -79,8 +80,10 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
const XLATensorPtr& input, const torch::lazy::Value& token,
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);

void custom_sharding_(const XLATensorPtr& input,
const std::shared_ptr<XLATensor::ShardingSpec>& spec);
void custom_sharding_(
const XLATensorPtr& input,
const std::shared_ptr<XLATensor::ShardingSpec>& spec,
const CustomSharding::Type& type = CustomSharding::Type::kSharding);

std::vector<XLATensorPtr> tpu_custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& payload,
7 changes: 4 additions & 3 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -1266,9 +1266,10 @@ xla::XlaOp BuildMultinomial(xla::XlaOp input, int64_t num_samples,
return output;
}

xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
return xla::CustomCall(input.builder(), /*call_target_name=*/"Sharding",
{input}, ShapeHelper::ShapeOfXlaOp(input));
xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
const xla::Shape& output_shape) {
return xla::CustomCall(input.builder(), /*call_target_name=*/type, {input},
output_shape);
}

std::vector<xla::XlaOp> BuildTpuCustomCall(
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/xla_lower_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ xla::XlaOp BuildPixelShuffle(xla::XlaOp input, int64_t upscale_factor);

xla::XlaOp BuildUpperTriangle(xla::XlaOp input);

xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);
xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
const xla::Shape& output_shape);

std::vector<xla::XlaOp> BuildTpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,