diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index 05b5eab6f03..c9c6eda9324 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -24,36 +24,44 @@ Tensor& add_out( Tensor& out) { ET_KERNEL_CHECK( ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, + (executorch::runtime::tensor_is_realhbbf16_type(a) && + executorch::runtime::tensor_is_realhbbf16_type(b) && + executorch::runtime::tensor_is_realhbbf16_type(out)), InvalidArgument, out); + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype ET_KERNEL_CHECK( ctx, - (executorch::runtime::tensor_is_realhbbf16_type(a) && - executorch::runtime::tensor_is_realhbbf16_type(b) && - executorch::runtime::tensor_is_realhbbf16_type(out)), + (canCast(common_type, out.scalar_type()) && + check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), InvalidArgument, out); + + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType alpha_type = utils::get_scalar_dtype(alpha); - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); - ScalarType out_type = out.scalar_type(); - - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); + // Resize ET_KERNEL_CHECK( - ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out); + ctx, + resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "add.out"; - ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_bitensor_elementwise_fn( - [alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) { - CTYPE_COMMON val_alpha = 
utils::scalar_to(alpha); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return val_a + val_alpha * val_b; }, a, @@ -73,52 +81,50 @@ Tensor& add_scalar_out( const Scalar& b, const Scalar& alpha, Tensor& out) { - (void)ctx; - - // Resize for dynamic shape - ET_KERNEL_CHECK_MSG( + ET_KERNEL_CHECK( ctx, - resize_tensor(out, a.sizes()) == Error::Ok, + (executorch::runtime::tensor_is_realhbbf16_type(a) && + executorch::runtime::tensor_is_realhbbf16_type(out)), InvalidArgument, - out, - "Failed to resize output tensor."); + out); + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); + + // Check Common Dtype ET_KERNEL_CHECK( ctx, - (executorch::runtime::tensor_is_realhbbf16_type(a) && - executorch::runtime::tensor_is_realhbbf16_type(out)), + (common_type == out.scalar_type() && + check_alpha_type(utils::get_scalar_dtype(alpha), common_type)), InvalidArgument, out); + + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - ScalarType a_type = a.scalar_type(); - ScalarType alpha_type = utils::get_scalar_dtype(alpha); - ScalarType common_type = - utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); - ScalarType out_type = out.scalar_type(); - - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Resize ET_KERNEL_CHECK( - ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out); + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); - if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { - common_type = ScalarType::Float; - } + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = 
"add.Scalar_out"; - ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_unitensor_elementwise_fn( - [b, alpha](const CTYPE_COMMON val_a) { - CTYPE_COMMON val_b = utils::scalar_to(b); - CTYPE_COMMON val_alpha = utils::scalar_to(alpha); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + // Convert the scalars once per call rather than once per element. + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { return val_a + val_alpha * val_b; }, a, utils::SupportedTensorDtypes::REALHBBF16, out, - utils::SupportedTensorDtypes::REALHBBF16); + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index c8bb9297a13..35218cbb599 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -73,38 +73,52 @@ Tensor& clamp_out( const exec_aten::optional& min_opt, const exec_aten::optional& max_opt, Tensor& out) { - (void)ctx; + ET_KERNEL_CHECK( + ctx, + (executorch::runtime::tensor_is_realhbbf16_type(in) && + executorch::runtime::tensor_is_realhbbf16_type(out)), + InvalidArgument, + out); + + bool has_min = min_opt.has_value(); + bool has_max = max_opt.has_value(); ET_KERNEL_CHECK_MSG( ctx, - resize_tensor(out, in.sizes()) == Error::Ok, + has_min || has_max, InvalidArgument, out, - "Failed to resize output tensor."); - - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + "At least one of 'min' or 'max' must not be None"); + // Input Dtypes ScalarType in_type = in.scalar_type(); - ScalarType min_type = in_type; - ScalarType max_type = in_type; - ScalarType common_type = in_type; + ScalarType min_type = + has_min ? utils::get_scalar_dtype(min_opt.value()) : in_type; + ScalarType max_type = + has_max ? 
utils::get_scalar_dtype(max_opt.value()) : in_type; ScalarType out_type = out.scalar_type(); - bool has_min = min_opt.has_value(); + // Common Dtype + ScalarType common_type = in_type; if (has_min) { - min_type = utils::get_scalar_dtype(min_opt.value()); common_type = utils::promote_type_with_scalar(common_type, min_opt.value()); + } + if (has_max) { + common_type = utils::promote_type_with_scalar(common_type, max_opt.value()); + } + + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + + // Check Scalar Bounds + if (has_min) { ET_KERNEL_CHECK( ctx, check_bounds(min_opt.value(), min_type, out_type, "minimum"), InvalidArgument, out); } - bool has_max = max_opt.has_value(); if (has_max) { - max_type = utils::get_scalar_dtype(max_opt.value()); - common_type = utils::promote_type_with_scalar(common_type, max_opt.value()); ET_KERNEL_CHECK( ctx, check_bounds(max_opt.value(), max_type, out_type, "maximum"), @@ -112,35 +126,38 @@ Tensor& clamp_out( out); } - ET_KERNEL_CHECK_MSG( - ctx, - has_min || has_max, - InvalidArgument, - out, - "At least one of 'min' or 'max' must not be None"); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "clamp.out"; - ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_unitensor_elementwise_fn( - [has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) { - CTYPE_COMMON val_out = val_in; + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_unitensor_elementwise_fn( + [has_min, min_opt, has_max, max_opt](const 
CTYPE_COMPUTE val_in) { + CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override( - val_out, utils::scalar_to(min_opt.value())); + val_out, utils::scalar_to(min_opt.value())); } if (has_max) { val_out = utils::min_override( - val_out, utils::scalar_to(max_opt.value())); + val_out, utils::scalar_to(max_opt.value())); } return val_out; }, in, utils::SupportedTensorDtypes::REALHBBF16, out, - utils::SupportedTensorDtypes::REALHBBF16); + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; @@ -152,8 +169,6 @@ Tensor& clamp_tensor_out( const exec_aten::optional& min_opt, const exec_aten::optional& max_opt, Tensor& out) { - (void)ctx; - bool has_min = min_opt.has_value(); bool has_max = max_opt.has_value(); @@ -167,42 +182,55 @@ Tensor& clamp_tensor_out( const Tensor& min = has_min ? min_opt.value() : in; const Tensor& max = has_max ? max_opt.value() : in; + ET_KERNEL_CHECK( + ctx, + (executorch::runtime::tensor_is_realhbbf16_type(in) && + executorch::runtime::tensor_is_realhbbf16_type(min) && + executorch::runtime::tensor_is_realhbbf16_type(max) && + executorch::runtime::tensor_is_realhbbf16_type(out)), + InvalidArgument, + out); + + // Common Dtype + ScalarType common_type = in.scalar_type(); + if (has_min) { + common_type = promoteTypes(common_type, min.scalar_type()); + } + if (has_max) { + common_type = promoteTypes(common_type, max.scalar_type()); + } + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, min, max, out), InvalidArgument, out); + // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(in, min, max, out) == Error::Ok, InvalidArgument, out); - ScalarType in_type = in.scalar_type(); - ScalarType min_type = min.scalar_type(); - ScalarType max_type = max.scalar_type(); - ScalarType common_type = in_type; - ScalarType out_type = out.scalar_type(); - - if (has_min) { - common_type = 
promoteTypes(common_type, min_type, /*half_to_float*/ true); - } - if (has_max) { - common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true); - } - - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "clamp.Tensor_out"; - ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_tritensor_elementwise_fn( + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_tritensor_elementwise_fn( [has_min, has_max]( - const CTYPE_COMMON val_in, - const CTYPE_COMMON val_min, - const CTYPE_COMMON val_max) { - CTYPE_COMMON val_out = val_in; + const CTYPE_COMPUTE val_in, + const CTYPE_COMPUTE val_min, + const CTYPE_COMPUTE val_max) { + CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override(val_out, val_min); } diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index 3765590ee83..983cbc8cbb9 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -19,36 +19,44 @@ Tensor& where_out( const Tensor& a, const Tensor& b, Tensor& out) { - ScalarType cond_type = cond.scalar_type(); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type); - ScalarType out_type = out.scalar_type(); + ET_KERNEL_CHECK( + ctx, + ((cond.scalar_type() == ScalarType::Bool || + cond.scalar_type() == ScalarType::Byte) && + executorch::runtime::tensor_is_realhbbf16_type(a) && + executorch::runtime::tensor_is_realhbbf16_type(b) && + executorch::runtime::tensor_is_realhbbf16_type(out)), + InvalidArgument, + out); + + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // 
Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); - // Determine output size and resize for dynamic shapes + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out); + + // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok, InvalidArgument, out); - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "where.self_out"; - ET_CHECK_MSG( - cond_type == ScalarType::Bool || cond_type == ScalarType::Byte, - "Unhandled dtype %s for where.self_out", - torch::executor::toString(cond_type)); - - ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() { - utils::apply_tritensor_elementwise_fn( - [](const CTYPE_COMMON val_a, - const CTYPE_COMMON val_b, - const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; }, + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_tritensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, + const CTYPE_COMPUTE val_b, + const CTYPE_COMPUTE val_c) { return val_c ? 
val_a : val_b; }, a, utils::SupportedTensorDtypes::REALHBBF16, b, diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 19dcbd73ba0..52ad6fca116 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -80,6 +80,41 @@ load_to_common_fn get_load_to_common_fn_bool_or_byte( return result; } +template +load_to_common_fn get_load_to_common_fn_same_as_compute( + const Tensor& t) { + constexpr auto common_scalar_type = CppTypeToScalarType::value; + ET_CHECK_MSG( // on failure, report the dtype actually found in `t` + t.scalar_type() == common_scalar_type, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + return internal::load_and_convert; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> +load_to_common_fn get_load_to_common_fn_same_as_common( + const Tensor& t) { + CTYPE_COMMON (*result)(const void*) = nullptr; + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() { + result = internal::load_and_convert; + }); + return result; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> +load_to_common_fn get_load_to_common_fn_same_as_common( + const Tensor& t) { + return get_load_to_common_fn_same_as_compute(t); +} + template using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*); @@ -105,11 +140,48 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) { return result; } +template +store_common_to_tensor_fn +get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) { + constexpr auto common_scalar_type = CppTypeToScalarType::value; + ET_CHECK_MSG( // on failure, report the dtype actually found in `t` + t.scalar_type() == common_scalar_type, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + return internal::convert_and_store; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> 
+store_common_to_tensor_fn +get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { + void (*result)(CTYPE_COMMON, void*) = nullptr; + ET_SWITCH_THREE_TYPES( + Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() { + result = internal::convert_and_store; + }); + return result; +} + +template < + typename CTYPE_COMMON, + const char* op_name, + std::enable_if_t, bool> = true> +store_common_to_tensor_fn +get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { + return get_store_common_to_tensor_fn_same_as_compute( + t); +} + } // namespace internal enum class SupportedTensorDtypes { REALHBBF16, BOOL_OR_BYTE, + SAME_AS_COMPUTE, SAME_AS_COMMON, }; @@ -124,16 +196,10 @@ load_to_common_fn get_load_to_common_fn( return get_load_to_common_fn_realhbbf16(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_load_to_common_fn_bool_or_byte(t); - case SupportedTensorDtypes::SAME_AS_COMMON: { - constexpr auto common_scalar_type = - CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); - return internal::load_and_convert; - } + case SupportedTensorDtypes::SAME_AS_COMPUTE: + return get_load_to_common_fn_same_as_compute(t); + case SupportedTensorDtypes::SAME_AS_COMMON: + return get_load_to_common_fn_same_as_common(t); } ET_CHECK(false); return nullptr; @@ -149,15 +215,14 @@ store_common_to_tensor_fn get_store_common_to_tensor_fn( case SupportedTensorDtypes::BOOL_OR_BYTE: return get_store_common_to_tensor_fn_bool_or_byte( t); + case SupportedTensorDtypes::SAME_AS_COMPUTE: + return get_store_common_to_tensor_fn_same_as_compute< + CTYPE_COMMON, + op_name>(t); case SupportedTensorDtypes::SAME_AS_COMMON: { - constexpr auto common_scalar_type = - CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); 
- return internal::convert_and_store; + return get_store_common_to_tensor_fn_same_as_common< + CTYPE_COMMON, + op_name>(t); } } ET_CHECK(false); @@ -330,6 +395,18 @@ inline void apply_tritensor_elementwise_fn( } } +/** + * Returns the dtype in which arithmetic is performed for `common_type`: Half + * and BFloat16 map to Float; every other dtype is returned unchanged. The + * argument is taken by value because it is only read. + */ +inline ScalarType get_compute_type(ScalarType common_type) { + if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { + return ScalarType::Float; + } + return common_type; +} + } // namespace utils } // namespace native } // namespace executor