Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,46 @@ def test_manual_sharding_api_e2e(self):
self.assertEqual(xxx.shape, (8, 8))
self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
def test_spmd_reduce_scatter(self):
xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
x = torch.ones(8, 8).to(xm.xla_device())

# Reduce scatter
x = xs.enable_manual_sharding(x, (None, None)).global_tensor
x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, 0,
self.n_devices,
[self.device_ids])
x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
self.assertIn(
f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
hlo)

expected_x = torch.ones(2, 8) * 4
self.assertTrue(torch.allclose(x.cpu(), expected_x))

@unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device")
def test_spmd_reduce_scatter_canonical_index(self):
xs.set_global_mesh(self._get_mesh((1, self.n_devices)))
x = torch.ones(8, 8).to(xm.xla_device())

# Reduce scatter
x = xs.enable_manual_sharding(x, (None, None)).global_tensor
x = torch_xla._XLAC._xla_spmd_reduce_scatter(xm.REDUCE_SUM, x, 1.0, -1,
self.n_devices,
[self.device_ids])
x = xs.disable_manual_sharding(x, (None, None), x.shape).global_tensor

hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
self.assertIn(
f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
hlo)

expected_x = torch.ones(8, 2) * 4
self.assertTrue(torch.allclose(x.cpu(), expected_x))


if __name__ == '__main__':
test = unittest.main()
Expand Down
24 changes: 24 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,30 @@ ReduceScatterResult BuildReduceScatter(
return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
double scale, int64_t scatter_dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups) {
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
// Just a dummy channel handle, and it's required to set the
// use_global_device_ids which is requried for SPMD.
xla::ChannelHandle channel_handle;
channel_handle.set_handle(1);
channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
xla::XlaOp reduce_result;
reduce_result = xla::ReduceScatter(
input, GetReduceComutation(reduce_type, input_shape.element_type()),
scatter_dim, shard_count, std::move(reduce_groups),
std::move(channel_handle), std::nullopt, true);
if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, input_shape.element_type(), input.builder());
reduce_result = reduce_result * scaling_value;
}
return reduce_result;
}

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ ReduceScatterResult BuildReduceScatter(
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
double scale, int64_t scatter_dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups);

ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
Expand Down
11 changes: 11 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,17 @@ void InitXlaModuleBindings(py::module m) {
result_tuple[1] = new_token;
return result_tuple;
});
m.def(
"_xla_spmd_reduce_scatter",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the only difference is this one does not have token?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's one difference. Others are mentioned in the description.

[](const std::string& reduce_type, const at::Tensor& input, double scale,
int64_t scatter_dim, int64_t shard_count, const py::list& groups) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
auto result = tensor_methods::reduce_scatter(
bridge::GetXlaTensor(input), GetReduceType(reduce_type), scale,
scatter_dim, shard_count, replica_groups);
return bridge::AtenFromXlaTensor(std::move(result));
});
m.def("_xla_reduce_scatter",
[](const std::string& reduce_type, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
Expand Down
38 changes: 38 additions & 0 deletions torch_xla/csrc/ops/reduce_scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type,
return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn);
}

xla::Shape NodeOutputShape(AllReduceType reduce_type,
const torch::lazy::Value input, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups) {
auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
xla::XlaOp inputOp = operands[0];
return BuildReduceScatter(reduce_type, inputOp, scale, scatter_dim,
shard_count, groups);
};
return InferOutputShape({GetXlaShape(input)}, shape_fn);
}

xla::Shape NodeOutputShapeCoalesced(
AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& token, double scale, int64_t scatter_dim,
Expand Down Expand Up @@ -73,6 +85,27 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type,
groups_(std::move(groups)),
pin_layout_(pin_layout) {}

ReduceScatter::ReduceScatter(AllReduceType reduce_type,
const torch::lazy::Value& input, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups)
: XlaNode(
xla_reduce_scatter, {input},
[&]() {
return NodeOutputShape(reduce_type, input, scale, scatter_dim,
shard_count, groups);
},
/*num_outputs=*/1,
torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
scatter_dim, shard_count, groups)),
reduce_type_(reduce_type),
scale_(scale),
scatter_dim_(scatter_dim),
shard_count_(shard_count),
groups_(std::move(groups)),
pin_layout_(false),
has_token_(false) {}

ReduceScatterCoalesced::ReduceScatterCoalesced(
AllReduceType reduce_type, c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& token, double scale, int64_t scatter_dim,
Expand Down Expand Up @@ -111,6 +144,11 @@ torch::lazy::NodePtr ReduceScatterCoalesced::Clone(

XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
if (!has_token_) {
auto result = BuildReduceScatter(reduce_type_, input, scale_, scatter_dim_,
shard_count_, groups_);
return ReturnOp(result, loctx);
}
xla::XlaOp token = loctx->GetOutputOp(operand(1));
ReduceScatterResult result =
BuildReduceScatter(reduce_type_, input, token, scale_, scatter_dim_,
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/reduce_scatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ class ReduceScatter : public XlaNode {
const torch::lazy::Value& token, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups, bool pin_layout);
ReduceScatter(AllReduceType reduce_type, const torch::lazy::Value& input,
double scale, int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups);

std::string ToString() const override;

Expand All @@ -34,6 +37,7 @@ class ReduceScatter : public XlaNode {
int64_t shard_count_;
std::vector<std::vector<int64_t>> groups_;
bool pin_layout_;
bool has_token_{true};
};

class ReduceScatterCoalesced : public XlaNode {
Expand Down
11 changes: 11 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,17 @@ std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
torch::lazy::Value(node, 1)};
}

XLATensorPtr reduce_scatter(const XLATensorPtr& input,
AllReduceType reduce_type, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups) {
auto canonical_scatter_dim = torch::lazy::GetCanonicalDimensionIndex(
scatter_dim, input->shape().get().rank());
return input->CreateFrom(torch::lazy::MakeNode<ReduceScatter>(
reduce_type, input->GetIrValue(), scale, canonical_scatter_dim,
shard_count, std::move(groups)));
}

torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
const XLATensorPtr& input,
const torch::lazy::Value& token,
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
int64_t shard_count, std::vector<std::vector<int64_t>> groups,
bool pin_layout);

XLATensorPtr reduce_scatter(const XLATensorPtr& input,
AllReduceType reduce_type, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups);

torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
const XLATensorPtr& input,
const torch::lazy::Value& token,
Expand Down