Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2913,6 +2913,36 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexPut) {
}
}

TEST_F(AtenXlaTensorTest, TestMultiIndexPutHeadNull) {
  // The leading index is an undefined tensor, so actual indexing starts at
  // dimension 1; verifies the XLA result matches the ATen reference for both
  // overwrite and accumulate modes across several element types.
  at::Tensor null_index;
  at::Tensor index_a =
      at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
  at::Tensor index_b =
      at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
  for (at::ScalarType scalar_type :
       {at::kFloat, at::kByte, at::kChar, at::kShort, at::kInt, at::kLong}) {
    at::Tensor params;
    if (isFloatingType(scalar_type)) {
      params = at::rand({4, 3, 3, 6, 7}, at::TensorOptions(scalar_type));
    } else {
      params =
          at::randint(100, {4, 3, 3, 6, 7}, at::TensorOptions(scalar_type));
    }
    at::Tensor values = at::ones({3, 6, 7}, at::TensorOptions(scalar_type));
    for (bool accumulate : {false, true}) {
      at::Tensor result = at::index_put(params, {null_index, index_a, index_b},
                                        values, accumulate);
      ForEachDevice([&](const Device& device) {
        at::Tensor xla_params = bridge::CreateXlaTensor(params, device);
        at::Tensor xla_index_a = bridge::CreateXlaTensor(index_a, device);
        at::Tensor xla_index_b = bridge::CreateXlaTensor(index_b, device);
        at::Tensor xla_values = bridge::CreateXlaTensor(values, device);
        at::Tensor xla_result =
            at::index_put(xla_params, {null_index, xla_index_a, xla_index_b},
                          xla_values, accumulate);
        AllClose(result, xla_result);
      });
    }
  }
}

