Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 8 additions & 20 deletions kernels/portable/cpu/op_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ bool check_fast_path_conditions(
if (index.dim() != 1) {
return false;
}

// Fast path only supports non-negative indices.
if (ix_type == ScalarType::Int) {
const int32_t* const data = index.const_data_ptr<int32_t>();
if (std::any_of(data, data + index.numel(), [](const auto x) {
return x < 0;
})) {
return false;
}
} else { // ScalarType::Long
const int64_t* const data = index.const_data_ptr<int64_t>();
if (std::any_of(data, data + index.numel(), [](const auto x) {
return x < 0;
})) {
return false;
}
}
}
}

Expand Down Expand Up @@ -96,8 +79,10 @@ bool check_fast_path_args(
Long, Int, index.scalar_type(), ctx, "index.Tensor", CTYPE, [&]() {
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
for (const auto i : c10::irange(index.numel())) {
if (index_arr[i] < 0 ||
index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
CTYPE index_val = index_arr[i];
CTYPE dim_size = static_cast<CTYPE>(in.size(dim));
index_val = index_val < 0 ? index_val + dim_size : index_val;
if (index_val < 0 || index_val >= dim_size) {
ET_LOG(
Error,
"Index %" PRId64
Expand Down Expand Up @@ -189,11 +174,14 @@ Tensor& fast_path(

ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
CTYPE dim_size = static_cast<CTYPE>(in.size(dim));
for (const auto i : c10::irange(leading_dims)) {
const char* src = in_data + i * in_dim_length * length_per_step;
char* dest = out_data + i * out_dim_length * length_per_step;
for (const auto j : c10::irange(out_dim_length)) {
const char* copy_src = src + index_arr[j] * length_per_step;
auto index_val =
index_arr[j] < 0 ? index_arr[j] + dim_size : index_arr[j];
const char* copy_src = src + index_val * length_per_step;
char* copy_dest = dest + j * length_per_step;
memcpy(copy_dest, copy_src, length_per_step);
}
Expand Down
53 changes: 53 additions & 0 deletions kernels/test/op_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,56 @@ TEST_F(OpIndexTensorOutTest, FastPathEmptyInput) {

EXPECT_TENSOR_EQ(out, expected);
}

TEST_F(OpIndexTensorOutTest, FastPathNegativeIndex) {
TensorFactory<ScalarType::Float> tf;
TensorFactory<ScalarType::Long> tfl;

// clang-format off
Tensor x = tf.make(
{2, 3, 4},
{
// [0, :, :]
1., 2., 3., 4., // [0, 0, :]
5., 6., 7., 8., // [0, 1, :]
9., 10., 11., 12., // [0, 2, :]

// [1, :, :]
-1., -2., -3., -4., // [1, 0, :]
-5., -6., -7., -8., // [1, 1, :]
-9., -10., -11., -12., // [1, 2, :]
});
// clang-format on

// Use negative indices in the first dimension: -1, 0, -2
std::array<optional<Tensor>, 3> indices = {
optional<Tensor>(tfl.make({3}, {-1, 0, -2})),
optional<Tensor>(),
optional<Tensor>()};

Tensor out = tf.zeros({3, 3, 4});
// clang-format off
Tensor expected = tf.make(
{3, 3, 4},
{
// [1, :, :]
-1., -2., -3., -4., // [1, 0, :]
-5., -6., -7., -8., // [1, 1, :]
-9., -10., -11., -12., // [1, 2, :]

// [0, :, :]
1., 2., 3., 4., // [0, 0, :]
5., 6., 7., 8., // [0, 1, :]
9., 10., 11., 12., // [0, 2, :]

// [0, :, :] again (since -2 wraps to 0)
1., 2., 3., 4., // [0, 0, :]
5., 6., 7., 8., // [0, 1, :]
9., 10., 11., 12., // [0, 2, :]
});
// clang-format on

op_index_tensor_out(x, indices, out);

EXPECT_TENSOR_EQ(out, expected);
}
Loading