From 97e2d890772b806ce7c64aae8e690fe324bce349 Mon Sep 17 00:00:00 2001
From: hjm-aws
Date: Tue, 4 Jan 2022 20:52:14 -0800
Subject: [PATCH 1/2] Implement all_gather with native XLA primitive.

---
 test/test_mp_all_gather.py               |  8 +--
 test/test_mp_distributed_mm.py           |  7 ++-
 torch_xla/core/xla_model.py              | 21 +++----
 torch_xla/csrc/cross_replica_reduces.cpp | 13 +++++
 torch_xla/csrc/cross_replica_reduces.h   | 10 ++++
 torch_xla/csrc/init_python_bindings.cpp  | 30 +++++++++++
 torch_xla/csrc/ops/all_gather.cpp        | 68 ++++++++++++++++++++++++
 torch_xla/csrc/ops/all_gather.h          | 38 +++++++++++++
 torch_xla/csrc/ops/xla_ops.cpp           |  1 +
 torch_xla/csrc/ops/xla_ops.h             |  1 +
 torch_xla/csrc/tensor.h                  |  4 ++
 torch_xla/csrc/tensor_methods.cpp        |  9 ++++
 12 files changed, 187 insertions(+), 23 deletions(-)
 create mode 100644 torch_xla/csrc/ops/all_gather.cpp
 create mode 100644 torch_xla/csrc/ops/all_gather.h

diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py
index 334ad4c3d349..1ca4c9189a92 100644
--- a/test/test_mp_all_gather.py
+++ b/test/test_mp_all_gather.py
@@ -1,3 +1,4 @@
+import os
 import sys
 import torch
 import torch_xla
@@ -7,8 +8,8 @@

 def _mp_fn(index):
   device = xm.xla_device()
-  world_size = xm.xrt_world_size()
-  if world_size > 1:
+  if xm.xla_device_hw(device) == 'TPU':
+    world_size = xm.xrt_world_size()
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor)
@@ -20,8 +21,7 @@
       sys.exit(1)
   else:
     print(
-        'Default device {} does not support replication'.format(device),
-        file=sys.stderr)
+        'Default device {} is not a TPU device'.format(device), file=sys.stderr)


 if __name__ == '__main__':
diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py
index 9dc45dd504e8..6664b7bb7807 100644
--- a/test/test_mp_distributed_mm.py
+++ b/test/test_mp_distributed_mm.py
@@ -9,8 +9,8 @@

 def _mp_fn(index):
   device = xm.xla_device()
-  world_size = xm.xrt_world_size()
-  if world_size > 1:
+  if xm.xla_device_hw(device) == 'TPU':
+    world_size = xm.xrt_world_size()
     torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
         use_full_mat_mul_precision=True)
     torch.manual_seed(11)
@@ -35,8 +35,7 @@
       sys.exit(1)
   else:
     print(
-        'Default device {} does not support replication'.format(device),
-        file=sys.stderr)
+        'Default device {} is not a TPU device'.format(device), file=sys.stderr)


 if __name__ == '__main__':
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 881343d4ad06..4d04dd4c8516 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -598,21 +598,12 @@ def all_gather(value, dim=0, groups=None):
   """
   if dim < 0:
     dim = value.dim() + dim
-  size = value.size(dim)
-  padding = [0] * (2 * value.dim())
-  ordinal = get_ordinal()
-  if groups is None:
-    left, right = ordinal, xrt_world_size() - 1 - ordinal
-  else:
-    ordinals = dict()
-    for g in groups:
-      for i, x in enumerate(g):
-        ordinals[x] = (i, len(g) - 1 - i)
-    left, right = ordinals[ordinal]
-  idx = value.dim() - 1 - dim
-  padding[2 * idx] = left * size
-  padding[2 * idx + 1] = right * size
-  return all_reduce(REDUCE_SUM, F.pad(value, padding), groups=groups)
+  token, devctx = _get_all_reduce_token()
+  shard_count = None if groups else xrt_world_size()
+  result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
+                                           groups or [])
+  devctx.all_reduce_token = result[1]
+  return result[0]


 def all_to_all(value,
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index b05ad405cb3f..8c28fcd8f63f 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -140,6 +140,19 @@ AllToAllResult BuildAllToAll(
   return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }

+AllGatherResult BuildAllGather(
+    xla::XlaOp input, xla::XlaOp token, xla::int64_t dim,
+    xla::int64_t shard_count,
+    const std::vector<std::vector<xla::int64_t>>& groups) {
+  std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  TokenHandler token_handler(token);
+  xla::XlaOp all_gather_result =
+      xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
+                     shard_count, reduce_groups);
+  return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
+}
+
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
     const std::vector<std::pair<xla::int64_t, xla::int64_t>>&
diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h
index 0334e931ffe3..2670e3a8c26b 100644
--- a/torch_xla/csrc/cross_replica_reduces.h
+++ b/torch_xla/csrc/cross_replica_reduces.h
@@ -21,6 +21,11 @@ struct AllToAllResult {
   xla::XlaOp token;
 };

+struct AllGatherResult {
+  xla::XlaOp result;
+  xla::XlaOp token;
+};
+
 struct CollectivePermuteResult {
   xla::XlaOp result;
   xla::XlaOp token;
@@ -41,6 +46,11 @@ AllToAllResult BuildAllToAll(
     xla::int64_t concat_dimension, xla::int64_t split_count,
     const std::vector<std::vector<xla::int64_t>>& groups);

+AllGatherResult BuildAllGather(
+    xla::XlaOp input, xla::XlaOp token, xla::int64_t dim,
+    xla::int64_t shard_count,
+    const std::vector<std::vector<xla::int64_t>>& groups);
+
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
     const std::vector<std::pair<xla::int64_t, xla::int64_t>>&
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 6d3b40ffa53e..a14fde73be6d 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -213,6 +213,18 @@ std::pair<at::Tensor, std::shared_ptr<ir::Value>> ReduceScatter(
       std::make_shared<ir::Value>(new_token));
 }

+std::pair<at::Tensor, std::shared_ptr<ir::Value>> AllGather(
+    const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
+    xla::int64_t dim, xla::int64_t shard_count,
+    const std::vector<std::vector<xla::int64_t>>& replica_groups) {
+  XLATensor result;
+  ir::Value new_token;
+  std::tie(result, new_token) = XLATensor::all_gather(
+      bridge::GetXlaTensor(input), *token, dim, shard_count, replica_groups);
+  return {bridge::AtenFromXlaTensor(std::move(result)),
+          std::make_shared<ir::Value>(new_token)};
+}
+
 std::pair<at::Tensor, std::shared_ptr<ir::Value>> AllToAll(
     const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
     xla::int64_t split_dimension, xla::int64_t concat_dimension,
@@ -900,6 +912,24 @@ void InitXlaModuleBindings(py::module m) {
           result_tuple[1] = new_token;
           return result_tuple;
         });
+  m.def("_xla_all_gather",
+        [](const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
+           xla::int64_t dim, xla::int64_t shard_count,
+           const py::list& groups) {
+          std::vector<std::vector<xla::int64_t>> replica_groups =
+              CreateReduceGroups(groups);
+          at::Tensor result;
+          std::shared_ptr<ir::Value> new_token;
+          {
+            NoGilSection nogil;
+            std::tie(result, new_token) =
+                AllGather(input, token, dim, shard_count, replica_groups);
+          }
+          auto result_tuple = py::tuple(2);
+          result_tuple[0] = torch::autograd::make_variable(
+              result, /*requires_grad=*/input.requires_grad());
+          result_tuple[1] = new_token;
+          return result_tuple;
+        });
   m.def("_xla_collective_permute",
         [](const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
            const py::list& pairs) {
diff --git a/torch_xla/csrc/ops/all_gather.cpp b/torch_xla/csrc/ops/all_gather.cpp
new file mode 100644
index 000000000000..4926593119c1
--- /dev/null
+++ b/torch_xla/csrc/ops/all_gather.cpp
@@ -0,0 +1,68 @@
+#include "torch_xla/csrc/ops/all_gather.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+namespace {
+
+xla::Shape NodeOutputShape(
+    const Value& input, const Value& token, xla::int64_t dim,
+    xla::int64_t shard_count,
+    const std::vector<std::vector<xla::int64_t>>& groups) {
+  auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    AllGatherResult result =
+        BuildAllGather(operands[0], operands[1], dim, shard_count, groups);
+    return xla::Tuple(operands[0].builder(), {result.result, result.token});
+  };
+  return InferOutputShape({input.shape(), token.shape()}, shape_fn);
+}
+
+}  // namespace
+
+AllGather::AllGather(const Value& input, const Value& token, xla::int64_t dim,
+                     xla::int64_t shard_count,
+                     std::vector<std::vector<xla::int64_t>> groups)
+    : Node(xla_all_gather, {input, token},
+           [&]() {
+             return NodeOutputShape(input, token, dim, shard_count, groups);
+           },
+           /*num_outputs=*/2, torch::lazy::MHash(dim, shard_count, groups)),
+      dim_(dim),
+      shard_count_(shard_count),
+      groups_(std::move(groups)) {}
+
+NodePtr AllGather::Clone(OpList operands) const {
+  return MakeNode<AllGather>(operands.at(0), operands.at(1), dim_,
+                             shard_count_, groups_);
+}
+
+XlaOpVector AllGather::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp token = loctx->GetOutputOp(operand(1));
+  AllGatherResult result =
+      BuildAllGather(input, token, dim_, shard_count_, groups_);
+  return ReturnOps({result.result, result.token}, loctx);
+}
+
+std::string AllGather::ToString() const {
+  std::stringstream ss;
+  ss << Node::ToString() << ", dim=" << dim_
+     << ", shard_count=" << shard_count_ << ", groups=(";
+  for (size_t i = 0; i < groups_.size(); ++i) {
+    ss << (i == 0 ? "(" : ",(");
+    ss << absl::StrJoin(groups_[i], ", ") << ")";
+  }
+  ss << ")";
+  return ss.str();
+}
+
+}  // namespace ops
+}  // namespace ir
+}  // namespace torch_xla
"(" : ",("); + ss << absl::StrJoin(groups_[i], ", ") << ")"; + } + ss << ")"; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/all_gather.h b/torch_xla/csrc/ops/all_gather.h new file mode 100644 index 000000000000..62065e74367c --- /dev/null +++ b/torch_xla/csrc/ops/all_gather.h @@ -0,0 +1,38 @@ +#pragma once + +#include "torch_xla/csrc/cross_replica_reduces.h" +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class AllGather : public Node { + public: + AllGather(const Value& input, const Value& token, xla::int64_t dim, + xla::int64_t shard_count, + std::vector> groups); + + std::string ToString() const override; + + NodePtr Clone(OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + xla::int64_t dim() const { return dim_; } + + xla::int64_t shard_count() const { return shard_count_; } + + const std::vector>& groups() const { + return groups_; + } + + private: + xla::int64_t dim_; + xla::int64_t shard_count_; + std::vector> groups_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index 4f3b3d2d7802..25308bb3f68b 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -4,6 +4,7 @@ namespace torch_xla { namespace ir { namespace ops { +const OpKindWrapper xla_all_gather("xla::all_gather"); const OpKindWrapper xla_all_to_all("xla::all_to_all"); const OpKindWrapper xla_as_strided_view_update("xla::as_strided_view_update"); const OpKindWrapper xla_cast("xla::cast"); diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index 940a5bb14115..17f3c685de9c 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -28,6 +28,7 @@ class OpKindWrapper { mutable std::once_flag once_; }; +extern const OpKindWrapper xla_all_gather; extern const OpKindWrapper xla_all_to_all; extern const OpKindWrapper xla_as_strided_view_update; extern const OpKindWrapper xla_cast; diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 806a117b91a3..7263aa3cc1e8 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -209,6 +209,10 @@ class XLATensor { xla::int64_t split_dimension, xla::int64_t concat_dimension, xla::int64_t split_count, std::vector> groups); + static std::pair all_gather( + const XLATensor& input, const ir::Value& token, xla::int64_t dim, + xla::int64_t shard_count, std::vector> groups); + static std::pair collective_permute( const XLATensor& input, const ir::Value& token, std::vector> source_target_pairs); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 8c41a828e9cb..4e56e7fe7ddb 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -21,6 +21,7 @@ #include "torch_xla/csrc/ops/adaptive_avg_pool3d.h" #include "torch_xla/csrc/ops/adaptive_max_pool2d.h" #include "torch_xla/csrc/ops/all.h" +#include "torch_xla/csrc/ops/all_gather.h" #include "torch_xla/csrc/ops/all_reduce.h" #include "torch_xla/csrc/ops/all_to_all.h" #include "torch_xla/csrc/ops/amax.h" @@ -389,6 +390,14 @@ std::pair XLATensor::all_to_all( return {input.CreateFrom(ir::Value(node, 0)), ir::Value(node, 1)}; } +std::pair XLATensor::all_gather( + const XLATensor& input, const ir::Value& token, xla::int64_t dim, + xla::int64_t shard_count, std::vector> groups) { + ir::NodePtr node = ir::MakeNode( + input.GetIrValue(), 
From 671b2adc603e7db35ada1a41478c907545305cec Mon Sep 17 00:00:00 2001
From: hjm-aws
Date: Mon, 17 Jan 2022 00:45:56 -0800
Subject: [PATCH 2/2] Enable GPU test.

---
 test/test_mp_all_gather.py     | 5 +++--
 test/test_mp_distributed_mm.py | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py
index 1ca4c9189a92..597ff079cecd 100644
--- a/test/test_mp_all_gather.py
+++ b/test/test_mp_all_gather.py
@@ -8,7 +8,7 @@

 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) == 'TPU':
+  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
     world_size = xm.xrt_world_size()
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor)
@@ -21,7 +21,8 @@
       sys.exit(1)
   else:
     print(
-        'Default device {} is not a TPU device'.format(device), file=sys.stderr)
+        'Default device {} is not a TPU or GPU device'.format(device),
+        file=sys.stderr)


 if __name__ == '__main__':
diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py
index 6664b7bb7807..3dc45732ac36 100644
--- a/test/test_mp_distributed_mm.py
+++ b/test/test_mp_distributed_mm.py
@@ -9,7 +9,7 @@

 def _mp_fn(index):
   device = xm.xla_device()
-  if xm.xla_device_hw(device) == 'TPU':
+  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
     world_size = xm.xrt_world_size()
     torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
         use_full_mat_mul_precision=True)
     torch.manual_seed(11)
@@ -35,7 +35,8 @@ def _mp_fn(index):
       sys.exit(1)
   else:
     print(
-        'Default device {} is not a TPU device'.format(device), file=sys.stderr)
+        'Default device {} is not a TPU or GPU device'.format(device),
+        file=sys.stderr)


 if __name__ == '__main__':
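Because the gather dimension is forwarded directly to xla::AllGather, gathering along a non-leading axis needs none of the padding arithmetic the removed Python implementation performed by hand. Another sketch (again an assumption-laden example, not part of the patch), under the same multi-device environment:

    # Sketch only: gather along dim=1 with the native lowering.
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp


    def _mp_fn(index):
      device = xm.xla_device()
      # Each replica holds a [2, 1] shard filled with its ordinal.
      shard = torch.full((2, 1), float(index)).to(device)
      # Gathering along dim=1 yields a [2, world_size] tensor whose
      # columns are the per-replica shards, in replica order.
      gathered = xm.all_gather(shard, dim=1)
      print(gathered.cpu())


    if __name__ == '__main__':
      xmp.spawn(_mp_fn, args=())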