diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h
index 6d941509f72..3c2dfea4e39 100644
--- a/kernels/optimized/cpu/binary_ops.h
+++ b/kernels/optimized/cpu/binary_ops.h
@@ -41,10 +41,56 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension if it exists.
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // We would like to handle this case as well, but don't yet:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check
+  // 1. if any dim value is 1 (it constitutes a broadcast dim)
+  // 2. if more than one dim value is 1 (we cannot handle that)
+  // 3. if the non-1 dim values are equal
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index, counted from the end of sizes(), is stored.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -63,6 +109,15 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle last-dim broadcast.
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) { return x == 1; }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -85,7 +140,22 @@ ElementwiseOptimizedPath inline select_optimized_path(
        internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(const Tensor& a, const int32_t broadcast_dim) {
+  ET_CHECK_MSG(a.dim() > broadcast_dim, "Size of tensor: %zd, must be larger than broadcast_dim: %d", a.dim(), broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] = normalized_tensor_size[0] * a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] = normalized_tensor_size[2] * a.size(i);
+  }
+  return normalized_tensor_size;
 }
 
 } // namespace executor
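For illustration, here is a minimal standalone sketch of what the two helpers above are expected to produce, written against plain std::vector sizes rather than the ExecuTorch Tensor API (the *_ref names are hypothetical, and the leading-1 trimming done by arrayref_begin_ignoring_leading_1s is omitted for brevity). For lhs = [2, 3, 4, 5] and rhs = [2, 1, 4, 5], the single broadcast dim is reported as -3 (a negative, from-the-end index), and the lhs shape collapses to [outer, broadcast, inner] = [2, 3, 20].

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns the broadcast dim as a negative index from the end, or 0 if the
// shapes are not handled (different ranks, more than one broadcast dim, or
// mismatched non-1 dims). Dim 0 is not examined, matching the while-loop
// bound in the diff.
int32_t get_broadcast_dim_ref(
    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0;
  }
  const int32_t rank = static_cast<int32_t>(lhs.size());
  int32_t broadcast_dim = 0;
  for (int32_t i = rank - 1; i > 0; --i) {
    if (lhs[i] == 1 || rhs[i] == 1) {
      if (broadcast_dim != 0) {
        return 0;  // more than one broadcast dim
      }
      broadcast_dim = i - rank;  // negative, from-the-end index
    } else if (lhs[i] != rhs[i]) {
      return 0;  // non-1 dims must match
    }
  }
  return broadcast_dim;
}

// Collapses `sizes` into [outer, broadcast, inner] around broadcast_dim.
std::array<int64_t, 3> get_normalized_tensor_size_ref(
    const std::vector<int64_t>& sizes, int32_t broadcast_dim) {
  std::array<int64_t, 3> normalized = {1, sizes[broadcast_dim], 1};
  for (int32_t i = 0; i < broadcast_dim; ++i) {
    normalized[0] *= sizes[i];
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); ++i) {
    normalized[2] *= sizes[i];
  }
  return normalized;
}

int main() {
  const std::vector<int64_t> lhs = {2, 3, 4, 5};
  const std::vector<int64_t> rhs = {2, 1, 4, 5};
  const int32_t d = get_broadcast_dim_ref(lhs, rhs);  // -3
  const auto n = get_normalized_tensor_size_ref(
      lhs, static_cast<int32_t>(lhs.size()) + d);      // broadcast dim 1
  std::cout << d << " -> [" << n[0] << ", " << n[1] << ", " << n[2] << "]\n";
  // Prints: -3 -> [2, 3, 20]
}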
diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index ad6034638f9..6c994c0dc51 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -130,15 +130,17 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+          (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd));
       lhs = &a;
       rhs = &b;
     }
@@ -149,15 +151,32 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
+    int64_t outer_size = 1;
+    int64_t broadcast_size;
+    int64_t inner_size;
+    if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+      int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
+      auto normalized_tensor_size_lhs = get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+      outer_size = normalized_tensor_size_lhs[0];
+      broadcast_size = normalized_tensor_size_lhs[1];
+      inner_size = normalized_tensor_size_lhs[2];
+    } else {
+      broadcast_size = lhs->sizes()[lhs->dim() - 2];
+      inner_size = lhs->sizes()[lhs->dim() - 1];
+    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
          [](Vec x, Vec y) { return x * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
+          outer_size,
+          broadcast_size,
+          inner_size);
     });
   } else {
     ScalarType common_type =
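To make the shape bookkeeping above concrete: for a with sizes [2, 3, 4, 5] and b with sizes [2, 1, 4, 5], the selection logic returns kBroadcastNdByNd, get_broadcast_dim gives -3, and get_normalized_tensor_size collapses lhs to outer_size = 2, broadcast_size = 3, inner_size = 20, with rhs acting as the "unsqueezed" [2, 1, 20] operand. Below is a plain scalar reference of the semantics the vectorized kernel has to reproduce (illustration only; the function name is hypothetical and not part of this diff).

#include <cstdint>

// out and lhs are laid out as [outer_size, broadcast_size, inner_size];
// rhs is [outer_size, 1, inner_size] and is reused across the broadcast dim.
void mul_broadcast_nd_by_nd_ref(
    float* out,
    const float* lhs,
    const float* rhs,
    int64_t outer_size,
    int64_t broadcast_size,
    int64_t inner_size) {
  for (int64_t o = 0; o < outer_size; ++o) {
    for (int64_t b = 0; b < broadcast_size; ++b) {
      for (int64_t i = 0; i < inner_size; ++i) {
        out[(o * broadcast_size + b) * inner_size + i] =
            lhs[(o * broadcast_size + b) * inner_size + i] *
            rhs[o * inner_size + i];
      }
    }
  }
}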
diff --git a/kernels/optimized/vec/functional_base.h b/kernels/optimized/vec/functional_base.h
index 7edb043abc9..3a66904f8a9 100644
--- a/kernels/optimized/vec/functional_base.h
+++ b/kernels/optimized/vec/functional_base.h
@@ -330,6 +330,43 @@ inline void map4(
 // a two-dimensional array of size (size, size2), input_data2 is a
 // one-dimensional array of size size2, and input_data2 is broadcast
 // to be of size (size, size2).
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_3d_and_unsqueezed_3d(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size,
+    int64_t inner_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = inner_size * broadcast_size;
+  int64_t outer_stride_rhs = inner_size;
+  int64_t broadcast_stride_lhs = inner_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
+    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
+      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
+      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
+      int64_t inner_idx = 0;
+      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx);
+      }
+      if (inner_size - inner_idx > 0) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -338,27 +375,8 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  using Vec = vec::Vectorized<scalar_t>;
-  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
-    const scalar_t* input_data_row = input_data + outer_idx * size2;
-    scalar_t* output_data_row = output_data + outer_idx * size2;
-    int64_t inner_idx = 0;
-    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx);
-    }
-    if (size2 - inner_idx > 0) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
-    }
-  }
+  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
-
-
 
 } // namespace vec
 } // namespace executorch
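Finally, a minimal usage sketch of the new helper on raw float buffers (illustration only: the include path, buffer sizes, and values are assumptions, not part of this diff). It shows the layout contract: lhs/out are [outer, broadcast, inner], rhs is the unsqueezed [outer, 1, inner] operand reused across the broadcast dimension.

#include <cstdint>
#include <vector>

#include <executorch/kernels/optimized/vec/functional.h>  // assumed include path

void example_mul_nd_by_nd() {
  const int64_t outer = 2, broadcast = 3, inner = 20;
  // lhs laid out as [outer, broadcast, inner]; rhs as [outer, 1, inner].
  std::vector<float> lhs(outer * broadcast * inner, 2.0f);
  std::vector<float> rhs(outer * inner, 3.0f);
  std::vector<float> out(lhs.size());

  using Vec = executorch::vec::Vectorized<float>;
  executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<float>(
      [](Vec x, Vec y) { return x * y; },
      out.data(), lhs.data(), rhs.data(), outer, broadcast, inner);
  // Each rhs row o is reused for all `broadcast` rows of lhs chunk o,
  // so every element of out ends up as 6.0f in this example.
}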