From 87eaeea1efae2f66e5c5016a9ba8c8e27dc67eda Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Mon, 28 Oct 2024 08:47:48 -0700
Subject: [PATCH] [Executorch] mul broadcast update

Handle broadcast for > 2D tensors in the optimized library. For now, only
broadcasting across a dim other than the 0th and the (N-1)st dim is supported
in the optimized path.

Differential Revision: [D64156862](https://our.internmc.facebook.com/intern/diff/D64156862/)

[ghstack-poisoned]
---
 kernels/optimized/cpu/binary_ops.h      | 88 +++++++++++++++++++++++-
 kernels/optimized/cpu/op_mul.cpp        | 37 ++++++++--
 kernels/optimized/vec/functional_base.h | 68 ++++++++++++-------
 kernels/test/op_mul_test.cpp            | 89 +++++++++++++++++++++++++
 4 files changed, 249 insertions(+), 33 deletions(-)

diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h
index 6d941509f72..d02153ea441 100644
--- a/kernels/optimized/cpu/binary_ops.h
+++ b/kernels/optimized/cpu/binary_ops.h
@@ -41,10 +41,62 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension, if it exists.
+// This path aims to handle broadcasts of the following form:
+// A = [a1, a2, ..., 1, ..., an]
+// B = [b1, b2, ..., bm, ..., bn]
+// OR
+// A = [a1, a2, ..., am, ..., an]
+// B = [b1, b2, ..., 1, ..., bn]
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // The following example is not handled at the moment:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check
+  // 1. if any dim value is 1 (it constitutes a broadcast dim)
+  // 2. if more than one dim value is 1 (we cannot handle that)
+  // 3. if the non-1 dim values are equal
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index, counted from the end, is used.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If the non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -63,6 +115,17 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle broadcast over the last dim.
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -85,7 +148,28 @@ ElementwiseOptimizedPath inline select_optimized_path(
        internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(
+    const Tensor& a,
+    const int32_t broadcast_dim) {
+  ET_CHECK_MSG(
+      a.dim() > broadcast_dim,
+      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
+      a.dim(),
+      broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] *= a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] *= a.size(i);
+  }
+  return normalized_tensor_size;
 }
 
 } // namespace executor

diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index ad6034638f9..9208ca3b3a7 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -130,15 +130,19 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
       lhs = &b;
       rhs = &a;
     } else {
      // Catch failure to update logic when adding new broadcasting possibility.
      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcastNdByNd));
       lhs = &a;
       rhs = &b;
     }
@@ -149,15 +153,34 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
+    int64_t outer_size = 1;
+    int64_t broadcast_size;
+    int64_t inner_size;
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+      auto normalized_tensor_size_lhs =
+          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+      outer_size = normalized_tensor_size_lhs[0];
+      broadcast_size = normalized_tensor_size_lhs[1];
+      inner_size = normalized_tensor_size_lhs[2];
+    } else {
+      broadcast_size = lhs->sizes()[lhs->dim() - 2];
+      inner_size = lhs->sizes()[lhs->dim() - 1];
+    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
           [](Vec x, Vec y) { return x * y; },
           out.mutable_data_ptr<CTYPE>(),
           lhs->const_data_ptr<CTYPE>(),
           rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
+          outer_size,
+          broadcast_size,
+          inner_size);
     });
   } else {
     ScalarType common_type =
diff --git a/kernels/optimized/vec/functional_base.h b/kernels/optimized/vec/functional_base.h
index 7edb043abc9..6f9fcd8e796 100644
--- a/kernels/optimized/vec/functional_base.h
+++ b/kernels/optimized/vec/functional_base.h
@@ -326,10 +326,49 @@ inline void map4(
 }
 
 
-// Map vec_fun across input_data and input_data2, where input_data is
-// a two-dimensional array of size (size, size2), input_data2 is a
-// one-dimensional array of size size2, and input_data2 is broadcast
-// to be of size (size, size2).
+// This function implements a broadcasting binary operation on two tensors,
+// where the lhs tensor is treated as having shape [outer_size, broadcast_size, inner_size]
+// and the rhs tensor is treated as having shape [outer_size, 1, inner_size].
+// The middle (1st) dimension is the broadcast dimension.
+// This formulation can map broadcasting on any dim = broadcast_dim
+// for any two N-dimensional tensors, where 0 < broadcast_dim < N-1.
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_3d_and_unsqueezed_3d(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size,
+    int64_t inner_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = inner_size * broadcast_size;
+  int64_t outer_stride_rhs = inner_size;
+  int64_t broadcast_stride_lhs = inner_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
+    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
+      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
+      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
+      int64_t inner_idx = 0;
+      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx);
+      }
+      if (inner_size - inner_idx > 0) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -338,27 +377,8 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  using Vec = vec::Vectorized<scalar_t>;
-  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
-    const scalar_t* input_data_row = input_data + outer_idx * size2;
-    scalar_t* output_data_row = output_data + outer_idx * size2;
-    int64_t inner_idx = 0;
-    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx);
-    }
-    if (size2 - inner_idx > 0) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
-    }
-  }
+  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
-
-
 } // namespace vec
 } // namespace executorch
diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp
index f3c9e54c862..6430bfbd8f5 100644
--- a/kernels/test/op_mul_test.cpp
+++ b/kernels/test/op_mul_test.cpp
@@ -153,6 +153,73 @@ class OpMulOutTest : public OperatorTest {
     }
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of mul.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{2, 6, 12, 8, 15, 24, 35, 48, 63, 50, 66, 84});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_4D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    Tensor b = tf_a.make(
+        {2, 1, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+    // Destination for output of mul.
+    Tensor out = tf_a.zeros({2, 2, 3, 5});
+    Tensor expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,    4,    9,    16,   25,   36,   49,   64,   81,   100,
+                  121,  144,  169,  196,  225,  16,   34,   54,   76,   100,
+                  126,  154,  184,  216,  250,  286,  324,  364,  406,  450,
+                  496,  544,  594,  646,  700,  756,  814,  874,  936,  1000,
+                  1066, 1134, 1204, 1276, 1350, 736,  799,  864,  931,  1000,
+                  1071, 1144, 1219, 1296, 1375, 1456, 1539, 1624, 1711, 1800});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+
+    b = tf_a.make(
+        {2, 2, 1, 5}, /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
+                                11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,   4,   9,   16,  25,  6,   14,  24,  36,  50,
+                  11,  24,  39,  56,  75,  96,  119, 144, 171, 200,
+                  126, 154, 184, 216, 250, 156, 189, 224, 261, 300,
+                  341, 384, 429, 476, 525, 396, 444, 494, 546, 600,
+                  451, 504, 559, 616, 675, 736, 799, 864, 931, 1000,
+                  816, 884, 954, 1026, 1100, 896, 969, 1044, 1121, 1200});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
   template <ScalarType DTYPE>
   void test_broadcast_b2a() {
     TensorFactory<DTYPE> tf_a;
@@ -296,6 +363,16 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
   test_broadcast_a2b<ScalarType::Int>();
   test_broadcast_a2b<ScalarType::Half>();
   test_broadcast_a2b<ScalarType::BFloat16>();
+
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Int>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Int>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
 }
 
 // Broadcast tensor a's size to tensor b's size
@@ -305,6 +382,18 @@ TEST_F(OpMulOutTest, BroadcastB2ATest) {
   test_broadcast_b2a<ScalarType::Half>();
   test_broadcast_b2a<ScalarType::BFloat16>();
 }
 
+TEST_F(OpMulOutTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Int>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Int>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
+}
+
 // Broadcast tensor a and b's size to a new size c.
 TEST_F(OpMulOutTest, BroadcastAB2CTest) {
   TensorFactory<ScalarType::Int> tf_a;
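
The [outer_size, broadcast_size, inner_size] normalization that the new optimized path relies on is easiest to see with plain scalar code. The sketch below is not part of the patch: it uses hypothetical helper names (normalize, mul_broadcast) and no ExecuTorch or Vectorized types, and it only illustrates, under those assumptions, how collapsing an N-d shape with a single broadcast dim into three sizes reproduces the {2, 2, 3} * {2, 1, 3} case from test_broadcast_3D.

// Standalone illustration (not part of the patch): collapse an N-d shape with
// a single broadcast dim into [outer, broadcast, inner] and multiply, the same
// decomposition used by get_normalized_tensor_size() and
// broadcasting_map_3d_and_unsqueezed_3d() above. Helper names are hypothetical.
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Product of the dims before and after broadcast_dim, keeping the size at
// broadcast_dim itself as the middle entry.
static std::array<int64_t, 3> normalize(
    const std::vector<int64_t>& sizes, size_t broadcast_dim) {
  std::array<int64_t, 3> n = {1, sizes[broadcast_dim], 1};
  for (size_t i = 0; i < broadcast_dim; i++) {
    n[0] *= sizes[i];
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); i++) {
    n[2] *= sizes[i];
  }
  return n;
}

// Scalar reference with the same loop structure as the vectorized kernel:
// lhs is [outer, broadcast, inner], rhs is [outer, 1, inner].
static void mul_broadcast(
    float* out,
    const float* lhs,
    const float* rhs,
    int64_t outer,
    int64_t broadcast,
    int64_t inner) {
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t b = 0; b < broadcast; ++b) {
      for (int64_t i = 0; i < inner; ++i) {
        out[(o * broadcast + b) * inner + i] =
            lhs[(o * broadcast + b) * inner + i] * rhs[o * inner + i];
      }
    }
  }
}

int main() {
  // Mirrors test_broadcast_3D: a is {2, 2, 3}, b is {2, 1, 3}, broadcast dim 1.
  std::vector<int64_t> a_sizes = {2, 2, 3};
  std::array<int64_t, 3> n = normalize(a_sizes, 1);  // outer=2, broadcast=2, inner=3
  std::vector<float> a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::vector<float> b = {2, 3, 4, 5, 6, 7};
  std::vector<float> out(a.size());
  mul_broadcast(out.data(), a.data(), b.data(), n[0], n[1], n[2]);
  for (float v : out) {
    std::printf("%g ", v);  // prints: 2 6 12 8 15 24 35 48 63 50 66 84
  }
  std::printf("\n");
  return 0;
}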