From eee5a6e1536d9055afed71eefd4b74df63909f84 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Sat, 11 Jan 2020 18:06:58 -0800
Subject: [PATCH] Added masked_scatter support.

---
 test/cpp/run_tests.sh                 |   2 +-
 test/cpp/test_aten_xla_tensor.cpp     |  25 ++++++
 test/run_tests.sh                     |   2 +-
 torch_xla/csrc/aten_xla_type.cpp      |  17 ++++
 torch_xla/csrc/aten_xla_type.h        |   3 +
 torch_xla/csrc/helpers.cpp            |   6 ++
 torch_xla/csrc/helpers.h              |   3 +
 torch_xla/csrc/ops/masked_scatter.cpp |  31 +++++++
 torch_xla/csrc/ops/masked_scatter.h   |  23 +++++
 torch_xla/csrc/tensor.h               |   3 +
 torch_xla/csrc/tensor_methods.cpp     |  27 ++++--
 torch_xla/csrc/xla_lower_util.cpp     | 124 ++++++++++++++++----------
 torch_xla/csrc/xla_lower_util.h       |   9 +-
 13 files changed, 218 insertions(+), 57 deletions(-)
 create mode 100644 torch_xla/csrc/ops/masked_scatter.cpp
 create mode 100644 torch_xla/csrc/ops/masked_scatter.h

diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh
index 6fbfcedf3df..07ee5ce3c3e 100755
--- a/test/cpp/run_tests.sh
+++ b/test/cpp/run_tests.sh
@@ -8,7 +8,7 @@ FILTER=
 BUILD_ONLY=0
 RMBUILD=1
 LOGFILE=/tmp/pytorch_cpp_test.log
-XLA_EXPERIMENTAL="nonzero:masked_select"
+XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter"
 
 if [ "$DEBUG" == "1" ]; then
   BUILDTYPE="Debug"
diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 3aef9378a91..e06f85ef32b 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -3804,6 +3804,31 @@ TEST_F(AtenXlaTensorTest, TestMaskedSelect) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestMaskedScatter) {
+  torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b =
+      torch::randint(0, 2, {3, 5}, torch::TensorOptions(torch::kBool));
+  torch::Tensor c = torch::rand({15}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor d = torch::masked_scatter(a, b, c);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = CopyToDevice(b, device);
+    torch::Tensor xla_c = CopyToDevice(c, device);
+    torch::Tensor xla_d = torch::masked_scatter(xla_a, xla_b, xla_c);
+    AllClose(d, xla_d);
+
+    if (DebugUtil::ExperimentEnabled("masked_scatter") &&
+        bridge::AtenDeviceToXlaDevice(device).hw_type == DeviceType::TPU) {
+      // If the masked_scatter support is enabled, we must not see any aten::
+      // calls.
+      ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+    }
+    ExpectCounterChanged("xla::masked_scatter_",
+                         cpp_test::GetIgnoredCounters());
+    ResetCounters();
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestMultiIndexHeadNull) {
   for (torch::ScalarType scalar_type :
        {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
diff --git a/test/run_tests.sh b/test/run_tests.sh
index dea0e7d8965..40bbd3fbd18 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -36,7 +36,7 @@ function run_opbyop {
 }
 
 function run_dynamic {
-  XLA_EXPERIMENTAL="nonzero:masked_select" "$@"
+  XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" "$@"
 }
 
 function run_all_tests {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 2957187babe..ee418a61c94 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1596,6 +1596,23 @@ at::Tensor& AtenXlaType::masked_fill_(at::Tensor& self, const at::Tensor& mask,
   return masked_fill_(self, mask, value.item());
 }
 
+at::Tensor& AtenXlaType::masked_scatter_(at::Tensor& self,
+                                         const at::Tensor& mask,
+                                         const at::Tensor& source) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  // Initially make the XLA handling of masked_scatter_() experimental, and
+  // opt-in. Only the XLA TPU backend for now implements the dynamic dimension
+  // setting required by the masked_scatter_ implementation.
+  if (!DebugUtil::ExperimentEnabled("masked_scatter") ||
+      self_tensor.GetDevice().hw_type != DeviceType::TPU) {
+    return AtenXlaTypeDefault::masked_scatter_(self, mask, source);
+  }
+  XLATensor::masked_scatter_(self_tensor, bridge::GetXlaTensor(mask),
+                             bridge::GetXlaTensor(source));
+  return self;
+}
+
 at::Tensor AtenXlaType::masked_select(const at::Tensor& self,
                                       const at::Tensor& mask) {
   XLA_FN_COUNTER("xla::");
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 477baf40d4d..4afdeccff7c 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -509,6 +509,9 @@ class AtenXlaType {
   static at::Tensor& masked_fill_(at::Tensor& self, const at::Tensor& mask,
                                   const at::Tensor& value);
 
+  static at::Tensor& masked_scatter_(at::Tensor& self, const at::Tensor& mask,
+                                     const at::Tensor& source);
+
   static at::Tensor masked_select(const at::Tensor& self,
                                   const at::Tensor& mask);
 
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index 5c00f68df1a..038fafeae52 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -334,6 +334,12 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
                            dynamic_dimension);
 }
 
+bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
+                                      const xla::Shape& shape2) {
+  return shape1.is_static() && shape2.is_static() &&
+         shape1.dimensions() == shape2.dimensions();
+}
+
 xla::XlaOp XlaHelpers::Flatten(xla::XlaOp input, xla::Shape* input_shape) {
   xla::util::MaybePtr<xla::Shape> input_shape_tmp(input_shape);
   *input_shape_tmp = ShapeOfXlaOp(input);
diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h
index ef2af0e02b4..08e650dbecd 100644
--- a/torch_xla/csrc/helpers.h
+++ b/torch_xla/csrc/helpers.h
@@ -137,6 +137,9 @@ class XlaHelpers {
   static xla::XlaOp DynamicReshapeAs(xla::XlaOp input,
                                      const xla::Shape& shape);
 
+  static bool SameStaticDimensions(const xla::Shape& shape1,
+                                   const xla::Shape& shape2);
+
   // Creates a convolution or dot precision configuration.
   static xla::PrecisionConfig BuildPrecisionConfig(
       const xla::PrecisionConfig::Precision conv_precision);
 
diff --git a/torch_xla/csrc/ops/masked_scatter.cpp b/torch_xla/csrc/ops/masked_scatter.cpp
new file mode 100644
index 00000000000..6d9a3212705
--- /dev/null
+++ b/torch_xla/csrc/ops/masked_scatter.cpp
@@ -0,0 +1,31 @@
+#include "torch_xla/csrc/ops/masked_scatter.h"
+
+#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/xla_lower_util.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+MaskedScatter::MaskedScatter(const Value& input, const Value& mask,
+                             const Value& source)
+    : Node(ir::OpKind(at::aten::masked_scatter), {input, mask, source},
+           input.shape(),
+           /*num_outputs=*/1) {}
+
+NodePtr MaskedScatter::Clone(OpList operands) const {
+  return MakeNode<MaskedScatter>(operands.at(0), operands.at(1),
+                                 operands.at(2));
+}
+
+XlaOpVector MaskedScatter::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp mask = loctx->GetOutputOp(operand(1));
+  xla::XlaOp source = loctx->GetOutputOp(operand(2));
+  return ReturnOp(BuildMaskedScatter(input, mask, source), loctx);
+}
+
+} // namespace ops
+} // namespace ir
+} // namespace torch_xla
diff --git a/torch_xla/csrc/ops/masked_scatter.h b/torch_xla/csrc/ops/masked_scatter.h
new file mode 100644
index 00000000000..6a1ac2d3b4e
--- /dev/null
+++ b/torch_xla/csrc/ops/masked_scatter.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+namespace ir {
+namespace ops {
+
+// This node has no metadata, so it could have been implemented as a generic
+// op in ops.cpp, but since this might require special handling from upper IR
+// layers, it gets its own IR node class.
+class MaskedScatter : public Node {
+ public:
+  MaskedScatter(const Value& input, const Value& mask, const Value& source);
+
+  NodePtr Clone(OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+};
+
+} // namespace ops
+} // namespace ir
+} // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 34281bc44b3..230d77ffdb4 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -612,6 +612,9 @@ class XLATensor {
   static void masked_fill_(XLATensor& input, const XLATensor& mask,
                            at::Scalar value);
 
