
Commit ec27f90

Fix allgather to be compatible with the OpenXLA allgather tuple change (no token in the result tuple)
1 parent 3e97aa8 commit ec27f90

4 files changed: 37 additions (+), 23 deletions (-)

torch_xla/core/xla_model.py
Lines changed: 2 additions & 2 deletions

@@ -594,8 +594,8 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
     result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
                                                        shard_count, groups or [],
                                                        pin_layout)
-    torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
-    return result[:-1]
+    torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+    return result[0]
 
 
 def all_to_all(value,
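
The indexing change above follows from the new C++ return contract: BuildAllGather now returns the gathered values and the chained token as separate fields instead of a token-terminated vector, and the lowering re-appends the token as the last node output, so for the single-input case the bound result reads as result[0] (value) and result[1] (token). A minimal hedged sketch of the two contracts, using Op as a stand-in for xla::XlaOp rather than the real torch_xla builders:

// Sketch only: Op stands in for xla::XlaOp; these are not the real
// torch_xla builders.
#include <vector>

using Op = int;

// Old contract: the chained token rides along as the last vector element,
// so Python sliced it off with result[:-1] and read it via result[-1].
std::vector<Op> BuildAllGatherOldStyle(const std::vector<Op>& inputs,
                                       Op token) {
  std::vector<Op> out = inputs;
  out.push_back(token);
  return out;
}

// New contract: payload and token are separate named fields; the lowering
// appends the token last, giving result[0] / result[1] for a single input.
struct AllGatherResultSketch {
  std::vector<Op> result;
  Op token;
};

AllGatherResultSketch BuildAllGatherNewStyle(const std::vector<Op>& inputs,
                                             Op token) {
  return {inputs, token};
}

int main() {
  AllGatherResultSketch r = BuildAllGatherNewStyle({1, 2}, 99);
  return r.result.size() == 2 && r.token == 99 ? 0 : 1;
}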

torch_xla/csrc/cross_replica_reduces.cpp
Lines changed: 11 additions & 15 deletions

@@ -210,23 +210,18 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
   return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }
 
-std::vector<xla::XlaOp> BuildAllGather(
+AllGatherResult BuildAllGather(
     absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
     int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
     bool pin_layout) {
   std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
+  TokenHandler token_handler(token);
   // TODO: We use pseudo-tokens ATM, which are real values. This need to be
   // switched to use the real XLA Token once support has been added to XLA
   // AllGather().
-  xla::XlaOp chained_token = token;
   ReduceContext cc_ctx = GetReduceContext(inputs);
   std::vector<xla::XlaOp> result(inputs.size());
   for (auto& type_ctx : cc_ctx.contexts) {
-    xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
-    type_ctx.second.ops.push_back(token_op);
-    type_ctx.second.operand_shapes.push_back(
-        ShapeHelper::ShapeOfXlaOp(token_op));
-
     xla::XlaOp all_gather_result;
     if (pin_layout) {
       all_gather_result = xla::AllGather(
@@ -239,16 +234,17 @@ std::vector<xla::XlaOp> BuildAllGather(
           xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
                          dim, shard_count, cc_groups);
     }
-    for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
-      size_t op_idx = type_ctx.second.indices[i];
-      result[op_idx] = xla::GetTupleElement(all_gather_result, i);
+    if (type_ctx.second.indices.size() > 1) {
+      for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
+        size_t op_idx = type_ctx.second.indices[i];
+        result[op_idx] = xla::GetTupleElement(all_gather_result, i);
+      }
+    }
+    else {
+      result[0] = all_gather_result;
     }
-    chained_token =
-        xla::GetTupleElement(all_gather_result, type_ctx.second.indices.size());
   }
-  result.push_back(
-      MaybeConvertTo(chained_token, XlaHelpers::TypeOfXlaOp(token)));
-  return result;
+  return {result, token_handler.GetNewToken(result[0])};
 }
 
 CollectivePermuteResult BuildCollectivePermute(
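
The new indices.size() > 1 branch is the heart of the compatibility fix: with the pseudo-token no longer appended as an extra operand, a single-operand xla::AllGather returns the gathered array directly rather than a one-element tuple, so xla::GetTupleElement is only legal in the multi-operand case. A hedged, self-contained sketch of that dispatch, with stub types standing in for the real XLA builder API:

// Sketch only: FakeOp stands in for xla::XlaOp, FakeGetTupleElement for
// xla::GetTupleElement; not the real XLA API.
#include <cassert>
#include <cstddef>
#include <vector>

struct FakeOp {
  bool is_tuple;            // whether the op has tuple shape
  std::vector<int> elems;   // gathered values
};

// Only valid on tuple-shaped ops, mirroring xla::GetTupleElement.
int FakeGetTupleElement(const FakeOp& op, size_t i) {
  assert(op.is_tuple && "GetTupleElement on a non-tuple op");
  return op.elems[i];
}

// Mirrors the post-change AllGather behavior this commit adapts to:
// a tuple result only appears when more than one operand is gathered.
FakeOp FakeAllGather(const std::vector<int>& operands) {
  return {operands.size() > 1, operands};
}

std::vector<int> UnpackGather(const FakeOp& gathered, size_t n_operands) {
  std::vector<int> result(n_operands);
  if (n_operands > 1) {
    for (size_t i = 0; i < n_operands; ++i) {
      result[i] = FakeGetTupleElement(gathered, i);  // multi-operand: tuple
    }
  } else {
    result[0] = gathered.elems[0];  // single operand: plain array, no tuple
  }
  return result;
}

int main() {
  FakeOp multi = FakeAllGather({1, 2, 3});
  FakeOp single = FakeAllGather({7});
  std::vector<int> a = UnpackGather(multi, 3);   // tuple path
  std::vector<int> b = UnpackGather(single, 1);  // direct-array path
  return a[2] == 3 && b[0] == 7 ? 0 : 1;
}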

torch_xla/csrc/cross_replica_reduces.h
Lines changed: 12 additions & 1 deletion

@@ -25,6 +25,11 @@ struct AllToAllResult {
   xla::XlaOp token;
 };
 
+struct AllGatherResult {
+  std::vector<xla::XlaOp> result;
+  xla::XlaOp token;
+};
+
 struct CollectivePermuteResult {
   xla::XlaOp result;
   xla::XlaOp token;
@@ -40,6 +45,11 @@ struct RecvResult {
   xla::XlaOp token;
 };
 
+struct ReduceScatterResult {
+  std::vector<xla::XlaOp> result;
+  xla::XlaOp token;
+};
+
 std::vector<xla::XlaOp> BuildAllReduce(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
     xla::XlaOp token, double scale,
@@ -51,7 +61,7 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                              const std::vector<std::vector<int64_t>>& groups,
                              bool pin_layout);
 
-std::vector<xla::XlaOp> BuildAllGather(
+AllGatherResult BuildAllGather(
     absl::Span<const xla::XlaOp>, xla::XlaOp token, int64_t dim,
     int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
     bool pin_layout);
@@ -66,6 +76,7 @@ SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
 RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
                               int64_t channel_id);
 
+//ReduceScatterResult BuildReduceScatter(
 std::vector<xla::XlaOp> BuildReduceScatter(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
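
AllGatherResult follows the existing AllToAllResult / CollectivePermuteResult convention of pairing the payload with the chained token, and ReduceScatterResult is declared ahead of the analogous reduce-scatter migration (the commented-out BuildReduceScatter signature marks the pending change). A hedged sketch of consuming such a struct at a call site, with a stub Op type instead of xla::XlaOp:

// Sketch only: Op stands in for xla::XlaOp; not the real headers.
#include <cstdio>
#include <vector>

struct Op { int id; };

// Same shape as the new AllGatherResult declaration above.
struct GatherResult {
  std::vector<Op> result;
  Op token;
};

// Toy builder: derives a fresh token from the incoming one.
GatherResult BuildGatherSketch(const std::vector<Op>& inputs, Op token) {
  return {inputs, Op{token.id + 1}};
}

int main() {
  // C++17 structured bindings name both halves, so the token can no longer
  // be confused with (or lost among) the payload elements.
  auto [values, new_token] = BuildGatherSketch({Op{1}, Op{2}}, Op{10});
  std::printf("values=%zu new_token=%d\n", values.size(), new_token.id);
  return 0;
}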

torch_xla/csrc/ops/all_gather.cpp
Lines changed: 12 additions & 5 deletions

@@ -16,17 +16,23 @@ xla::Shape NodeOutputShape(c10::ArrayRef<torch::lazy::Value> inputs,
                            const std::vector<std::vector<int64_t>>& groups,
                            bool pin_layout) {
   auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    std::vector<xla::XlaOp> result =
+    AllGatherResult result =
         BuildAllGather(operands.subspan(0, operands.size() - 1),
                        operands.back(), dim, shard_count, groups, pin_layout);
-    return xla::Tuple(operands[0].builder(), result);
+    std::vector<xla::XlaOp> outputs;
+    for (size_t i = 0; i < result.result.size(); ++i) {
+      outputs.emplace_back(result.result[i]);
+    }
+    outputs.emplace_back(result.token);
+    return xla::Tuple(operands[0].builder(), outputs);
   };
   std::vector<xla::Shape> input_shapes;
   for (const auto& input : inputs) {
     input_shapes.emplace_back(GetXlaShape(input));
   }
   input_shapes.emplace_back(GetXlaShape(token));
   return InferOutputShape(input_shapes, shape_fn);
+
 }
 
 }  // namespace
@@ -61,9 +67,10 @@ XlaOpVector AllGather::Lower(LoweringContext* loctx) const {
     inputs.push_back(loctx->GetOutputOp(operand_list[i]));
   }
   xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
-  return ReturnOps(
-      BuildAllGather(inputs, token, dim_, shard_count_, groups_, pin_layout_),
-      loctx);
+  AllGatherResult result =
+      BuildAllGather(inputs, token, dim_, shard_count_, groups_, pin_layout_);
+  result.result.push_back(result.token);
+  return ReturnOps(result.result, loctx);
 }
 
 std::string AllGather::ToString() const {
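
Both call sites above re-flatten the struct: the shape-inference lambda rebuilds the output tuple with the token last, and Lower() appends the token before ReturnOps, so the IR node keeps its N values + 1 token output layout and the Python side can read result[0] / result[1]. A minimal hedged sketch of that flattening, with a stub Op type rather than the real lowering types:

// Sketch only: Op stands in for xla::XlaOp.
#include <utility>
#include <vector>

struct Op { int id; };

struct GatherResult {
  std::vector<Op> result;
  Op token;
};

// Flatten {values, token} back into the node's historical output list:
// values first, token as the final element.
std::vector<Op> FlattenForReturnOps(GatherResult r) {
  std::vector<Op> outputs = std::move(r.result);
  outputs.push_back(r.token);
  return outputs;
}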
