7 changes: 4 additions & 3 deletions test/test_mp_all_gather.py
@@ -1,3 +1,4 @@
import os
import sys
import torch
import torch_xla
@@ -7,8 +8,8 @@

def _mp_fn(index):
  device = xm.xla_device()
-  world_size = xm.xrt_world_size()
-  if world_size > 1:
+  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+    world_size = xm.xrt_world_size()
    ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
    result = xm.all_gather(ordinal_tensor)

@@ -20,7 +21,7 @@ def _mp_fn(index):
      sys.exit(1)
  else:
    print(
-        'Default device {} does not support replication'.format(device),
+        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)


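The updated test only exercises the collective when the default device is a TPU or GPU. A minimal sketch of how such an _mp_fn is typically driven and what the gather should produce; the xmp.spawn entry point and the exact assertions live outside the hunks shown above, so treat the details here as assumptions rather than the PR's code:

import sys
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
    result = xm.all_gather(ordinal_tensor)
    # Every ordinal should see the concatenation of all per-ordinal tensors,
    # i.e. [0.0, 1.0, ..., world_size - 1].
    expected = torch.arange(0, world_size, dtype=torch.float)
    if not result.cpu().allclose(expected):
      sys.exit(1)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())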
6 changes: 3 additions & 3 deletions test/test_mp_distributed_mm.py
@@ -9,8 +9,8 @@
def _mp_fn(index):
  device = xm.xla_device()

-  world_size = xm.xrt_world_size()
-  if world_size > 1:
+  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
+    world_size = xm.xrt_world_size()
    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)
    torch.manual_seed(11)
@@ -35,7 +35,7 @@ def _mp_fn(index):
      sys.exit(1)
  else:
    print(
-        'Default device {} does not support replication'.format(device),
+        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)


21 changes: 6 additions & 15 deletions torch_xla/core/xla_model.py
@@ -598,21 +598,12 @@ def all_gather(value, dim=0, groups=None):
"""
if dim < 0:
dim = value.dim() + dim
size = value.size(dim)
padding = [0] * (2 * value.dim())
ordinal = get_ordinal()
if groups is None:
left, right = ordinal, xrt_world_size() - 1 - ordinal
else:
ordinals = dict()
for g in groups:
for i, x in enumerate(g):
ordinals[x] = (i, len(g) - 1 - i)
left, right = ordinals[ordinal]
idx = value.dim() - 1 - dim
padding[2 * idx] = left * size
padding[2 * idx + 1] = right * size
return all_reduce(REDUCE_SUM, F.pad(value, padding), groups=groups)
token, devctx = _get_all_reduce_token()
shard_count = None if groups else xrt_world_size()
result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
groups or [])
devctx.all_reduce_token = result[1]
return result[0]


def all_to_all(value,
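For reference, the deleted Python path emulated all_gather by zero-padding each ordinal's shard into its slot along dim and summing the padded tensors with all_reduce(REDUCE_SUM, ...). A CPU-only sketch (plain torch, no XLA, not part of this PR) of why that emulation equals a concatenation, which is what the native _xla_all_gather lowering now produces directly:

import torch
import torch.nn.functional as F

world_size, dim = 4, 0
shards = [torch.tensor([float(i)]) for i in range(world_size)]

padded = []
for ordinal, value in enumerate(shards):
  # Mirrors the removed code: pad `left * size` zeros before the shard and
  # `right * size` zeros after it along `dim`.
  size = value.size(dim)
  padding = [0] * (2 * value.dim())
  left, right = ordinal, world_size - 1 - ordinal
  idx = value.dim() - 1 - dim
  padding[2 * idx] = left * size
  padding[2 * idx + 1] = right * size
  padded.append(F.pad(value, padding))

# Summing the padded tensors is what all_reduce(REDUCE_SUM, ...) did on device.
emulated = torch.stack(padded).sum(0)
assert torch.equal(emulated, torch.cat(shards, dim=dim))  # tensor([0., 1., 2., 3.])

Roughly speaking, the emulation all-reduces a mostly-zero tensor of the full gathered size on every replica, while a native AllGather only exchanges each replica's shard.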
13 changes: 13 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -140,6 +140,19 @@ AllToAllResult BuildAllToAll(
  return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

+AllGatherResult BuildAllGather(
+    xla::XlaOp input, xla::XlaOp token, xla::int64_t dim,
+    xla::int64_t shard_count,
+    const std::vector<std::vector<xla::int64_t>>& groups) {
+  std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  TokenHandler token_handler(token);
+  xla::XlaOp all_gather_result =
+      xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
+                     shard_count, reduce_groups);
+  return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
+}

CollectivePermuteResult BuildCollectivePermute(
    xla::XlaOp input, xla::XlaOp token,
    const std::vector<std::pair<xla::int64_t, xla::int64_t>>&
10 changes: 10 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
@@ -21,6 +21,11 @@ struct AllToAllResult {
  xla::XlaOp token;
};

+struct AllGatherResult {
+  xla::XlaOp result;
+  xla::XlaOp token;
+};

struct CollectivePermuteResult {
  xla::XlaOp result;
  xla::XlaOp token;
@@ -41,6 +46,11 @@
    xla::int64_t concat_dimension, xla::int64_t split_count,
    const std::vector<std::vector<xla::int64_t>>& groups);

+AllGatherResult BuildAllGather(
+    xla::XlaOp input, xla::XlaOp token, xla::int64_t dim,
+    xla::int64_t shard_count,
+    const std::vector<std::vector<xla::int64_t>>& groups);

CollectivePermuteResult BuildCollectivePermute(
    xla::XlaOp input, xla::XlaOp token,
    const std::vector<std::pair<xla::int64_t, xla::int64_t>>&
30 changes: 30 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -213,6 +213,18 @@ std::pair<at::Tensor, std::shared_ptr<ir::Value>> ReduceScatter(
      std::make_shared<ir::Value>(new_token));
}

+std::pair<at::Tensor, std::shared_ptr<ir::Value>> AllGather(
+    const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
+    xla::int64_t dim, xla::int64_t shard_count,
+    const std::vector<std::vector<xla::int64_t>>& replica_groups) {
+  XLATensor result;
+  ir::Value new_token;
+  std::tie(result, new_token) = XLATensor::all_gather(
+      bridge::GetXlaTensor(input), *token, dim, shard_count, replica_groups);
+  return {bridge::AtenFromXlaTensor(std::move(result)),
+          std::make_shared<ir::Value>(new_token)};
+}

std::pair<at::Tensor, std::shared_ptr<ir::Value>> AllToAll(
    const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
    xla::int64_t split_dimension, xla::int64_t concat_dimension,
@@ -900,6 +912,24 @@ void InitXlaModuleBindings(py::module m) {
          result_tuple[1] = new_token;
          return result_tuple;
        });
+  m.def("_xla_all_gather",
+        [](const at::Tensor& input, const std::shared_ptr<ir::Value>& token,
+           xla::int64_t dim, xla::int64_t shard_count, const py::list& groups) {
+          std::vector<std::vector<xla::int64_t>> replica_groups =
+              CreateReduceGroups(groups);
+          at::Tensor result;
+          std::shared_ptr<ir::Value> new_token;
+          {
+            NoGilSection nogil;
+            std::tie(result, new_token) =
+                AllGather(input, token, dim, shard_count, replica_groups);
+          }
+          auto result_tuple = py::tuple(2);
+          result_tuple[0] = torch::autograd::make_variable(
+              result, /*requires_grad=*/input.requires_grad());
+          result_tuple[1] = new_token;
+          return result_tuple;
+        });
  m.def("_xla_collective_permute", [](const at::Tensor& input,
                                      const std::shared_ptr<ir::Value>& token,
                                      const py::list& pairs) {
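The new _xla_all_gather binding mirrors the other collectives: it takes the current ordering token, returns a (tensor, new_token) pair, and the Python wrapper stores the fresh token back on the device context. A paraphrase of the call path from the xla_model.py hunk, restricted to the default replica group; _get_all_reduce_token is the private helper used in that hunk, so treat this as a sketch rather than documented API:

import torch_xla
import torch_xla.core.xla_model as xm


def all_gather_default_group(value, dim=0):
  # The current token keeps this collective ordered after prior CC ops.
  token, devctx = xm._get_all_reduce_token()
  shard_count = xm.xrt_world_size()
  result, new_token = torch_xla._XLAC._xla_all_gather(value, token, dim,
                                                      shard_count, [])
  # Thread the new token so later collectives are sequenced after this one.
  devctx.all_reduce_token = new_token
  return result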
68 changes: 68 additions & 0 deletions torch_xla/csrc/ops/all_gather.cpp
@@ -0,0 +1,68 @@
#include "torch_xla/csrc/ops/all_gather.h"

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(
    const Value& input, const Value& token, xla::int64_t dim,
    xla::int64_t shard_count,
    const std::vector<std::vector<xla::int64_t>>& groups) {
  auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    AllGatherResult result =
        BuildAllGather(operands[0], operands[1], dim, shard_count, groups);
    return xla::Tuple(operands[0].builder(), {result.result, result.token});
  };
  return InferOutputShape({input.shape(), token.shape()}, shape_fn);
}

} // namespace

AllGather::AllGather(const Value& input, const Value& token, xla::int64_t dim,
                     xla::int64_t shard_count,
                     std::vector<std::vector<xla::int64_t>> groups)
    : Node(xla_all_gather, {input, token},
           [&]() {
             return NodeOutputShape(input, token, dim, shard_count, groups);
           },
           /*num_outputs=*/2, torch::lazy::MHash(dim, shard_count, groups)),
      dim_(dim),
      shard_count_(shard_count),
      groups_(std::move(groups)) {}

NodePtr AllGather::Clone(OpList operands) const {
  return MakeNode<AllGather>(operands.at(0), operands.at(1), dim_, shard_count_,
                             groups_);
}

XlaOpVector AllGather::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp token = loctx->GetOutputOp(operand(1));
  AllGatherResult result =
      BuildAllGather(input, token, dim_, shard_count_, groups_);
  return ReturnOps({result.result, result.token}, loctx);
}

std::string AllGather::ToString() const {
  std::stringstream ss;
  ss << Node::ToString() << ", dim=" << dim_ << ", shard_count=" << shard_count_
     << ", groups=(";
  for (size_t i = 0; i < groups_.size(); ++i) {
    ss << (i == 0 ? "(" : ",(");
    ss << absl::StrJoin(groups_[i], ", ") << ")";
  }
  ss << ")";
  return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
38 changes: 38 additions & 0 deletions torch_xla/csrc/ops/all_gather.h
@@ -0,0 +1,38 @@
#pragma once

#include "torch_xla/csrc/cross_replica_reduces.h"
#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class AllGather : public Node {
 public:
  AllGather(const Value& input, const Value& token, xla::int64_t dim,
            xla::int64_t shard_count,
            std::vector<std::vector<xla::int64_t>> groups);

  std::string ToString() const override;

  NodePtr Clone(OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  xla::int64_t dim() const { return dim_; }

  xla::int64_t shard_count() const { return shard_count_; }

  const std::vector<std::vector<xla::int64_t>>& groups() const {
    return groups_;
  }

 private:
  xla::int64_t dim_;
  xla::int64_t shard_count_;
  std::vector<std::vector<xla::int64_t>> groups_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.cpp
@@ -4,6 +4,7 @@ namespace torch_xla {
namespace ir {
namespace ops {

+const OpKindWrapper xla_all_gather("xla::all_gather");
const OpKindWrapper xla_all_to_all("xla::all_to_all");
const OpKindWrapper xla_as_strided_view_update("xla::as_strided_view_update");
const OpKindWrapper xla_cast("xla::cast");
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.h
@@ -28,6 +28,7 @@ class OpKindWrapper {
  mutable std::once_flag once_;
};

+extern const OpKindWrapper xla_all_gather;
extern const OpKindWrapper xla_all_to_all;
extern const OpKindWrapper xla_as_strided_view_update;
extern const OpKindWrapper xla_cast;
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -209,6 +209,10 @@ class XLATensor {
      xla::int64_t split_dimension, xla::int64_t concat_dimension,
      xla::int64_t split_count, std::vector<std::vector<xla::int64_t>> groups);

+  static std::pair<XLATensor, ir::Value> all_gather(
+      const XLATensor& input, const ir::Value& token, xla::int64_t dim,
+      xla::int64_t shard_count, std::vector<std::vector<xla::int64_t>> groups);

  static std::pair<XLATensor, ir::Value> collective_permute(
      const XLATensor& input, const ir::Value& token,
      std::vector<std::pair<xla::int64_t, xla::int64_t>> source_target_pairs);
9 changes: 9 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -21,6 +21,7 @@
#include "torch_xla/csrc/ops/adaptive_avg_pool3d.h"
#include "torch_xla/csrc/ops/adaptive_max_pool2d.h"
#include "torch_xla/csrc/ops/all.h"
#include "torch_xla/csrc/ops/all_gather.h"
#include "torch_xla/csrc/ops/all_reduce.h"
#include "torch_xla/csrc/ops/all_to_all.h"
#include "torch_xla/csrc/ops/amax.h"
@@ -389,6 +390,14 @@ std::pair<XLATensor, ir::Value> XLATensor::all_to_all(
  return {input.CreateFrom(ir::Value(node, 0)), ir::Value(node, 1)};
}

+std::pair<XLATensor, ir::Value> XLATensor::all_gather(
+    const XLATensor& input, const ir::Value& token, xla::int64_t dim,
+    xla::int64_t shard_count, std::vector<std::vector<xla::int64_t>> groups) {
+  ir::NodePtr node = ir::MakeNode<ir::ops::AllGather>(
+      input.GetIrValue(), token, dim, shard_count, std::move(groups));
+  return {input.CreateFrom(ir::Value(node, 0)), ir::Value(node, 1)};
+}

std::pair<XLATensor, ir::Value> XLATensor::collective_permute(
    const XLATensor& input, const ir::Value& token,
    std::vector<std::pair<xla::int64_t, xla::int64_t>> source_target_pairs) {
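With the whole stack in place (xm.all_gather → _xla_all_gather → XLATensor::all_gather → ir::ops::AllGather → BuildAllGather), the observable contract is that only the gathered dimension grows, by a factor of the shard count. A quick shape check in the same style as the tests above, assuming a TPU or GPU multiprocess context; this is a sketch, not a test from this PR:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    value = torch.full((2, 3), float(index), device=device)
    gathered = xm.all_gather(value, dim=1)
    # Only `dim` grows: (2, 3) gathered over world_size replicas -> (2, 3 * world_size).
    assert gathered.size() == (2, 3 * world_size)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())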