diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index 0d132ab1e03..48670b7441b 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -10,7 +10,7 @@
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -22,76 +22,35 @@ namespace native {
 
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& opt_mul_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
 
+  ScalarType common_type = promoteTypes(a_type, b_type);
+
+  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "mul.out";
+
   if (b.numel() == 1) {
     if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
         a_type != ScalarType::BFloat16) {
-      ET_KERNEL_CHECK(
-          ctx,
-          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-          InvalidArgument,
-          out);
-
-      ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
-        ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
+      ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
+        ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
           CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
           CTYPE b_casted = static_cast<CTYPE>(b_val);
@@ -111,17 +70,11 @@ Tensor& opt_mul_out(
 
   auto selected_optimized_path = select_optimized_path(a, b, out);
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
     if (executorch::runtime::isComplexType(out_type)) {
       ET_KERNEL_CHECK(
           ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
 
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         using Vec = at::vec::Vectorized<CTYPE>;
         at::vec::map2<CTYPE>(
             [](Vec x, Vec y) { return x * y; },
@@ -131,7 +84,7 @@ Tensor& opt_mul_out(
             out.numel());
       });
     } else {
-      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
         using Vec = at::vec::Vectorized<CTYPE>;
         at::vec::map2<CTYPE>(
             [](Vec x, Vec y) { return x * y; },
@@ -146,36 +99,26 @@ Tensor& opt_mul_out(
 
       ET_KERNEL_CHECK(
           ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
        auto mul_lambda = [](auto x, auto y) { return x * y; };
        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, mul_lambda, a, b, out, selected_optimized_path);
      });
    } else {
-      ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
        auto mul_lambda = [](auto x, auto y) { return x * y; };
        torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, mul_lambda, a, b, out, selected_optimized_path);
      });
    }
  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
    if (executorch::runtime::isComplexType(a_type) ||
        executorch::runtime::isComplexType(b_type) ||
        executorch::runtime::isComplexType(out_type)) {
      ET_KERNEL_CHECK(
          ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);
-      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
        apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
            [](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
            a,
@@ -183,26 +126,20 @@ Tensor& opt_mul_out(
             out);
       });
     } else {
-      ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
-        ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-          using CTYPE_IN = typename torch::executor::
-              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-          ET_SWITCH_REALHBBF16_TYPES(
-              out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      CTYPE_IN value = a_casted * b_casted;
-
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
+      ScalarType compute_type = utils::internal::get_compute_type(common_type);
+
+      ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+        utils::apply_bitensor_elementwise_fn<
+            CTYPE_COMPUTE,
+            op_name,
+            utils::SupportedTensorDtypes::REALHBBF16>(
+            [](const auto val_a, const auto val_b) { return val_a * val_b; },
+            ctx,
+            a,
+            utils::SupportedTensorDtypes::REALHBBF16,
+            b,
+            utils::SupportedTensorDtypes::REALHBBF16,
+            out);
       });
     }
   }
@@ -215,26 +152,24 @@ Tensor& opt_mul_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
   ScalarType out_type = out.scalar_type();
 
-  ET_CHECK(common_type == out_type);
+  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
 
-  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
-    common_type = ScalarType::Float;
-  }
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
-  // Resize for dynamic shape
-  auto error = resize_tensor(out, a.sizes());
-  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "mul.Scalar_out";
 
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
       CTYPE b_casted = utils::scalar_to<CTYPE>(b);
 
       using Vec = at::vec::Vectorized<CTYPE>;
@@ -245,22 +180,19 @@ Tensor& opt_mul_scalar_out(
           out.numel());
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_REALB_TYPES(
-          common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REALHBBF16_TYPES(
-                out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
-
-                  const size_t n = a.numel();
-                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                  for (auto i = 0; i < n; ++i) {
-                    out_data[i] = static_cast<CTYPE_OUT>(
-                        static_cast<CTYPE_IN>(a_data[i]) * b_casted);
-                  }
-                });
-          });
+    ScalarType compute_type = utils::internal::get_compute_type(common_type);
+
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+      utils::apply_unitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+          [val_b](const auto val_a) { return val_a * val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
     });
   }
diff --git a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
index 8c2a5a417ef..e6849182e52 100644
--- a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl
@@ -261,6 +261,8 @@ OPTIMIZED_ATEN_OPS = (
             ":binary_ops",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/runtime/core/exec_aten/util:tensor_util",
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         ],
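
For readers unfamiliar with the elementwise_util pattern this diff migrates to, below is a minimal, self-contained C++ sketch of the idea behind utils::apply_bitensor_elementwise_fn: load both operands, promote them to a common compute type, apply the lambda in that type, then convert to the output type. The name apply_bitensor and its raw-pointer interface are illustrative stand-ins, not the ExecuTorch API; the real helper additionally handles broadcasting, dtype validation against SupportedTensorDtypes sets such as REALHBBF16, and error reporting through the kernel runtime context.

// Minimal stand-in for the promote-compute-convert pattern used by the new
// code path. "apply_bitensor" is a hypothetical name; the real
// utils::apply_bitensor_elementwise_fn also handles broadcasting, dtype
// checks, and kernel-context error reporting.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Load CTYPE_A/CTYPE_B elements, apply the op in CTYPE_COMPUTE,
// store the result as CTYPE_OUT.
template <
    typename CTYPE_COMPUTE,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_OUT,
    typename Op>
void apply_bitensor(
    Op op, const CTYPE_A* a, const CTYPE_B* b, CTYPE_OUT* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = static_cast<CTYPE_OUT>(op(
        static_cast<CTYPE_COMPUTE>(a[i]), static_cast<CTYPE_COMPUTE>(b[i])));
  }
}

int main() {
  std::vector<int32_t> a{1, 2, 3};
  std::vector<float> b{0.5f, 0.25f, 2.0f};
  std::vector<float> out(3);
  // Compute in float, mirroring the diff's generic lambda
  // [](const auto val_a, const auto val_b) { return val_a * val_b; }.
  apply_bitensor<float>(
      [](float x, float y) { return x * y; },
      a.data(),
      b.data(),
      out.data(),
      out.size());
  for (float v : out) {
    std::cout << v << ' '; // prints: 0.5 0.5 6
  }
  std::cout << '\n';
}

One observable consequence in the diff itself: the removed MulInner machinery dispatched three nested ET_SWITCH macros (input A, input B, output), while the new path switches only on the single compute type, so the load/store conversions are resolved at runtime and far fewer templates are instantiated.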
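The mul.Scalar_out path follows the same scheme with a one-tensor variant: the scalar operand is converted to the compute type once and captured in the lambda. Again a hedged, self-contained analogue with stand-in names; apply_unitensor below is not the real signature of utils::apply_unitensor_elementwise_fn, which also enforces output-dtype constraints such as SAME_AS_COMMON that this sketch omits.

// Stand-in for the scalar path: the scalar is promoted to the compute type
// once and captured, so only one tensor is traversed per element.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

template <
    typename CTYPE_COMPUTE,
    typename CTYPE_A,
    typename CTYPE_OUT,
    typename Op>
void apply_unitensor(Op op, const CTYPE_A* a, CTYPE_OUT* out, std::size_t n) {
  std::transform(a, a + n, out, [&](CTYPE_A v) {
    return static_cast<CTYPE_OUT>(op(static_cast<CTYPE_COMPUTE>(v)));
  });
}

int main() {
  std::vector<int32_t> a{1, 2, 3};
  std::vector<int32_t> out(3);
  const float val_b = 4.0f; // scalar operand, converted to the compute type
  // Mirrors the diff's [val_b](const auto val_a) { return val_a * val_b; }.
  apply_unitensor<float>(
      [val_b](float x) { return x * val_b; },
      a.data(),
      out.data(),
      out.size());
  for (int32_t v : out) {
    std::cout << v << ' '; // prints: 4 8 12
  }
  std::cout << '\n';
}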