From 9eaf41865c5d1d32563e290571be5a57b650f057 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 27 Aug 2025 14:41:04 -0700 Subject: [PATCH] Refactor op_div: fix bug + enable Half/Bfloat16 Reviewed By: SS-JIA Differential Revision: D81169893 --- kernels/optimized/cpu/op_div.cpp | 179 ++++++++---------- kernels/test/op_div_test.cpp | 22 ++- .../optimized/op_registration_util.bzl | 2 + 3 files changed, 101 insertions(+), 102 deletions(-) diff --git a/kernels/optimized/cpu/op_div.cpp b/kernels/optimized/cpu/op_div.cpp index 7af2b4b4695..d74a293af8a 100644 --- a/kernels/optimized/cpu/op_div.cpp +++ b/kernels/optimized/cpu/op_div.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include @@ -20,7 +20,7 @@ namespace native { namespace { -ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) { +ScalarType get_common_type(ScalarType a_type, ScalarType b_type) { ET_CHECK( !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type)); ET_CHECK( @@ -43,14 +43,27 @@ Tensor& opt_div_out( const Tensor& a, const Tensor& b, Tensor& out) { - (void)ctx; + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + + // Resize + ET_KERNEL_CHECK( + ctx, + resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "div.out"; ScalarType a_type = a.scalar_type(); ScalarType b_type = b.scalar_type(); ScalarType out_type = out.scalar_type(); if (a.numel() == 1 || b.numel() == 1) { - if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) { + if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half && + a_type != ScalarType::BFloat16) { const Tensor* tensor; const Tensor* scalar; ScalarType tensor_type; @@ -66,13 +79,8 @@ Tensor& opt_div_out( scalar = &b; scalar_type = b_type; } - ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, - InvalidArgument, - out); - ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() { - ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() { + ET_SWITCH_REALB_TYPES(tensor_type, ctx, op_name, CTYPE, [&]() { + ET_SWITCH_REALB_TYPES(scalar_type, ctx, op_name, CTYPE_SCALAR, [&]() { CTYPE_SCALAR scalar_val = *scalar->const_data_ptr(); CTYPE scalar_casted = static_cast(scalar_val); @@ -101,16 +109,7 @@ Tensor& opt_div_out( auto selected_optimized_path = select_optimized_path(a, b, out); if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) { - // Resize for dynamic shape - auto error = resize_tensor(out, a.sizes()); - ET_KERNEL_CHECK_MSG( - ctx, - error == Error::Ok, - InvalidArgument, - out, - "Failed to resize output tensor."); - - ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "div.out", CTYPE, [&]() { + ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() { using Vec = at::vec::Vectorized; at::vec::map2( [](Vec x, Vec y) { return x / y; }, @@ -122,7 +121,7 @@ Tensor& opt_div_out( } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { // Reason for using alpha is becasuse handle_broadcast_elementwise // is used for add and sub as well: - ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() { + ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() { if (selected_optimized_path == ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments || selected_optimized_path == @@ -139,33 +138,21 @@ Tensor& opt_div_out( } }); } else { - ScalarType common_type = get_compute_type(a_type, b_type); - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - - ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, - InvalidArgument, - out); - - ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() { - ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() { - ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = a_casted / b_casted; - - return static_cast(value); - }, - a, - b, - out); - }); - }); - }); + ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type()); + ScalarType compute_type = utils::get_compute_type(common_type); + + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::FLOATHBF16>( + [](const auto val_a, const auto val_b) { return val_a / val_b; }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out); }); } @@ -177,63 +164,57 @@ Tensor& opt_div_scalar_out( const Tensor& a, const Scalar& b, Tensor& out) { - (void)ctx; - ScalarType a_type = a.scalar_type(); ScalarType b_type = utils::get_scalar_dtype(b); ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float; ScalarType out_type = out.scalar_type(); - ET_CHECK(common_type == out_type); - - // Resize for dynamic shape - auto error = resize_tensor(out, a.sizes()); - ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor."); - - if (a_type == common_type && a_type == out_type) { - ET_SWITCH_REAL_TYPES(a_type, ctx, "div.Scalar_out", CTYPE, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE b_casted = static_cast(b_val); - - using Vec = at::vec::Vectorized; - Vec inv_b_casted_vec(CTYPE(1) / b_casted); - at::vec::map( - [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; }, - out.mutable_data_ptr(), - a.const_data_ptr(), - out.numel()); - }); + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); + + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "div.Scalar_out"; + + if (a_type == common_type && a_type == out_type && + a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { + ET_SWITCH_REAL_TYPES(a_type, ctx, op_name, CTYPE, [&]() { + ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() { + CTYPE_B b_val; + ET_EXTRACT_SCALAR(b, b_val); + CTYPE b_casted = static_cast(b_val); + + using Vec = at::vec::Vectorized; + Vec inv_b_casted_vec(CTYPE(1) / b_casted); + at::vec::map( + [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; }, + out.mutable_data_ptr(), + a.const_data_ptr(), + out.numel()); + }); }); } else { - ET_SWITCH_REAL_TYPES_AND( - Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES( - common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES( - out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE_IN b_casted = static_cast(b_val); - CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted; - - const size_t n = a.numel(); - const CTYPE_A* a_data = a.const_data_ptr(); - CTYPE_OUT* out_data = - out.mutable_data_ptr(); - for (auto i = 0; i < n; ++i) { - out_data[i] = static_cast( - static_cast(a_data[i]) * - inv_b_casted); - } - }); - }); - }); - }); + ScalarType compute_type = utils::get_compute_type(common_type); + + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [val_b](const auto val_a) { return val_a / val_b; }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out); + }); } return out; diff --git a/kernels/test/op_div_test.cpp b/kernels/test/op_div_test.cpp index 8f41419a8e0..61c15675bc8 100644 --- a/kernels/test/op_div_test.cpp +++ b/kernels/test/op_div_test.cpp @@ -54,7 +54,7 @@ class OpDivOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_div(); - ET_FORALL_FLOAT_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_FLOATHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -64,7 +64,7 @@ class OpDivOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_div_enumerate_out_types(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -183,7 +183,7 @@ void OpDivOutTest::test_div_enumerate_a_types() { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_div_enumerate_b_types(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) test_div(); @@ -283,6 +283,22 @@ TEST_F(OpDivOutTest, BroadcastScalarSupported2) { EXPECT_TENSOR_EQ(out, ret); } +TEST_F(OpDivOutTest, BroadcastSupported3) { + TensorFactory tf; + + Tensor a = tf.make({5}, {2, 3, 4, 5, 6}); + Tensor b = tf.make({1, 5}, {2, 1, 2, 2, 3}); + + // Destination for the broadcasting div. Follow the broadcasting rules in + // https://fburl.com/n9wl4d0o + Tensor out = tf.zeros({1, 5}); + + op_div_out(a, b, out); + + Tensor ret = tf.make({1, 5}, {1, 3, 2, 2.5, 2}); + EXPECT_TENSOR_EQ(out, ret); +} + TEST_F(OpDivOutTest, BroadcastScalarRank0Supported) { TensorFactory tf; diff --git a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl index 7d9b1a0c317..8c2a5a417ef 100644 --- a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl @@ -180,6 +180,8 @@ OPTIMIZED_ATEN_OPS = ( ":binary_ops", "//executorch/kernels/portable/cpu:scalar_utils", "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:dtype_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", ], ),