From 5fc9f37fed9357be80acd4f53daa56bd6fa8dad0 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Wed, 2 Oct 2024 16:04:38 -0700
Subject: [PATCH] [ExecuTorch] Just pass SupportedTensorDtypes for each tensor
 to apply_ternary_elementwise_fn

No more function pointers! Also, we check that each Tensor's type is in the
allowed set.
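
For context, each call site now passes a SupportedTensorDtypes entry next to
the corresponding tensor and lets apply_ternary_elementwise_fn look up the
load/store conversions itself, checking that the tensor's dtype falls in the
stated set. A rough sketch of the new calling convention, mirroring the
op_where.cpp hunk below (surrounding kernel boilerplate elided):

```cpp
// Sketch mirroring op_where.cpp below: one SupportedTensorDtypes per tensor,
// no load/store function pointers at the call site.
static constexpr const char op_name[] = "where.self_out";

ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
  apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
      [](const CTYPE_COMMON val_a,
         const CTYPE_COMMON val_b,
         const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
      a,
      SupportedTensorDtypes::REALHBBF16,
      b,
      SupportedTensorDtypes::REALHBBF16,
      cond,
      SupportedTensorDtypes::BOOL_OR_BYTE,
      out,
      SupportedTensorDtypes::SAME_AS_COMMON);
});
```

Each tensor/dtype-set pair replaces the explicit load_and_convert and
convert_and_store function pointers that the old signature required.
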
Differential Revision: [D63794199](https://our.internmc.facebook.com/intern/diff/D63794199/)

[ghstack-poisoned]
---
 kernels/portable/cpu/op_clamp.cpp          |  10 +-
 kernels/portable/cpu/op_where.cpp          |  24 ++--
 kernels/portable/cpu/util/broadcast_util.h | 116 ++++++++++++++++++---
 3 files changed, 116 insertions(+), 34 deletions(-)

diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index c73b2909ac6..37c5d0f6c21 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -215,7 +215,7 @@ Tensor& clamp_tensor_out(
   static constexpr const char op_name[] = "clamp.Tensor_out";
 
   ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
-    apply_ternary_elementwise_fn<CTYPE_COMMON>(
+    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
         [has_min, has_max](
             const CTYPE_COMMON val_in,
             const CTYPE_COMMON val_min,
@@ -230,13 +230,13 @@ Tensor& clamp_tensor_out(
           return val_out;
         },
         in,
+        SupportedTensorDtypes::REALHBBF16,
         min,
+        SupportedTensorDtypes::REALHBBF16,
         max,
+        SupportedTensorDtypes::REALHBBF16,
         out,
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(in),
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(min),
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(max),
-        get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(out));
+        SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp
index d93efcf5398..cb2616fa391 100644
--- a/kernels/portable/cpu/op_where.cpp
+++ b/kernels/portable/cpu/op_where.cpp
@@ -38,29 +38,25 @@ Tensor& where_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
 
-  constexpr auto name = "where.self_out";
+  static constexpr const char op_name[] = "where.self_out";
 
   ET_CHECK_MSG(
       cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
       "Unhandled dtype %s for where.self_out",
       torch::executor::toString(cond_type));
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_OUT =
-          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      apply_ternary_elementwise_fn<CTYPE_OUT>(
-          [](const CTYPE_OUT val_a,
-             const CTYPE_OUT val_b,
-             const CTYPE_OUT val_c) { return val_c ? val_a : val_b; },
+  ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
+        [](const CTYPE_COMMON val_a,
+           const CTYPE_COMMON val_b,
+           const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
         a,
+        SupportedTensorDtypes::REALHBBF16,
         b,
+        SupportedTensorDtypes::REALHBBF16,
        cond,
+        SupportedTensorDtypes::BOOL_OR_BYTE,
        out,
-        internal::load_and_convert<CTYPE_OUT, CTYPE_A>,
-        internal::load_and_convert<CTYPE_OUT, CTYPE_B>,
-        internal::load_and_convert<CTYPE_OUT, uint8_t>,
-        internal::convert_and_store<CTYPE_OUT, CTYPE_OUT>);
-    });
+        SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h
index c75883322d6..e78b4296384 100644
--- a/kernels/portable/cpu/util/broadcast_util.h
+++ b/kernels/portable/cpu/util/broadcast_util.h
@@ -280,7 +280,6 @@ template <typename To, typename From>
 void convert_and_store(From f, void* dst) {
   *reinterpret_cast<To*>(dst) = static_cast<To>(f);
 }
-} // namespace internal
 
 template <typename CTYPE_COMMON>
 using load_to_common_fn = CTYPE_COMMON (*)(const void*);
@@ -296,6 +295,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+    result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+  });
+  return result;
+}
+
 template <typename CTYPE_COMMON>
 using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
 
@@ -310,6 +318,72 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(Bool, Byte,
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+} // namespace internal
+
+enum class SupportedTensorDtypes {
+  REALHBBF16,
+  BOOL_OR_BYTE,
+  SAME_AS_COMMON,
+};
+
+namespace internal {
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+} // namespace internal
+
 /**
  * Useful for binary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
@@ -356,33 +430,45 @@ inline void apply_binary_elementwise_fn(
  *
  * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
  * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
- * are passed as CTYPE_COMMON. We require compute_fun to return
- * CTYPE_COMMON, and we require loading conversion functions from each
- * input type to CTYPE_COMMON and a storing conversion from
- * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
- * must take a void* pointing to an element of the corresponding
- * tensor, load that element, and convert it to CTYPE_COMMON. The
- * storing conversion function must have the signature
- * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
- * and store it to the given location.
+ * are passed as CTYPE_COMMON.
+ *
+ * Each tensor's supported dtypes set must be provided. The tensor
+ * will be checked to ensure that its dtype falls into that set.
+ *
+ * op_name is used to support dtype selective build, as with the
+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
+ * can't pass a string literal for op_name. Instead, you should do the
+ * following:
+ *
+ * static constexpr const char op_name[] = "my_op";
+ * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
  */
-template <typename CTYPE_COMMON, typename Op>
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
     const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
     const Tensor& c,
+    SupportedTensorDtypes c_dtypes,
     const Tensor& out,
-    CTYPE_COMMON (*load_a_to_common)(const void*),
-    CTYPE_COMMON (*load_b_to_common)(const void*),
-    CTYPE_COMMON (*load_c_to_common)(const void*),
-    void (*store_common_to_out)(CTYPE_COMMON, void*)) {
+    SupportedTensorDtypes out_dtypes) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto load_b_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
+  const auto load_c_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
   const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
   const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
   const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());