From aef60d23e7f8ad75d1cdbc28a3944cbbf7bce995 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Mon, 9 Jun 2025 10:15:43 -0700 Subject: [PATCH] Use nullptr to represent fallback kernels (#11484) Summary: `KernelKey`s can be constructed as either a "fallback" or "specialized" key. The "fallback" type uses the default constructor, and the "specialized" takes a specially formatted string. Internally, this was represented as a pointer to the optional key string and a bool for whether it is a fallback kernel. As it would not make sense to construct a "specialized" kernel without a key string, this diff eliminates the bool `is_fallback_` in favor of using `kernel_key_data_ == nullptr` to represent fallback kernels. Each `KernelKey` is nested within the `Kernel` data structure, which makes up the list of registered kernels. As the default size of the registered_kernels array is 2000 kernel entries, this diff can reduce the size of the `registered_kernels` array by 8 KB. This diff also changes the backing storage buffer for `registered_kernels_data` to ensure that there is enough space for each Kernel element in the array to be aligned according to `alignas(Kernel)`. Differential Revision: D76201866 --- runtime/kernel/operator_registry.cpp | 7 +++++-- runtime/kernel/operator_registry.h | 11 +++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index d7e7b298c10..d5c9a982d6d 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -35,9 +35,12 @@ constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp; // require constructing them at init time. Since we don't care about the values // until we add each entry to the table, allocate static zeroed memory instead // and point the table at it. +struct alignas(Kernel) KernelBuffer { + uint8_t data[sizeof(Kernel)]; +}; + // @lint-ignore CLANGTIDY facebook-hte-CArray -alignas(sizeof(Kernel)) uint8_t - registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)]; +KernelBuffer registered_kernels_data[kMaxRegisteredKernels]; /// Global table of registered kernels. Kernel* registered_kernels = reinterpret_cast(registered_kernels_data); diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index f7a62208dd8..9bd6318676c 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -123,7 +123,7 @@ struct KernelKey { * for all input tensor dtypes and dim orders if the specialized kernel is not * registered. */ - KernelKey() : is_fallback_(true) {} + KernelKey() = default; /** * Creates a specialized (non-fallback) kernel key that matches a specific @@ -131,7 +131,7 @@ struct KernelKey { * expected format of `kernel_key_data`. */ /* implicit */ KernelKey(const char* kernel_key_data) - : kernel_key_data_(kernel_key_data), is_fallback_(false) {} + : kernel_key_data_(kernel_key_data) {} bool operator==(const KernelKey& other) const { return this->equals(other); @@ -142,17 +142,17 @@ struct KernelKey { } bool equals(const KernelKey& other) const { - if (is_fallback_ != other.is_fallback_) { + if (is_fallback() != other.is_fallback()) { return false; } - if (is_fallback_) { + if (is_fallback()) { return true; } return strcmp(kernel_key_data_, other.kernel_key_data_) == 0; } bool is_fallback() const { - return is_fallback_; + return kernel_key_data_ == nullptr; } const char* data() const { @@ -168,7 +168,6 @@ struct KernelKey { private: const char* kernel_key_data_ = nullptr; - bool is_fallback_; }; /**