From 9a2cb2e914d36fe1e14f17c1ee481952eeec186b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20=C5=9Euhan?= Date: Wed, 17 Apr 2019 11:01:30 -0700 Subject: [PATCH] Fix index_put starting with null indices --- test/cpp/test_aten_xla_tensor.cpp | 30 +++++++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 20 +++++++------- torch_xla/csrc/ops/index_ops.cpp | 42 +++++------------------------- torch_xla/csrc/ops/index_ops.h | 2 +- torch_xla/csrc/ops/index_put.cpp | 43 +++++++++++++++++++++++++++++++ torch_xla/csrc/ops/index_put.h | 31 ++++++++++++++++++++++ torch_xla/csrc/tensor.h | 4 +-- torch_xla/csrc/tensor_methods.cpp | 10 +++---- torch_xla/csrc/tensor_ops.cpp | 13 ++++++---- torch_xla/csrc/xla_lower_util.cpp | 21 ++++++++++----- torch_xla/csrc/xla_lower_util.h | 2 +- 11 files changed, 151 insertions(+), 67 deletions(-) create mode 100644 torch_xla/csrc/ops/index_put.cpp create mode 100644 torch_xla/csrc/ops/index_put.h diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index eb073e0ac498..ab9ecf35fc6f 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -2913,6 +2913,36 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexPut) { } } +TEST_F(AtenXlaTensorTest, TestMultiIndexPutHeadNull) { + at::Tensor indices_0 = + at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong)); + at::Tensor indices_null; + at::Tensor indices_1 = + at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong)); + for (at::ScalarType scalar_type : + {at::kFloat, at::kByte, at::kChar, at::kShort, at::kInt, at::kLong}) { + at::Tensor params = + isFloatingType(scalar_type) + ? at::rand({4, 3, 3, 6, 7}, at::TensorOptions(scalar_type)) + : at::randint(100, {4, 3, 3, 6, 7}, at::TensorOptions(scalar_type)); + at::Tensor values = at::ones({3, 6, 7}, at::TensorOptions(scalar_type)); + for (bool accumulate : {false, true}) { + at::Tensor result = at::index_put( + params, {indices_null, indices_0, indices_1}, values, accumulate); + ForEachDevice([&](const Device& device) { + at::Tensor xla_params = bridge::CreateXlaTensor(params, device); + at::Tensor xla_indices_0 = bridge::CreateXlaTensor(indices_0, device); + at::Tensor xla_indices_1 = bridge::CreateXlaTensor(indices_1, device); + at::Tensor xla_values = bridge::CreateXlaTensor(values, device); + at::Tensor xla_result = at::index_put( + xla_params, {indices_null, xla_indices_0, xla_indices_1}, + xla_values, accumulate); + AllClose(result, xla_result); + }); + } + } +} + TEST_F(AtenXlaTensorTest, TestMultiIndexPutMiddleNull) { at::Tensor indices_0 = at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong)); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 898c98d68b33..cb3313d36852 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1499,11 +1499,11 @@ at::Tensor AtenXlaType::index_put(const at::Tensor& self, bool accumulate) const { CanonicalIndexInfo canonical_index_info = GetCanonicalIndexInfo(self, indices); - return bridge::AtenFromXlaTensor( - XLATensor::index_put(bridge::GetXlaTensor(canonical_index_info.base), - bridge::GetXlaTensors(canonical_index_info.indices), - bridge::GetXlaTensor(values), accumulate, - canonical_index_info.result_permutation)); + return bridge::AtenFromXlaTensor(XLATensor::index_put( + bridge::GetXlaTensor(canonical_index_info.base), + bridge::GetXlaTensors(canonical_index_info.indices), + canonical_index_info.start_dim, bridge::GetXlaTensor(values), accumulate, + canonical_index_info.result_permutation)); } at::Tensor& AtenXlaType::index_put_(at::Tensor& self, at::TensorList indices, @@ -1512,11 +1512,11 @@ at::Tensor& AtenXlaType::index_put_(at::Tensor& self, at::TensorList indices, CanonicalIndexInfo canonical_index_info = GetCanonicalIndexInfo(self, indices); XLATensor self_tensor = bridge::GetXlaTensor(self); - XLATensor::index_put_(self_tensor, - bridge::GetXlaTensor(canonical_index_info.base), - bridge::GetXlaTensors(canonical_index_info.indices), - bridge::GetXlaTensor(values), accumulate, - canonical_index_info.result_permutation); + XLATensor::index_put_( + self_tensor, bridge::GetXlaTensor(canonical_index_info.base), + bridge::GetXlaTensors(canonical_index_info.indices), + canonical_index_info.start_dim, bridge::GetXlaTensor(values), accumulate, + canonical_index_info.result_permutation); return self; } diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index f52644b47482..388d669952e2 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -10,6 +10,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/arithmetic_ir_ops.h" #include "torch_xla/csrc/ops/index_get.h" +#include "torch_xla/csrc/ops/index_put.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/ops.h" #include "torch_xla/csrc/ops/permute.h" @@ -141,38 +142,6 @@ std::vector WrapIndicesOnce( return canonical_indices; } -ir::NodePtr IndexPutOp(const ir::Value& buffer, const ir::Value& indices, - const ir::Value& values, bool accumulate) { - static std::function - add_scatter_combiner = - [](const xla::XlaOp& x, const xla::XlaOp& y, - xla::XlaBuilder* builder) -> xla::XlaOp { return x + y; }; - auto lower_fn = [accumulate](const ir::Node& node, - ir::LoweringContext* loctx) -> ir::XlaOpVector { - xla::XlaOp xla_base = loctx->GetOutputOp(node.operand(0)); - xla::XlaOp xla_indices = loctx->GetOutputOp(node.operand(1)); - xla::XlaOp xla_values = loctx->GetOutputOp(node.operand(2)); - return node.ReturnOp( - CreateIndexUpdate(xla_base, xla_indices, xla_values, - accumulate ? add_scatter_combiner : nullptr), - loctx); - }; - auto lower_for_shape_fn = - [&](tensorflow::gtl::ArraySlice operands) - -> xla::XlaOp { - // The combiner doesn't matter for shape. - return CreateIndexUpdate(operands[0], operands[1], operands[2], nullptr); - }; - return ir::ops::GenericOp( - ir::OpKind(at::aten::index_put), {buffer, indices, values}, - [&]() { - return ir::ops::InferOutputShape( - {buffer.shape(), indices.shape(), values.shape()}, - lower_for_shape_fn); - }, - std::move(lower_fn)); -} - ir::NodePtr IndexFillOp(const ir::Value& buffer, xla::int64 dim, const ir::Value& index, const ir::Value& value) { auto lower_fn = [dim](const ir::Node& node, @@ -287,19 +256,20 @@ XLATensor IndexByTensors(const XLATensor& base, ir::Value IndexPutByTensors( const XLATensor& base, tensorflow::gtl::ArraySlice indices, - const XLATensor& values, bool accumulate, + xla::int64 start_dim, const XLATensor& values, bool accumulate, tensorflow::gtl::ArraySlice result_permutation) { if (indices.empty()) { return base.GetIrValue(); } - auto canonical_indices = WrapIndicesOnce(base, indices, 0); + auto canonical_indices = WrapIndicesOnce(base, indices, start_dim); xla::int64 indices_rank = canonical_indices.front().shape().get().rank(); // Stack the indices to allow the whole multi-indexing to be dispatched with a // single scatter. XLATensor indices_nd = XLATensor::stack(canonical_indices, indices_rank); return ir::MakeNode( - IndexPutOp(base.GetIrValue(), indices_nd.GetIrValue(), - values.GetIrValue(), accumulate), + ir::MakeNode(base.GetIrValue(), + indices_nd.GetIrValue(), start_dim, + values.GetIrValue(), accumulate), xla::util::ToVector(result_permutation)); } diff --git a/torch_xla/csrc/ops/index_ops.h b/torch_xla/csrc/ops/index_ops.h index 62ada9ab239c..7a28d094f85e 100644 --- a/torch_xla/csrc/ops/index_ops.h +++ b/torch_xla/csrc/ops/index_ops.h @@ -56,7 +56,7 @@ XLATensor IndexByTensors(const XLATensor& base, ir::Value IndexPutByTensors( const XLATensor& base, tensorflow::gtl::ArraySlice indices, - const XLATensor& updates, bool accumulate, + xla::int64 start_dim, const XLATensor& updates, bool accumulate, tensorflow::gtl::ArraySlice result_permutation); ir::NodePtr IndexFill(const XLATensor& base, xla::int64 dim, diff --git a/torch_xla/csrc/ops/index_put.cpp b/torch_xla/csrc/ops/index_put.cpp new file mode 100644 index 000000000000..6e0e69d6c3e0 --- /dev/null +++ b/torch_xla/csrc/ops/index_put.cpp @@ -0,0 +1,43 @@ +#include "torch_xla/csrc/ops/index_put.h" + +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +IndexPut::IndexPut(const ir::Value& base, const ir::Value& indices, + xla::int64 start_dim, const ir::Value& values, + bool accumulate) + : Node(OpKind(at::aten::index_put), {base, indices, values}, base.shape(), + /*num_outputs=*/1, xla::util::MHash(start_dim, accumulate)), + start_dim_(start_dim), + accumulate_(accumulate) {} + +std::string IndexPut::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", start_dim=" << start_dim_ + << ", accumulate=" << accumulate_; + return ss.str(); +} + +XlaOpVector IndexPut::Lower(LoweringContext* loctx) const { + std::function + add_scatter_combiner = + [](const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) -> xla::XlaOp { return x + y; }; + + xla::XlaOp base = loctx->GetOutputOp(operand(0)); + xla::XlaOp indices = loctx->GetOutputOp(operand(1)); + xla::XlaOp values = loctx->GetOutputOp(operand(2)); + xla::XlaOp output = + CreateIndexUpdate(base, indices, start_dim_, values, + accumulate_ ? add_scatter_combiner : nullptr); + return ReturnOp(output, loctx); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/index_put.h b/torch_xla/csrc/ops/index_put.h new file mode 100644 index 000000000000..b356220beec5 --- /dev/null +++ b/torch_xla/csrc/ops/index_put.h @@ -0,0 +1,31 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class IndexPut : public Node { + public: + IndexPut(const ir::Value& base, const ir::Value& indices, + xla::int64 start_dim, const ir::Value& values, bool accumulate); + + std::string ToString() const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + xla::int64 start_dim() const { return start_dim_; } + + bool accumulate() const { return accumulate_; } + + private: + // The dimension number at which indexing starts. + xla::int64 start_dim_; + // Whether to accumulate instead of set. + bool accumulate_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 66c4768cef1e..6a6b3d8796de 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -511,13 +511,13 @@ class XLATensor { static XLATensor index_put( const XLATensor& input, tensorflow::gtl::ArraySlice indices, - const XLATensor& values, bool accumulate, + xla::int64 start_dim, const XLATensor& values, bool accumulate, tensorflow::gtl::ArraySlice result_permutation); static void index_put_( XLATensor& input, const XLATensor& canonical_base, tensorflow::gtl::ArraySlice indices, - const XLATensor& values, bool accumulate, + xla::int64 start_dim, const XLATensor& values, bool accumulate, tensorflow::gtl::ArraySlice result_permutation); static XLATensor index_select(const XLATensor& input, xla::int64 dim, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 775aeae3ca2a..90471c8ea0ad 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1121,19 +1121,19 @@ void XLATensor::index_fill_(XLATensor& input, xla::int64 dim, XLATensor XLATensor::index_put( const XLATensor& input, - tensorflow::gtl::ArraySlice indices, + tensorflow::gtl::ArraySlice indices, xla::int64 start_dim, const XLATensor& values, bool accumulate, tensorflow::gtl::ArraySlice result_permutation) { - return input.CreateFrom(IndexPutByTensors(input, indices, values, accumulate, - result_permutation)); + return input.CreateFrom(IndexPutByTensors(input, indices, start_dim, values, + accumulate, result_permutation)); } void XLATensor::index_put_( XLATensor& input, const XLATensor& canonical_base, - tensorflow::gtl::ArraySlice indices, + tensorflow::gtl::ArraySlice indices, xla::int64 start_dim, const XLATensor& values, bool accumulate, tensorflow::gtl::ArraySlice result_permutation) { - input.SetIrValue(IndexPutByTensors(canonical_base, indices, values, + input.SetIrValue(IndexPutByTensors(canonical_base, indices, start_dim, values, accumulate, result_permutation)); } diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index cefdfddcb39e..2adf2ccc745d 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -198,7 +198,8 @@ XLATensor EmbeddingDenseBackward(const XLATensor& grad_output, XLATensor::full({num_weights}, 0, indices.GetDevice(), indices.dtype()); XLATensor ones = XLATensor::full({numel}, 1, indices.GetDevice(), indices.dtype()); - XLATensor::index_put_(counts, counts, {indices_rank1}, ones, + XLATensor::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0, + /*values=*/ones, /*accumulate=*/true, /*result_permutation=*/{0}); XLATensor grad_weights_scale = XLATensor::index(counts, {indices_rank1}, 0); // Scale the value of the gradient by the histogram. @@ -212,10 +213,12 @@ XLATensor EmbeddingDenseBackward(const XLATensor& grad_output, XLATensor::expand(skip_padding, grad.shape().get().dimensions()); XLATensor zero_grad = XLATensor::full_like(grad, 0, grad.GetDevice(), grad.dtype()); - return XLATensor::index_put(grad_weight, {indices_rank1}, - XLATensor::where(skip_padding, grad, zero_grad), - /*accumulate=*/true, - /*result_permutation=*/{0, 1}); + return XLATensor::index_put( + grad_weight, {indices_rank1}, + /*start_dim=*/0, + /*values=*/XLATensor::where(skip_padding, grad, zero_grad), + /*accumulate=*/true, + /*result_permutation=*/{0, 1}); } } // namespace tensor_ops diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index d06dc7819064..cf0f1f855898 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -338,7 +338,7 @@ xla::XlaOp CreateIndex(const xla::XlaOp& input, const xla::XlaOp& indices, } xla::XlaOp CreateIndexUpdate( - const xla::XlaOp& buffer, const xla::XlaOp& indices, + const xla::XlaOp& buffer, const xla::XlaOp& indices, xla::int64 start_dim, const xla::XlaOp& values, const std::function& combiner) { @@ -360,9 +360,13 @@ xla::XlaOp CreateIndexUpdate( xla::int64 num_window_dims_in_values = buffer_rank - num_index_dims; // Make the values match the rank expected by scatter. - std::vector expected_values_dims(indices_dims.begin(), - indices_dims.end()); - for (xla::int64 dim = num_index_dims; dim < buffer_rank; ++dim) { + std::vector expected_values_dims; + for (xla::int64 dim = 0; dim < start_dim; ++dim) { + expected_values_dims.push_back(buffer_shape.dimensions(dim)); + } + expected_values_dims.insert(expected_values_dims.end(), indices_dims.begin(), + indices_dims.end()); + for (xla::int64 dim = num_index_dims + start_dim; dim < buffer_rank; ++dim) { expected_values_dims.push_back(buffer_shape.dimensions(dim)); } xla::XlaOp new_values = values; @@ -374,13 +378,16 @@ xla::XlaOp CreateIndexUpdate( values_shape = XlaHelpers::ShapeOfXlaOp(new_values); values_rank = values_shape.rank(); - for (xla::int64 i = (values_rank - num_window_dims_in_values); + for (xla::int64 dim = 0; dim < start_dim; ++dim) { + dim_numbers.add_update_window_dims(dim); + } + for (xla::int64 i = values_rank - num_window_dims_in_values + start_dim; i < values_rank; ++i) { dim_numbers.add_update_window_dims(i); } for (xla::int64 i = 0; i < num_index_dims; ++i) { - dim_numbers.add_inserted_window_dims(i); - dim_numbers.add_scatter_dims_to_operand_dims(i); + dim_numbers.add_inserted_window_dims(i + start_dim); + dim_numbers.add_scatter_dims_to_operand_dims(i + start_dim); } xla::XlaComputation combiner_computation = MakeScatterComputation(combiner, buffer_shape.element_type()); diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 51fa0dec569b..1c643aad9240 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -30,7 +30,7 @@ xla::XlaOp CreateIndex(const xla::XlaOp& input, const xla::XlaOp& indices, // Similar to tf.scatter_nd, used to implement advanced indexing updates. xla::XlaOp CreateIndexUpdate( - const xla::XlaOp& buffer, const xla::XlaOp& indices, + const xla::XlaOp& buffer, const xla::XlaOp& indices, xla::int64 start_dim, const xla::XlaOp& updates, const std::function& combiner);