diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index c2843f5c2f7..6b9f9cb959c 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -1,4 +1,9 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(
+    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
+    "get_compiler_optimization_flags",
+)
+
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -34,7 +39,7 @@ def define_common_targets():
                 "//executorch/kernels/portable/cpu/util:reduce_util",
                 "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
             ],
-            compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+            compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
             visibility = [
                 "//executorch/...",
                 "//executorch/extension/llm/custom_ops/...",
diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h
index 6d941509f72..d02153ea441 100644
--- a/kernels/optimized/cpu/binary_ops.h
+++ b/kernels/optimized/cpu/binary_ops.h
@@ -41,10 +41,62 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension, if it exists.
+// This path aims to handle broadcasts of the following form:
+// A = [a1, a2, ..., 1, ..., an]
+// B = [b1, b2, ..., bm, ..., bn]
+// OR
+// A = [a1, a2, ..., am, ..., an]
+// B = [b1, b2, ..., 1, ..., bn]
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // The following example is not handled at the moment:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check
+  // 1. if any dim value is 1 (it constitutes a broadcast dim)
+  // 2. if more than one dim value is 1 (we cannot handle that)
+  // 3. if the non-1 dim values are equal
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index, counted from the end of the sizes, is used.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If the non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -63,6 +115,17 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle last-dim broadcast (broadcast_dim == -1).
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -85,7 +148,28 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(
+    const Tensor& a,
+    const int32_t broadcast_dim) {
+  ET_CHECK_MSG(
+      a.dim() > broadcast_dim,
+      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
+      a.dim(),
+      broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] *= a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] *= a.size(i);
+  }
+  return normalized_tensor_size;
 }
 
 } // namespace executor
diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index ad6034638f9..9208ca3b3a7 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -130,15 +130,19 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcastNdByNd));
       lhs = &a;
       rhs = &b;
     }
@@ -149,15 +153,34 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
+    int64_t outer_size = 1;
+    int64_t broadcast_size;
+    int64_t inner_size;
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+      auto normalized_tensor_size_lhs =
+          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+      outer_size = normalized_tensor_size_lhs[0];
+      broadcast_size = normalized_tensor_size_lhs[1];
+      inner_size = normalized_tensor_size_lhs[2];
+    } else {
+      broadcast_size = lhs->sizes()[lhs->dim() - 2];
+      inner_size = lhs->sizes()[lhs->dim() - 1];
+    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
           [](Vec x, Vec y) { return x * y; },
           out.mutable_data_ptr<CTYPE>(),
           lhs->const_data_ptr<CTYPE>(),
           rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
+          outer_size,
+          broadcast_size,
+          inner_size);
     });
   } else {
     ScalarType common_type =
diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl
index d9721e5055d..c3799f7db51 100644
--- a/kernels/optimized/lib_defs.bzl
+++ b/kernels/optimized/lib_defs.bzl
@@ -2,6 +2,10 @@ load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFOR
 load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
 load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(
+    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
+    "get_compiler_optimization_flags",
+)
 
 # Because vec exists as a collection of header files, compile and preprocessor
 # flags applied to the vec target do not have any effect, since no compilation
@@ -121,6 +125,7 @@ def define_libs():
         exported_headers = native.glob([
             "blas/**/*.h",
         ]),
+        compiler_flags = get_compiler_optimization_flags(),
        header_namespace = "executorch/kernels/optimized",
         visibility = [
             "//executorch/...",
diff --git a/kernels/optimized/op_registration_util.bzl b/kernels/optimized/op_registration_util.bzl
index c969aa81a9a..6e74836bb79 100644
--- a/kernels/optimized/op_registration_util.bzl
+++ b/kernels/optimized/op_registration_util.bzl
@@ -4,6 +4,10 @@ load(
     "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
     "get_vec_android_preprocessor_flags",
 )
+load(
+    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
+    "get_compiler_optimization_flags",
+)
 
 def op_target(name, deps = []):
     """Registers an optimized implementation for an operator overload group.
@@ -87,7 +91,7 @@ def define_op_library(name, deps):
         ],
         # kernels often have helpers with no prototypes just disabling the warning here as the headers
         # are codegend and linked in later
-        compiler_flags = ["-Wno-missing-prototypes"],
+        compiler_flags = ["-Wno-missing-prototypes"] + get_compiler_optimization_flags(),
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
         ] + augmented_deps,
diff --git a/kernels/optimized/vec/functional_base.h b/kernels/optimized/vec/functional_base.h
index 7edb043abc9..6f9fcd8e796 100644
--- a/kernels/optimized/vec/functional_base.h
+++ b/kernels/optimized/vec/functional_base.h
@@ -326,10 +326,49 @@ inline void map4(
 }
 
 
-// Map vec_fun across input_data and input_data2, where input_data is
-// a two-dimensional array of size (size, size2), input_data2 is a
-// one-dimensional array of size size2, and input_data2 is broadcast
-// to be of size (size, size2).
+// This function implements a broadcasting binary operation on two tensors,
+// where the lhs tensor is treated as having shape
+// [outer_size, broadcast_size, inner_size] and the rhs tensor is treated as
+// having shape [outer_size, 1, inner_size]; the middle dimension is the
+// broadcast dimension. This formulation can map broadcasting on any
+// dim = broadcast_dim of two N-dimensional tensors, where 0 < broadcast_dim < N-1.
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_3d_and_unsqueezed_3d(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size,
+    int64_t inner_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = inner_size * broadcast_size;
+  int64_t outer_stride_rhs = inner_size;
+  int64_t broadcast_stride_lhs = inner_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
+    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
+      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
+      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
+      int64_t inner_idx = 0;
+      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx);
+      }
+      if (inner_size - inner_idx > 0) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -338,27 +377,8 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  using Vec = vec::Vectorized<scalar_t>;
-  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
-    const scalar_t* input_data_row = input_data + outer_idx * size2;
-    scalar_t* output_data_row = output_data + outer_idx * size2;
-    int64_t inner_idx = 0;
-    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx);
-    }
-    if (size2 - inner_idx > 0) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
-    }
-  }
+  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
-
-
 
 } // namespace vec
 } // namespace executorch
diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp
index f3c9e54c862..6430bfbd8f5 100644
--- a/kernels/test/op_mul_test.cpp
+++ b/kernels/test/op_mul_test.cpp
@@ -153,6 +153,73 @@ class OpMulOutTest : public OperatorTest {
     }
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of mul.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{2, 6, 12, 8, 15, 24, 35, 48, 63, 50, 66, 84});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_4D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    Tensor b = tf_a.make(
+        {2, 1, 3, 5},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+    // Destination for output of mul.
+    Tensor out = tf_a.zeros({2, 2, 3, 5});
+    Tensor expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 4, 9, 16, 25, 36, 49, 64, 81, 100,
+                  121, 144, 169, 196, 225, 16, 34, 54, 76, 100,
+                  126, 154, 184, 216, 250, 286, 324, 364, 406, 450,
+                  496, 544, 594, 646, 700, 756, 814, 874, 936, 1000,
+                  1066, 1134, 1204, 1276, 1350, 736, 799, 864, 931, 1000,
+                  1071, 1144, 1219, 1296, 1375, 1456, 1539, 1624, 1711, 1800});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+
+    b = tf_a.make(
+        {2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+                                11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1, 4, 9, 16, 25, 6, 14, 24, 36, 50,
+                  11, 24, 39, 56, 75, 96, 119, 144, 171, 200,
+                  126, 154, 184, 216, 250, 156, 189, 224, 261, 300,
+                  341, 384, 429, 476, 525, 396, 444, 494, 546, 600,
+                  451, 504, 559, 616, 675, 736, 799, 864, 931, 1000,
+                  816, 884, 954, 1026, 1100, 896, 969, 1044, 1121, 1200});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
   template <ScalarType DTYPE>
   void test_broadcast_b2a() {
     TensorFactory<DTYPE> tf_a;
@@ -296,6 +363,16 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
   test_broadcast_a2b<ScalarType::Float>();
   test_broadcast_a2b<ScalarType::Half>();
   test_broadcast_a2b<ScalarType::BFloat16>();
+
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
 }
 
 // Broadcast tensor a's size to tensor b's size
@@ -305,6 +382,18 @@ TEST_F(OpMulOutTest, BroadcastB2ATest) {
   test_broadcast_b2a<ScalarType::Half>();
   test_broadcast_b2a<ScalarType::BFloat16>();
 }
 
+TEST_F(OpMulOutTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  test_broadcast_4D<ScalarType::BFloat16>();
+}
+
 // Broadcast tensor a and b's size to a new size c.
 TEST_F(OpMulOutTest, BroadcastAB2CTest) {
   TensorFactory tf_a;
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index ef170d62970..6a25f35c304 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -1,6 +1,24 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "is_xplat", "runtime")
 load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
 
+def get_compiler_optimization_flags():
+    # Various ovr_configs are not available in OSS.
+    if not runtime.is_oss:
+        compiler_flags = select({
+            "DEFAULT": [],
+            "ovr_config//os:android-arm64": [
+                "-O2",
+            ],
+            "ovr_config//os:iphoneos": [
+                "-O2",
+            ],
+            "ovr_config//os:macos-arm64": [
+                "-O2",
+            ],
+        })
+        return compiler_flags
+    return []
+
 def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = False, _aten_mode_deps = []):
     """Registers an implementation of an operator overload group.
@@ -132,7 +150,7 @@ def define_op_library(name, deps, android_deps, aten_target, _allow_third_party_
             # library, and it blocks users like unit tests to use kernel
             # implementation directly. So we enable this for xplat only.
             ["-fvisibility=hidden"] if is_xplat() else []
-        ),
+        ) + get_compiler_optimization_flags(),
         deps = [
             "//executorch/runtime/kernel:kernel_includes" + aten_suffix,
         ] + deps,
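
Reviewer note (not part of the patch): the NdByNd path rests on two shape helpers. get_broadcast_dim() returns the broadcast dimension as a negative index counted from the end of the sizes (e.g. -2 for {2, 2, 3} vs {2, 1, 3}), and get_normalized_tensor_size() collapses the lhs shape into {outer, broadcast, inner} around that dimension, which is exactly the triple fed to broadcasting_map_3d_and_unsqueezed_3d() in opt_mul_out(). Below is a minimal standalone sketch of that index arithmetic using plain std::vector shapes instead of the Tensor API; the helper names are illustrative, not from the patch, and it skips the leading-1 handling (arrayref_begin_ignoring_leading_1s) that the real code does.

// Standalone sketch: mirrors the shape logic of get_broadcast_dim() and
// get_normalized_tensor_size() on plain shapes, for checking the arithmetic.
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the broadcast dim as a negative index from the end (e.g. -2 for
// {2, 2, 3} vs {2, 1, 3}), or 0 if this single-dim broadcast does not apply.
int32_t broadcast_dim_of(
    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0;
  }
  int32_t broadcast_dim = 0;
  // Walk from the last dim toward the front, as the patch does.
  for (int32_t i = static_cast<int32_t>(lhs.size()) - 1; i > 0; --i) {
    if (lhs[i] == 1 || rhs[i] == 1) {
      if (broadcast_dim != 0) {
        return 0; // more than one broadcast dim: not handled
      }
      broadcast_dim = i - static_cast<int32_t>(lhs.size());
    } else if (lhs[i] != rhs[i]) {
      return 0; // non-1 dims must match
    }
  }
  return broadcast_dim;
}

// Collapses a shape into {outer, broadcast, inner} around a positive dim index.
std::array<int64_t, 3> normalize(
    const std::vector<int64_t>& sizes, int32_t dim) {
  std::array<int64_t, 3> out = {1, sizes[dim], 1};
  for (int32_t i = 0; i < dim; ++i) {
    out[0] *= sizes[i];
  }
  for (size_t i = dim + 1; i < sizes.size(); ++i) {
    out[2] *= sizes[i];
  }
  return out;
}

int main() {
  // {2, 2, 3, 5} * {2, 2, 1, 5}: broadcast over dim -2 (the "3" vs "1" axis).
  std::vector<int64_t> a = {2, 2, 3, 5};
  std::vector<int64_t> b = {2, 2, 1, 5};
  int32_t bdim = broadcast_dim_of(a, b); // -2
  auto n = normalize(a, static_cast<int32_t>(a.size()) + bdim);
  // outer = 2 * 2 = 4, broadcast = 3, inner = 5: the values that opt_mul_out()
  // would pass to broadcasting_map_3d_and_unsqueezed_3d() for this case.
  assert(bdim == -2 && n[0] == 4 && n[1] == 3 && n[2] == 5);
  return 0;
}

Because the sketch ignores leading size-1 dimensions, it only illustrates the common case; the patch additionally trims those via arrayref_begin_ignoring_leading_1s before the scan, and treats a broadcast over the last dim (-1) as not eligible for this path.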