diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp
index 88b102b5650..562d4e227dd 100644
--- a/kernels/optimized/cpu/op_add.cpp
+++ b/kernels/optimized/cpu/op_add.cpp
@@ -10,7 +10,8 @@
 #include
 #include
 #include
-#include
+#include
+#include
 #include
 #include
 
@@ -31,6 +32,26 @@ Tensor& opt_add_out(
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
+  ScalarType common_type = promoteTypes(a_type, b_type);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out_type) &&
+       check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "add.out";
 
   if (b.numel() == 1) {
     if (executorch::runtime::isComplexType(a_type) ||
@@ -40,13 +61,8 @@ Tensor& opt_add_out(
       // output tensors have the same dtype. Support mixed dtypes in the future.
       ET_KERNEL_CHECK(
           ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
-      ET_KERNEL_CHECK(
-          ctx,
-          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-          InvalidArgument,
-          out);
 
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         CTYPE alpha_val = utils::scalar_to(alpha);
         CTYPE b_val = *b.const_data_ptr();
 
@@ -61,14 +77,8 @@ Tensor& opt_add_out(
   } else if (
       a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
       a_type != ScalarType::BFloat16) {
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
-      ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
+    ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
+      ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
         CTYPE alpha_val;
         ET_KERNEL_CHECK(
             ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
@@ -91,7 +101,6 @@ Tensor& opt_add_out(
     return opt_add_out(ctx, b, a, alpha, out);
   }
 
-  static constexpr const char op_name[] = "add.out";
   return torch::executor::kernels::impl::opt_add_sub_out_impl(
       ctx, a, b, alpha, out);
 }
@@ -102,26 +111,29 @@ Tensor& opt_add_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
   ScalarType out_type = out.scalar_type();
 
-  ET_CHECK(common_type == out_type);
+  ET_KERNEL_CHECK(
+      ctx,
+      (common_type == a_type &&
+       check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
+      InvalidArgument,
+      out);
 
-  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
-    common_type = ScalarType::Float;
-  }
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
 
-  // Resize for dynamic shape
-  auto error = resize_tensor(out, a.sizes());
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "add.Scalar_out";
 
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
       CTYPE b_casted = utils::scalar_to(b);
       CTYPE alpha_val;
       ET_KERNEL_CHECK(
@@ -137,28 +149,28 @@ Tensor& opt_add_scalar_out(
           out.numel());
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(
-          common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REALHBBF16_TYPES(
-                out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_IN b_casted = utils::scalar_to(b);
-                  CTYPE_IN alpha_val;
-                  ET_KERNEL_CHECK(
-                      ctx,
-                      utils::extract_scalar(alpha, &alpha_val),
-                      InvalidArgument, );
-
-                  const size_t n = a.numel();
-                  const CTYPE_A* a_data = a.const_data_ptr();
-                  CTYPE_OUT* out_data = out.mutable_data_ptr();
-                  for (auto i = 0; i < n; ++i) {
-                    out_data[i] = static_cast(
-                        static_cast(a_data[i]) +
-                        alpha_val * b_casted);
-                  }
-                });
-          });
+    ScalarType compute_type = utils::internal::get_compute_type(common_type);
+
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      CTYPE_COMPUTE val_b = utils::scalar_to(b);
+      CTYPE_COMPUTE val_alpha;
+      ET_KERNEL_CHECK(
+          ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
+      auto val_alpha_times_b = val_alpha * val_b;
+      utils::apply_unitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+          [val_alpha_times_b](const auto val_a) {
+            // Cast here supports vectorization; either it does nothing
+            // or it casts from CTYPE_COMPUTE to
+            // Vectorized.
+            return val_a + decltype(val_a)(val_alpha_times_b);
+          },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
     });
   }
 
diff --git a/kernels/optimized/cpu/op_add_sub_impl.h b/kernels/optimized/cpu/op_add_sub_impl.h
index b3dcd41d74b..d15c143770d 100644
--- a/kernels/optimized/cpu/op_add_sub_impl.h
+++ b/kernels/optimized/cpu/op_add_sub_impl.h
@@ -10,7 +10,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 #include
 
@@ -19,55 +19,6 @@
 namespace executor {
 namespace kernels {
 namespace impl {
 
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast(val_a);
-          CTYPE_IN b_casted = static_cast(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner
-    : public ReportCanCastBug {};
-
-} // namespace
-
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
@@ -78,8 +29,6 @@ Tensor& opt_add_sub_out_impl(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
@@ -115,14 +64,6 @@ Tensor& opt_add_sub_out_impl(
   }
 
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
     ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
       CTYPE alpha_val;
       ET_KERNEL_CHECK(
@@ -202,39 +143,32 @@ Tensor& opt_add_sub_out_impl(
       }
     });
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+    ScalarType common_type = promoteTypes(a_type, b_type);
+    ScalarType compute_type =
+        native::utils::internal::get_compute_type(common_type);
 
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, op_name, CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types::type;
-        ET_DCHECK(CppTypeToScalarType::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, op_name, CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx,
-              torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
-              InvalidArgument, );
-          if constexpr (is_sub) {
-            alpha_val = -alpha_val;
-          }
-
-          AddInner<
-              can_cast::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      CTYPE_COMPUTE val_alpha;
+      ET_KERNEL_CHECK(
+          ctx,
+          native::utils::extract_scalar(alpha, &val_alpha),
+          InvalidArgument, );
+      if constexpr (is_sub) {
+        val_alpha = -val_alpha;
+      }
+      native::utils::apply_bitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          native::utils::SupportedTensorDtypes::REALHBBF16>(
+          [val_alpha](const auto val_a, const auto val_b) {
+            return val_a + val_alpha * val_b;
+          },
+          ctx,
+          a,
+          native::utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          native::utils::SupportedTensorDtypes::REALHBBF16,
+          out);
     });
   }
 
diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp
index 58f8d2a7fdf..41d46d1661e 100644
--- a/kernels/optimized/cpu/op_sub.cpp
+++ b/kernels/optimized/cpu/op_sub.cpp
@@ -10,7 +10,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 #include
 #include
@@ -20,55 +20,6 @@
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct SubInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct SubInner {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast(val_a);
-          CTYPE_IN b_casted = static_cast(val_b);
-          CTYPE_IN value = a_casted - alpha_val * b_casted;
-
-          return static_cast(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct SubInner
-    : public ReportCanCastBug {};
-
-} // namespace
 
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
@@ -79,19 +30,36 @@ Tensor& opt_sub_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
   ScalarType out_type = out.scalar_type();
 
+  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
+
+  ScalarType common_type = promoteTypes(a_type, b_type);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out_type) && canCast(alpha_type, common_type)),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
   ET_KERNEL_CHECK(
       ctx,
-      executorch::runtime::tensor_is_realhbf16_type(out),
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "sub.out";
+
   if (a.numel() == 1 || b.numel() == 1) {
-    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
+    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
+        a_type != ScalarType::BFloat16) {
       const Tensor* tensor;
       const Tensor* scalar;
       ScalarType tensor_type;
@@ -107,13 +75,8 @@ Tensor& opt_sub_out(
         scalar = &b;
         scalar_type = b_type;
       }
-      ET_KERNEL_CHECK(
-          ctx,
-          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-          InvalidArgument,
-          out);
-      ET_SWITCH_REAL_TYPES(tensor_type, ctx, "sub.out", CTYPE, [&]() {
-        ET_SWITCH_REAL_TYPES(scalar_type, ctx, "sub.out", CTYPE_SCALAR, [&]() {
+      ET_SWITCH_REAL_TYPES(tensor_type, ctx, op_name, CTYPE, [&]() {
+        ET_SWITCH_REAL_TYPES(scalar_type, ctx, op_name, CTYPE_SCALAR, [&]() {
          CTYPE alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
@@ -144,7 +107,6 @@ Tensor& opt_sub_out(
     }
   }
 
-  static constexpr const char op_name[] = "sub.out";
   return torch::executor::kernels::impl::opt_add_sub_out_impl(
       ctx, a, b, alpha, out);
 }
@@ -155,26 +117,31 @@ Tensor& opt_sub_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
   ScalarType out_type = out.scalar_type();
 
-  ET_CHECK(common_type == out_type);
+  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
 
-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
-  }
+  ET_KERNEL_CHECK(
+      ctx,
+      (common_type == out_type && canCast(alpha_type, common_type)),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
 
-  // Resize for dynamic shape
-  auto error = resize_tensor(out, a.sizes());
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "sub.Scalar_out";
 
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
-    ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE, [&]() {
+    ET_SWITCH_REAL_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
       CTYPE b_casted = utils::scalar_to(b);
       CTYPE alpha_val;
       ET_KERNEL_CHECK(
@@ -190,26 +157,23 @@ Tensor& opt_sub_scalar_out(
           out.numel());
     });
   } else {
-    ET_SWITCH_REALHBF16_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALHBF16_TYPES(
-            out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
-              CTYPE_IN b_casted = utils::scalar_to(b);
-              CTYPE_IN alpha_val;
-              ET_KERNEL_CHECK(
-                  ctx,
-                  utils::extract_scalar(alpha, &alpha_val),
-                  InvalidArgument, );
-
-              const size_t n = a.numel();
-              const CTYPE_A* a_data = a.const_data_ptr();
-              CTYPE_OUT* out_data = out.mutable_data_ptr();
-              for (auto i = 0; i < n; ++i) {
-                out_data[i] = static_cast(
-                    static_cast(a_data[i]) - alpha_val * b_casted);
-              }
-            });
-      });
+    ScalarType compute_type = utils::internal::get_compute_type(common_type);
+
+    ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      const CTYPE_COMPUTE val_b = utils::scalar_to(b);
+      const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha);
+      const auto val_alpha_times_b = val_alpha * val_b;
+      utils::apply_unitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+          [val_alpha_times_b](const auto val_a) {
+            return val_a - (decltype(val_a))(val_alpha_times_b);
+          },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBF16,
+          out);
     });
   }
 
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 3d9c5caf815..b5fbff3021e 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -29,6 +29,9 @@ def define_common_targets():
         exported_deps = [
             "//executorch/runtime/core:core",
             "//executorch/kernels/portable/cpu/util:broadcast_indexes_range",
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     )
 
diff --git a/kernels/test/op_sub_test.cpp b/kernels/test/op_sub_test.cpp
index aa7d4d51e4e..c8e7c69c443 100644
--- a/kernels/test/op_sub_test.cpp
+++ b/kernels/test/op_sub_test.cpp
@@ -73,7 +73,7 @@ class OpSubOutTest : public OperatorTest {
 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
   test_sub_enumerate_out_types();
 
-    ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
+    ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
 
 #undef ENUMERATE_TEST_ENTRY
   }
@@ -208,7 +208,7 @@ class OpSubOutTest : public OperatorTest {
 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
   test_sub_enumerate_b_types();
 
-    ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
+    ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
 
 #undef ENUMERATE_TEST_ENTRY
   }
diff --git a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
index 8c2a5a417ef..a3c9d707a67 100644
--- a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
@@ -153,6 +153,9 @@ OPTIMIZED_ATEN_OPS = (
             ":add_sub_impl",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         ],
     ),
@@ -280,6 +283,8 @@ OPTIMIZED_ATEN_OPS = (
             ":add_sub_impl",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         ],
     ),