Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ Tensor& index_Tensor_out(
if (block_count == 0) {
ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
ET_SWITCH_REALHB_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
memcpy(out_data, in_data, in.nbytes());
Expand Down
48 changes: 38 additions & 10 deletions kernels/test/op_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,36 @@ class OpIndexTensorOutTest : public OperatorTest {

ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);

#undef TEST_ENTRY
}

template <executorch::aten::ScalarType INPUT_DTYPE>
void test_indices_with_only_null_tensors_supported() {
TensorFactory<INPUT_DTYPE> tf;

Tensor x = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
Tensor out = tf.zeros({2, 3});

std::array<optional<Tensor>, 1> indices1 = {optional<Tensor>()};
op_index_tensor_out(x, indices1, out);
EXPECT_TENSOR_EQ(out, x);

out = tf.zeros({2, 3});
std::array<optional<Tensor>, 2> indices2 = {
optional<Tensor>(), std::optional<Tensor>()};
op_index_tensor_out(x, indices2, out);
EXPECT_TENSOR_EQ(out, x);
}

/**
* Test indices with only null tensors for all input data types
*/
void test_indices_with_only_null_tensors_enumerate_in_types() {
#define TEST_ENTRY(ctype, dtype) \
test_indices_with_only_null_tensors_supported<ScalarType::dtype>();

ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);

#undef TEST_ENTRY
}

Expand Down Expand Up @@ -405,21 +435,19 @@ TEST_F(OpIndexTensorOutTest, IndicesWithOnlyNullTensorsSupported) {
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
GTEST_SKIP() << "ATen kernel test fails";
}
TensorFactory<ScalarType::Double> tf;
test_indices_with_only_null_tensors_enumerate_in_types();
}

TEST_F(OpIndexTensorOutTest, TooManyNullIndices) {
TensorFactory<ScalarType::Double> tf;
Tensor x = tf.make({2, 3}, {1., 2., 3., 4., 5., 6.});
std::array<optional<Tensor>, 1> indices0 = {optional<Tensor>()};
run_test_cases(x, indices0, x);

std::array<optional<Tensor>, 2> indices1 = {
optional<Tensor>(), std::optional<Tensor>()};
run_test_cases(x, indices1, x);

std::array<optional<Tensor>, 3> indices2 = {
std::array<optional<Tensor>, 3> indices = {
optional<Tensor>(), std::optional<Tensor>(), std::optional<Tensor>()};
Tensor out = tf.ones({2, 3});
ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
context_, op_index_tensor_out(x, indices2, out), "");
context_,
op_index_tensor_out(x, indices, out),
"Indexing too many dimensions");
}

TEST_F(OpIndexTensorOutTest, EmptyIndicesSupported) {
Expand Down
Loading