TEST_F(AtenXlaTensorTest, TestMultiIndexPutMiddleNull) {
at::Tensor indices_0 =
at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
Expand Down
20 changes: 10 additions & 10 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1499,11 +1499,11 @@ at::Tensor AtenXlaType::index_put(const at::Tensor& self,
bool accumulate) const {
CanonicalIndexInfo canonical_index_info =
GetCanonicalIndexInfo(self, indices);
return bridge::AtenFromXlaTensor(
XLATensor::index_put(bridge::GetXlaTensor(canonical_index_info.base),
bridge::GetXlaTensors(canonical_index_info.indices),
bridge::GetXlaTensor(values), accumulate,
canonical_index_info.result_permutation));
return bridge::AtenFromXlaTensor(XLATensor::index_put(
bridge::GetXlaTensor(canonical_index_info.base),
bridge::GetXlaTensors(canonical_index_info.indices),
canonical_index_info.start_dim, bridge::GetXlaTensor(values), accumulate,
canonical_index_info.result_permutation));
}

at::Tensor& AtenXlaType::index_put_(at::Tensor& self, at::TensorList indices,
Expand All @@ -1512,11 +1512,11 @@ at::Tensor& AtenXlaType::index_put_(at::Tensor& self, at::TensorList indices,
CanonicalIndexInfo canonical_index_info =
GetCanonicalIndexInfo(self, indices);
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::index_put_(self_tensor,
bridge::GetXlaTensor(canonical_index_info.base),
bridge::GetXlaTensors(canonical_index_info.indices),
bridge::GetXlaTensor(values), accumulate,
canonical_index_info.result_permutation);
XLATensor::index_put_(
self_tensor, bridge::GetXlaTensor(canonical_index_info.base),
bridge::GetXlaTensors(canonical_index_info.indices),
canonical_index_info.start_dim, bridge::GetXlaTensor(values), accumulate,
canonical_index_info.result_permutation);
return self;
}

Expand Down
42 changes: 6 additions & 36 deletions torch_xla/csrc/ops/index_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
#include "torch_xla/csrc/ops/index_get.h"
#include "torch_xla/csrc/ops/index_put.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/ops/permute.h"
Expand Down Expand Up @@ -141,38 +142,6 @@ std::vector<XLATensor> WrapIndicesOnce(
return canonical_indices;
}

// Builds a generic IR node lowering at::index_put over already-stacked
// indices: positions of `buffer` selected by `indices` receive the matching
// slices of `values`; with accumulate=true the values are added instead.
ir::NodePtr IndexPutOp(const ir::Value& buffer, const ir::Value& indices,
                       const ir::Value& values, bool accumulate) {
  // Scatter combiner which adds the incoming value to the existing buffer
  // contents (accumulate=true). Kept `static` so the deferred lower_fn below
  // can reference it safely after this function returns.
  static std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>
      add_scatter_combiner =
          [](const xla::XlaOp& x, const xla::XlaOp& y,
             xla::XlaBuilder* builder) -> xla::XlaOp { return x + y; };
  // Lowering callback run at compile time; operand order (base, indices,
  // values) matches the operand list passed to GenericOp below.
  auto lower_fn = [accumulate](const ir::Node& node,
                               ir::LoweringContext* loctx) -> ir::XlaOpVector {
    xla::XlaOp xla_base = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp xla_indices = loctx->GetOutputOp(node.operand(1));
    xla::XlaOp xla_values = loctx->GetOutputOp(node.operand(2));
    return node.ReturnOp(
        CreateIndexUpdate(xla_base, xla_indices, xla_values,
                          accumulate ? add_scatter_combiner : nullptr),
        loctx);
  };
  // Shape-only lowering used by InferOutputShape to compute the result shape.
  auto lower_for_shape_fn =
      [&](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
      -> xla::XlaOp {
    // The combiner doesn't matter for shape.
    return CreateIndexUpdate(operands[0], operands[1], operands[2], nullptr);
  };
  return ir::ops::GenericOp(
      ir::OpKind(at::aten::index_put), {buffer, indices, values},
      [&]() {
        return ir::ops::InferOutputShape(
            {buffer.shape(), indices.shape(), values.shape()},
            lower_for_shape_fn);
      },
      std::move(lower_fn));
}

ir::NodePtr IndexFillOp(const ir::Value& buffer, xla::int64 dim,
const ir::Value& index, const ir::Value& value) {
auto lower_fn = [dim](const ir::Node& node,
Expand Down Expand Up @@ -287,19 +256,20 @@ XLATensor IndexByTensors(const XLATensor& base,

// Builds the IR for a multi-tensor index_put. NOTE(review): this span was
// diff residue containing both the pre- and post-change lines fused together
// (duplicate parameter lists, duplicate `canonical_indices` definitions and
// both the old IndexPutOp call and the new IndexPut node); resolved here to
// the merged (start_dim-aware) version.
//
// The indices are wrapped once (negative entries made positive), stacked into
// a single tensor, and dispatched as one scatter starting at `start_dim`; the
// result is permuted by `result_permutation` back to the caller's dimension
// order.
ir::Value IndexPutByTensors(
    const XLATensor& base, tensorflow::gtl::ArraySlice<const XLATensor> indices,
    xla::int64 start_dim, const XLATensor& values, bool accumulate,
    tensorflow::gtl::ArraySlice<const xla::int64> result_permutation) {
  if (indices.empty()) {
    // No indices means nothing to update; the result is the base itself.
    return base.GetIrValue();
  }
  auto canonical_indices = WrapIndicesOnce(base, indices, start_dim);
  xla::int64 indices_rank = canonical_indices.front().shape().get().rank();
  // Stack the indices to allow the whole multi-indexing to be dispatched with a
  // single scatter.
  XLATensor indices_nd = XLATensor::stack(canonical_indices, indices_rank);
  return ir::MakeNode<ir::ops::Permute>(
      ir::MakeNode<ir::ops::IndexPut>(base.GetIrValue(),
                                      indices_nd.GetIrValue(), start_dim,
                                      values.GetIrValue(), accumulate),
      xla::util::ToVector<xla::int64>(result_permutation));
}

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/index_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ XLATensor IndexByTensors(const XLATensor& base,

ir::Value IndexPutByTensors(
const XLATensor& base, tensorflow::gtl::ArraySlice<const XLATensor> indices,
const XLATensor& updates, bool accumulate,
xla::int64 start_dim, const XLATensor& updates, bool accumulate,
tensorflow::gtl::ArraySlice<const xla::int64> result_permutation);

ir::NodePtr IndexFill(const XLATensor& base, xla::int64 dim,
Expand Down
43 changes: 43 additions & 0 deletions torch_xla/csrc/ops/index_put.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "torch_xla/csrc/ops/index_put.h"

#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace ir {
namespace ops {

// Node carrying the index_put scatter: base, indices and values operands plus
// the start_dim/accumulate attributes, which are folded into the node hash.
IndexPut::IndexPut(const ir::Value& base, const ir::Value& indices,
                   xla::int64 start_dim, const ir::Value& values,
                   bool accumulate)
    : Node(OpKind(at::aten::index_put), {base, indices, values}, base.shape(),
           /*num_outputs=*/1, xla::util::MHash(start_dim, accumulate)),
      start_dim_(start_dim),
      accumulate_(accumulate) {}

std::string IndexPut::ToString() const {
  std::stringstream ss;
  ss << Node::ToString();
  ss << ", start_dim=" << start_dim_;
  ss << ", accumulate=" << accumulate_;
  return ss.str();
}

XlaOpVector IndexPut::Lower(LoweringContext* loctx) const {
  // An empty combiner means "overwrite"; when accumulating, the combiner adds
  // the incoming value to the existing buffer contents.
  std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)> combiner;
  if (accumulate_) {
    combiner = [](const xla::XlaOp& x, const xla::XlaOp& y,
                  xla::XlaBuilder* builder) -> xla::XlaOp { return x + y; };
  }
  // Operand order matches the constructor's operand list: base, indices,
  // values.
  xla::XlaOp buffer = loctx->GetOutputOp(operand(0));
  xla::XlaOp index_tensor = loctx->GetOutputOp(operand(1));
  xla::XlaOp update_values = loctx->GetOutputOp(operand(2));
  return ReturnOp(CreateIndexUpdate(buffer, index_tensor, start_dim_,
                                    update_values, combiner),
                  loctx);
}

}  // namespace ops
}  // namespace ir
}  // namespace torch_xla
31 changes: 31 additions & 0 deletions torch_xla/csrc/ops/index_put.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

// IR node for at::index_put: scatters `values` into `base` at the positions
// selected by `indices`, with indexing starting at dimension `start_dim`.
// When `accumulate` is true, values are added to the existing buffer contents
// instead of overwriting them.
class IndexPut : public Node {
 public:
  IndexPut(const ir::Value& base, const ir::Value& indices,
           xla::int64 start_dim, const ir::Value& values, bool accumulate);

  std::string ToString() const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  xla::int64 start_dim() const { return start_dim_; }

  bool accumulate() const { return accumulate_; }

 private:
  // The dimension number at which indexing starts.
  xla::int64 start_dim_;
  // Whether to accumulate instead of set.
  bool accumulate_;
};

}  // namespace ops
}  // namespace ir
}  // namespace torch_xla
4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,13 @@ class XLATensor {
static XLATensor index_put(
const XLATensor& input,
tensorflow::gtl::ArraySlice<const XLATensor> indices,
const XLATensor& values, bool accumulate,
xla::int64 start_dim, const XLATensor& values, bool accumulate,
tensorflow::gtl::ArraySlice<const xla::int64> result_permutation);

static void index_put_(
XLATensor& input, const XLATensor& canonical_base,
tensorflow::gtl::ArraySlice<const XLATensor> indices,
const XLATensor& values, bool accumulate,
xla::int64 start_dim, const XLATensor& values, bool accumulate,
tensorflow::gtl::ArraySlice<const xla::int64> result_permutation);

static XLATensor index_select(const XLATensor& input, xla::int64 dim,
Expand Down
10 changes: 5 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,19 +1121,19 @@ void XLATensor::index_fill_(XLATensor& input, xla::int64 dim,

// Out-of-place index_put. NOTE(review): this span was diff residue with old
// and new lines interleaved (two `indices` parameter lines, two return
// statements); resolved here to the merged (start_dim-aware) version.
//
// Returns a new tensor equal to `input` with the positions selected by
// `indices` (indexing starting at dimension `start_dim`) replaced by —
// or, when `accumulate` is true, incremented by — `values`; the result is
// permuted by `result_permutation`.
XLATensor XLATensor::index_put(
    const XLATensor& input,
    tensorflow::gtl::ArraySlice<const XLATensor> indices, xla::int64 start_dim,
    const XLATensor& values, bool accumulate,
    tensorflow::gtl::ArraySlice<const xla::int64> result_permutation) {
  return input.CreateFrom(IndexPutByTensors(input, indices, start_dim, values,
                                            accumulate, result_permutation));
}

// In-place index_put. NOTE(review): this span was diff residue with old and
// new lines interleaved (two `indices` parameter lines, two SetIrValue
// calls); resolved here to the merged (start_dim-aware) version.
//
// Rebinds `input`'s IR value to the scatter of `values` into
// `canonical_base` at the positions selected by `indices`, starting at
// dimension `start_dim`, permuted by `result_permutation`.
void XLATensor::index_put_(
    XLATensor& input, const XLATensor& canonical_base,
    tensorflow::gtl::ArraySlice<const XLATensor> indices, xla::int64 start_dim,
    const XLATensor& values, bool accumulate,
    tensorflow::gtl::ArraySlice<const xla::int64> result_permutation) {
  input.SetIrValue(IndexPutByTensors(canonical_base, indices, start_dim, values,
                                     accumulate, result_permutation));
}

Expand Down
13 changes: 8 additions & 5 deletions torch_xla/csrc/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ XLATensor EmbeddingDenseBackward(const XLATensor& grad_output,
XLATensor::full({num_weights}, 0, indices.GetDevice(), indices.dtype());
XLATensor ones =
XLATensor::full({numel}, 1, indices.GetDevice(), indices.dtype());
XLATensor::index_put_(counts, counts, {indices_rank1}, ones,
XLATensor::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0,
/*values=*/ones,
/*accumulate=*/true, /*result_permutation=*/{0});
XLATensor grad_weights_scale = XLATensor::index(counts, {indices_rank1}, 0);
// Scale the value of the gradient by the histogram.
Expand All @@ -212,10 +213,12 @@ XLATensor EmbeddingDenseBackward(const XLATensor& grad_output,
XLATensor::expand(skip_padding, grad.shape().get().dimensions());
XLATensor zero_grad =
XLATensor::full_like(grad, 0, grad.GetDevice(), grad.dtype());
return XLATensor::index_put(grad_weight, {indices_rank1},
XLATensor::where(skip_padding, grad, zero_grad),
/*accumulate=*/true,
/*result_permutation=*/{0, 1});
return XLATensor::index_put(
grad_weight, {indices_rank1},
/*start_dim=*/0,
/*values=*/XLATensor::where(skip_padding, grad, zero_grad),
/*accumulate=*/true,
/*result_permutation=*/{0, 1});
}

} // namespace tensor_ops
Expand Down
21 changes: 14 additions & 7 deletions torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ xla::XlaOp CreateIndex(const xla::XlaOp& input, const xla::XlaOp& indices,
}

xla::XlaOp CreateIndexUpdate(
const xla::XlaOp& buffer, const xla::XlaOp& indices,
const xla::XlaOp& buffer, const xla::XlaOp& indices, xla::int64 start_dim,
const xla::XlaOp& values,
const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
combiner) {
Expand All @@ -360,9 +360,13 @@ xla::XlaOp CreateIndexUpdate(
xla::int64 num_window_dims_in_values = buffer_rank - num_index_dims;

// Make the values match the rank expected by scatter.
std::vector<xla::int64> expected_values_dims(indices_dims.begin(),
indices_dims.end());
for (xla::int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
std::vector<xla::int64> expected_values_dims;
for (xla::int64 dim = 0; dim < start_dim; ++dim) {
expected_values_dims.push_back(buffer_shape.dimensions(dim));
}
expected_values_dims.insert(expected_values_dims.end(), indices_dims.begin(),
indices_dims.end());
for (xla::int64 dim = num_index_dims + start_dim; dim < buffer_rank; ++dim) {
expected_values_dims.push_back(buffer_shape.dimensions(dim));
}
xla::XlaOp new_values = values;
Expand All @@ -374,13 +378,16 @@ xla::XlaOp CreateIndexUpdate(
values_shape = XlaHelpers::ShapeOfXlaOp(new_values);
values_rank = values_shape.rank();

for (xla::int64 i = (values_rank - num_window_dims_in_values);
for (xla::int64 dim = 0; dim < start_dim; ++dim) {
dim_numbers.add_update_window_dims(dim);
}
for (xla::int64 i = values_rank - num_window_dims_in_values + start_dim;
i < values_rank; ++i) {
dim_numbers.add_update_window_dims(i);
}
for (xla::int64 i = 0; i < num_index_dims; ++i) {
dim_numbers.add_inserted_window_dims(i);
dim_numbers.add_scatter_dims_to_operand_dims(i);
dim_numbers.add_inserted_window_dims(i + start_dim);
dim_numbers.add_scatter_dims_to_operand_dims(i + start_dim);
}
xla::XlaComputation combiner_computation =
MakeScatterComputation(combiner, buffer_shape.element_type());
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_lower_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ xla::XlaOp CreateIndex(const xla::XlaOp& input, const xla::XlaOp& indices,

// Similar to tf.scatter_nd, used to implement advanced indexing updates.
xla::XlaOp CreateIndexUpdate(
const xla::XlaOp& buffer, const xla::XlaOp& indices,
const xla::XlaOp& buffer, const xla::XlaOp& indices, xla::int64 start_dim,
const xla::XlaOp& updates,
const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
combiner);
Expand Down