From 27a79c4e9337f66e4ae5fc40f4be395c848ce135 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 4 Feb 2025 17:06:31 -0800 Subject: [PATCH 1/4] [Executorch] Refactor op_mul's broadcasting utils Summary: Refactoring broadcast handling utils that were added for op_mul. This is in preparation for using these utils to handle broadcast for other ops such as add, sub, div. Plus remove a redundant test Test Plan: optimized_kernels_test in CI Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- kernels/optimized/cpu/binary_ops.h | 112 +++++++++++++++++++++++++++++ kernels/optimized/cpu/op_mul.cpp | 112 +---------------------------- kernels/test/op_mul_test.cpp | 7 -- 3 files changed, 115 insertions(+), 116 deletions(-) diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h index ce19a8fa9de..b7592ce6554 100644 --- a/kernels/optimized/cpu/binary_ops.h +++ b/kernels/optimized/cpu/binary_ops.h @@ -8,6 +8,7 @@ #pragma once +#include #include namespace torch { @@ -190,5 +191,116 @@ std::array inline get_normalized_tensor_size( return normalized_tensor_size; } +template +Tensor& handle_last_dim_broadcast_elementwise( + KernelRuntimeContext& ctx, + const Op& vec_fun, + const Tensor& a, + const Tensor& b, + Tensor& out, + const ElementwiseOptimizedPath selected_optimized_path) { + ScalarType out_type = out.scalar_type(); + const Tensor* lhs; + const Tensor* rhs; + if (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) { + lhs = &b; + rhs = &a; + } else { + lhs = &a; + rhs = &b; + } + auto error = resize_tensor(out, lhs->sizes()); + ET_KERNEL_CHECK_MSG( + ctx, + error == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + const size_t outer_size = getLeadingDims(out, out.dim() - 1); + const auto broadcast_size = out.size(out.dim() - 1); + ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { + executorch::vec::broadcasting_map_broadcast_last_dim( + vec_fun, 
out.mutable_data_ptr(), + lhs->const_data_ptr(), + rhs->const_data_ptr(), + outer_size, + broadcast_size); + }); + return out; +} + +template +Tensor& handle_broadcast_elementwise( + KernelRuntimeContext& ctx, + const Op& vec_fun, + const Tensor& a, + const Tensor& b, + Tensor& out, + const ElementwiseOptimizedPath selected_optimized_path) { + if ((selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastLastDim) || + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) { + return handle_last_dim_broadcast_elementwise( + ctx, vec_fun, a, b, out, selected_optimized_path); + } + + ScalarType out_type = out.scalar_type(); + const Tensor* lhs; + const Tensor* rhs; + if ((selected_optimized_path == + ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { + lhs = &b; + rhs = &a; + } else { + // Catch failure to update logic when adding new broadcasting possibility. 
+ ET_DCHECK( + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcast2dBy1d) || + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastNdByNd)); + lhs = &a; + rhs = &b; + } + auto error = resize_tensor(out, lhs->sizes()); + ET_KERNEL_CHECK_MSG( + ctx, + error == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + int64_t outer_size = 1; + int64_t broadcast_size; + int64_t inner_size; + if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || + (selected_optimized_path == + ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { + int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs); + int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim; + auto normalized_tensor_size_lhs = + get_normalized_tensor_size(*lhs, broadcast_dim_lhs); + outer_size = normalized_tensor_size_lhs[0]; + broadcast_size = normalized_tensor_size_lhs[1]; + inner_size = normalized_tensor_size_lhs[2]; + } else { + broadcast_size = lhs->sizes()[lhs->dim() - 2]; + inner_size = lhs->sizes()[lhs->dim() - 1]; + } + ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { + executorch::vec::broadcasting_map_3d_and_unsqueezed_3d( + vec_fun, + out.mutable_data_ptr(), + lhs->const_data_ptr(), + rhs->const_data_ptr(), + outer_size, + broadcast_size, + inner_size); + }); + return out; +} } // namespace executor } // namespace torch diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp index c8e6ba7a99a..3564f8a872c 100644 --- a/kernels/optimized/cpu/op_mul.cpp +++ b/kernels/optimized/cpu/op_mul.cpp @@ -68,114 +68,6 @@ template < struct MulInner : public ReportCanCastBug {}; -Tensor& handle_last_dim_broadcast( - KernelRuntimeContext& ctx, - const Tensor& a, - const Tensor& b, - Tensor& out, - const ElementwiseOptimizedPath selected_optimized_path) { - ScalarType out_type = out.scalar_type(); - const Tensor* lhs; - const Tensor* rhs; - if (selected_optimized_path == - 
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) { - lhs = &b; - rhs = &a; - } else { - lhs = &a; - rhs = &b; - } - auto error = resize_tensor(out, lhs->sizes()); - ET_KERNEL_CHECK_MSG( - ctx, - error == Error::Ok, - InvalidArgument, - out, - "Failed to resize output tensor."); - const size_t outer_size = getLeadingDims(out, out.dim() - 1); - const auto broadcast_size = out.size(out.dim() - 1); - ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { - using Vec = executorch::vec::Vectorized; - executorch::vec::broadcasting_map_broadcast_last_dim( - [](Vec x, Vec y) { return x * y; }, - out.mutable_data_ptr(), - lhs->const_data_ptr(), - rhs->const_data_ptr(), - outer_size, - broadcast_size); - }); - return out; -} - -Tensor& handle_broadcast_mul( - KernelRuntimeContext& ctx, - const Tensor& a, - const Tensor& b, - Tensor& out, - const ElementwiseOptimizedPath selected_optimized_path) { - if ((selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastLastDim) || - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) { - return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path); - } - - ScalarType out_type = out.scalar_type(); - const Tensor* lhs; - const Tensor* rhs; - if ((selected_optimized_path == - ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) || - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { - lhs = &b; - rhs = &a; - } else { - // Catch failure to update logic when adding new broadcasting possibility. 
- ET_DCHECK( - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcast2dBy1d) || - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastNdByNd)); - lhs = &a; - rhs = &b; - } - auto error = resize_tensor(out, lhs->sizes()); - ET_KERNEL_CHECK_MSG( - ctx, - error == Error::Ok, - InvalidArgument, - out, - "Failed to resize output tensor."); - int64_t outer_size = 1; - int64_t broadcast_size; - int64_t inner_size; - if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) || - (selected_optimized_path == - ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) { - int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs); - int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim; - auto normalized_tensor_size_lhs = - get_normalized_tensor_size(*lhs, broadcast_dim_lhs); - outer_size = normalized_tensor_size_lhs[0]; - broadcast_size = normalized_tensor_size_lhs[1]; - inner_size = normalized_tensor_size_lhs[2]; - } else { - broadcast_size = lhs->sizes()[lhs->dim() - 2]; - inner_size = lhs->sizes()[lhs->dim() - 1]; - } - ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { - using Vec = executorch::vec::Vectorized; - executorch::vec::broadcasting_map_3d_and_unsqueezed_3d( - [](Vec x, Vec y) { return x * y; }, - out.mutable_data_ptr(), - lhs->const_data_ptr(), - rhs->const_data_ptr(), - outer_size, - broadcast_size, - inner_size); - }); - return out; -} } // namespace Tensor& opt_mul_out( @@ -238,7 +130,9 @@ Tensor& opt_mul_out( out.numel()); }); } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { - return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path); + auto mul_lambda = [](auto x, auto y) { return x * y; }; + return torch::executor::handle_broadcast_elementwise( + ctx, mul_lambda, a, b, out, selected_optimized_path); } else { ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); diff --git a/kernels/test/op_mul_test.cpp 
b/kernels/test/op_mul_test.cpp index 6a4314db304..5e7b0a4efe4 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -453,13 +453,6 @@ TEST_F(OpMulOutTest, BroadcastNDTest) { test_broadcast_last_dim(); } -TEST_F(OpMulOutTest, BroadcastLastDimTest) { - // Test broadcasting on the last dimension - test_broadcast_last_dim(); - test_broadcast_last_dim(); - test_broadcast_last_dim(); -} - // Broadcast tensor a and b's size to a new size c. TEST_F(OpMulOutTest, BroadcastAB2CTest) { TensorFactory tf_a; From ed79e8c4ebf5afbed5180320a542cdc81854537e Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 10 Feb 2025 16:06:08 -0800 Subject: [PATCH 2/4] Update base for Update on "[Executorch] Refactor op_mul's broadcasting utils" Summary: Refactoring broadcast handling utils that were added for op_mul. This is in preparation for using these utils to handle broadcast for other ops such as add, sub, div. Plus remove a redundant test Test Plan: optimized_kernels_test in CI Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned] From ebf62fedd646a859335584751b2449c53f91bd52 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 11 Feb 2025 19:16:11 -0800 Subject: [PATCH 3/4] Update base for Update on "[Executorch] Refactor op_mul's broadcasting utils" Summary: Refactoring broadcast handling utils that were added for op_mul. This is in preparation for using these utils to handle broadcast for other ops such as add, sub, div. 
Plus remove a redundant test Test Plan: optimized_kernels_test in CI Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491816](https://our.internmc.facebook.com/intern/diff/D69491816) [ghstack-poisoned] From f25833fe134f23c820ac9d15ae5ce4c23732ae7c Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 12 Feb 2025 12:55:23 -0800 Subject: [PATCH 4/4] Update base for Update on "[Executorch] Refactor op_mul's broadcasting utils" Summary: Refactoring broadcast handling utils that were added for op_mul. This is in preparation for using these utils to handle broadcast for other ops such as add, sub, div. Plus remove a redundant test Test Plan: optimized_kernels_test in CI Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491816](https://our.internmc.facebook.com/intern/diff/D69491816) [ghstack-poisoned]