From 6b0e11f5577187637bf7c124fc721e12a2e764f4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 25 Mar 2025 16:56:48 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_add.cpp | 56 ++++++------ kernels/portable/cpu/op_clamp.cpp | 41 ++++----- kernels/portable/cpu/op_div.cpp | 95 ++++++++++---------- kernels/portable/cpu/op_floor_divide.cpp | 42 ++++----- kernels/portable/cpu/op_maximum.cpp | 27 +++--- kernels/portable/cpu/op_minimum.cpp | 27 +++--- kernels/portable/cpu/op_mul.cpp | 48 +++++----- kernels/portable/cpu/op_remainder.cpp | 69 +++++++------- kernels/portable/cpu/op_rsub.cpp | 27 +++--- kernels/portable/cpu/op_sub.cpp | 56 ++++++------ kernels/portable/cpu/op_where.cpp | 31 +++---- kernels/portable/cpu/pattern/comparison_op.h | 23 ++--- kernels/portable/cpu/util/elementwise_util.h | 16 +++- 13 files changed, 294 insertions(+), 264 deletions(-) diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index adb9d4ea723..411861acce3 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -50,20 +50,21 @@ Tensor& add_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "add.out"; - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_bitensor_elementwise_fn( - [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return val_a + val_alpha * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a + val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); return out; } @@ -99,19 +100,20 @@ Tensor& add_scalar_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "add.Scalar_out"; - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn( - [b, alpha](const CTYPE_COMPUTE val_a) { - CTYPE_COMPUTE val_b = utils::scalar_to(b); - CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - return val_a + val_alpha * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_unitensor_elementwise_fn( + [b, alpha](const CTYPE_COMPUTE val_a) { + CTYPE_COMPUTE val_b = utils::scalar_to(b); + CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + return val_a + val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); return out; } diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index c1c40a38f34..7a1c9143f40 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -133,26 +133,27 @@ Tensor& clamp_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "clamp.out"; - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn( - [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) { - CTYPE_COMPUTE val_out = val_in; - if (has_min) { - val_out = utils::max_override( - val_out, utils::scalar_to(min_opt.value())); - } - if (has_max) { - val_out = utils::min_override( - val_out, utils::scalar_to(max_opt.value())); - } - return val_out; - }, - ctx, - in, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_unitensor_elementwise_fn( + [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) { + CTYPE_COMPUTE val_out = val_in; + if (has_min) { + val_out = utils::max_override( + val_out, utils::scalar_to(min_opt.value())); + } + if (has_max) { + val_out = utils::min_override( + val_out, utils::scalar_to(max_opt.value())); + } + return val_out; + }, + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); return out; } diff --git a/kernels/portable/cpu/op_div.cpp b/kernels/portable/cpu/op_div.cpp index 94cd9ea5011..0a2bec571cc 100644 --- a/kernels/portable/cpu/op_div.cpp +++ b/kernels/portable/cpu/op_div.cpp @@ -121,34 +121,36 @@ Tensor& div_out_mode( const bool mode_is_trunc = mode_val == "trunc"; bool div_by_zero_error = false; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [mode_is_trunc, &div_by_zero_error]( - const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - if (is_integral_type::value) { - if (val_b == 0) { - div_by_zero_error = true; - return static_cast(0); - } - } - CTYPE_COMPUTE value = val_a / val_b; - if (mode_is_trunc) { - value = std::trunc(value); - } else { - // We established above that the mode is either trunc or floor, so - // it must be floor. - value = utils::floor_divide(val_a, val_b); - } - return value; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [mode_is_trunc, &div_by_zero_error]( + const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + if (is_integral_type:: + value) { + if (val_b == 0) { + div_by_zero_error = true; + return static_cast(0); + } + } + CTYPE_COMPUTE value = val_a / val_b; + if (mode_is_trunc) { + value = std::trunc(value); + } else { + // We established above that the mode is either trunc or floor, + // so it must be floor. + value = utils::floor_divide(val_a, val_b); + } + return value; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); + }); ET_KERNEL_CHECK_MSG( ctx, @@ -252,24 +254,25 @@ Tensor& div_scalar_mode_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "div.Scalar_mode_out"; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( - [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) { - CTYPE_COMPUTE value = val_a / val_b; - if (mode_is_trunc) { - value = std::trunc(value); - } else { - value = utils::floor_divide(val_a, val_b); - } - return value; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + utils::apply_unitensor_elementwise_fn( + [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) { + CTYPE_COMPUTE value = val_a / val_b; + if (mode_is_trunc) { + value = std::trunc(value); + } else { + value = utils::floor_divide(val_a, val_b); + } + return value; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); + }); return out; } diff --git a/kernels/portable/cpu/op_floor_divide.cpp b/kernels/portable/cpu/op_floor_divide.cpp index 85eb612ea1e..59bbdfa4e2a 100644 --- a/kernels/portable/cpu/op_floor_divide.cpp +++ b/kernels/portable/cpu/op_floor_divide.cpp @@ -52,26 +52,28 @@ Tensor& floor_divide_out( bool div_by_zero_error = false; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [&div_by_zero_error]( - const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - if (is_integral_type::value) { - if (val_b == 0) { - div_by_zero_error = true; - return static_cast(0); - } - } - return utils::floor_divide(val_a, val_b); - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [&div_by_zero_error]( + const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + if (is_integral_type:: + value) { + if (val_b == 0) { + div_by_zero_error = true; + return static_cast(0); + } + } + return utils::floor_divide(val_a, val_b); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); + }); ET_KERNEL_CHECK_MSG( ctx, diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp index 5cf3b5a19f8..70d496973e6 100644 --- a/kernels/portable/cpu/op_maximum.cpp +++ b/kernels/portable/cpu/op_maximum.cpp @@ -44,19 +44,20 @@ Tensor& maximum_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "maximum.out"; - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return utils::max_override(val_a, val_b); - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return utils::max_override(val_a, val_b); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); return out; } diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp index e2c641bdb22..b3ea41771a3 100644 --- a/kernels/portable/cpu/op_minimum.cpp +++ b/kernels/portable/cpu/op_minimum.cpp @@ -44,19 +44,20 @@ Tensor& minimum_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "minimum.out"; - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return utils::min_override(val_a, val_b); - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return utils::min_override(val_a, val_b); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); return out; } diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 1ee73d342ca..4e050ef2dc5 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -51,19 +51,20 @@ Tensor& mul_out( InvalidArgument, out); - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return val_a * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); return out; } @@ -93,16 +94,17 @@ Tensor& mul_scalar_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "mul.Scalar_out"; - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( - [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + utils::apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); return out; } diff --git a/kernels/portable/cpu/op_remainder.cpp b/kernels/portable/cpu/op_remainder.cpp index d34c34a0380..5393ced574c 100644 --- a/kernels/portable/cpu/op_remainder.cpp +++ b/kernels/portable/cpu/op_remainder.cpp @@ -52,28 +52,30 @@ Tensor& remainder_Tensor_out( bool div_by_zero_error = false; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [&div_by_zero_error]( - const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - CTYPE_COMPUTE value = 0; - if (is_integral_type::value) { - if (val_b == 0) { - div_by_zero_error = true; + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [&div_by_zero_error]( + const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + CTYPE_COMPUTE value = 0; + if (is_integral_type:: + value) { + if (val_b == 0) { + div_by_zero_error = true; + return value; + } + } + value = utils::remainder_override(val_a, val_b); return value; - } - } - value = utils::remainder_override(val_a, val_b); - return value; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); - }); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); + }); ET_KERNEL_CHECK_MSG( ctx, @@ -124,18 +126,19 @@ Tensor& remainder_Scalar_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "remainder.Scalar_out"; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( - [val_b](const CTYPE_COMPUTE val_a) { - return utils::remainder_override(val_a, val_b); - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + utils::apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { + return utils::remainder_override(val_a, val_b); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); + }); return out; } diff --git a/kernels/portable/cpu/op_rsub.cpp b/kernels/portable/cpu/op_rsub.cpp index 46af021efda..5104a79fc5e 100644 --- a/kernels/portable/cpu/op_rsub.cpp +++ b/kernels/portable/cpu/op_rsub.cpp @@ -49,19 +49,20 @@ Tensor& rsub_scalar_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "rsub.Scalar_out"; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_b = utils::scalar_to(b); - const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_unitensor_elementwise_fn( - [val_b, val_alpha](const CTYPE_COMPUTE val_a) { - return val_b - val_alpha * val_a; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { + return val_b - val_alpha * val_a; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); return out; } diff --git a/kernels/portable/cpu/op_sub.cpp b/kernels/portable/cpu/op_sub.cpp index 6217f82c3b1..38c6334d089 100644 --- a/kernels/portable/cpu/op_sub.cpp +++ b/kernels/portable/cpu/op_sub.cpp @@ -54,20 +54,21 @@ Tensor& sub_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "sub.out"; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_bitensor_elementwise_fn( - [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return val_a - val_alpha * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBF16, - b, - utils::SupportedTensorDtypes::REALHBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a - val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBF16, + b, + utils::SupportedTensorDtypes::REALHBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); + }); return out; } @@ -107,19 +108,20 @@ Tensor& sub_scalar_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "sub.Scalar_out"; - ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_b = utils::scalar_to(b); - const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_unitensor_elementwise_fn( - [val_b, val_alpha](const CTYPE_COMPUTE val_a) { - return val_a - val_alpha * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { + return val_a - val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); return out; } diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index b455c45c2d1..3da999e1b84 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -42,21 +42,22 @@ Tensor& where_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "where.self_out"; - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_tritensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, - const CTYPE_COMPUTE val_b, - const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - cond, - utils::SupportedTensorDtypes::BOOL_OR_BYTE, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_tritensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, + const CTYPE_COMPUTE val_b, + const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + cond, + utils::SupportedTensorDtypes::BOOL_OR_BYTE, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); return out; } diff --git a/kernels/portable/cpu/pattern/comparison_op.h b/kernels/portable/cpu/pattern/comparison_op.h index e0d9bf4dcab..9535c769441 100644 --- a/kernels/portable/cpu/pattern/comparison_op.h +++ b/kernels/portable/cpu/pattern/comparison_op.h @@ -90,17 +90,18 @@ Tensor& comparison_tensor_out( // Compute Dtype ScalarType compute_type = utils::get_compute_type(common_type); - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - ComparisonFnForOp::value, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); + ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( + compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + ComparisonFnForOp::value, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); return out; } diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 206be87f98e..93b3631cc8a 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -199,11 +199,16 @@ inline void apply_tritensor_elementwise_fn( } inline ScalarType get_compute_type(ScalarType& common_type) { - ScalarType compute_type = common_type; + if (common_type == ScalarType::Long) { + return common_type; + } + if (isIntegralType(common_type, /*includeBool=*/true)) { + return ScalarType::Int; + } if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { - compute_type = ScalarType::Float; + return ScalarType::Float; } - return compute_type; + return common_type; } } // namespace internal @@ -214,6 +219,11 @@ using internal::apply_tritensor_elementwise_fn; using internal::apply_unitensor_elementwise_fn; using internal::get_compute_type; +#define ET_SWITCH_ELEMENTWISE_COMPUTE_TYPES( \ + TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_FLOAT_TYPES_AND2( \ + Int, Long, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + } // namespace utils } // namespace native } // namespace executor From ac64f9edf69aadf5e90e1dbdeaa64f659ee2b261 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 26 Mar 2025 09:39:00 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/portable/cpu/util/elementwise_util.h | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 93b3631cc8a..66ad1a51568 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -199,12 +199,25 @@ inline void apply_tritensor_elementwise_fn( } inline ScalarType get_compute_type(ScalarType& common_type) { - if (common_type == ScalarType::Long) { - return common_type; - } - if (isIntegralType(common_type, /*includeBool=*/true)) { + // Code size optimization: on typical 32-bit or 64-bit CPUs, the ALU should be + // just as good at 32-bit arithmetic as it is at 16-bit or 8-bit + // arithmetic, so don't go out of our way to generate 8-bit or + // 16-bit code. + + // Gate above optimization off if we appear to be on some kind of 8-bit or + // 16-bit CPU, which would invalidate our assumption about 32-bit + // math being just as fast. + constexpr bool cpu_appears_to_be_at_least_32_bit = sizeof(void*) >= 4 && sizeof(int) >= 4; + + if (cpu_appears_to_be_at_least_32_bit && + // Don't mess up 64-bit ints. + common_type != ScalarType::Long && + isIntegralType(common_type, /*includeBool=*/true)) { return ScalarType::Int; } + + // We compute in float for reduced-precision floating-point types as + // a matter of policy, not size optimization. if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { return ScalarType::Float; }