Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 77 additions & 11 deletions extension/kernel_util/make_boxed_from_unboxed_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
#include <executorch/extension/kernel_util/meta_programming.h>
#include <executorch/extension/kernel_util/type_list.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <cstdlib>
#include <memory>
Expand All @@ -61,6 +63,49 @@ namespace extension {
// internal namespace to avoid conflicts with other extensions.
namespace kernel_util_internal {

// Type trait: `value` is true only when T is exactly a non-const lvalue
// reference to a tensor (`executorch::aten::Tensor&`). Every other type,
// including `const Tensor&` and `Tensor` by value, yields false.
template <class T>
struct is_nonconst_tensor : std::is_same<T, executorch::aten::Tensor&> {};

// Counts the non-const tensor references (`Tensor&`) that sit at the END of a
// typelist. An entry is counted only if it is a `Tensor&` and every entry
// after it is a `Tensor&` as well, so the result is the length of the
// trailing run of `Tensor&` types. A `Tensor&` that is followed by any other
// type does not contribute.
//
// Examples (T = executorch::aten::Tensor):
//   typelist<>                        -> 0
//   typelist<int>                     -> 0
//   typelist<const T&, T&>            -> 1
//   typelist<T&, int, T&, T&>         -> 2
//   typelist<T&, int>                 -> 0
template <class TypeList>
struct count_nonconst_tensors;

// Recursion anchor: an empty list has no trailing tensors.
template <>
struct count_nonconst_tensors<typelist<>> {
  static constexpr size_t value = 0;
};

// Recursive step: evaluate the remainder of the list first, then decide
// whether the first element extends the trailing run.
template <class First, class... Rest>
struct count_nonconst_tensors<typelist<First, Rest...>> {
 private:
  using rest_count = count_nonconst_tensors<typelist<Rest...>>;

  // True when everything after `First` is a `Tensor&` (vacuously true when
  // `Rest` is empty).
  static constexpr bool rest_is_only_tensors =
      rest_count::value == sizeof...(Rest);

  static constexpr bool first_extends_run =
      is_nonconst_tensor<First>::value && rest_is_only_tensors;

 public:
  static constexpr size_t value =
      first_extends_run ? rest_count::value + 1 : rest_count::value;
};

template <class T>
struct decay_if_not_tensor final {
using type = std::decay_t<T>;
Expand Down Expand Up @@ -110,16 +155,29 @@ struct evalue_to_arg<executorch::aten::ArrayRef<std::optional<T>>> final {
}
};

template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes>
/**
 * Unboxes the kernel arguments from `stack`, invokes the function wrapped by
 * `Functor`, and then logs the trailing stack entries to the event tracer.
 *
 * @tparam Functor Compile-time function wrapper; must expose `func_ptr()` and
 *     `func_name_` (the name used for the profiling scopes).
 * @tparam nonconst_tensors_to_log How many entries at the end of `stack` are
 *     passed to `event_tracer_log_evalue` after the call. Must not exceed the
 *     number of unboxed arguments.
 * @tparam evalue_arg_indices Positions in `stack` of the arguments, in call
 *     order.
 * @tparam ArgTypes Parameter types of the wrapped function, excluding the
 *     leading context parameter.
 *
 * @param ctx Runtime context; forwarded to the kernel as its first argument
 *     and used to obtain the event tracer.
 * @param stack Boxed arguments, one EValue pointer per entry.
 */
template <
    class Functor,
    size_t nonconst_tensors_to_log,
    size_t... evalue_arg_indices,
    typename... ArgTypes>
void call_functor_with_args_from_stack(
    executorch::runtime::KernelRuntimeContext& ctx,
    executorch::runtime::Span<executorch::runtime::EValue*> stack,
    std::index_sequence<evalue_arg_indices...>,
    typelist<ArgTypes...>*) {
  constexpr size_t num_inputs = sizeof...(evalue_arg_indices);
  // The loop below starts at `num_inputs - nonconst_tensors_to_log`; with
  // unsigned arithmetic a too-large count would wrap around and silently skip
  // all logging, so reject it at compile time instead.
  static_assert(
      nonconst_tensors_to_log <= num_inputs,
      "Cannot log more values than the kernel has arguments");

  // Both profiling scopes stay alive until the function returns, so they
  // cover the kernel call and the logging that follows it.
  executorch::runtime::internal::EventTracerProfileOpScope
      event_tracer_op_scope(ctx.internal_event_tracer(), Functor::func_name_);
  EXECUTORCH_SCOPE_PROF(Functor::func_name_);

  // Convert every boxed EValue to the type the kernel expects and call it.
  (*Functor::func_ptr())(
      ctx,
      evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
          *stack[evalue_arg_indices])...);

  // Log the last `nonconst_tensors_to_log` stack entries.
  constexpr size_t first_logged = num_inputs - nonconst_tensors_to_log;
  for (size_t i = first_logged; i < num_inputs; ++i) {
    executorch::runtime::internal::event_tracer_log_evalue(
        ctx.internal_event_tracer(), *stack[i]);
  }
}

} // namespace kernel_util_internal
Expand Down Expand Up @@ -154,11 +212,16 @@ struct WrapUnboxedIntoFunctor {
executorch::runtime::Span<executorch::runtime::EValue*> stack) {
constexpr size_t num_inputs =
kernel_util_internal::size<ContextRemovedArgsType>::value;
return kernel_util_internal::call_functor_with_args_from_stack<FuncType>(
ctx,
stack,
std::make_index_sequence<num_inputs>(),
static_cast<ContextRemovedArgsType*>(nullptr));
constexpr size_t num_nonconst_tensors =
kernel_util_internal::count_nonconst_tensors<
ContextRemovedArgsType>::value;
static_assert(num_nonconst_tensors == 1, "Invalid number of inputs");
return kernel_util_internal::
call_functor_with_args_from_stack<FuncType, num_nonconst_tensors>(
ctx,
stack,
std::make_index_sequence<num_inputs>(),
static_cast<ContextRemovedArgsType*>(nullptr));
}
};

