From 2aa18ed1eb343fee3bbcf0fc5fc77cc5c4b965bd Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 3 Sep 2024 09:42:54 -0700
Subject: [PATCH] Add op: topk (#4307)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4307

Differential Revision: D59936967
---
 kernels/portable/cpu/op_topk.cpp              | 251 ++++++++++++++++++
 kernels/portable/functions.yaml               |   5 +
 kernels/test/op_topk_test.cpp                 | 138 ++++++++++
 kernels/test/targets.bzl                      |   1 +
 .../kernels/portable/op_registration_util.bzl |   3 +
 5 files changed, 398 insertions(+)
 create mode 100644 kernels/portable/cpu/op_topk.cpp
 create mode 100644 kernels/test/op_topk_test.cpp

diff --git a/kernels/portable/cpu/op_topk.cpp b/kernels/portable/cpu/op_topk.cpp
new file mode 100644
index 00000000000..3cc0ccb9de4
--- /dev/null
+++ b/kernels/portable/cpu/op_topk.cpp
@@ -0,0 +1,251 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <tuple>
+#include <utility>
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace {
+
+bool check_topk_args(
+    const Tensor& in,
+    int64_t k,
+    int64_t dim,
+    Tensor& values,
+    Tensor& indices) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, values));
+  ET_LOG_AND_RETURN_IF_FALSE(indices.scalar_type() == ScalarType::Long);
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
+  if (dim < 0) {
+    dim += nonzero_dim(in);
+  }
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      k >= 0 && k <= nonempty_size(in, dim), "selected index k out of range");
+  return true;
+}
+
+bool get_topk_target_size(
+    const Tensor& in,
+    int64_t k,
+    int64_t dim,
+    Tensor::SizesType* target_size,
+    size_t* target_dim) {
+  *target_dim = in.dim();
+  for (size_t i = 0; i < *target_dim; ++i) {
+    if (i == static_cast<size_t>(dim)) {
+      target_size[i] = k;
+    } else {
+      target_size[i] = in.size(i);
+    }
+  }
+  return true;
+}
+
+template <typename CTYPE, typename elem_t = std::pair<CTYPE, int64_t>>
+void perform_topk(
+    const Tensor& in,
+    int64_t k,
+    int64_t dim,
+    bool largest,
+    bool sorted,
+    Tensor& values,
+    Tensor& indices,
+    elem_t* queue) {
+  const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
+  CTYPE* values_data = values.mutable_data_ptr<CTYPE>();
+  int64_t* indices_data = indices.mutable_data_ptr<int64_t>();
+
+  if (in.dim() == 0) {
+    values_data[0] = in_data[0];
+    indices_data[0] = 0;
+    return;
+  }
+
+  if (k == 0) {
+    return;
+  }
+
+  const size_t outer_size = getLeadingDims(in, dim);
+
+  const size_t dim_size = in.size(dim);
+  const size_t dim_stride = in.strides()[dim];
+
+  const size_t outer_stride_in = dim_size * dim_stride;
+  const size_t outer_stride_out = k * dim_stride;
+
+  // partial_sort is O(n log k); when k is small relative to the row length it
+  // beats nth_element (O(n)) followed by an O(k log k) sort of the prefix.
+  bool use_partial_sort = k * 64 <= dim_size;
+
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    size_t outer_in = outer_idx * outer_stride_in;
+    size_t outer_out = outer_idx * outer_stride_out;
+    // Loop through all inner dimensions
+    for (size_t inner_idx = 0; inner_idx < dim_stride; ++inner_idx) {
+      size_t base_in = outer_in + inner_idx;
+      size_t base_out = outer_out + inner_idx;
+
+      // Populate the queue with the values from the input tensor
+      for (size_t i = 0; i < dim_size; ++i) {
+        size_t in_ix = base_in + i * dim_stride;
+        queue[i].first = in_data[in_ix];
+        queue[i].second = i;
+      }
+
+      // Perform topk on the queue. The comparators order NaN as the largest
+      // value, matching ATen's topk behavior.
+      if (use_partial_sort) {
+        if (largest) {
+          std::partial_sort(
+              queue,
+              queue + k,
+              queue + dim_size,
+              [](const elem_t& x, const elem_t& y) -> bool {
+                return (
+                    (std::isnan(x.first) && !std::isnan(y.first)) ||
+                    (x.first > y.first));
+              });
+        } else {
+          std::partial_sort(
+              queue,
+              queue + k,
+              queue + dim_size,
+              [](const elem_t& x, const elem_t& y) -> bool {
+                return (
+                    (!std::isnan(x.first) && std::isnan(y.first)) ||
+                    (x.first < y.first));
+              });
+        }
+      } else {
+        if (largest) {
+          std::nth_element(
+              queue,
+              queue + k - 1,
+              queue + dim_size,
+              [](const elem_t& x, const elem_t& y) -> bool {
+                return (
+                    (std::isnan(x.first) && !std::isnan(y.first)) ||
+                    (x.first > y.first));
+              });
+          if (sorted) {
+            // queue[k - 1] is already in its final position after
+            // nth_element, so only the first k - 1 elements need sorting.
+            std::sort(
+                queue,
+                queue + k - 1,
+                [](const elem_t& x, const elem_t& y) -> bool {
+                  return (
+                      (std::isnan(x.first) && !std::isnan(y.first)) ||
+                      (x.first > y.first));
+                });
+          }
+        } else {
+          std::nth_element(
+              queue,
+              queue + k - 1,
+              queue + dim_size,
+              [](const elem_t& x, const elem_t& y) -> bool {
+                return (
+                    (!std::isnan(x.first) && std::isnan(y.first)) ||
+                    (x.first < y.first));
+              });
+          if (sorted) {
+            std::sort(
+                queue,
+                queue + k - 1,
+                [](const elem_t& x, const elem_t& y) -> bool {
+                  return (
+                      (!std::isnan(x.first) && std::isnan(y.first)) ||
+                      (x.first < y.first));
+                });
+          }
+        }
+      }
+
+      // Write the topk values and indices to the output tensors
+      for (size_t i = 0; i < k; ++i) {
+        size_t out_ix = base_out + i * dim_stride;
+
+        values_data[out_ix] = queue[i].first;
+        indices_data[out_ix] = queue[i].second;
+      }
+    }
+  }
+}
+
+void* allocate_temp_memory(RuntimeContext& ctx, size_t size) {
+  Result<void*> temp_mem_res = ctx.allocate_temp(size);
+  return temp_mem_res.ok() ? temp_mem_res.get() : nullptr;
+}
+
+} // namespace
+
+std::tuple<Tensor&, Tensor&> topk_values(
+    RuntimeContext& ctx,
+    const Tensor& in,
+    int64_t k,
+    int64_t dim,
+    bool largest,
+    bool sorted,
+    Tensor& values,
+    Tensor& indices) {
+  auto out = std::tuple<Tensor&, Tensor&>({values, indices});
+
+  ET_KERNEL_CHECK(
+      ctx, check_topk_args(in, k, dim, values, indices), InvalidArgument, out);
+
+  if (dim < 0) {
+    dim += nonzero_dim(in);
+  }
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  Tensor::SizesType target_size[kTensorDimensionLimit];
+  size_t target_dim = 0;
+  get_topk_target_size(in, k, dim, target_size, &target_dim);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(values, {target_size, target_dim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(indices, {target_size, target_dim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  constexpr auto name = "topk.values";
+
+  if (in.numel() == 0 || (k == 0 && in.dim() > 0)) {
+    return out;
+  }
+
+  bool temp_mem_allocated = false;
+
+  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
+    using elem_t = std::pair<CTYPE, int64_t>;
+    size_t temp_mem_size = nonempty_size(in, dim) * sizeof(elem_t);
+
+    elem_t* queue = (elem_t*)allocate_temp_memory(ctx, temp_mem_size);
+    if (queue == nullptr) {
+      return;
+    }
+    temp_mem_allocated = true;
+
+    perform_topk<CTYPE>(in, k, dim, largest, sorted, values, indices, queue);
+  });
+
+  ET_KERNEL_CHECK(ctx, temp_mem_allocated, MemoryAllocationFailed, out);
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index 21258329aa8..6b0b2466888 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -847,6 +847,11 @@
     - arg_meta: null
       kernel_name: torch::executor::tanh_out
 
+- op: topk.values
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::topk_values
+
 - op: transpose_copy.int_out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/op_topk_test.cpp b/kernels/test/op_topk_test.cpp
new file mode 100644
index 00000000000..9f57225ba4f
--- /dev/null
+++ b/kernels/test/op_topk_test.cpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using exec_aten::IntArrayRef;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::runtime::MemoryAllocator;
+using torch::executor::testing::TensorFactory;
+
+class TempMemoryAllocator final : public MemoryAllocator {
+ private:
+  // We allocate a little more than requested and use that memory as a node in
+  // a linked list, pushing the allocated buffers onto a list that's iterated
+  // and freed when the KernelRuntimeContext is destroyed.
+  struct AllocationNode {
+    void* data;
+    AllocationNode* next;
+  };
+
+  AllocationNode* head_ = nullptr;
+
+ public:
+  TempMemoryAllocator() : MemoryAllocator(0, nullptr) {}
+
+  void* allocate(size_t size, size_t alignment = kDefaultAlignment) override {
+    if (!isPowerOf2(alignment)) {
+      ET_LOG(Error, "Alignment %zu is not a power of 2", alignment);
+      return nullptr;
+    }
+
+    // Allocate enough memory for the node, the data and the alignment bump.
+    size_t alloc_size = sizeof(AllocationNode) + size + alignment;
+    void* node_memory = malloc(alloc_size);
+
+    // If allocation failed, log message and return nullptr.
+    if (node_memory == nullptr) {
+      ET_LOG(Error, "Failed to allocate %zu bytes", alloc_size);
+      return nullptr;
+    }
+
+    // Compute data pointer.
+    uint8_t* data_ptr =
+        reinterpret_cast<uint8_t*>(node_memory) + sizeof(AllocationNode);
+
+    // Align the data pointer.
+    void* aligned_data_ptr = alignPointer(data_ptr, alignment);
+
+    // Assert that the alignment didn't overflow the allocated memory.
+    ET_DCHECK_MSG(
+        reinterpret_cast<uintptr_t>(aligned_data_ptr) + size <=
+            reinterpret_cast<uintptr_t>(node_memory) + alloc_size,
+        "aligned_data_ptr %p + size %zu > node_memory %p + alloc_size %zu",
+        aligned_data_ptr,
+        size,
+        node_memory,
+        alloc_size);
+
+    // Construct the node.
+    AllocationNode* new_node = reinterpret_cast<AllocationNode*>(node_memory);
+    new_node->data = aligned_data_ptr;
+    new_node->next = head_;
+    head_ = new_node;
+
+    // Return the aligned data pointer.
+    return head_->data;
+  }
+
+  void reset() override {
+    AllocationNode* current = head_;
+    while (current != nullptr) {
+      AllocationNode* next = current->next;
+      free(current);
+      current = next;
+    }
+    head_ = nullptr;
+  }
+
+  ~TempMemoryAllocator() override {
+    reset();
+  }
+};
+
+std::tuple<Tensor&, Tensor&> op_topk_values(
+    const Tensor& input,
+    int64_t k,
+    int64_t dim,
+    bool largest,
+    bool sorted,
+    Tensor& values,
+    Tensor& indices) {
+  TempMemoryAllocator allocator = TempMemoryAllocator();
+  exec_aten::RuntimeContext context(nullptr, &allocator);
+  return torch::executor::aten::topk_outf(
+      context, input, k, dim, largest, sorted, values, indices);
+}
+
+class OpTopkValuesTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Since these tests cause ET_LOG to be called, the PAL must be
+    // initialized first.
+    torch::executor::runtime_init();
+  }
+};
+
+TEST_F(OpTopkValuesTest, SmokeTest) {
+  TensorFactory<ScalarType::Float> tfFloat;
+  TensorFactory<ScalarType::Long> tfLong;
+
+  Tensor input =
+      tfFloat.make({3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  int64_t k = 2;
+  int64_t dim = 0;
+  bool largest = true;
+  bool sorted = true;
+  Tensor values = tfFloat.zeros({2, 2, 2});
+  Tensor indices = tfLong.zeros({2, 2, 2});
+  Tensor values_expected =
+      tfFloat.make({2, 2, 2}, {9, 10, 11, 12, 5, 6, 7, 8});
+  Tensor indices_expected = tfLong.make({2, 2, 2}, {2, 2, 2, 2, 1, 1, 1, 1});
+  op_topk_values(input, k, dim, largest, sorted, values, indices);
+  EXPECT_TENSOR_CLOSE(values, values_expected);
+  EXPECT_TENSOR_EQ(indices, indices_expected);
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 69f4e176ff9..07421b25e51 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -295,6 +295,7 @@ def define_common_targets():
     _common_op_test("op_tan_test", ["aten", "portable"])
     _common_op_test("op_tanh_test", ["aten", "portable"])
    _common_op_test("op_to_copy_test", ["aten", "portable"])
+    _common_op_test("op_topk_test", ["aten", "portable"])
     _common_op_test("op_transpose_copy_test", ["aten", "portable"])
     _common_op_test("op_tril_test", ["aten", "portable"])
     _common_op_test("op_trunc_test", ["aten", "portable"])
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index 0cc9ab5fd0e..04e824db57c 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -1125,6 +1125,9 @@ ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
     ),
+    op_target(
+        name = "op_topk",
+    ),
     op_target(
         name = "op_transpose_copy",
        deps = ["//executorch/kernels/portable/cpu/util:transpose_util"],
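
--
Reviewer note (not part of the patch): below is a minimal standalone sketch of
the selection strategy op_topk.cpp uses, std::partial_sort when k is small
relative to the row length and std::nth_element plus a prefix std::sort
otherwise, with the same NaN-first comparator. All names here are illustrative
plain C++, not the ExecuTorch API, and it assumes a single contiguous row of
floats rather than a strided tensor.

// Illustrative sketch of the top-k selection strategy (largest mode only).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using Elem = std::pair<float, int64_t>; // (value, original index)

// NaN-first "greater" comparator: NaN counts as the largest value, as in ATen.
static bool greater_nan_first(const Elem& x, const Elem& y) {
  return (std::isnan(x.first) && !std::isnan(y.first)) || (x.first > y.first);
}

// Returns the k largest (value, index) pairs of `data`, sorted descending.
static std::vector<Elem> topk_largest(
    const std::vector<float>& data, size_t k) {
  std::vector<Elem> queue(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    queue[i] = {data[i], static_cast<int64_t>(i)};
  }
  k = std::min(k, queue.size());
  if (k == 0) {
    return {};
  }
  if (k * 64 <= queue.size()) {
    // k is small relative to the row: partial_sort is O(n log k).
    std::partial_sort(
        queue.begin(), queue.begin() + k, queue.end(), greater_nan_first);
  } else {
    // Otherwise select with nth_element (O(n)); queue[k - 1] lands in its
    // final sorted position, so only the first k - 1 elements need sorting.
    std::nth_element(
        queue.begin(), queue.begin() + k - 1, queue.end(), greater_nan_first);
    std::sort(queue.begin(), queue.begin() + k - 1, greater_nan_first);
  }
  queue.resize(k);
  return queue;
}

int main() {
  const std::vector<float> data = {3.f, 1.f, 4.f, 1.f, 5.f, 9.f, 2.f, 6.f};
  for (const Elem& e : topk_largest(data, 3)) {
    // Prints 9 (index 5), 6 (index 7), 5 (index 4).
    std::printf(
        "value=%g index=%lld\n", e.first, static_cast<long long>(e.second));
  }
  return 0;
}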