2 changes: 1 addition & 1 deletion test/cpp/run_tests.sh
@@ -8,7 +8,7 @@ FILTER=
BUILD_ONLY=0
RMBUILD=1
LOGFILE=/tmp/pytorch_cpp_test.log
XLA_EXPERIMENTAL="nonzero:masked_select"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter"

if [ "$DEBUG" == "1" ]; then
BUILDTYPE="Debug"
25 changes: 25 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -3804,6 +3804,31 @@ TEST_F(AtenXlaTensorTest, TestMaskedSelect) {
});
}

TEST_F(AtenXlaTensorTest, TestMaskedScatter) {
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b =
torch::randint(0, 2, {3, 5}, torch::TensorOptions(torch::kBool));
torch::Tensor c = torch::rand({15}, torch::TensorOptions(torch::kFloat));
torch::Tensor d = torch::masked_scatter(a, b, c);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = CopyToDevice(b, device);
torch::Tensor xla_c = CopyToDevice(c, device);
torch::Tensor xla_d = torch::masked_scatter(xla_a, xla_b, xla_c);
AllClose(d, xla_d);

if (DebugUtil::ExperimentEnabled("masked_scatter") &&
bridge::AtenDeviceToXlaDevice(device).hw_type == DeviceType::TPU) {
// If the masked_scatter support is enabled, we must not see any aten::
// calls.
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}
ExpectCounterChanged("xla::masked_scatter_",
cpp_test::GetIgnoredCounters());
ResetCounters();
});
}

TEST_F(AtenXlaTensorTest, TestMultiIndexHeadNull) {
for (torch::ScalarType scalar_type :
{torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
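For reference, the new test exercises the standard ATen semantics of masked_scatter: elements of `source` are copied, in row-major order, into the positions of the input where `mask` is true. A minimal standalone libtorch illustration of those semantics (not part of this PR):

```cpp
// Minimal sketch of masked_scatter semantics, independent of the XLA backend.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor self = torch::zeros({2, 3});
  torch::Tensor mask =
      torch::tensor({true, false, true, false, true, false}).reshape({2, 3});
  torch::Tensor source = torch::arange(1, 7, torch::kFloat);
  // The true positions of `mask` are (0,0), (0,2) and (1,1); they receive the
  // first three values of `source`, taken in order.
  torch::Tensor out = torch::masked_scatter(self, mask, source);
  std::cout << out << std::endl;
  // Expected:
  //  1  0  2
  //  0  3  0
}
```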
2 changes: 1 addition & 1 deletion test/run_tests.sh
@@ -36,7 +36,7 @@ function run_opbyop {
}

function run_dynamic {
XLA_EXPERIMENTAL="nonzero:masked_select" "$@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" "$@"
}

function run_all_tests {
17 changes: 17 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1596,6 +1596,23 @@ at::Tensor& AtenXlaType::masked_fill_(at::Tensor& self, const at::Tensor& mask,
return masked_fill_(self, mask, value.item());
}

at::Tensor& AtenXlaType::masked_scatter_(at::Tensor& self,
const at::Tensor& mask,
const at::Tensor& source) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
// Initially make the XLA-handled masked_scatter_() path experimental and
// opt-in. Only the XLA TPU backend currently implements the dynamic dimension
// setting required by the masked_scatter_ implementation.
if (!DebugUtil::ExperimentEnabled("masked_scatter") ||
self_tensor.GetDevice().hw_type != DeviceType::TPU) {
return AtenXlaTypeDefault::masked_scatter_(self, mask, source);
}
XLATensor::masked_scatter_(self_tensor, bridge::GetXlaTensor(mask),
bridge::GetXlaTensor(source));
return self;
}

at::Tensor AtenXlaType::masked_select(const at::Tensor& self,
const at::Tensor& mask) {
XLA_FN_COUNTER("xla::");
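As with masked_select, the new masked_scatter_ override is gated behind an opt-in experiment: unless "masked_scatter" appears in the colon-separated XLA_EXPERIMENTAL list (which the updated test scripts now set) and the tensor lives on a TPU device, the call falls back to AtenXlaTypeDefault::masked_scatter_. Below is a rough, purely illustrative sketch of how such a colon-separated opt-in list can be parsed; the real DebugUtil::ExperimentEnabled implementation is not part of this diff and may differ:

```cpp
// Hypothetical parser for a colon-separated experiment list such as
// XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter".
#include <cstdlib>
#include <sstream>
#include <string>
#include <unordered_set>

bool ExperimentEnabledSketch(const std::string& name) {
  static const std::unordered_set<std::string>* experiments = [] {
    auto* set = new std::unordered_set<std::string>();
    if (const char* env = std::getenv("XLA_EXPERIMENTAL")) {
      std::istringstream stream(env);
      std::string item;
      while (std::getline(stream, item, ':')) {
        set->insert(item);
      }
    }
    return set;
  }();
  return experiments->count(name) > 0;
}
```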
3 changes: 3 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -509,6 +509,9 @@ class AtenXlaType {
static at::Tensor& masked_fill_(at::Tensor& self, const at::Tensor& mask,
const at::Tensor& value);

static at::Tensor& masked_scatter_(at::Tensor& self, const at::Tensor& mask,
const at::Tensor& source);

static at::Tensor masked_select(const at::Tensor& self,
const at::Tensor& mask);

6 changes: 6 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -334,6 +334,12 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
dynamic_dimension);
}

bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
const xla::Shape& shape2) {
return shape1.is_static() && shape2.is_static() &&
shape1.dimensions() == shape2.dimensions();
}

xla::XlaOp XlaHelpers::Flatten(xla::XlaOp input, xla::Shape* input_shape) {
xla::util::MaybePtr<xla::Shape> input_shape_tmp(input_shape);
*input_shape_tmp = ShapeOfXlaOp(input);
3 changes: 3 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -137,6 +137,9 @@ class XlaHelpers {

static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape);

static bool SameStaticDimensions(const xla::Shape& shape1,
const xla::Shape& shape2);

// Creates a convolution or dot precision configuration.
static xla::PrecisionConfig BuildPrecisionConfig(
const xla::PrecisionConfig::Precision conv_precision);
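The new SameStaticDimensions helper is a small predicate: it returns true only when both shapes are fully static and their dimensions match. Its call site is not shown in this diff, but a plausible use, sketched here purely for illustration, is to skip a dynamic reshape when the shapes already agree:

```cpp
// Illustrative only: avoid emitting a reshape when both shapes are static and
// already identical. Uses only the helpers declared in helpers.h above.
xla::XlaOp MaybeDynamicReshape(xla::XlaOp input, const xla::Shape& input_shape,
                               const xla::Shape& target_shape) {
  if (torch_xla::XlaHelpers::SameStaticDimensions(input_shape, target_shape)) {
    return input;
  }
  return torch_xla::XlaHelpers::DynamicReshapeAs(input, target_shape);
}
```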
31 changes: 31 additions & 0 deletions torch_xla/csrc/ops/masked_scatter.cpp
@@ -0,0 +1,31 @@
#include "torch_xla/csrc/ops/masked_scatter.h"

#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace ir {
namespace ops {

MaskedScatter::MaskedScatter(const Value& input, const Value& mask,
const Value& source)
: Node(ir::OpKind(at::aten::masked_scatter), {input, mask, source},
input.shape(),
/*num_outputs=*/1) {}

NodePtr MaskedScatter::Clone(OpList operands) const {
return MakeNode<MaskedScatter>(operands.at(0), operands.at(1),
operands.at(2));
}

XlaOpVector MaskedScatter::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp mask = loctx->GetOutputOp(operand(1));
xla::XlaOp source = loctx->GetOutputOp(operand(2));
return ReturnOp(BuildMaskedScatter(input, mask, source), loctx);
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
23 changes: 23 additions & 0 deletions torch_xla/csrc/ops/masked_scatter.h
@@ -0,0 +1,23 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

// This node has no metadata, so it could have been implemented as a generic op
// in ops.cpp, but since it might require special handling from upper IR
// layers, it gets its own IR node class.
class MaskedScatter : public Node {
public:
MaskedScatter(const Value& input, const Value& mask, const Value& source);

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -612,6 +612,9 @@ class XLATensor {
static void masked_fill_(XLATensor& input, const XLATensor& mask,
at::Scalar value);

static void masked_scatter_(XLATensor& input, const XLATensor& mask,
const XLATensor& source);

static XLATensor masked_select(const XLATensor& input, const XLATensor& mask);

static XLATensor matmul(const XLATensor& input, const XLATensor& other);
27 changes: 21 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
@@ -57,6 +57,7 @@
#include "torch_xla/csrc/ops/linear_interpolation.h"
#include "torch_xla/csrc/ops/log_softmax.h"
#include "torch_xla/csrc/ops/masked_fill.h"
#include "torch_xla/csrc/ops/masked_scatter.h"
#include "torch_xla/csrc/ops/masked_select.h"
#include "torch_xla/csrc/ops/max_in_dim.h"
#include "torch_xla/csrc/ops/max_pool_nd.h"
@@ -234,6 +235,14 @@ absl::optional<ir::Value> GetOptionalIrValue(const XLATensor& tensor) {
return value;
}

ir::Value MaybeExpand(const ir::Value& input, const xla::Shape& target_shape) {
if (input.shape().dimensions() == target_shape.dimensions()) {
return input;
}
return ir::MakeNode<ir::ops::Expand>(
input, xla::util::ToVector<xla::int64>(target_shape.dimensions()));
}

void CheckIsIntegralOrPred(const xla::Shape& shape,
const std::string& op_name) {
XLA_CHECK(xla::ShapeUtil::ElementIsIntegral(shape) ||
@@ -1385,12 +1394,18 @@ void XLATensor::lt_(XLATensor& input, const XLATensor& other) {

void XLATensor::masked_fill_(XLATensor& input, const XLATensor& mask,
at::Scalar value) {
// Expand mask to be the same size as input.
ir::NodePtr expanded_mask = ir::MakeNode<ir::ops::Expand>(
mask.GetIrValue(),
xla::util::ToVector<xla::int64>(input.shape().get().dimensions()));
input.SetIrValue(ir::MakeNode<ir::ops::MaskedFill>(input.GetIrValue(),
expanded_mask, value));
ir::ScopePusher ir_scope(at::aten::masked_fill.toQualString());
input.SetIrValue(ir::MakeNode<ir::ops::MaskedFill>(
input.GetIrValue(), MaybeExpand(mask.GetIrValue(), input.shape()),
value));
}

void XLATensor::masked_scatter_(XLATensor& input, const XLATensor& mask,
const XLATensor& source) {
ir::ScopePusher ir_scope(at::aten::masked_scatter.toQualString());
input.SetIrValue(ir::MakeNode<ir::ops::MaskedScatter>(
input.GetIrValue(), MaybeExpand(mask.GetIrValue(), input.shape()),
source.GetIrValue()));
}

XLATensor XLATensor::masked_select(const XLATensor& input,