From 11443faf2bb7ffaad21add4e2eac6aa863d4a15f Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 26 Jun 2025 11:03:30 -0700
Subject: [PATCH 1/2] [EE/BE][ET][Portable] Move scalar_to utils to scalar_utils.h

Differential Revision: [D75962656](https://our.internmc.facebook.com/intern/diff/D75962656/)

ghstack-source-id: 292676104
Pull Request resolved: https://github.com/pytorch/executorch/pull/12009
---
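Note: scalar_to<T> reads whichever payload the Scalar actually holds (bool,
int64_t, or double) and casts it to T; the double and int64_t specializations
avoid a detour through the wrong intermediate type. A minimal usage sketch
(the variable names are illustrative, not taken from this patch):

```cpp
// Assuming a Scalar `b`, e.g. the `b` argument of the Scalar-overload
// kernels touched in patch 2/2. An integer-valued Scalar converts
// exactly: scalar_to<double> reads the int64_t payload and casts once.
double d = torch::executor::native::utils::scalar_to<double>(b);
int64_t i = torch::executor::native::utils::scalar_to<int64_t>(b);
```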
 kernels/portable/cpu/scalar_utils.h          | 28 ++++++++++++++++++-
 kernels/portable/cpu/util/elementwise_util.h | 29 +-------------------
 kernels/portable/cpu/util/targets.bzl        |  2 +-
 3 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h
index 02700804819..162f96ba85d 100644
--- a/kernels/portable/cpu/scalar_utils.h
+++ b/kernels/portable/cpu/scalar_utils.h
@@ -8,7 +8,6 @@
 
 #pragma once
 
-#include 
 #include 
 #include 
@@ -261,6 +260,33 @@ bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
   return false;
 }
 
+/*
+ * Convert Scalar to C++ type
+ */
+
+template <typename T>
+T scalar_to(const Scalar& s) {
+  if (s.isBoolean()) {
+    return static_cast<T>(s.to<bool>());
+  } else if (s.isFloatingPoint()) {
+    return static_cast<T>(s.to<double>());
+  } else {
+    return static_cast<T>(s.to<int64_t>());
+  }
+}
+
+template <>
+inline double scalar_to<double>(const Scalar& s) {
+  return s.isFloatingPoint() ? s.to<double>()
+                             : static_cast<double>(s.to<int64_t>());
+}
+
+template <>
+inline int64_t scalar_to<int64_t>(const Scalar& s) {
+  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
+                             : s.to<int64_t>();
+}
+
 } // namespace utils
 } // namespace native
 } // namespace executor
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 948da50fdd4..6adf81f70e3 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include 
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include 
 #include 
 #include 
@@ -27,34 +28,6 @@ namespace torch {
 namespace executor {
 namespace native {
 namespace utils {
-
-/*
- * Convert Scalar to C++ type
- */
-
-template <typename T>
-T scalar_to(const Scalar& s) {
-  if (s.isBoolean()) {
-    return static_cast<T>(s.to<bool>());
-  } else if (s.isFloatingPoint()) {
-    return static_cast<T>(s.to<double>());
-  } else {
-    return static_cast<T>(s.to<int64_t>());
-  }
-}
-
-template <>
-inline double scalar_to<double>(const Scalar& s) {
-  return s.isFloatingPoint() ? s.to<double>()
-                             : static_cast<double>(s.to<int64_t>());
-}
-
-template <>
-inline int64_t scalar_to<int64_t>(const Scalar& s) {
-  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
-                             : s.to<int64_t>();
-}
-
 namespace internal {
 /**
  * Causes these utility functions to make sure to respect Tensor
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index 1523fcfe706..ef3a878fd70 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -116,9 +116,9 @@ def define_common_targets():
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/extension/threadpool:threadpool",
+            "//executorch/kernels/portable/cpu:scalar_utils",
         ],
         deps = [
-            "//executorch/kernels/portable/cpu:scalar_utils",
            "//executorch/runtime/kernel:kernel_includes",
         ],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],

From ae5d69b796e8f45d61db3116e0d48bbc575e4e78 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 26 Jun 2025 15:34:11 -0700
Subject: [PATCH 2/2] [ET][Optimized] Eliminate usage of ET_SWITCH_SCALAR in optimized kernels

Differential Revision: [D77405083](https://our.internmc.facebook.com/intern/diff/D77405083/)

ghstack-source-id: 292967023
Pull Request resolved: https://github.com/pytorch/executorch/pull/12033
---
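Note: each ET_SWITCH_SCALAR_OBJ_TYPES nested inside a dtype switch stamps out
the kernel body once per Scalar payload type, so dropping it removes a whole
level of template instantiations. The shape of the change, distilled from the
op_mul hunks below (a sketch, not a complete kernel):

```cpp
// Before: a second compile-time switch over the Scalar's payload type.
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
  CTYPE_B b_val;
  ET_EXTRACT_SCALAR(b, b_val);
  CTYPE b_casted = static_cast<CTYPE>(b_val);
  // ... vectorized body ...
});

// After: one runtime conversion; the body is instantiated only per CTYPE.
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
// ... vectorized body ...
```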
 kernels/optimized/cpu/op_add.cpp | 77 +++++++++++++++-----------------
 kernels/optimized/cpu/op_mul.cpp | 55 ++++++++++-------------
 kernels/optimized/cpu/op_sub.cpp | 73 +++++++++++++-----------------
 3 files changed, 90 insertions(+), 115 deletions(-)

diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp
index de16429c598..97bdb0a0d5e 100644
--- a/kernels/optimized/cpu/op_add.cpp
+++ b/kernels/optimized/cpu/op_add.cpp
@@ -45,9 +45,7 @@ Tensor& opt_add_out(
   ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
     CTYPE alpha_val;
     ET_KERNEL_CHECK(
-        ctx,
-        torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
-        InvalidArgument, );
+        ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
 
     CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
     CTYPE b_casted = static_cast<CTYPE>(b_val);
@@ -81,7 +79,6 @@ Tensor& opt_add_scalar_out(
   (void)ctx;
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType common_type =
       utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
   ScalarType out_type = out.scalar_type();
@@ -99,47 +96,43 @@ Tensor& opt_add_scalar_out(
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
     ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
-        CTYPE_B b_val;
-        ET_EXTRACT_SCALAR(b, b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-        CTYPE alpha_val;
-        ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-        using Vec = at::vec::Vectorized<CTYPE>;
-        at::vec::map<CTYPE>(
-            [alpha_val, b_casted](Vec x) {
-              return x + Vec(alpha_val * b_casted);
-            },
-            out.mutable_data_ptr<CTYPE>(),
-            a.const_data_ptr<CTYPE>(),
-            out.numel());
-      });
+      CTYPE b_casted = utils::scalar_to<CTYPE>(b);
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK(
+          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+
+      using Vec = at::vec::Vectorized<CTYPE>;
+      at::vec::map<CTYPE>(
+          [alpha_val, b_casted](Vec x) {
+            return x + Vec(alpha_val * b_casted);
+          },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          out.numel());
     });
   } else {
     ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(
-            common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
-              ET_SWITCH_REALHBBF16_TYPES(
-                  out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
-                    CTYPE_B b_val;
-                    ET_EXTRACT_SCALAR(b, b_val);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                    CTYPE_IN alpha_val;
-                    ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                    const size_t n = a.numel();
-                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                    for (auto i = 0; i < n; ++i) {
-                      out_data[i] = static_cast<CTYPE_OUT>(
-                          static_cast<CTYPE_IN>(a_data[i]) +
-                          alpha_val * b_casted);
-                    }
-                  });
-            });
-      });
+      ET_SWITCH_REALB_TYPES(
+          common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
+            ET_SWITCH_REALHBBF16_TYPES(
+                out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
+                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
+                  CTYPE_IN alpha_val;
+                  ET_KERNEL_CHECK(
+                      ctx,
+                      utils::extract_scalar(alpha, &alpha_val),
+                      InvalidArgument, );
+
+                  const size_t n = a.numel();
+                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
+                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+                  for (auto i = 0; i < n; ++i) {
+                    out_data[i] = static_cast<CTYPE_OUT>(
+                        static_cast<CTYPE_IN>(a_data[i]) +
+                        alpha_val * b_casted);
+                  }
+                });
+          });
     });
   }
diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index 50eb297a625..8783812ede1 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -218,7 +218,6 @@ Tensor& opt_mul_scalar_out(
   (void)ctx;
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType common_type =
       utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
   ScalarType out_type = out.scalar_type();
@@ -236,40 +235,32 @@ Tensor& opt_mul_scalar_out(
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
     ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
-        CTYPE_B b_val;
-        ET_EXTRACT_SCALAR(b, b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-        using Vec = at::vec::Vectorized<CTYPE>;
-        at::vec::map<CTYPE>(
-            [b_casted](Vec x) { return x * Vec(b_casted); },
-            out.mutable_data_ptr<CTYPE>(),
-            a.const_data_ptr<CTYPE>(),
-            out.numel());
-      });
+      CTYPE b_casted = utils::scalar_to<CTYPE>(b);
+
+      using Vec = at::vec::Vectorized<CTYPE>;
+      at::vec::map<CTYPE>(
+          [b_casted](Vec x) { return x * Vec(b_casted); },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          out.numel());
     });
   } else {
     ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(
-            common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
-              ET_SWITCH_REALHBBF16_TYPES(
-                  out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
-                    CTYPE_B b_val;
-                    ET_EXTRACT_SCALAR(b, b_val);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-
-                    const size_t n = a.numel();
-                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                    for (auto i = 0; i < n; ++i) {
-                      out_data[i] = static_cast<CTYPE_OUT>(
-                          static_cast<CTYPE_IN>(a_data[i]) * b_casted);
-                    }
-                  });
-            });
-      });
+      ET_SWITCH_REALB_TYPES(
+          common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
+            ET_SWITCH_REALHBBF16_TYPES(
+                out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
+                  CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
+
+                  const size_t n = a.numel();
+                  const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
+                  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+                  for (auto i = 0; i < n; ++i) {
+                    out_data[i] = static_cast<CTYPE_OUT>(
+                        static_cast<CTYPE_IN>(a_data[i]) * b_casted);
+                  }
+                });
+          });
     });
   }
diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp
index 94f0ba4e785..db2f1dd97f7 100644
--- a/kernels/optimized/cpu/op_sub.cpp
+++ b/kernels/optimized/cpu/op_sub.cpp
@@ -154,7 +154,6 @@ Tensor& opt_sub_scalar_out(
   (void)ctx;
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType common_type =
       utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
   ScalarType out_type = out.scalar_type();
@@ -172,49 +171,41 @@ Tensor& opt_sub_scalar_out(
   if (a_type == common_type && a_type == out_type &&
       a_type != ScalarType::Half) {
     ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE, [&]() {
-      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
-          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
-            CTYPE_B b_val;
-            ET_EXTRACT_SCALAR(b, b_val);
-            CTYPE b_casted = static_cast<CTYPE>(b_val);
-            CTYPE alpha_val;
-            ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-            using Vec = at::vec::Vectorized<CTYPE>;
-            at::vec::map<CTYPE>(
-                [alpha_val, b_casted](Vec x) {
-                  return x - Vec(alpha_val * b_casted);
-                },
-                out.mutable_data_ptr<CTYPE>(),
-                a.const_data_ptr<CTYPE>(),
-                out.numel());
-          });
+      CTYPE b_casted = utils::scalar_to<CTYPE>(b);
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK(
+          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+
+      using Vec = at::vec::Vectorized<CTYPE>;
+      at::vec::map<CTYPE>(
+          [alpha_val, b_casted](Vec x) {
+            return x - Vec(alpha_val * b_casted);
+          },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          out.numel());
     });
   } else {
     ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
-          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
-            ET_SWITCH_REAL_TYPES(
-                common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
-                  ET_SWITCH_REALH_TYPES(
-                      out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
-                        CTYPE_B b_val;
-                        ET_EXTRACT_SCALAR(b, b_val);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                        CTYPE_IN alpha_val;
-                        ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                        const size_t n = a.numel();
-                        const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                        CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                        for (auto i = 0; i < n; ++i) {
-                          out_data[i] = static_cast<CTYPE_OUT>(
-                              static_cast<CTYPE_IN>(a_data[i]) -
-                              alpha_val * b_casted);
-                        }
-                      });
-                });
-          });
+      ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
+        ET_SWITCH_REALH_TYPES(
+            out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
+              CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
+              CTYPE_IN alpha_val;
+              ET_KERNEL_CHECK(
+                  ctx,
+                  utils::extract_scalar(alpha, &alpha_val),
+                  InvalidArgument, );
+
+              const size_t n = a.numel();
+              const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
+              CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+              for (auto i = 0; i < n; ++i) {
+                out_data[i] = static_cast<CTYPE_OUT>(
+                    static_cast<CTYPE_IN>(a_data[i]) - alpha_val * b_casted);
+              }
+            });
+      });
     });
   }
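Note: alpha handling changes alongside the switch removal: extraction now goes
through an explicit ET_KERNEL_CHECK over utils::extract_scalar, which reports
InvalidArgument through the kernel context on failure. Distilled from the
hunks above (a sketch; the empty trailing macro argument is the early-return
value of the enclosing void lambda):

```cpp
// Before: macro-based extraction.
CTYPE alpha_val;
ET_EXTRACT_SCALAR(alpha, alpha_val);

// After: on failure, flag InvalidArgument on ctx and return from the
// surrounding ET_SWITCH_* lambda (hence the empty last argument).
CTYPE alpha_val;
ET_KERNEL_CHECK(
    ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
```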