kernels/aten/functions.yaml (2 additions, 0 deletions)

@@ -201,6 +201,8 @@

- op: index_put.out

- op: index_put_

- op: index_select.out

- op: index.Tensor_out
kernels/portable/cpu/op_index_put.cpp (175 additions, 1 deletion)

@@ -11,18 +11,21 @@

#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;
using TensorOptList =
executorch::aten::ArrayRef<executorch::aten::optional<Tensor>>;

Tensor& index_put_out(
KernelRuntimeContext& ctx,
const Tensor& in,
-    executorch::aten::ArrayRef<executorch::aten::optional<Tensor>> indices,
+    TensorOptList indices,
const Tensor& values,
const bool accumulate,
Tensor& out) {
@@ -154,6 +157,177 @@ Tensor& index_put_out(
return out;
}

namespace {

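// Validates the restricted pattern supported by the in-place kernel: no
// accumulation and exactly one non-null index tensor, which must be 1-D
// with in-bounds Long or Int entries, and `values` must match the input
// shape except along the indexed dim. On success, *dim is set to the
// indexed dimension.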
bool check_special_case_in_place_args(
Tensor& in,
TensorOptList indices,
const Tensor& values,
const bool accumulate,
size_t* dim) {
ET_CHECK_OR_RETURN_FALSE(
!accumulate,
"Special case in-place index_put does not support accumulate");

ET_CHECK_OR_RETURN_FALSE(
static_cast<ssize_t>(indices.size()) <= in.dim(),
"Indexing too many dimensions");

bool found_index = false;
for (const auto i : c10::irange(indices.size())) {
if (indices[i].has_value()) {
*dim = i;
ET_CHECK_OR_RETURN_FALSE(
!found_index,
"Special case in-place index_put only supports a single non-null index tensor");
found_index = true;
const Tensor& index = indices[i].value();
ScalarType ix_type = index.scalar_type();
ET_CHECK_OR_RETURN_FALSE(
ix_type == ScalarType::Long || ix_type == ScalarType::Int,
"Special case in-place index_put only supports Long or Int index tensors; got %d",
static_cast<int>(ix_type));
ET_CHECK_OR_RETURN_FALSE(
index.dim() == 1,
"Special case in-place index_put only supports 1-dimensional index tensors; got %d",
static_cast<int>(index.dim()));
}
}

ET_CHECK_OR_RETURN_FALSE(
found_index,
"Special case in-place index_put needs at least one non-null index tensor");

const Tensor& index = indices[*dim].value();

bool is_valid_index = true;
ET_SWITCH_TWO_TYPES(
Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
for (const auto i : c10::irange(index.numel())) {
if (index_arr[i] < 0 ||
index_arr[i] >= static_cast<CTYPE>(in.size(*dim))) {
ET_LOG(
Error,
"Index %" PRId64
" out of range for tensor with size %zd"
" at dimension %zu",
static_cast<int64_t>(index_arr[i]),
in.size(*dim),
*dim);
is_valid_index = false;
break;
}
}
});

ET_CHECK_OR_RETURN_FALSE(
is_valid_index,
"Some index values are not within bounds of input tensor at indexed dim");

ET_CHECK_OR_RETURN_FALSE(
values.size(*dim) == index.size(0),
"Special case in-place index_put requires values to match index length at the indexed dim; values.size(%zu) = %" ET_PRI_TENSOR_SIZE
", index_length = %zd",
*dim,
values.size(*dim),
index.size(0));

Tensor::SizesType expected_values_size[kTensorDimensionLimit] = {};
size_t in_ndim = static_cast<size_t>(in.dim());
for (const auto i : c10::irange(in_ndim)) {
if (i != *dim) {
expected_values_size[i] = static_cast<Tensor::SizesType>(in.size(i));
}
}
expected_values_size[*dim] = static_cast<Tensor::SizesType>(index.size(0));

#if ET_LOG_ENABLED
auto in_shape_str = executorch::runtime::tensor_shape_to_c_string(
executorch::runtime::Span<const Tensor::SizesType>(
in.sizes().data(), in.sizes().size()));
auto values_shape_str = executorch::runtime::tensor_shape_to_c_string(
executorch::runtime::Span<const Tensor::SizesType>(
values.sizes().data(), values.sizes().size()));

ET_CHECK_OR_RETURN_FALSE(
tensor_has_expected_size(values, {expected_values_size, in_ndim}),
"Special case in-place index_put requires values to match input shape except for indexed dim; got input shape %s and values shape %s",
in_shape_str.data(),
values_shape_str.data());
#else
ET_CHECK_OR_RETURN_FALSE(
tensor_has_expected_size(values, {expected_values_size, in_ndim}),
"Special case in-place index_put requires values to match input shape except for indexed dim");
#endif // ET_LOG_ENABLED

return true;
}

} // namespace

Contributor review comment: "is it worth making the out variant call the inplace one on out?" (One possible approach is sketched after this file.)

Tensor& index_put_(
KernelRuntimeContext& ctx,
Tensor& in,
TensorOptList indices,
const Tensor& values,
const bool accumulate) {
(void)ctx;

ET_KERNEL_CHECK(
ctx, tensors_have_same_dtype(in, values), InvalidArgument, in);

ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, values), InvalidArgument, in);

ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, in);

size_t dim = 0;
ET_KERNEL_CHECK(
ctx,
check_special_case_in_place_args(in, indices, values, accumulate, &dim),
InvalidArgument,
in);

const Tensor& index = indices[dim].value();
ScalarType index_type = index.scalar_type();

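// Rank-0 input: nothing to index, so overwrite the single element wholesale.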
if (in.dim() == 0) {
memcpy(in.mutable_data_ptr(), values.const_data_ptr(), in.nbytes());
return in;
}

size_t leading_dims = getLeadingDims(in, dim);
size_t trailing_dims = getTrailingDims(in, dim);

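// A zero-sized dimension anywhere in the input means there are no elements
// to write.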
if (leading_dims == 0 || trailing_dims == 0) {
return in;
}

size_t values_dim_length = values.size(dim);
size_t in_dim_length = in.size(dim);

size_t length_per_step = trailing_dims * in.element_size();

const char* values_data = values.const_data_ptr<char>();
char* in_data = in.mutable_data_ptr<char>();

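// For each of the leading_dims outer slices, copy one contiguous run of
// trailing_dims elements from values row j into in row index_arr[j] along
// the indexed dim, i.e. in[..., index_arr[j], ...] = values[..., j, ...],
// done with raw memcpy on the contiguous layout.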
ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, "index_put_", CTYPE, [&]() {
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
for (const auto i : c10::irange(leading_dims)) {
const char* src = values_data + i * values_dim_length * length_per_step;
char* dest = in_data + i * in_dim_length * length_per_step;
for (const auto j : c10::irange(values_dim_length)) {
const char* copy_src = src + j * length_per_step;
char* copy_dest = dest + index_arr[j] * length_per_step;
memcpy(copy_dest, copy_src, length_per_step);
}
}
});

return in;
}

} // namespace native
} // namespace executor
} // namespace torch
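
On the reviewer's question above: a minimal sketch of how the out variant could reuse the in-place kernel, assuming the portable resize_tensor helper and that the special-case restrictions (single non-null 1-D index, no accumulate) hold. index_put_out_via_inplace is an illustrative name, not code from this PR:

namespace torch {
namespace executor {
namespace native {

// Sketch: materialize `in` into `out`, then scatter `values` in place.
Tensor& index_put_out_via_inplace(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
    const Tensor& values,
    const bool accumulate,
    Tensor& out) {
  // Shape out like in, then copy the input data over.
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
  memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
  // Reuse the in-place kernel on the copy. Note index_put_ only handles the
  // special case validated above, so the real out variant could only
  // delegate on that path and would keep its general fallback otherwise.
  return index_put_(ctx, out, indices, values, accumulate);
}

} // namespace native
} // namespace executor
} // namespace torch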
kernels/portable/functions.yaml (5 additions, 0 deletions)

@@ -452,6 +452,11 @@
- arg_meta: null
kernel_name: torch::executor::index_put_out

- op: index_put_
kernels:
- arg_meta: null
kernel_name: torch::executor::index_put_

- op: index_select.out
kernels:
- arg_meta: null
kernels/test/op_index_put_test.cpp (96 additions, 0 deletions)

@@ -1011,3 +1011,99 @@ TEST_F(OpIndexPutOutTest, DynamicShapeUnbound) {
test_dynamic_shape(
{1, 1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}

class OpIndexPutInplaceTest : public OperatorTest {
protected:
Tensor& op_index_put_(
Tensor& input,
OptTensorArrayRef indices,
const Tensor& values,
const bool accumulate) {
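// In ATen mode index_put_ takes a c10::List of optional tensors, so wrap
// the incoming ArrayRef; the portable build consumes the ArrayRef directly.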
#ifdef USE_ATEN_LIB
c10::List<std::optional<at::Tensor>> indices_list(indices);
return torch::executor::aten::index_put_(
context_, input, indices_list, values, accumulate);
#else
return torch::executor::aten::index_put_(
context_, input, indices, values, accumulate);
#endif
}

template <
executorch::aten::ScalarType INPUT_DTYPE,
executorch::aten::ScalarType INDICES_DTYPE>
void test_dtype() {
TensorFactory<INPUT_DTYPE> tf;
TensorFactory<INDICES_DTYPE> tfl;

// clang-format off
Tensor x = tf.make(
{2, 3, 4},
{
// [0, :, :]
1, 1, 1, 1, // [0, 0, :]
0, 0, 0, 0, // [0, 1, :]
2, 2, 2, 2, // [0, 2, :]

// [1, :, :]
3, 3, 3, 3, // [1, 0, :]
0, 0, 0, 0, // [1, 1, :]
5, 5, 5, 5, // [1, 2, :]
});
// clang-format on

optional<Tensor> indices[] = {
optional<Tensor>(),
optional<Tensor>(tfl.make({2}, {0, 2})),
};

// clang-format off
Tensor values = tf.make(
{2, 2, 4},
{
// [0, :, :]
1, 2, 3, 4, // [0, 0, :]
5, 6, 7, 8, // [0, 1, :]

// [1, :, :]
9, 10, 11, 12, // [1, 0, :]
13, 14, 15, 16, // [1, 1, :]
});
// clang-format on

// clang-format off
Tensor expected = tf.make(
{2, 3, 4},
{
// [0, :, :]
1, 2, 3, 4, // [0, 0, :]
0, 0, 0, 0, // [0, 1, :]
5, 6, 7, 8, // [0, 2, :]

// [1, :, :]
9, 10, 11, 12, // [1, 0, :]
0, 0, 0, 0, // [1, 1, :]
13, 14, 15, 16, // [1, 2, :]
});
// clang-format on

Tensor ret = op_index_put_(x, indices, values, /*accumulate=*/false);

EXPECT_TENSOR_EQ(ret, x);
EXPECT_TENSOR_EQ(ret, expected);
}
};

TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForInput) {
#define TEST_ENTRY(ctype, dtype) \
test_dtype<ScalarType::dtype, ScalarType::Long>();

ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);

#undef TEST_ENTRY
}

TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForIndicesList) {
test_dtype<ScalarType::Float, ScalarType::Long>();
test_dtype<ScalarType::Float, ScalarType::Int>();
}
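
For reference, the pattern this special-case kernel covers is plain (non-accumulating) slice assignment along one dimension with a single 1-D index tensor. A standalone libtorch sketch of the equivalent eager-mode call (assumes a standard libtorch build; not part of this test file):

#include <torch/torch.h>

int main() {
  auto x = torch::zeros({2, 3, 4});
  auto idx = torch::tensor({0, 2}, torch::kLong); // single non-null 1-D index
  auto values = torch::ones({2, 2, 4});
  // Equivalent to x[:, idx, :] = values with accumulate = false.
  x.index_put_({torch::indexing::Slice(), idx}, values);
  return 0;
}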