Expand All @@ -181,11 +244,14 @@ static executorch::runtime::Kernel make_boxed_kernel(
// Registers `func` as the kernel for the operator named "<ns>::<op_name>".
// This is a thin front end: it forwards its arguments unchanged to
// _EXECUTORCH_LIBRARY_IMPL and adds ET_UID, which the implementation macro
// uses as a suffix for the names of the file-scope statics it generates
// (presumably unique per expansion -- confirm against ET_UID's definition).
#define EXECUTORCH_LIBRARY(ns, op_name, func) \
_EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, ET_UID)

#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid) \
static auto ET_CONCATENATE(res_##ns##_, uid) = \
::executorch::runtime::register_kernel( \
::executorch::extension::make_boxed_kernel( \
#ns "::" op_name, EXECUTORCH_FN(func)))
// Implementation of EXECUTORCH_LIBRARY. Each expansion emits two file-scope
// statics:
//   1. `name_of_op_<uid>`: a constexpr char array holding the operator name,
//      built by concatenating the stringified `ns`, "::", and `op_name`
//      (so `op_name` must be a string literal).
//   2. `res_<ns>_<uid>`: initialized with the result of register_kernel() for
//      the boxed kernel produced by make_boxed_kernel(). Its initializer is
//      what performs the registration.
// EXECUTORCH_FN receives the named array rather than the literal itself.
// NOTE(review): presumably because the name is used as a `const char*`
// template argument, which cannot be a string literal -- confirm against
// meta_programming.h.
#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid) \
static constexpr const char ET_CONCATENATE(name_of_op_, uid)[] = \
#ns "::" op_name; \
static auto ET_CONCATENATE(res_##ns##_, uid) = \
::executorch::runtime::register_kernel( \
::executorch::extension::make_boxed_kernel( \
#ns "::" op_name, \
EXECUTORCH_FN(func, ET_CONCATENATE(name_of_op_, uid))))

namespace torch {
namespace executor {
Expand Down
15 changes: 9 additions & 6 deletions extension/kernel_util/meta_programming.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ template <class T>
using is_function_type_t = typename is_function_type<T>::type;

// A compile-time wrapper around a function pointer
template <class FuncType_, FuncType_* func_ptr_>
template <class FuncType_, FuncType_* func_ptr_, const char* func_name>
struct CompileTimeFunctionPointer final {
static_assert(
is_function_type<FuncType_>::value,
"EXECUTORCH_FN can only wrap function types.");
using FuncType = FuncType_;
static constexpr const char* func_name_ = func_name;

static constexpr FuncType* func_ptr() {
return func_ptr_;
Expand All @@ -47,15 +48,17 @@ struct CompileTimeFunctionPointer final {
// Check if a given type is a compile-time function pointer
template <class T>
struct is_compile_time_function_pointer : std::false_type {};
template <class FuncType, FuncType* func_ptr>
template <class FuncType, FuncType* func_ptr, const char* func_name>
struct is_compile_time_function_pointer<
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
CompileTimeFunctionPointer<FuncType, func_ptr, func_name>>
: std::true_type {};

#define EXECUTORCH_FN_TYPE(func) \
#define EXECUTORCH_FN_TYPE(func, name) \
::executorch::extension::kernel_util_internal::CompileTimeFunctionPointer< \
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
func>
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
func, \
name>
#define EXECUTORCH_FN(func, name) EXECUTORCH_FN_TYPE(func, name)()

/**
* strip_class: helper to remove the class type from pointers to `operator()`.
Expand Down
Loading