From 707033a8771bab5aa879e55013df1b2460d518fe Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Mon, 8 Sep 2025 09:00:34 -0700
Subject: [PATCH] Allow custom ops to log outputs and profiling events for
 etdump

Summary:
Our logging and profiling currently do not work with custom ops registered via
the EXECUTORCH_LIBRARY macro. This diff fixes that. It requires building the
sources of the custom op libs with event tracing enabled.

Reviewed By: Gasoonjia

Differential Revision: D81131610
---
 .../make_boxed_from_unboxed_functor.h     | 88 ++++++++++++++++---
 extension/kernel_util/meta_programming.h  | 15 ++--
 2 files changed, 86 insertions(+), 17 deletions(-)

diff --git a/extension/kernel_util/make_boxed_from_unboxed_functor.h b/extension/kernel_util/make_boxed_from_unboxed_functor.h
index 8f3d63db449..1710f876b20 100644
--- a/extension/kernel_util/make_boxed_from_unboxed_functor.h
+++ b/extension/kernel_util/make_boxed_from_unboxed_functor.h
@@ -41,7 +41,9 @@
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -61,6 +63,49 @@ namespace extension {
 // internal namespace to avoid conflicts with other extensions.
 namespace kernel_util_internal {
+// Template trait to check if a type is a non-const tensor
+template <typename T>
+struct is_nonconst_tensor : std::false_type {};
+
+template <>
+struct is_nonconst_tensor<executorch::aten::Tensor&> : std::true_type {};
+
+// Template trait to check if a type is a non-const tensor
+// Count non-const tensors in a typelist
+template <typename T>
+struct count_nonconst_tensors;
+
+template <>
+struct count_nonconst_tensors<typelist<>> {
+  static constexpr size_t value = 0;
+};
+
+template <typename T>
+struct count_nonconst_tensors<typelist<T>> {
+  static constexpr size_t value = 0;
+};
+
+template <>
+struct count_nonconst_tensors<typelist<executorch::aten::Tensor&>> {
+  static constexpr size_t value = 1;
+};
+
+template <typename Head, typename... Tail>
+struct count_nonconst_tensors<typelist<Head, Tail...>> {
+ private:
+  static constexpr size_t tail_tensor_count =
+      count_nonconst_tensors<typelist<Tail...>>::value;
+  static constexpr size_t tail_args_count = sizeof...(Tail);
+  static constexpr bool is_head_a_tensor = is_nonconst_tensor<Head>::value;
+  static constexpr bool all_tail_args_are_tensor =
+      tail_tensor_count == tail_args_count;
+
+ public:
+  static constexpr size_t value = (is_head_a_tensor && all_tail_args_are_tensor)
+      ? tail_tensor_count + 1
+      : tail_tensor_count;
+};
+
 template <class T>
 struct decay_if_not_tensor final {
   using type = std::decay_t<T>;
 };
@@ -110,16 +155,29 @@ struct evalue_to_arg<executorch::aten::ArrayRef<std::optional<T>>> final {
   }
 };
 
-template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes>
+template <
+    class Functor,
+    size_t nonconst_tensors_to_log,
+    size_t... evalue_arg_indices,
+    typename... ArgTypes>
 void call_functor_with_args_from_stack(
     executorch::runtime::KernelRuntimeContext& ctx,
     executorch::runtime::Span<executorch::runtime::EValue*> stack,
     std::index_sequence<evalue_arg_indices...>,
     typelist<ArgTypes...>*) {
+  executorch::runtime::internal::EventTracerProfileOpScope
+      event_tracer_op_scope(ctx.internal_event_tracer(), Functor::func_name_);
+  EXECUTORCH_SCOPE_PROF(Functor::func_name_);
   (*Functor::func_ptr())(
       ctx,
       evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
           *stack[evalue_arg_indices])...);
+  constexpr size_t num_inputs =
+      std::index_sequence<evalue_arg_indices...>::size();
+  for (size_t i = num_inputs - nonconst_tensors_to_log; i < num_inputs; ++i) {
+    executorch::runtime::internal::event_tracer_log_evalue(
+        ctx.internal_event_tracer(), *stack[i]);
+  }
 }
 
 } // namespace kernel_util_internal
@@ -154,11 +212,16 @@ struct WrapUnboxedIntoFunctor {
       executorch::runtime::Span<executorch::runtime::EValue*> stack) {
     constexpr size_t num_inputs =
         kernel_util_internal::size<ContextRemovedArgsType>::value;
-    return kernel_util_internal::call_functor_with_args_from_stack<FuncType>(
-        ctx,
-        stack,
-        std::make_index_sequence<num_inputs>(),
-        static_cast<ContextRemovedArgsType*>(nullptr));
+    constexpr size_t num_nonconst_tensors =
+        kernel_util_internal::count_nonconst_tensors<
+            ContextRemovedArgsType>::value;
+    static_assert(num_nonconst_tensors == 1, "Invalid number of inputs");
+    return kernel_util_internal::
+        call_functor_with_args_from_stack<FuncType, num_nonconst_tensors>(
+            ctx,
+            stack,
+            std::make_index_sequence<num_inputs>(),
+            static_cast<ContextRemovedArgsType*>(nullptr));
   }
 };
 
@@ -181,11 +244,14 @@ static executorch::runtime::Kernel make_boxed_kernel(
 #define EXECUTORCH_LIBRARY(ns, op_name, func) \
   _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, ET_UID)
 
-#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid) \
-  static auto ET_CONCATENATE(res_##ns##_, uid) =         \
-      ::executorch::runtime::register_kernel(            \
-          ::executorch::extension::make_boxed_kernel(    \
-              #ns "::" op_name, EXECUTORCH_FN(func)))
+#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid)            \
+  static constexpr const char ET_CONCATENATE(name_of_op_, uid)[] =  \
+      #ns "::" op_name;                                             \
+  static auto ET_CONCATENATE(res_##ns##_, uid) =                    \
+      ::executorch::runtime::register_kernel(                       \
+          ::executorch::extension::make_boxed_kernel(               \
+              #ns "::" op_name,                                     \
+              EXECUTORCH_FN(func, ET_CONCATENATE(name_of_op_, uid))))
 
 namespace torch {
 namespace executor {
diff --git a/extension/kernel_util/meta_programming.h b/extension/kernel_util/meta_programming.h
index 027568fe687..e3b3649dd4f 100644
--- a/extension/kernel_util/meta_programming.h
+++ b/extension/kernel_util/meta_programming.h
@@ -32,12 +32,13 @@ template <class T>
 using is_function_type_t = typename is_function_type<T>::type;
 
 // A compile-time wrapper around a function pointer
-template <class FuncType_, FuncType_* func_ptr_>
+template <class FuncType_, FuncType_* func_ptr_, const char* func_name>
 struct CompileTimeFunctionPointer final {
   static_assert(
       is_function_type<FuncType_>::value,
       "EXECUTORCH_FN can only wrap function types.");
   using FuncType = FuncType_;
+  static constexpr const char* func_name_ = func_name;
 
   static constexpr FuncType* func_ptr() {
     return func_ptr_;
@@ -47,15 +48,17 @@ struct CompileTimeFunctionPointer final {
 // Check if a given type is a compile-time function pointer
 template <class T>
 struct is_compile_time_function_pointer : std::false_type {};
-template <class FuncType, FuncType* func_ptr>
+template <class FuncType, FuncType* func_ptr, const char* func_name>
 struct is_compile_time_function_pointer<
-    CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
+    CompileTimeFunctionPointer<FuncType, func_ptr, func_name>>
+    : std::true_type {};
 
-#define EXECUTORCH_FN_TYPE(func)                                              \
+#define EXECUTORCH_FN_TYPE(func, name)                                        \
   ::executorch::extension::kernel_util_internal::CompileTimeFunctionPointer< \
       std::remove_pointer_t<std::decay_t<decltype(func)>>,                   \
-      func>
-#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
+      func,                                                                   \
+      name>
+#define EXECUTORCH_FN(func, name) EXECUTORCH_FN_TYPE(func, name)()
 
 /**
  * strip_class: helper to remove the class type from pointers to `operator()`.
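
Usage note (not part of the diff): a minimal sketch of how a custom op registered via
EXECUTORCH_LIBRARY picks up the new behavior. The namespace my_ops, the op name
"my_copy.out", and the kernel below are hypothetical; only the EXECUTORCH_LIBRARY macro,
the ctx-first kernel convention, and the trailing non-const Tensor& output come from the
code above. When the custom op library is compiled with event tracing enabled, the boxed
wrapper opens a profiling scope named "my_ops::my_copy.out" and logs the trailing
non-const tensor output to etdump; the registration call site itself does not change.

// Hypothetical custom op used only to illustrate the registration path.
#include <cstring>

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace my_ops {

using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

// Unboxed kernel: copies `in` into `out`. The trailing non-const Tensor& is the
// value the patched wrapper logs via the event tracer after the kernel returns.
Tensor& my_copy_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
  (void)ctx;
  std::memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
  return out;
}

} // namespace my_ops

// Same macro invocation as before the patch; it now also captures
// "my_ops::my_copy.out" as a compile-time string, so profiling scopes and
// output logging in etdump are attributed to this op.
EXECUTORCH_LIBRARY(my_ops, "my_copy.out", my_ops::my_copy_out);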