From aa9354cc2ada9078fae7252dfc338728667e0e33 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 26 Jun 2025 15:34:09 -0700 Subject: [PATCH] [ET][Optimized] Eliminate usage of ET_SWITCH_SCALAR in optimized kernels Differential Revision: [D77405083](https://our.internmc.facebook.com/intern/diff/D77405083/) [ghstack-poisoned] --- kernels/optimized/cpu/op_add.cpp | 77 +++++++++++++++----------------- kernels/optimized/cpu/op_mul.cpp | 55 ++++++++++------------- kernels/optimized/cpu/op_sub.cpp | 73 +++++++++++++----------------- 3 files changed, 90 insertions(+), 115 deletions(-) diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp index de16429c598..97bdb0a0d5e 100644 --- a/kernels/optimized/cpu/op_add.cpp +++ b/kernels/optimized/cpu/op_add.cpp @@ -45,9 +45,7 @@ Tensor& opt_add_out( ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { CTYPE alpha_val; ET_KERNEL_CHECK( - ctx, - torch::executor::native::utils::extract_scalar(alpha, &alpha_val), - InvalidArgument, ); + ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); CTYPE_B b_val = *b.const_data_ptr(); CTYPE b_casted = static_cast(b_val); @@ -81,7 +79,6 @@ Tensor& opt_add_scalar_out( (void)ctx; ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); ScalarType common_type = utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); ScalarType out_type = out.scalar_type(); @@ -99,47 +96,43 @@ Tensor& opt_add_scalar_out( if (a_type == common_type && a_type == out_type && a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE b_casted = static_cast(b_val); - CTYPE alpha_val; - ET_EXTRACT_SCALAR(alpha, alpha_val); - - using Vec = at::vec::Vectorized; - at::vec::map( - [alpha_val, b_casted](Vec x) { - return x + Vec(alpha_val * b_casted); - }, - out.mutable_data_ptr(), - a.const_data_ptr(), - out.numel()); - }); + CTYPE b_casted = utils::scalar_to(b); + CTYPE alpha_val; + ET_KERNEL_CHECK( + ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); + + using Vec = at::vec::Vectorized; + at::vec::map( + [alpha_val, b_casted](Vec x) { + return x + Vec(alpha_val * b_casted); + }, + out.mutable_data_ptr(), + a.const_data_ptr(), + out.numel()); }); } else { ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES( - common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHBBF16_TYPES( - out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE_IN b_casted = static_cast(b_val); - CTYPE_IN alpha_val; - ET_EXTRACT_SCALAR(alpha, alpha_val); - - const size_t n = a.numel(); - const CTYPE_A* a_data = a.const_data_ptr(); - CTYPE_OUT* out_data = out.mutable_data_ptr(); - for (auto i = 0; i < n; ++i) { - out_data[i] = static_cast( - static_cast(a_data[i]) + - alpha_val * b_casted); - } - }); - }); - }); + ET_SWITCH_REALB_TYPES( + common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() { + ET_SWITCH_REALHBBF16_TYPES( + out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() { + CTYPE_IN b_casted = utils::scalar_to(b); + CTYPE_IN alpha_val; + ET_KERNEL_CHECK( + ctx, + utils::extract_scalar(alpha, &alpha_val), + InvalidArgument, ); + + const size_t n = a.numel(); + const CTYPE_A* a_data = a.const_data_ptr(); + CTYPE_OUT* out_data = out.mutable_data_ptr(); + for (auto i = 0; i < n; ++i) { + out_data[i] = static_cast( + static_cast(a_data[i]) + + alpha_val * b_casted); + } + }); + }); }); } diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp index 50eb297a625..8783812ede1 100644 --- a/kernels/optimized/cpu/op_mul.cpp +++ b/kernels/optimized/cpu/op_mul.cpp @@ -218,7 +218,6 @@ Tensor& opt_mul_scalar_out( (void)ctx; ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); ScalarType common_type = utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); ScalarType out_type = out.scalar_type(); @@ -236,40 +235,32 @@ Tensor& opt_mul_scalar_out( if (a_type == common_type && a_type == out_type && a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE b_casted = static_cast(b_val); - - using Vec = at::vec::Vectorized; - at::vec::map( - [b_casted](Vec x) { return x * Vec(b_casted); }, - out.mutable_data_ptr(), - a.const_data_ptr(), - out.numel()); - }); + CTYPE b_casted = utils::scalar_to(b); + + using Vec = at::vec::Vectorized; + at::vec::map( + [b_casted](Vec x) { return x * Vec(b_casted); }, + out.mutable_data_ptr(), + a.const_data_ptr(), + out.numel()); }); } else { ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES( - common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHBBF16_TYPES( - out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE_IN b_casted = static_cast(b_val); - - const size_t n = a.numel(); - const CTYPE_A* a_data = a.const_data_ptr(); - CTYPE_OUT* out_data = out.mutable_data_ptr(); - for (auto i = 0; i < n; ++i) { - out_data[i] = static_cast( - static_cast(a_data[i]) * b_casted); - } - }); - }); - }); + ET_SWITCH_REALB_TYPES( + common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() { + ET_SWITCH_REALHBBF16_TYPES( + out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() { + CTYPE_IN b_casted = utils::scalar_to(b); + + const size_t n = a.numel(); + const CTYPE_A* a_data = a.const_data_ptr(); + CTYPE_OUT* out_data = out.mutable_data_ptr(); + for (auto i = 0; i < n; ++i) { + out_data[i] = static_cast( + static_cast(a_data[i]) * b_casted); + } + }); + }); }); } diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp index 94f0ba4e785..db2f1dd97f7 100644 --- a/kernels/optimized/cpu/op_sub.cpp +++ b/kernels/optimized/cpu/op_sub.cpp @@ -154,7 +154,6 @@ Tensor& opt_sub_scalar_out( (void)ctx; ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); ScalarType common_type = utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); ScalarType out_type = out.scalar_type(); @@ -172,49 +171,41 @@ Tensor& opt_sub_scalar_out( if (a_type == common_type && a_type == out_type && a_type != ScalarType::Half) { ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE, [&]() { - ET_SWITCH_SCALAR_OBJ_REAL_TYPES( - b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE b_casted = static_cast(b_val); - CTYPE alpha_val; - ET_EXTRACT_SCALAR(alpha, alpha_val); - - using Vec = at::vec::Vectorized; - at::vec::map( - [alpha_val, b_casted](Vec x) { - return x - Vec(alpha_val * b_casted); - }, - out.mutable_data_ptr(), - a.const_data_ptr(), - out.numel()); - }); + CTYPE b_casted = utils::scalar_to(b); + CTYPE alpha_val; + ET_KERNEL_CHECK( + ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); + + using Vec = at::vec::Vectorized; + at::vec::map( + [alpha_val, b_casted](Vec x) { + return x - Vec(alpha_val * b_casted); + }, + out.mutable_data_ptr(), + a.const_data_ptr(), + out.numel()); }); } else { ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_REAL_TYPES( - b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES( - common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALH_TYPES( - out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - ET_EXTRACT_SCALAR(b, b_val); - CTYPE_IN b_casted = static_cast(b_val); - CTYPE_IN alpha_val; - ET_EXTRACT_SCALAR(alpha, alpha_val); - - const size_t n = a.numel(); - const CTYPE_A* a_data = a.const_data_ptr(); - CTYPE_OUT* out_data = out.mutable_data_ptr(); - for (auto i = 0; i < n; ++i) { - out_data[i] = static_cast( - static_cast(a_data[i]) - - alpha_val * b_casted); - } - }); - }); - }); + ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() { + ET_SWITCH_REALH_TYPES( + out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() { + CTYPE_IN b_casted = utils::scalar_to(b); + CTYPE_IN alpha_val; + ET_KERNEL_CHECK( + ctx, + utils::extract_scalar(alpha, &alpha_val), + InvalidArgument, ); + + const size_t n = a.numel(); + const CTYPE_A* a_data = a.const_data_ptr(); + CTYPE_OUT* out_data = out.mutable_data_ptr(); + for (auto i = 0; i < n; ++i) { + out_data[i] = static_cast( + static_cast(a_data[i]) - alpha_val * b_casted); + } + }); + }); }); }