diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index 28f1a215562..48a8d3bc8ee 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -201,6 +201,8 @@
 - op: index_put.out
 
+- op: index_put_
+
 - op: index_select.out
 
 - op: index.Tensor_out
diff --git a/kernels/portable/cpu/op_index_put.cpp b/kernels/portable/cpu/op_index_put.cpp
index 942892c31ec..76bd7a48922 100644
--- a/kernels/portable/cpu/op_index_put.cpp
+++ b/kernels/portable/cpu/op_index_put.cpp
@@ -11,6 +11,7 @@
 
 #include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -18,11 +19,13 @@ namespace executor {
 namespace native {
 
 using Tensor = executorch::aten::Tensor;
+using TensorOptList =
+    executorch::aten::ArrayRef<executorch::aten::optional<Tensor>>;
 
 Tensor& index_put_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
-    executorch::aten::ArrayRef<executorch::aten::optional<Tensor>> indices,
+    TensorOptList indices,
     const Tensor& values,
     const bool accumulate,
     Tensor& out) {
@@ -154,6 +157,177 @@ Tensor& index_put_out(
   return out;
 }
 
+namespace {
+
+bool check_special_case_in_place_args(
+    Tensor& in,
+    TensorOptList indices,
+    const Tensor& values,
+    const bool accumulate,
+    size_t* dim) {
+  ET_CHECK_OR_RETURN_FALSE(
+      !accumulate,
+      "Special case in-place index_put does not support accumulate");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      static_cast<ssize_t>(indices.size()) <= in.dim(),
+      "Indexing too many dimensions");
+
+  bool found_index = false;
+  for (const auto i : c10::irange(indices.size())) {
+    if (indices[i].has_value()) {
+      *dim = i;
+      ET_CHECK_OR_RETURN_FALSE(
+          !found_index,
+          "Special case in-place index_put only supports a single non-null index tensor");
+      found_index = true;
+      const Tensor& index = indices[i].value();
+      ScalarType ix_type = index.scalar_type();
+      ET_CHECK_OR_RETURN_FALSE(
+          ix_type == ScalarType::Long || ix_type == ScalarType::Int,
+          "Special case in-place index_put only supports Long or Int index tensors; got %d",
+          static_cast<int>(ix_type));
+      ET_CHECK_OR_RETURN_FALSE(
+          index.dim() == 1,
+          "Special case in-place index_put only supports 1-dimensional index tensors; got %d",
+          static_cast<int>(index.dim()));
+    }
+  }
+
+  ET_CHECK_OR_RETURN_FALSE(
+      found_index,
+      "Special case in-place index_put needs at least one non-null index tensor");
+
+  const Tensor& index = indices[*dim].value();
+
+  bool is_valid_index = true;
+  ET_SWITCH_TWO_TYPES(
+      Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
+        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+        for (const auto i : c10::irange(index.numel())) {
+          if (index_arr[i] < 0 ||
+              index_arr[i] >= static_cast<CTYPE>(in.size(*dim))) {
+            ET_LOG(
+                Error,
+                "Index %" PRId64
+                " out of range for tensor with size %zd"
+                " at dimension %zu",
+                static_cast<int64_t>(index_arr[i]),
+                in.size(*dim),
+                *dim);
+            is_valid_index = false;
+            break;
+          }
+        }
+      });
+
+  ET_CHECK_OR_RETURN_FALSE(
+      is_valid_index,
+      "Some index values are not within bounds of input tensor at indexed dim");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      values.size(*dim) == index.size(0),
+      "Special case in-place index_put requires values to match index length at the indexed dim; values.size(%zu) = %" ET_PRI_TENSOR_SIZE
+      ", index_length = %zd",
+      *dim,
+      values.size(*dim),
+      index.size(0));
+
+  Tensor::SizesType expected_values_size[kTensorDimensionLimit] = {};
+  size_t in_ndim = static_cast<size_t>(in.dim());
+  for (const auto i : c10::irange(in_ndim)) {
+    if (i != *dim) {
+      expected_values_size[i] = static_cast<Tensor::SizesType>(in.size(i));
+    }
+  }
+  expected_values_size[*dim] = static_cast<Tensor::SizesType>(index.size(0));
+
+#if ET_LOG_ENABLED
+  auto in_shape_str =
+      executorch::runtime::tensor_shape_to_c_string(
+          executorch::runtime::Span<const Tensor::SizesType>(
+              in.sizes().data(), in.sizes().size()));
+  auto values_shape_str = executorch::runtime::tensor_shape_to_c_string(
+      executorch::runtime::Span<const Tensor::SizesType>(
+          values.sizes().data(), values.sizes().size()));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      tensor_has_expected_size(values, {expected_values_size, in_ndim}),
+      "Special case in-place index_put requires values to match input shape except for indexed dim; got input shape %s and values shape %s",
+      in_shape_str.data(),
+      values_shape_str.data());
+#else
+  ET_CHECK_OR_RETURN_FALSE(
+      tensor_has_expected_size(values, {expected_values_size, in_ndim}),
+      "Special case in-place index_put requires values to match input shape except for indexed dim");
+#endif // ET_LOG_ENABLED
+
+  return true;
+}
+
+} // namespace
+
+Tensor& index_put_(
+    KernelRuntimeContext& ctx,
+    Tensor& in,
+    TensorOptList indices,
+    const Tensor& values,
+    const bool accumulate) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dtype(in, values), InvalidArgument, in);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, values), InvalidArgument, in);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, in);
+
+  size_t dim = 0;
+  ET_KERNEL_CHECK(
+      ctx,
+      check_special_case_in_place_args(in, indices, values, accumulate, &dim),
+      InvalidArgument,
+      in);
+
+  const Tensor& index = indices[dim].value();
+  ScalarType index_type = index.scalar_type();
+
+  if (in.dim() == 0) {
+    memcpy(in.mutable_data_ptr(), values.const_data_ptr(), in.nbytes());
+    return in;
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  if (leading_dims == 0 || trailing_dims == 0) {
+    return in;
+  }
+
+  size_t values_dim_length = values.size(dim);
+  size_t in_dim_length = in.size(dim);
+
+  size_t length_per_step = trailing_dims * in.element_size();
+
+  const char* values_data = values.const_data_ptr<char>();
+  char* in_data = in.mutable_data_ptr<char>();
+
+  // Walk the `leading_dims` slabs along the indexed dim, scattering each
+  // contiguous row of `values` into the row of `in` selected by the index.
+  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, "index_put_", CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(leading_dims)) {
+      const char* src = values_data + i * values_dim_length * length_per_step;
+      char* dest = in_data + i * in_dim_length * length_per_step;
+      for (const auto j : c10::irange(values_dim_length)) {
+        const char* copy_src = src + j * length_per_step;
+        char* copy_dest = dest + index_arr[j] * length_per_step;
+        memcpy(copy_dest, copy_src, length_per_step);
+      }
+    }
+  });
+
+  return in;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index 466e015e31d..ecd6a771646 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -452,6 +452,11 @@
     - arg_meta: null
      kernel_name: torch::executor::index_put_out
 
+- op: index_put_
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::index_put_
+
 - op: index_select.out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/op_index_put_test.cpp b/kernels/test/op_index_put_test.cpp
index 967760576da..b25cdb01e92 100644
--- a/kernels/test/op_index_put_test.cpp
+++ b/kernels/test/op_index_put_test.cpp
@@ -1011,3 +1011,99 @@ TEST_F(OpIndexPutOutTest, DynamicShapeUnbound) {
   test_dynamic_shape(
       {1, 1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
 }
+
+class OpIndexPutInplaceTest : public OperatorTest {
+ protected:
+  Tensor& op_index_put_(
+      Tensor& input,
+      OptTensorArrayRef indices,
+      const Tensor& values,
+      const bool accumulate) {
+#ifdef USE_ATEN_LIB
+    c10::List<std::optional<at::Tensor>> indices_list(indices);
+    return torch::executor::aten::index_put_(
+        context_, input, indices_list, values, accumulate);
+#else
+    return torch::executor::aten::index_put_(
+        context_, input, indices, values, accumulate);
+#endif
+  }
+
+  template <
+      executorch::aten::ScalarType INPUT_DTYPE,
+      executorch::aten::ScalarType INDICES_DTYPE>
+  void test_dtype() {
+    TensorFactory<INPUT_DTYPE> tf;
+    TensorFactory<INDICES_DTYPE> tfl;
+
+    // clang-format off
+    Tensor x = tf.make(
+        {2, 3, 4},
+        {
+            // [0, :, :]
+            1, 1, 1, 1, // [0, 0, :]
+            0, 0, 0, 0, // [0, 1, :]
+            2, 2, 2, 2, // [0, 2, :]
+
+            // [1, :, :]
+            3, 3, 3, 3, // [1, 0, :]
+            0, 0, 0, 0, // [1, 1, :]
+            5, 5, 5, 5, // [1, 2, :]
+        });
+    // clang-format on
+
+    optional<Tensor> indices[] = {
+        optional<Tensor>(),
+        optional<Tensor>(tfl.make({2}, {0, 2})),
+    };
+
+    // clang-format off
+    Tensor values = tf.make(
+        {2, 2, 4},
+        {
+            // [0, :, :]
+            1, 2, 3, 4, // [0, 0, :]
+            5, 6, 7, 8, // [0, 1, :]
+
+            // [1, :, :]
+             9, 10, 11, 12, // [1, 0, :]
+            13, 14, 15, 16, // [1, 1, :]
+        });
+    // clang-format on
+
+    // clang-format off
+    Tensor expected = tf.make(
+        {2, 3, 4},
+        {
+            // [0, :, :]
+            1, 2, 3, 4, // [0, 0, :]
+            0, 0, 0, 0, // [0, 1, :]
+            5, 6, 7, 8, // [0, 2, :]
+
+            // [1, :, :]
+             9, 10, 11, 12, // [1, 0, :]
+             0,  0,  0,  0, // [1, 1, :]
+            13, 14, 15, 16, // [1, 2, :]
+        });
+    // clang-format on
+
+    Tensor ret = op_index_put_(x, indices, values, /*accumulate=*/false);
+
+    EXPECT_TENSOR_EQ(ret, x);
+    EXPECT_TENSOR_EQ(ret, expected);
+  }
+};
+
+TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForInput) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_dtype<executorch::aten::ScalarType::dtype, executorch::aten::ScalarType::Long>();
+
+  ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
+
+#undef TEST_ENTRY
+}
+
+TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForIndicesList) {
+  test_dtype<executorch::aten::ScalarType::Float, executorch::aten::ScalarType::Int>();
+  test_dtype<executorch::aten::ScalarType::Float, executorch::aten::ScalarType::Long>();
+}
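
Usage sketch (not part of the patch): the snippet below shows how the new in-place kernel could be invoked directly, mirroring the test harness above. It is a minimal illustration under stated assumptions: the `TensorFactory` include path and the default-constructed `KernelRuntimeContext` come from ExecuTorch's test utilities, and the declaration of `torch::executor::native::index_put_` is assumed to be in scope via the generated portable ops header.

// Hypothetical standalone example; not added by this diff.
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
// index_put_ itself is assumed declared via the generated portable ops header.

using executorch::aten::optional;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using torch::executor::testing::TensorFactory;

int main() {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor x = tf.zeros({3, 2});                   // input, modified in place
  Tensor values = tf.make({2, 2}, {1, 2, 3, 4}); // one row per index entry
  // A single non-null 1-D Long index at dim 0: rows 0 and 2 get overwritten.
  optional<Tensor> indices[] = {optional<Tensor>(tfl.make({2}, {0, 2}))};

  torch::executor::KernelRuntimeContext ctx;
  torch::executor::native::index_put_(
      ctx, x, {indices, 1}, values, /*accumulate=*/false);
  // x is now {{1, 2}, {0, 0}, {3, 4}}; row 1 is untouched.
  return 0;
}

Note that this exercises exactly the special case the checks enforce: `accumulate` is false, and there is one non-null, 1-dimensional Long index whose length matches `values` at the indexed dim.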