+  static void masked_scatter_(XLATensor& input, const XLATensor& mask,
+                              const XLATensor& source);
+
   static XLATensor masked_select(const XLATensor& input, const XLATensor& mask);
 
   static XLATensor matmul(const XLATensor& input, const XLATensor& other);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 81819fe342b..33569be8ad3 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -57,6 +57,7 @@
 #include "torch_xla/csrc/ops/linear_interpolation.h"
 #include "torch_xla/csrc/ops/log_softmax.h"
 #include "torch_xla/csrc/ops/masked_fill.h"
+#include "torch_xla/csrc/ops/masked_scatter.h"
 #include "torch_xla/csrc/ops/masked_select.h"
 #include "torch_xla/csrc/ops/max_in_dim.h"
 #include "torch_xla/csrc/ops/max_pool_nd.h"
@@ -234,6 +235,14 @@ absl::optional<ir::Value> GetOptionalIrValue(const XLATensor& tensor) {
   return value;
 }
 
+ir::Value MaybeExpand(const ir::Value& input, const xla::Shape& target_shape) {
+  if (input.shape().dimensions() == target_shape.dimensions()) {
+    return input;
+  }
+  return ir::MakeNode<ops::Expand>(
+      input, xla::util::ToVector<xla::int64>(target_shape.dimensions()));
+}
+
 void CheckIsIntegralOrPred(const xla::Shape& shape,
                            const std::string& op_name) {
   XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(shape) ||
@@ -1385,12 +1394,18 @@ void XLATensor::lt_(XLATensor& input, const XLATensor& other) {
 
 void XLATensor::masked_fill_(XLATensor& input, const XLATensor& mask,
                              at::Scalar value) {
-  // Expand mask to be the same size as input.
-  ir::NodePtr expanded_mask = ir::MakeNode<ops::Expand>(
-      mask.GetIrValue(),
-      xla::util::ToVector<xla::int64>(input.shape().get().dimensions()));
-  input.SetIrValue(ir::MakeNode<ops::MaskedFill>(input.GetIrValue(),
-                                                 expanded_mask, value));
+  ir::ScopePusher ir_scope(at::aten::masked_fill.toQualString());
+  input.SetIrValue(ir::MakeNode<ops::MaskedFill>(
+      input.GetIrValue(), MaybeExpand(mask.GetIrValue(), input.shape()),
+      value));
+}
+
+void XLATensor::masked_scatter_(XLATensor& input, const XLATensor& mask,
+                                const XLATensor& source) {
+  ir::ScopePusher ir_scope(at::aten::masked_scatter.toQualString());
+  input.SetIrValue(ir::MakeNode<ops::MaskedScatter>(
+      input.GetIrValue(), MaybeExpand(mask.GetIrValue(), input.shape()),
+      source.GetIrValue()));
 }
 
 XLATensor XLATensor::masked_select(const XLATensor& input,
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 2b20b4d7ea6..d9418fb5d6b 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -23,30 +23,48 @@ namespace {
 struct ConditionMaskData {
   xla::Shape iota_shape;
   xla::int64 flattened_size;
-  xla::XlaOp reshaped_condition_int;
+  xla::XlaOp r1_condition_int;
+  xla::PrimitiveType condition_int_type;
   xla::XlaOp length;
 };
 
 ConditionMaskData CreateConditionMaskData(xla::XlaOp condition) {
+  static const xla::PrimitiveType kConditionType = xla::PrimitiveType::S32;
+  static const xla::PrimitiveType kIotaType = xla::PrimitiveType::S32;
   xla::Shape iota_shape = XlaHelpers::ShapeOfXlaOp(condition);
-  iota_shape.set_element_type(xla::PrimitiveType::S32);
+  iota_shape.set_element_type(kIotaType);
 
-  xla::int64 flattened_size = xla::Product(iota_shape.dimensions());
-  xla::XlaOp reshaped_condition =
+  xla::int64 flattened_size = xla::ShapeUtil::ElementsIn(iota_shape);
+  xla::XlaOp r1_condition =
       XlaHelpers::DynamicReshape(condition, {flattened_size});
-  xla::XlaOp zeros = xla::ZerosLike(reshaped_condition);
-  xla::XlaOp zeros_int =
-      xla::ConvertElementType(zeros, xla::PrimitiveType::S32);
-  xla::XlaOp reshaped_condition_int =
-      xla::ConvertElementType(reshaped_condition, xla::PrimitiveType::S32);
-  xla::XlaOp compared = xla::ConvertElementType(
-      xla::Gt(reshaped_condition_int, zeros_int), xla::PrimitiveType::S32);
+  xla::XlaOp r1_condition_int =
+      xla::ConvertElementType(r1_condition, kConditionType);
+  xla::XlaOp zeros = xla::ZerosLike(r1_condition_int);
+  xla::XlaOp compared =
+      xla::ConvertElementType(xla::Gt(r1_condition_int, zeros), kConditionType);
   xla::XlaOp length = xla::ReduceAll(
-      compared, xla::Zero(condition.builder(), xla::PrimitiveType::S32),
-      xla::CreateScalarAddComputation(xla::PrimitiveType::S32,
-                                      condition.builder()));
-  return {std::move(iota_shape), flattened_size, reshaped_condition_int,
-          length};
+      compared, xla::Zero(condition.builder(), kConditionType),
+      xla::CreateScalarAddComputation(kConditionType, condition.builder()));
+  return {std::move(iota_shape), flattened_size, r1_condition_int,
+          kConditionType, length};
+}
+
+xla::XlaOp GetPromotedR1Mask(xla::XlaOp mask, const xla::Shape& input_shape) {
+  const xla::Shape& mask_shape = XlaHelpers::ShapeOfXlaOp(mask);
+  xla::Shape promoted_mask_shape =
+      XlaHelpers::GetPromotedShape(mask_shape, input_shape);
+  xla::XlaOp bcast_mask =
+      XlaHelpers::ImplicitBroadcast(mask, mask_shape, promoted_mask_shape);
+  return XlaHelpers::Flatten(bcast_mask);
+}
+
+bool ShouldUseDenseScatter(const xla::Shape& input_shape,
+                           const xla::Shape& index_shape) {
+  static int dense_scatter_factor =
+      xla::sys_util::GetEnvInt("XLA_DENSE_SCATTER_FACTOR", 100);
+  xla::int64 input_elements = xla::ShapeUtil::ElementsIn(input_shape);
+  xla::int64 index_elements = xla::ShapeUtil::ElementsIn(index_shape);
+  return index_elements * dense_scatter_factor >= input_elements;
 }
 
 xla::XlaOp DotExpand(xla::XlaOp op, const xla::Shape& op_shape,
@@ -177,8 +195,8 @@ xla::XlaOp XlaDenseScatter(
   // a stable implementation.
   xla::XlaBuilder* builder = input.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
-    TF_ASSIGN_OR_RETURN(xla::Shape index_shape, builder->GetShape(index));
-    TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+    const xla::Shape& index_shape = XlaHelpers::ShapeOfXlaOp(index);
+    const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
     std::vector<xla::int64> index_broacast_dims;
     std::vector<xla::int64> sizes;
     for (xla::int64 i = 0; i < index_shape.rank(); ++i) {
@@ -207,7 +225,7 @@ xla::XlaOp XlaDenseScatter(
        xla::CreateScalarIdentityWithZeroComputation(input_shape.element_type(),
                                                     builder),
        {dim + 1});
-    if (index_shape.dimensions() == input_shape.dimensions()) {
+    if (XlaHelpers::SameStaticDimensions(index_shape, input_shape)) {
      // If the index shape is the same as the input shape, the input shape will
      // be fully covered (since scatter indices must be unique), so there is no
      // need for masking.
@@ -218,12 +236,8 @@ xla::XlaOp XlaDenseScatter(
        xla::CreateScalarOrComputation(xla::PrimitiveType::PRED, builder),
        {dim + 1});
     if (ScatterRequiresPadding(input_shape, index_shape)) {
-      masked_src =
-          PadToSize(masked_src, xla::Zero(builder, input_shape.element_type()),
-                    input_shape.dimensions());
-      reduced_mask =
-          PadToSize(reduced_mask, xla::ConstantR0<bool>(builder, false),
-                    input_shape.dimensions());
+      masked_src = PadToSize(masked_src, input_shape.dimensions());
+      reduced_mask = PadToSize(reduced_mask, input_shape.dimensions());
     }
     xla::XlaOp result;
     if (combiner != nullptr) {
@@ -237,13 +251,13 @@
 
 std::vector<xla::XlaOp> BuildConditionIndices(xla::XlaOp condition) {
   ConditionMaskData cmd = CreateConditionMaskData(condition);
-  std::vector<xla::XlaOp> to_sort = {cmd.reshaped_condition_int};
-  std::vector<xla::PrimitiveType> types_to_sort = {xla::PrimitiveType::S32};
+  std::vector<xla::XlaOp> to_sort = {cmd.r1_condition_int};
+  std::vector<xla::PrimitiveType> types_to_sort = {cmd.condition_int_type};
   for (xla::int64 axis = 0; axis < cmd.iota_shape.rank(); ++axis) {
     xla::XlaOp iota = xla::Iota(condition.builder(), cmd.iota_shape, axis);
     xla::XlaOp reshaped = xla::Reshape(iota, {cmd.flattened_size});
     to_sort.push_back(reshaped);
-    types_to_sort.push_back(xla::PrimitiveType::S32);
+    types_to_sort.push_back(cmd.iota_shape.element_type());
   }
 
   xla::XlaOp sorted = xla::Sort(
@@ -265,11 +279,14 @@ std::vector<xla::XlaOp> BuildConditionIndices(xla::XlaOp condition) {
 
 }  // namespace
 
-xla::XlaOp PadToSize(xla::XlaOp input, xla::XlaOp pad_value,
-                     tensorflow::gtl::ArraySlice<xla::int64> size) {
+xla::XlaOp PadToSize(xla::XlaOp input,
+                     tensorflow::gtl::ArraySlice<xla::int64> size,
+                     absl::optional<xla::XlaOp> pad_value) {
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   XLA_CHECK_EQ(input_shape.rank(), size.size());
-
+  if (!pad_value) {
+    pad_value = xla::Zero(input.builder(), input_shape.element_type());
+  }
   xla::PaddingConfig padding_config;
   for (size_t i = 0; i < size.size(); i++) {
     auto* dims = padding_config.add_dimensions();
@@ -278,7 +295,7 @@ xla::XlaOp PadToSize(xla::XlaOp input, xla::XlaOp pad_value,
     XLA_CHECK_GE(size[i], input_shape.dimensions(i));
     dims->set_edge_padding_high(size[i] - input_shape.dimensions(i));
   }
-  return xla::Pad(input, pad_value, padding_config);
+  return xla::Pad(input, *pad_value, padding_config);
 }
 
 std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, xla::int64 k,
@@ -586,8 +603,6 @@ XlaOpCombiner NumericAddCombiner() {
 
 xla::XlaOp CreateScatter(xla::XlaOp input, xla::XlaOp index, xla::XlaOp source,
                          xla::int64 dim, const XlaOpCombiner& combiner) {
-  static int dense_scatter_factor =
-      xla::sys_util::GetEnvInt("XLA_DENSE_SCATTER_FACTOR", 100);
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::Shape index_shape = XlaHelpers::ShapeOfXlaOp(index);
   const xla::Shape& source_shape = XlaHelpers::ShapeOfXlaOp(source);
@@ -597,10 +612,7 @@ xla::XlaOp CreateScatter(xla::XlaOp input, xla::XlaOp index, xla::XlaOp source,
     std::vector<xla::int64> base_indices(source_shape.rank(), 0);
     source_op = BuildSlice(source_op, base_indices, index_shape.dimensions());
   }
-
-  xla::int64 input_elements = xla::ShapeUtil::ElementsIn(input_shape);
-  xla::int64 index_elements = xla::ShapeUtil::ElementsIn(index_shape);
-  if (index_elements >= input_elements / dense_scatter_factor) {
+  if (ShouldUseDenseScatter(input_shape, index_shape)) {
     return XlaDenseScatter(input, index, source_op, dim, combiner);
   }
 
@@ -658,16 +670,10 @@ std::vector<xla::XlaOp> BuildNonZero(xla::XlaOp input) {
 std::vector<xla::XlaOp> BuildMaskedSelect(xla::XlaOp input, xla::XlaOp mask) {
   xla::Shape input_shape;
   xla::XlaOp r1_input = XlaHelpers::Flatten(input, &input_shape);
-  const xla::Shape& mask_shape = XlaHelpers::ShapeOfXlaOp(mask);
-  xla::Shape promoted_mask_shape =
-      XlaHelpers::GetPromotedShape(mask_shape, input_shape);
-  xla::XlaOp bcast_mask =
-      XlaHelpers::ImplicitBroadcast(mask, mask_shape, promoted_mask_shape);
-  xla::XlaOp r1_bcast_mask = XlaHelpers::Flatten(bcast_mask);
-
+  xla::XlaOp r1_bcast_mask = GetPromotedR1Mask(mask, input_shape);
   ConditionMaskData cmd = CreateConditionMaskData(r1_bcast_mask);
-  std::vector<xla::XlaOp> to_sort = {cmd.reshaped_condition_int, r1_input};
-  std::vector<xla::PrimitiveType> types_to_sort = {xla::PrimitiveType::S32,
+  std::vector<xla::XlaOp> to_sort = {cmd.r1_condition_int, r1_input};
+  std::vector<xla::PrimitiveType> types_to_sort = {cmd.condition_int_type,
                                                    input_shape.element_type()};
   xla::XlaOp sorted = xla::Sort(
       to_sort, xla::CreateScalarGtComputation(types_to_sort, input.builder()),
@@ -679,4 +685,28 @@ std::vector<xla::XlaOp> BuildMaskedSelect(xla::XlaOp input, xla::XlaOp mask) {
   return {sorted_input_padded, cmd.length};
 }
 
+xla::XlaOp BuildMaskedScatter(xla::XlaOp input, xla::XlaOp mask,
+                              xla::XlaOp source) {
+  xla::Shape input_shape;
+  xla::XlaOp r1_input = XlaHelpers::Flatten(input, &input_shape);
+  xla::XlaOp r1_bcast_mask = GetPromotedR1Mask(mask, input_shape);
+  xla::Shape source_shape;
+  xla::XlaOp r1_source = XlaHelpers::Flatten(source, &source_shape);
+
+  auto indices = BuildConditionIndices(r1_bcast_mask);
+  xla::XlaOp mask_indices = indices[0];
+  xla::XlaOp num_indices = indices[1];
+
+  xla::int64 input_size = xla::ShapeUtil::ElementsIn(input_shape);
+  if (input_size > xla::ShapeUtil::ElementsIn(source_shape)) {
+    r1_source = PadToSize(r1_source, {input_size});
+  }
+  r1_source = xla::SetDimensionSize(r1_source, num_indices, 0);
+
+  xla::XlaOp r1_index = XlaHelpers::Flatten(mask_indices);
+  xla::XlaOp r1_scatter = CreateScatter(r1_input, r1_index, r1_source,
+                                        /*dim=*/0, /*combiner=*/nullptr);
+  return XlaHelpers::DynamicReshapeAs(r1_scatter, input_shape);
+}
+
 } // namespace torch_xla
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index 6c459094ff9..8997cae784f 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -2,13 +2,15 @@
 
 #include <vector>
 
+#include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/gtl/array_slice.h" namespace torch_xla { -xla::XlaOp PadToSize(xla::XlaOp input, xla::XlaOp pad_value, - tensorflow::gtl::ArraySlice size); +xla::XlaOp PadToSize(xla::XlaOp input, + tensorflow::gtl::ArraySlice size, + absl::optional pad_value = absl::nullopt); std::vector CreateKthValue(xla::XlaOp input, xla::int64 k, xla::int64 dim, bool keepdim); @@ -62,4 +64,7 @@ std::vector BuildNonZero(xla::XlaOp input); std::vector BuildMaskedSelect(xla::XlaOp input, xla::XlaOp mask); +xla::XlaOp BuildMaskedScatter(xla::XlaOp input, xla::XlaOp mask, + xla::XlaOp source); + } // namespace torch_xla