Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions kernels/prim_ops/et_copy_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ constexpr size_t kTensorDimensionLimit = 16;
// torch.ops.executorch.prim.add.int(iteration_index, 1, iteration_index)
// done_bool = torch.ops.executorch.prim.eq.int(iteration_index,
// sym_size, done_bool) # Emitter inserts a instruction here, if
// done_bool == False jump to selcect_copy op # if not continue. return
// done_bool == False jump to select_copy op # if not continue. return
// add_tensor
//
// The output of each iteration (copy_from) is copied into the copy_to tensor at
Expand All @@ -79,12 +79,24 @@ void et_copy_index(KernelRuntimeContext& context, Span<EValue*> stack) {
auto copy_from = (*stack[1]).toTensor();
auto index = (*stack[2]).toInt();

ET_KERNEL_CHECK_MSG(
context,
index >= 0,
InvalidArgument,
/* void */,
"Expected index to be non-negative.");

// Number of bytes we need to copy over from copy_from tensor.
size_t size_copy_from = (copy_from.element_size()) * (copy_from.numel());

ET_CHECK_MSG(
ET_KERNEL_CHECK_MSG(
context,
(copy_to.sizes().size() - copy_from.sizes().size()) == 1,
"Ranks of copy_to and copy_from tensor should only differ by 1.");
InvalidArgument,
/* void */,
"Ranks of copy_to %zu and copy_from tensor %zu should only differ by 1.",
copy_to.sizes().size(),
copy_from.sizes().size());

// Here we calculate the size of the out_tensor after copy_from has
// been copied to it. This will be passed onto the resize call.
Expand All @@ -93,8 +105,11 @@ void et_copy_index(KernelRuntimeContext& context, Span<EValue*> stack) {
// If we're copying past the first index then the shape of
// copy_from and copy_to without the leading dimension should be
// the same. i.e. copy_to.size[1:] == copy_from.size[:].
ET_CHECK_MSG(
ET_KERNEL_CHECK_MSG(
context,
copy_to.sizes()[i + 1] == copy_from.sizes()[i],
InvalidArgument,
/* void */,
"Mismatch in shape between copy_to and copy_from tensors");
expected_output_size[i + 1] = copy_from.sizes()[i];
}
Expand All @@ -105,11 +120,22 @@ void et_copy_index(KernelRuntimeContext& context, Span<EValue*> stack) {
Error err =
resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()});
ET_CHECK(err == Error::Ok);
ET_CHECK_MSG(
ET_KERNEL_CHECK_MSG(
context,
data_ptr == copy_to.const_data_ptr(),
InvalidState,
/* void */,
"Data ptr of copy_to tensor changed after resize which isn't allowed for static/upper-bounded tensors");
}

// After potential resize, verify that index is within bounds.
ET_KERNEL_CHECK_MSG(
context,
index < copy_to.sizes()[0],
InvalidArgument,
/* void */,
"Index out of bounds");

auto copy_to_ptr = copy_to.const_data_ptr();
auto copy_from_ptr = copy_from.const_data_ptr();

Expand All @@ -118,12 +144,22 @@ void et_copy_index(KernelRuntimeContext& context, Span<EValue*> stack) {
// copy_from into the copy_to tensor.

// Check that the destination has enough space for the copy.
ET_KERNEL_CHECK_MSG(
context,
size_copy_from == 0 ||
static_cast<size_t>(index) <= SIZE_MAX / size_copy_from,
InvalidArgument,
/* void */,
"Offset multiplication .");
size_t offset = index * size_copy_from;
size_t copy_to_size = copy_to.element_size() * copy_to.numel();
ET_CHECK_MSG(
offset + size_copy_from <= copy_to_size,
"Buffer overflow: copy_to tensor is smaller than copy_from tensor.");

ET_KERNEL_CHECK_MSG(
context,
(offset <= SIZE_MAX - size_copy_from) &&
(offset + size_copy_from <= copy_to_size),
InvalidArgument,
/* void */,
"Buffer overflow; offset overflow or copy_to tensor is smaller than copy_from tensor.");
memcpy(
// NOLINTNEXTLINE(performance-no-int-to-ptr)
(void*)((uintptr_t)copy_to_ptr + offset),
Expand Down
5 changes: 3 additions & 2 deletions kernels/prim_ops/test/prim_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,9 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexMismatchShape) {
// Try to copy and replace at index 1. This will fail because
// copy_to.sizes[1:] and to_copy.sizes[:] don't match each other
// which is a pre-requisite for this operator.
ET_EXPECT_DEATH(
getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack), "");
ET_EXPECT_KERNEL_FAILURE(
context_,
getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack));
}

TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) {
Expand Down
Loading