diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h
index 3c2dfea4e39..3b25e011439 100644
--- a/kernels/optimized/cpu/binary_ops.h
+++ b/kernels/optimized/cpu/binary_ops.h
@@ -43,6 +43,8 @@ enum class ElementwiseOptimizedPath {
   kBroadcast2dBy1dReverseArguments,
   kBroadcastNdByNd,
   kBroadcastNdByNdReverseArguments,
+  kBroadcastLastDim,
+  kBroadcastLastDimReverseArguments,
 };
 
 namespace internal {
@@ -117,6 +119,12 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     } else {
       return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
     }
+  } else if (broadcast_dim == -1) {
+    if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) { return x == 1; }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastLastDim;
+    }
   }
   return ElementwiseOptimizedPath::kNone;
 }
diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index 6c994c0dc51..d8d2f505313 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -11,6 +11,7 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
@@ -66,6 +67,117 @@ template <
     typename CTYPE_OUT>
 struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
     : public ReportCanCastBug {};
+
+Tensor& handle_last_dim_broadcast(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if (selected_optimized_path ==
+      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
+  const auto broadcast_size = out.size(out.dim() - 1);
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_broadcast_last_dim(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size);
+  });
+  return out;
+}
+
+Tensor& handle_broadcast_mul(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDim) ||
+      (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
+    return handle_last_dim_broadcast(
+        ctx, a, b, out, selected_optimized_path);
+  }
+
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    // Catch failure to update logic when adding new broadcasting possibility.
+    ET_DCHECK(
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd));
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  int64_t outer_size = 1;
+  int64_t broadcast_size;
+  int64_t inner_size;
+  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+    int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
+    auto normalized_tensor_size_lhs =
+        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+    outer_size = normalized_tensor_size_lhs[0];
+    broadcast_size = normalized_tensor_size_lhs[1];
+    inner_size = normalized_tensor_size_lhs[2];
+  } else {
+    broadcast_size = lhs->sizes()[lhs->dim() - 2];
+    inner_size = lhs->sizes()[lhs->dim() - 1];
+  }
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size,
+        inner_size);
+  });
+  return out;
+}
 } // namespace
 
 Tensor& opt_mul_out(
@@ -128,56 +240,7 @@ Tensor& opt_mul_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when adding new broadcasting possibility.
-      ET_DCHECK(
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-          (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd));
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    int64_t outer_size = 1;
-    int64_t broadcast_size;
-    int64_t inner_size;
-    if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-      int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
-      auto normalized_tensor_size_lhs = get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-      outer_size = normalized_tensor_size_lhs[0];
-      broadcast_size = normalized_tensor_size_lhs[1];
-      inner_size = normalized_tensor_size_lhs[2];
-    } else {
-      broadcast_size = lhs->sizes()[lhs->dim() - 2];
-      inner_size = lhs->sizes()[lhs->dim() - 1];
-    }
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d(
-          [](Vec x, Vec y) { return x * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          lhs->const_data_ptr<CTYPE>(),
-          rhs->const_data_ptr<CTYPE>(),
-          outer_size,
-          broadcast_size,
-          inner_size);
-    });
+    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 488d2af7fa1..77a270cc45d 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -72,6 +72,7 @@ _OPTIMIZED_ATEN_OPS = (
             ":binary_ops",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
     ),
     op_target(
diff --git a/kernels/optimized/vec/functional_base.h b/kernels/optimized/vec/functional_base.h
index 3a66904f8a9..a43007bfa4e 100644
--- a/kernels/optimized/vec/functional_base.h
+++ b/kernels/optimized/vec/functional_base.h
@@ -378,5 +378,34 @@ inline void broadcasting_map_2d_by_1d(
   broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
 
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_broadcast_last_dim(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = broadcast_size;
+  int64_t outer_stride_rhs = 1;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    int64_t inner_idx = 0;
+    Vec data_vec2 = Vec(rhs[outer_idx]);
+    for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) {
+      Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx);
+    }
+    if (broadcast_size - inner_idx > 0) {
+      Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
+    }
+  }
+}
+
 } // namespace vec
 } // namespace executorch
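
For reference, the new broadcasting_map_broadcast_last_dim kernel computes out[i][j] = op(lhs[i][j], rhs[i]) over a contiguous (outer_size, broadcast_size) lhs, with rhs contributing a single value per row (its last dimension is 1). Below is a minimal scalar sketch of those semantics, useful as a correctness oracle for the vectorized loop; it is not part of the patch, and the names broadcast_last_dim_reference and the main() harness are invented for illustration:

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar reference for the new kernel: out[i][j] = op(lhs[i][j], rhs[i]).
// lhs is a contiguous (outer_size, broadcast_size) buffer; rhs holds one
// value per row because its last dimension is 1.
template <typename scalar_t, typename Op>
void broadcast_last_dim_reference(
    const Op& op,
    scalar_t* out,
    const scalar_t* lhs,
    const scalar_t* rhs,
    int64_t outer_size,
    int64_t broadcast_size) {
  for (int64_t i = 0; i < outer_size; ++i) {
    for (int64_t j = 0; j < broadcast_size; ++j) {
      out[i * broadcast_size + j] = op(lhs[i * broadcast_size + j], rhs[i]);
    }
  }
}

int main() {
  // (2, 4) * (2, 1): every element of row i is scaled by rhs[i].
  std::vector<float> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> rhs = {10, 100};
  std::vector<float> out(8);
  broadcast_last_dim_reference(
      [](float x, float y) { return x * y; },
      out.data(),
      lhs.data(),
      rhs.data(),
      /*outer_size=*/2,
      /*broadcast_size=*/4);
  assert(out[0] == 10.f && out[7] == 800.f);
  return 0;
}

The committed kernel does the same work Vec::size() lanes at a time, splatting rhs[outer_idx] into a vector once per row and falling back to the count-limited loadu/store overloads for the broadcast_size % Vec::size() tail elements.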
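
As for when these paths fire: in select_broadcast_optimized_path, broadcast_dim == -1 means only the last dimension differs between the operands. The reverse-arguments variant is chosen when the left-hand tensor is the one with the single size-1 dimension, so handle_last_dim_broadcast can swap the operands and always treat lhs as the full-size input. A few illustrative shape pairs, assuming that selection logic (the shapes are examples, not from the patch):

    mul(a: [5, 3, 8], b: [5, 3, 1]) -> kBroadcastLastDim
    mul(a: [5, 3, 1], b: [5, 3, 8]) -> kBroadcastLastDimReverseArguments
    mul(a: [5, 3, 8], b: [5, 1, 8]) -> kBroadcastNdByNd (middle dim, existing path)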