From 9b8112b4a2a6042a93c4df2bf497f9e259f6aedd Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 8 Oct 2024 14:54:54 -0700 Subject: [PATCH] [ET][Portable][Build Size] Introduce notion of compute type. Refactor add/clamp/where Introduced notion of compute type. We now perform the computation over `CTYPE_COMPUTE` (the compute type) rather than `CTYPE_COMMON` (the common type). All of the occurrences of `CTYPE_COMMON` in elementwise_util.h need to be replaced with `CTYPE_COMPUTE`, to properly reflect that we are dealing with the computation type, and not the common type. But we don't do that in this diff, to facilitate review. The previous `SupportedTensorDtypes::SAME_AS_COMMON` is transformed into `SupportedTensorDtypes::SAME_AS_COMPUTE` and a newer `SupportedTensorDtypes::SAME_AS_COMMON` is written. `SupportedTensorDtypes::SAME_AS_COMMON` should perform the reverse mapping of get_compute_type(). In this case, this means that when `CTYPE_COMPUTE` is anything but float, `SAME_AS_COMMON` is effectively the same as `SAME_AS_COMPUTE`. 
But when `CTYPE_COMP` is float, `SAME_AS_COMMON` switches over `Float`, `Half` and `BFloat16` Build size reduction: - add: 21K -> 18K - clamp: 28K -> 23K - where: 16K -> 12K Differential Revision: [D63860791](https://our.internmc.facebook.com/intern/diff/D63860791/) [ghstack-poisoned] --- kernels/portable/cpu/op_add.cpp | 87 ++++++------- kernels/portable/cpu/op_clamp.cpp | 126 +++++++++++-------- kernels/portable/cpu/op_where.cpp | 45 ++++--- kernels/portable/cpu/util/elementwise_util.h | 109 +++++++++++++--- 4 files changed, 238 insertions(+), 129 deletions(-) diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index 05b5eab6f03..ff48563c477 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -24,36 +24,43 @@ Tensor& add_out( Tensor& out) { ET_KERNEL_CHECK( ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, + (executorch::runtime::tensor_is_realhbbf16_type(a) && + executorch::runtime::tensor_is_realhbbf16_type(b) && + executorch::runtime::tensor_is_realhbbf16_type(out)), InvalidArgument, out); + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype ET_KERNEL_CHECK( ctx, - (executorch::runtime::tensor_is_realhbbf16_type(a) && - executorch::runtime::tensor_is_realhbbf16_type(b) && - executorch::runtime::tensor_is_realhbbf16_type(out)), + (canCast(common_type, out.scalar_type()) && + check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), InvalidArgument, out); + + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType alpha_type = utils::get_scalar_dtype(alpha); - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); - ScalarType out_type = out.scalar_type(); - - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); + // Resize 
ET_KERNEL_CHECK( - ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out); + ctx, + resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); static constexpr const char op_name[] = "add.out"; - ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_bitensor_elementwise_fn( - [alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) { - CTYPE_COMMON val_alpha = utils::scalar_to(alpha); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return val_a + val_alpha * val_b; }, a, @@ -73,52 +80,48 @@ Tensor& add_scalar_out( const Scalar& b, const Scalar& alpha, Tensor& out) { - (void)ctx; - - // Resize for dynamic shape - ET_KERNEL_CHECK_MSG( + ET_KERNEL_CHECK( ctx, - resize_tensor(out, a.sizes()) == Error::Ok, + (executorch::runtime::tensor_is_realhbbf16_type(a) && + executorch::runtime::tensor_is_realhbbf16_type(out)), InvalidArgument, - out, - "Failed to resize output tensor."); + out); + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); + + // Check Common Dtype ET_KERNEL_CHECK( ctx, - (executorch::runtime::tensor_is_realhbbf16_type(a) && - executorch::runtime::tensor_is_realhbbf16_type(out)), + (common_type == out.scalar_type() && + check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), InvalidArgument, out); + + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - ScalarType a_type = a.scalar_type(); - ScalarType alpha_type = utils::get_scalar_dtype(alpha); - ScalarType common_type = - utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); - ScalarType out_type = out.scalar_type(); - - 
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Resize ET_KERNEL_CHECK( - ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out); + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); - if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { - common_type = ScalarType::Float; - } + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); static constexpr const char op_name[] = "add.Scalar_out"; - ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_unitensor_elementwise_fn( - [b, alpha](const CTYPE_COMMON val_a) { - CTYPE_COMMON val_b = utils::scalar_to(b); - CTYPE_COMMON val_alpha = utils::scalar_to(alpha); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_unitensor_elementwise_fn( + [b, alpha](const CTYPE_COMPUTE val_a) { + CTYPE_COMPUTE val_b = utils::scalar_to(b); + CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); return val_a + val_alpha * val_b; }, a, utils::SupportedTensorDtypes::REALHBBF16, out, - utils::SupportedTensorDtypes::REALHBBF16); + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index c8bb9297a13..508ba711748 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -73,38 +73,52 @@ Tensor& clamp_out( const exec_aten::optional& min_opt, const exec_aten::optional& max_opt, Tensor& out) { - (void)ctx; + ET_KERNEL_CHECK( + ctx, + (executorch::runtime::tensor_is_realhbbf16_type(in) && + executorch::runtime::tensor_is_realhbbf16_type(out)), + InvalidArgument, + out); + + bool has_min = min_opt.has_value(); + bool has_max = max_opt.has_value(); ET_KERNEL_CHECK_MSG( ctx, - resize_tensor(out, in.sizes()) == Error::Ok, + has_min || has_max, InvalidArgument, out, - "Failed to resize output tensor."); - - ET_KERNEL_CHECK( - ctx, 
tensors_have_same_dim_order(in, out), InvalidArgument, out); + "At least one of 'min' or 'max' must not be None"); + // Input Dtypes ScalarType in_type = in.scalar_type(); - ScalarType min_type = in_type; - ScalarType max_type = in_type; - ScalarType common_type = in_type; + ScalarType min_type = + has_min ? utils::get_scalar_dtype(min_opt.value()) : in_type; + ScalarType max_type = + has_max ? utils::get_scalar_dtype(max_opt.value()) : in_type; ScalarType out_type = out.scalar_type(); - bool has_min = min_opt.has_value(); + // Common Dtype + ScalarType common_type = in_type; if (has_min) { - min_type = utils::get_scalar_dtype(min_opt.value()); common_type = utils::promote_type_with_scalar(common_type, min_opt.value()); + } + if (has_max) { + common_type = utils::promote_type_with_scalar(common_type, max_opt.value()); + } + + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + + // Check Scalar Bounds + if (has_min) { ET_KERNEL_CHECK( ctx, check_bounds(min_opt.value(), min_type, out_type, "minimum"), InvalidArgument, out); } - bool has_max = max_opt.has_value(); if (has_max) { - max_type = utils::get_scalar_dtype(max_opt.value()); - common_type = utils::promote_type_with_scalar(common_type, max_opt.value()); ET_KERNEL_CHECK( ctx, check_bounds(max_opt.value(), max_type, out_type, "maximum"), @@ -112,35 +126,37 @@ Tensor& clamp_out( out); } - ET_KERNEL_CHECK_MSG( - ctx, - has_min || has_max, - InvalidArgument, - out, - "At least one of 'min' or 'max' must not be None"); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); static constexpr const char op_name[] = "clamp.out"; - ET_SWITCH_REALHB_TYPES(common_type, 
ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_unitensor_elementwise_fn( - [has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) { - CTYPE_COMMON val_out = val_in; + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_unitensor_elementwise_fn( + [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) { + CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override( - val_out, utils::scalar_to(min_opt.value())); + val_out, utils::scalar_to(min_opt.value())); } if (has_max) { val_out = utils::min_override( - val_out, utils::scalar_to(max_opt.value())); + val_out, utils::scalar_to(max_opt.value())); } return val_out; }, in, utils::SupportedTensorDtypes::REALHBBF16, out, - utils::SupportedTensorDtypes::REALHBBF16); + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; @@ -152,8 +168,6 @@ Tensor& clamp_tensor_out( const exec_aten::optional& min_opt, const exec_aten::optional& max_opt, Tensor& out) { - (void)ctx; - bool has_min = min_opt.has_value(); bool has_max = max_opt.has_value(); @@ -167,42 +181,54 @@ Tensor& clamp_tensor_out( const Tensor& min = has_min ? min_opt.value() : in; const Tensor& max = has_max ? 
max_opt.value() : in; + ET_KERNEL_CHECK( + ctx, + (executorch::runtime::tensor_is_realhbbf16_type(in) && + executorch::runtime::tensor_is_realhbbf16_type(min) && + executorch::runtime::tensor_is_realhbbf16_type(max) && + executorch::runtime::tensor_is_realhbbf16_type(out)), + InvalidArgument, + out); + + // Common Dtype + ScalarType common_type = in.scalar_type(); + if (has_min) { + common_type = promoteTypes(common_type, min.scalar_type()); + } + if (has_max) { + common_type = promoteTypes(common_type, max.scalar_type()); + } + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, min, max, out), InvalidArgument, out); + // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(in, min, max, out) == Error::Ok, InvalidArgument, out); - ScalarType in_type = in.scalar_type(); - ScalarType min_type = min.scalar_type(); - ScalarType max_type = max.scalar_type(); - ScalarType common_type = in_type; - ScalarType out_type = out.scalar_type(); - - if (has_min) { - common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true); - } - if (has_max) { - common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true); - } - - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); static constexpr const char op_name[] = "clamp.Tensor_out"; - ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_tritensor_elementwise_fn( + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_tritensor_elementwise_fn( [has_min, has_max]( - const CTYPE_COMMON val_in, - const CTYPE_COMMON val_min, - const CTYPE_COMMON val_max) { - CTYPE_COMMON val_out = val_in; + const CTYPE_COMPUTE val_in, + const CTYPE_COMPUTE val_min, + const CTYPE_COMPUTE val_max) { + 
CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override(val_out, val_min); } diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index 3765590ee83..7a13bf881a0 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -19,36 +19,43 @@ Tensor& where_out( const Tensor& a, const Tensor& b, Tensor& out) { - ScalarType cond_type = cond.scalar_type(); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type); - ScalarType out_type = out.scalar_type(); + ET_KERNEL_CHECK( + ctx, + ((cond.scalar_type() == ScalarType::Bool || + cond.scalar_type() == ScalarType::Byte) && + executorch::runtime::tensor_is_realhbbf16_type(a) && + executorch::runtime::tensor_is_realhbbf16_type(b) && + executorch::runtime::tensor_is_realhbbf16_type(out)), + InvalidArgument, + out); + + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); - // Determine output size and resize for dynamic shapes + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out); + + // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok, InvalidArgument, out); - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); static constexpr const char op_name[] = "where.self_out"; - ET_CHECK_MSG( - cond_type == ScalarType::Bool || cond_type == ScalarType::Byte, - "Unhandled dtype %s for where.self_out", - torch::executor::toString(cond_type)); - - ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - 
utils::apply_tritensor_elementwise_fn( - [](const CTYPE_COMMON val_a, - const CTYPE_COMMON val_b, - const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; }, + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_tritensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, + const CTYPE_COMPUTE val_b, + const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; }, a, utils::SupportedTensorDtypes::REALHBBF16, b, diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 19dcbd73ba0..52ad6fca116 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -80,6 +80,41 @@ load_to_common_fn get_load_to_common_fn_bool_or_byte( return result; } +template +load_to_common_fn get_load_to_common_fn_same_as_compute( + const Tensor& t) { + constexpr auto common_scalar_type = CppTypeToScalarType::value; + ET_CHECK_MSG( + t.scalar_type() == common_scalar_type, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(common_scalar_type), + op_name); + return internal::load_and_convert; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> +load_to_common_fn get_load_to_common_fn_same_as_common( + const Tensor& t) { + CTYPE_COMMON (*result)(const void*) = nullptr; + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() { + result = internal::load_and_convert; + }); + return result; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> +load_to_common_fn get_load_to_common_fn_same_as_common( + const Tensor& t) { + return get_load_to_common_fn_same_as_compute(t); +} + template using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*); @@ -105,11 +140,48 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) { return result; } +template +store_common_to_tensor_fn 
+get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) { + constexpr auto common_scalar_type = CppTypeToScalarType::value; + ET_CHECK_MSG( + t.scalar_type() == common_scalar_type, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(common_scalar_type), + op_name); + return internal::convert_and_store; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> +store_common_to_tensor_fn +get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { + void (*result)(CTYPE_COMMON, void*) = nullptr; + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() { + result = internal::convert_and_store; + }); + return result; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> +store_common_to_tensor_fn +get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { + return get_store_common_to_tensor_fn_same_as_compute( + t); +} + } // namespace internal enum class SupportedTensorDtypes { REALHBBF16, BOOL_OR_BYTE, + SAME_AS_COMPUTE, SAME_AS_COMMON, }; @@ -124,16 +196,10 @@ load_to_common_fn get_load_to_common_fn( return get_load_to_common_fn_realhbbf16(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_load_to_common_fn_bool_or_byte(t); - case SupportedTensorDtypes::SAME_AS_COMMON: { - constexpr auto common_scalar_type = - CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); - return internal::load_and_convert; - } + case SupportedTensorDtypes::SAME_AS_COMPUTE: + return get_load_to_common_fn_same_as_compute(t); + case SupportedTensorDtypes::SAME_AS_COMMON: + return get_load_to_common_fn_same_as_common(t); } ET_CHECK(false); return nullptr; @@ -149,15 +215,14 @@ store_common_to_tensor_fn get_store_common_to_tensor_fn( case SupportedTensorDtypes::BOOL_OR_BYTE: return 
get_store_common_to_tensor_fn_bool_or_byte( t); + case SupportedTensorDtypes::SAME_AS_COMPUTE: + return get_store_common_to_tensor_fn_same_as_compute< + CTYPE_COMMON, + op_name>(t); case SupportedTensorDtypes::SAME_AS_COMMON: { - constexpr auto common_scalar_type = - CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); - return internal::convert_and_store; + return get_store_common_to_tensor_fn_same_as_common< + CTYPE_COMMON, + op_name>(t); } } ET_CHECK(false); @@ -330,6 +395,14 @@ inline void apply_tritensor_elementwise_fn( } } +inline ScalarType get_compute_type(ScalarType& common_type) { + ScalarType compute_type = common_type; + if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { + compute_type = ScalarType::Float; + } + return compute_type; +} + } // namespace utils } // namespace native } // namespace executor