From 81e150286d4dc9be4e1ec43f6cfed4ad970e955f Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 28 Oct 2024 08:48:05 -0700 Subject: [PATCH] [Executorch] enable sleef consistently Earlier only android platofrms had support for sleef Differential Revision: [D64571782](https://our.internmc.facebook.com/intern/diff/D64571782/) [ghstack-poisoned] --- extension/llm/custom_ops/targets.bzl | 8 ++- kernels/optimized/lib_defs.bzl | 67 +++++++++++----------- kernels/optimized/op_registration_util.bzl | 4 +- 3 files changed, 41 insertions(+), 38 deletions(-) diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index dcc172e2641..eea303c5a63 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -1,4 +1,9 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load( + "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", + "get_vec_preprocessor_flags", + "get_vec_deps", +) def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -21,6 +26,7 @@ def define_common_targets(): "op_sdpa.h", "op_update_quantized_cache.h", ], + preprocessor_flags = get_vec_preprocessor_flags(), exported_deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", @@ -33,7 +39,7 @@ def define_common_targets(): deps = [ "//executorch/kernels/portable/cpu/util:reduce_util", "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform", - ], + ] + get_vec_deps(), compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors", "-O2"], visibility = [ "//executorch/...", diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 1284f87df25..2cd803ff0e2 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -11,15 +11,37 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") # functions in order to declare the required compiler flags needed in order to # access CPU vector intrinsics. -def get_vec_android_preprocessor_flags(): - preprocessor_flags = [ - ( - "^android-arm64.*$", - [ - "-DET_BUILD_ARM_VEC256_WITH_SLEEF", - ], - ), - ] +def get_vec_preprocessor_flags(): + preprocessor_flags = select({ + "ovr_config//os:iphoneos": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return preprocessor_flags + +def get_vec_deps(): + preprocessor_flags = select({ + "ovr_config//os:linux-x86_64": [ + "fbsource//third-party/sleef:sleef", + ] if not runtime.is_oss else [], + "ovr_config//os:iphoneos": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) return preprocessor_flags def get_vec_cxx_preprocessor_flags(): @@ -56,32 +78,7 @@ def define_libs(): "//executorch/...", "@EXECUTORCH_CLIENTS", ], - cxx_platform_deps = select({ - "DEFAULT": [ - ( - DEVSERVER_PLATFORM_REGEX, - [ - "fbsource//third-party/sleef:sleef", - ], - ), - ], - "ovr_config//cpu:arm64": [ - ( - DEVSERVER_PLATFORM_REGEX, - [ - "fbsource//third-party/sleef:sleef_arm", - ], - ), - ], - }), - fbandroid_platform_deps = [ - ( - "^android-arm64.*$", - [ - "fbsource//third-party/sleef:sleef_arm", - ], - ), - ], + deps = get_vec_deps(), ) runtime.cxx_library( diff --git a/kernels/optimized/op_registration_util.bzl b/kernels/optimized/op_registration_util.bzl index 887d5711eab..0ef58a023f7 100644 --- a/kernels/optimized/op_registration_util.bzl +++ b/kernels/optimized/op_registration_util.bzl @@ -2,7 +2,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/build:selects.bzl", "selects") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", - "get_vec_android_preprocessor_flags", + "get_vec_preprocessor_flags", ) def op_target(name, deps = []): @@ -91,7 +91,7 @@ def define_op_library(name, deps): deps = [ "//executorch/runtime/kernel:kernel_includes", ] + augmented_deps, - fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(), + preprocessor_flags = get_vec_preprocessor_flags(), # sleef needs to be added as a direct dependency of the operator target when building for Android, # or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of # dependencies are not transitive