From bc012c8b7c222c8662172c7f957c28a1bf0de9b3 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 11 Oct 2024 11:58:19 -0700 Subject: [PATCH] Add op: masked_scatter (#6167) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/6167 Differential Revision: D64243532 --- kernels/aten/functions.yaml | 2 + kernels/portable/cpu/op_masked_scatter.cpp | 78 ++++++++++ kernels/portable/functions.yaml | 5 + kernels/test/op_masked_scatter_test.cpp | 145 ++++++++++++++++++ kernels/test/targets.bzl | 1 + .../kernels/portable/op_registration_util.bzl | 6 + 6 files changed, 237 insertions(+) create mode 100644 kernels/portable/cpu/op_masked_scatter.cpp create mode 100644 kernels/test/op_masked_scatter_test.cpp diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index cba03b8a743..8b21a2f1454 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -241,6 +241,8 @@ - op: masked_fill.Scalar_out +- op: masked_scatter.out + - op: max_pool2d_with_indices.out - op: max.dim_max diff --git a/kernels/portable/cpu/op_masked_scatter.cpp b/kernels/portable/cpu/op_masked_scatter.cpp new file mode 100644 index 00000000000..16cef033670 --- /dev/null +++ b/kernels/portable/cpu/op_masked_scatter.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+// NOTE(review): include targets reconstructed from the broadcast_util build
+// dependency and the helpers used below -- confirm the exact paths.
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+/**
+ * Portable kernel for `masked_scatter.out`.
+ *
+ * Writes to `out` the elementwise combination of `in` and `mask`: wherever the
+ * mask value is true the next unread element of `src` is stored, otherwise the
+ * value from `in` is kept. `src` is read linearly through its data pointer,
+ * one element per true mask value, in the order the elementwise helper
+ * invokes the callback.
+ *
+ * If `src` runs out of elements, the remaining masked positions keep the value
+ * from `in` and the call reports InvalidArgument after the traversal.
+ *
+ * @param ctx Kernel context that receives InvalidArgument when a check fails.
+ * @param in Input tensor; must pass tensor_is_realhbbf16_type.
+ * @param mask Selector tensor; its scalar type must be ScalarType::Bool.
+ * @param src Replacement values; must have the same scalar type as `in`.
+ * @param out Destination; must have the same scalar type as `in`. It is
+ *     resized with resize_to_broadcast_target_size(in, mask, out).
+ * @return `out`, whether or not the checks passed.
+ */
+Tensor& masked_scatter_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& mask,
+    const Tensor& src,
+    Tensor& out) {
+  const ScalarType in_type = in.scalar_type();
+
+  ET_KERNEL_CHECK(
+      ctx,
+      executorch::runtime::tensor_is_realhbbf16_type(in),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, mask.scalar_type() == ScalarType::Bool, InvalidArgument, out);
+  ET_KERNEL_CHECK(ctx, src.scalar_type() == in_type, InvalidArgument, out);
+  ET_KERNEL_CHECK(ctx, out.scalar_type() == in_type, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(in, mask, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  constexpr auto op_name = "masked_scatter.out";
+
+  // Read cursor into src; advanced once per true mask value.
+  int64_t idx = 0;
+  const int64_t src_numel = src.numel();
+  // Cleared by the callback when a true mask value finds src exhausted.
+  bool src_numel_check = true;
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE, [&]() {
+    const CTYPE* const src_data = src.const_data_ptr<CTYPE>();
+    // The element types cannot be deduced from the Tensor arguments, so they
+    // are spelled out: CTYPE input, bool mask, CTYPE output.
+    apply_binary_elementwise_fn<CTYPE, bool, CTYPE>(
+        [src_data, src_numel, &idx, &src_numel_check](
+            const CTYPE val_in, const bool val_mask) {
+          if (!val_mask) {
+            return val_in;
+          }
+          if (idx >= src_numel) {
+            // Never read past the end of src; record the shortfall and keep
+            // the input value so the traversal can finish.
+            src_numel_check = false;
+            return val_in;
+          }
+          return src_data[idx++];
+        },
+        in,
+        mask,
+        out);
+  });
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      src_numel_check,
+      InvalidArgument,
+      out,
+      "masked_scatter: src doesn't have enough elements");
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index 5136ea0a12f..b0ee9769a56 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -542,6 +542,11 @@
     - arg_meta: null
       kernel_name: torch::executor::masked_fill_scalar_out
 
+- op: masked_scatter.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::masked_scatter_out
+
 - op: max.dim_max
   kernels:
     - arg_meta: null
diff --git a/kernels/test/op_masked_scatter_test.cpp b/kernels/test/op_masked_scatter_test.cpp
new file mode 100644
index 00000000000..9116fe71a6c
--- /dev/null
+++ b/kernels/test/op_masked_scatter_test.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// NOTE(review): the include targets and the TensorFactory element types in
+// this file were missing and have been reconstructed (Int for data tensors,
+// Bool for masks) -- confirm against the upstream test.
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::testing::SupportedFeatures;
+using torch::executor::testing::TensorFactory;
+
+// Fixture that forwards to masked_scatter_outf using the fixture's context_.
+class OpMaskedScatterOutTest : public OperatorTest {
+ protected:
+  Tensor& op_masked_scatter_out(
+      const Tensor& in,
+      const Tensor& mask,
+      const Tensor& src,
+      Tensor& out) {
+    return torch::executor::aten::masked_scatter_outf(
+        context_, in, mask, src, out);
+  }
+};
+
+// Same-shape in and mask: true positions take src values in order.
+TEST_F(OpMaskedScatterOutTest, SmokeTest) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+  Tensor src = tf.make({3}, {10, 20, 30});
+
+  Tensor out = tf.zeros({2, 3});
+
+  op_masked_scatter_out(in, mask, src, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2, 3}, {10, 2, 3, 20, 5, 30}));
+}
+
+// A {3} input is broadcast against a {2, 3} mask.
+TEST_F(OpMaskedScatterOutTest, BroadcastInput) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({3}, {1, 2, 3});
+  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+  Tensor src = tf.make({3}, {10, 20, 30});
+
+  Tensor out = tf.zeros({2, 3});
+
+  op_masked_scatter_out(in, mask, src, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2, 3}, {10, 2, 3, 20, 2, 30}));
+}
+
+// A {3} mask is broadcast against a {2, 3} input.
+TEST_F(OpMaskedScatterOutTest, BroadcastMask) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+  Tensor mask = tfBool.make({3}, {false, true, false});
+  Tensor src = tf.make({2}, {10, 20});
+
+  Tensor out = tf.zeros({2, 3});
+
+  op_masked_scatter_out(in, mask, src, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2, 3}, {1, 10, 3, 4, 20, 6}));
+}
+
+// src holds four values for three true mask entries; the fourth is unused.
+TEST_F(OpMaskedScatterOutTest, SrcWithMoreElements) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+  Tensor src = tf.make({4}, {10, 20, 30, 40});
+
+  Tensor out = tf.zeros({2, 3});
+
+  op_masked_scatter_out(in, mask, src, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2, 3}, {10, 2, 3, 20, 5, 30}));
+}
+
+// src holds two values for three true mask entries; the kernel must fail.
+TEST_F(OpMaskedScatterOutTest, SrcWithLessElementsFails) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+  Tensor src = tf.make({2}, {10, 20});
+
+  Tensor out = tf.zeros({2, 3});
+
+  ET_EXPECT_KERNEL_FAILURE(context_, op_masked_scatter_out(in, mask, src, out));
+}
+
+// A {2, 0} mask against a {2, 1} input yields an empty {2, 0} output.
+TEST_F(OpMaskedScatterOutTest, EmptyMask) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 1}, {100, 200});
+  Tensor mask = tfBool.make({2, 0}, {});
+  Tensor src = tf.make({4}, {10, 20, 30, 40});
+
+  Tensor out = tf.zeros({2, 0});
+
+  op_masked_scatter_out(in, mask, src, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2, 0}, {}));
+}
+
+// An empty src is accepted when no mask entry is true; out equals in.
+TEST_F(OpMaskedScatterOutTest, EmptySrc) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 1}, {100, 200});
+  Tensor mask = tfBool.make({2, 1}, {false, false});
+  Tensor src = tf.make({0}, {});
+
+  Tensor out = tf.zeros({2, 1});
+
+  op_masked_scatter_out(in, mask, src, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2, 1}, {100, 200}));
+}
+
+// A {0} mask and a {0} src against a {2, 1} input yield an empty {2, 0} output.
+TEST_F(OpMaskedScatterOutTest, EmptyMaskAndSrc) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 1}, {100, 200});
+  Tensor mask = tfBool.make({0}, {});
+  Tensor src = tf.make({0}, {});
+
+  Tensor out = tf.zeros({2, 0});
+
+  op_masked_scatter_out(in, mask, src, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2, 0}, {}));
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 7bc2e7555c6..91b3ba89fde 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -254,6 +254,7 @@ def define_common_targets():
     _common_op_test("op_logit_test", ["aten", "portable"])
     _common_op_test("op_lt_test", ["aten", "portable"])
_common_op_test("op_masked_fill_test", ["aten", "portable"]) + _common_op_test("op_masked_scatter_test", ["aten", "portable"]) _common_op_test("op_max_test", ["aten", "portable"]) _common_op_test("op_max_pool2d_with_indices_test", ["aten", "portable"]) _common_op_test("op_maximum_test", ["aten", "portable"]) diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 32b9cec7e24..6328860f99a 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -776,6 +776,12 @@ ATEN_OPS = ( ":scalar_utils", ], ), + op_target( + name = "op_masked_scatter", + deps = [ + "//executorch/kernels/portable/cpu/util:broadcast_util", + ], + ), op_target( name = "op_max", deps = [