From d6e35a3caa9548b853aa662ede60a19728260b81 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 18 Sep 2025 20:18:26 -0700 Subject: [PATCH] Add selective build support for prim ops (#14332) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This diff implements selective build functionality for primitive operations (prim ops) in ExecuTorch, allowing users to include only specific prim ops in their builds to reduce binary size and compilation time. ## Key Changes: 1. **Conditional compilation in register_prim_ops.cpp**: Wrapped each of the prim op registrations with conditional compilation macros that check both selective build enablement (`EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD`) and individual op selection (e.g., `INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT`). 2. **Code generation tool**: Added `gen_selected_prim_ops.py`, which takes comma-separated prim op names and generates a header file (`selected_prim_ops.h`) containing the appropriate `#define` statements for the selected ops. The tool normalizes op names to a macro-safe format (e.g., `executorch_prim::et_view.default` → `INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT`). 3. **Build system integration**: To make et_operator_library also handle prim ops selective build, we make a few changes. 1. Extract prim ops in et_operator_library. 2. Similar to gen_op_list, we invoke a script that generates a selected_prim_ops.h file per et_operator_library target; thus et_operator_library now generates both selected_operators.yaml and selected_prim_ops.h. Note that to make this work, et_operator_library must handle the following cases: 1. all ops are aten ops, 2. all ops are prim ops, 3. a mix of the two. For this we must make sure that the genrule continues to produce the files it says it will produce: in case 1 we produce an empty selected_prim_ops.h, and in case 2 an empty selected_operators.yaml. 3. In gen_all_oplist we allow for an empty selected_operators.yaml and skip the file. 4. Similar to gen_all_oplist, we introduce another binary that combines all selected_prim_ops.h headers. 5. Then, in executorch_generated_lib, we query the targets from 4 that have selected_prim_ops and use those to compile register_prim_ops.cpp. executorch_generated_lib gains an include_all_prim_ops flag, which is True by default; to enable selective build for prim ops, one must turn that flag off.
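For illustration (a sketch, not literal tool output): if only `executorch_prim::et_view.default` and `aten::sym_size.int` were selected, the combined `selected_prim_ops.h` would contain roughly ``` #pragma once #define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD #define INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT #define INCLUDE_ATEN_SYM_SIZE_INT ``` and each registration in register_prim_ops.cpp is guarded as ``` #if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ defined(INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT) Kernel("executorch_prim::et_view.default", ...), #endif ``` so an op is compiled in either when selective build is off or when its macro is defined.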
## Usage: Users can now specify prim ops like: ``` et_operator_library(name="my_aten_prim_ops", ops=["aten::mul.out", "executorch_prim::et_view.default", "aten::sym_size.int"]) executorch_generated_lib(name="my_lib", deps=[":my_aten_prim_ops"] + other_deps, include_all_prim_ops=False) ``` Reviewed By: ivayloen, larryliu0820 Differential Revision: D81648030 --- codegen/tools/combine_prim_ops_headers.py | 164 +++++++++++++++++++ codegen/tools/gen_all_oplist.py | 20 ++- codegen/tools/gen_oplist.py | 20 ++- codegen/tools/gen_selected_prim_ops.py | 96 +++++++++++ codegen/tools/targets.bzl | 41 +++++ codegen/tools/test/test_gen_oplist.py | 11 +- examples/selective_build/targets.bzl | 114 +++++++++++++ kernels/prim_ops/register_prim_ops.cpp | 91 +++++++++- kernels/prim_ops/selective_build_prim_ops.h | 12 ++ kernels/prim_ops/targets.bzl | 27 ++- shim_et/xplat/executorch/codegen/codegen.bzl | 160 +++++++++++++++++- 11 files changed, 734 insertions(+), 22 deletions(-) create mode 100644 codegen/tools/combine_prim_ops_headers.py create mode 100644 codegen/tools/gen_selected_prim_ops.py create mode 100644 kernels/prim_ops/selective_build_prim_ops.h diff --git a/codegen/tools/combine_prim_ops_headers.py b/codegen/tools/combine_prim_ops_headers.py new file mode 100644 index 00000000000..b579de2047d --- /dev/null +++ b/codegen/tools/combine_prim_ops_headers.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Script to combine multiple selected_prim_ops.h header files into a single header. +This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies. +""" + +import argparse +import os +import sys +from pathlib import Path +from typing import List, Set + + +def read_header_file(file_path: Path) -> Set[str]: + """ + Read a selected_prim_ops.h file and extract the prim op macro defines. + + Args: + file_path: Path to the header file + + Returns: + The set of unique `#define INCLUDE_...` lines found in the file + """ + macros = set() + + try: + with open(file_path, "r") as f: + for line in f: + line = line.strip() + + # Extract #define statements for prim ops + if line.startswith("#define INCLUDE_") and not line.startswith( + "#define EXECUTORCH_ENABLE" + ): + macros.add(line) + except FileNotFoundError: + print(f"Warning: Header file not found: {file_path}", file=sys.stderr) + except Exception as e: + print(f"Error reading {file_path}: {e}", file=sys.stderr) + + return macros + + +def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None: + """ + Combine multiple selected_prim_ops.h files into a single header.
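+ + Duplicate macros across dependencies are collapsed into a set and sorted before writing, so the combined header is deterministic.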
+ + Args: + header_file_paths: List of paths to header files to combine + output_path: Path to output the combined header + """ + all_macros = set() + has_selective_build = False + + # Read all header files and collect unique macros + for header_file_path in header_file_paths: + header_file = Path(header_file_path) / "selected_prim_ops.h" + if os.path.exists(header_file): + macros = read_header_file(header_file) + all_macros.update(macros) + if len(all_macros) > 0: + has_selective_build = True + else: + print( + f"Warning: Header file does not exist: {header_file}", file=sys.stderr + ) + + # Generate combined header + header_content = [ + "// Combined header for selective prim ops build", + "// This file is auto-generated by combining multiple selected_prim_ops.h files", + "// Do not edit manually.", + "", + "#pragma once", + "", + ] + + if all_macros and has_selective_build: + header_content.extend( + [ + "// Enable selective build for prim ops", + "#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD", + "", + "// Combined prim ops macros from all dependencies", + ] + ) + + # Sort macros for deterministic output + sorted_macros = sorted(all_macros) + header_content.extend(sorted_macros) + else: + header_content.extend( + [ + "// No prim ops found in dependencies - all prim ops will be included", + "// Selective build is disabled", + ] + ) + + header_content.append("") + + # Write the combined header + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as f: + f.write("\n".join(header_content)) + + +def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]: + """ + Parse the output of a Buck query command to extract header file paths. + + Args: + query_output_file: Path to the file containing the query output + + Returns: + List of header file paths + """ + header_file_paths = [] + assert ( + query_output_file[0] == "@" + ), "query_output_file is not a valid file path, or it doesn't start with '@'."
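+ # Buck materializes the $(@query_outputs ...) results as an argfile ("@<path>"); strip the leading "@" to get the actual file path.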
+ query_output_file = query_output_file[1:] + + with open(query_output_file, "r") as f: + for line in f: + # Extract the header file path from the query output + header_file_paths += line.split() + return header_file_paths + + +def main(): + parser = argparse.ArgumentParser( + description="Combine multiple selected_prim_ops.h header files" + ) + parser.add_argument( + "--header_files", + required=True, + help="Argfile (@path) listing the header file paths to combine", + ) + parser.add_argument( + "--output_dir", required=True, help="Output directory for combined header" + ) + + args = parser.parse_args() + + header_file_paths = _get_header_file_paths_from_query_output(args.header_files) + + if not header_file_paths: + print("Error: No header files provided", file=sys.stderr) + sys.exit(1) + + # Generate output path + output_path = os.path.join(args.output_dir, "selected_prim_ops.h") + + combine_prim_ops_headers(header_file_paths, output_path) + + +if __name__ == "__main__": + main() diff --git a/codegen/tools/gen_all_oplist.py b/codegen/tools/gen_all_oplist.py index 5cb93bb9153..f33c3dc935d 100644 --- a/codegen/tools/gen_all_oplist.py +++ b/codegen/tools/gen_all_oplist.py @@ -10,7 +10,7 @@ import sys from functools import reduce from pathlib import Path -from typing import Any, List +from typing import Any, Dict, List import yaml from torchgen.selective_build.selector import ( @@ -72,6 +72,19 @@ def _raise_if_check_prim_ops_fail(options): raise Exception(error) +def _selected_ops_model_dict_is_empty(model_dict: Dict[str, Any]) -> bool: + return ( + not model_dict.get("build_features", []) + and not model_dict.get("custom_classes", []) + and not model_dict.get("et_kernel_metadata", None) + and not model_dict.get("include_all_non_op_selectives", False) + and not model_dict.get("include_all_operators", False) + and not model_dict.get("kernel_metadata", {}) + and not model_dict.get("operators", {}) + ) + + +# flake8: noqa: C901 def main(argv: List[Any]) -> None: """This binary generates 3 files: @@ -171,6 +184,11 @@ def main(argv: List[Any]) -> None: ), f"{model_file_name} is not a valid file path. This is likely a BUCK issue." with open(model_file_name, "rb") as model_file: model_dict = yaml.safe_load(model_file) + # It is possible that we created an empty yaml file. + # This is because et_operator_library may only contain prim ops. + # In that case selected_operators.yaml will be empty.
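+ # An empty model_dict has no operators, kernel_metadata, or other selective-build fields set (see _selected_ops_model_dict_is_empty above).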
+ if _selected_ops_model_dict_is_empty(model_dict): + continue resolved = resolve_model_file_path_to_buck_target(model_file_name) for op in model_dict["operators"]: model_dict["operators"][op]["debug_info"] = [resolved] diff --git a/codegen/tools/gen_oplist.py b/codegen/tools/gen_oplist.py index cca5bf1b1d2..28506050a8e 100644 --- a/codegen/tools/gen_oplist.py +++ b/codegen/tools/gen_oplist.py @@ -9,6 +9,7 @@ import os import sys from enum import IntEnum +from pathlib import Path from typing import Any, Dict, List, Optional, Set import yaml @@ -158,7 +159,7 @@ def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[ def _dump_yaml( op_list: List[str], - output_path: str, + output_path: Path, model_name: Optional[str] = None, et_kernel_metadata: Optional[Dict[str, List[str]]] = None, include_all_operators: bool = False, @@ -212,20 +213,23 @@ def create_kernel_key(maybe_kernel_key: str) -> str: def gen_oplist( - output_path: str, + output_path: Path, model_file_path: Optional[str] = None, ops_schema_yaml_path: Optional[str] = None, root_ops: Optional[str] = None, ops_dict: Optional[str] = None, include_all_operators: bool = False, ): - assert ( + if not ( model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators - ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators." + ): + # dump empty yaml file + _dump_yaml([], output_path) + return assert output_path, "Need to provide output_path for dumped yaml file." op_set = set() @@ -326,9 +330,15 @@ def main(args: List[Any]) -> None: ) options = parser.parse_args(args) + # If output_path is a directory, generate the operator list + # as selected_operators.yaml inside it + if Path(options.output_path).is_dir(): + output_path = Path(options.output_path) / "selected_operators.yaml" + else: + output_path = Path(options.output_path) try: gen_oplist( - output_path=options.output_path, + output_path=output_path, model_file_path=options.model_file_path, ops_schema_yaml_path=options.ops_schema_yaml_path, root_ops=options.root_ops, diff --git a/codegen/tools/gen_selected_prim_ops.py b/codegen/tools/gen_selected_prim_ops.py new file mode 100644 index 00000000000..4535ffaa57a --- /dev/null +++ b/codegen/tools/gen_selected_prim_ops.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import argparse +import os +import sys +from typing import Any, List + +from torchgen.code_template import CodeTemplate # type: ignore[import-not-found] + + +selected_prim_ops_h_template_str = """#pragma once +/** + * Generated by executorch/codegen/tools/gen_selected_prim_ops.py + */ + +$defines +""" +selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str) + + +def normalize_op_name(op_name: str) -> str: + """ + Normalize an operator name to a macro-safe format.
+ Convert op names like "executorch_prim::et_view.default" to "INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT" + or "aten::sym_size.int" to "INCLUDE_ATEN_SYM_SIZE_INT" + """ + # Replace the namespace separator with an underscore + normalized = op_name.replace("::", "_") + # Replace dots with underscores + normalized = normalized.replace(".", "_") + # Convert to uppercase + normalized = normalized.upper() + # Add INCLUDE_ prefix + normalized = f"INCLUDE_{normalized}" + return normalized + + +def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None: + """ + Generate selected_prim_ops.h from a list of prim op names. + + Args: + prim_op_names: List of prim op names like ["executorch_prim::et_view.default", "aten::sym_size.int"] + output_dir: Directory where to write selected_prim_ops.h + """ + # Generate #define statements for each op + defines = [] + for op_name in prim_op_names: + macro_name = normalize_op_name(op_name) + defines.append(f"#define {macro_name}") + + # Join all defines with newlines + defines_str = "\n".join(defines) + + # Generate header content + header_contents = selected_prim_ops_h_template.substitute(defines=defines_str) + + # Write to file + selected_prim_ops_path = os.path.join(output_dir, "selected_prim_ops.h") + with open(selected_prim_ops_path, "wb") as out_file: + out_file.write(header_contents.encode("utf-8")) + + +def main(argv: List[Any]) -> None: + parser = argparse.ArgumentParser(description="Generate selected prim ops header") + parser.add_argument( + "--prim-op-names", + "--prim_op_names", + help="Comma-separated list of prim op names to include", + required=True, + ) + parser.add_argument( + "--output-dir", + "--output_dir", + help="The directory to store the output header file (selected_prim_ops.h)", + required=True, + ) + + options = parser.parse_args(argv) + + # Parse comma-separated prim op names + prim_op_names = [ + name.strip() for name in options.prim_op_names.split(",") if name.strip() + ] + + write_selected_prim_ops(prim_op_names, options.output_dir) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/codegen/tools/targets.bzl b/codegen/tools/targets.bzl index acea3370e7d..d594b7178b8 100644 --- a/codegen/tools/targets.bzl +++ b/codegen/tools/targets.bzl @@ -103,6 +103,26 @@ def define_common_targets(is_fbcode = False): _is_external_target = True, ) + runtime.python_library( + name = "combine_prim_ops_headers_lib", + srcs = ["combine_prim_ops_headers.py"], + base_module = "executorch.codegen.tools", + visibility = ["//executorch/..."], + ) + + runtime.python_binary( + name = "combine_prim_ops_headers", + main_module = "executorch.codegen.tools.combine_prim_ops_headers", + package_style = "inplace", + visibility = [ + "PUBLIC", + ], + deps = [ + ":combine_prim_ops_headers_lib", + ], + _is_external_target = True, + ) + runtime.python_test( name = "test_gen_all_oplist", srcs = [ @@ -155,6 +175,27 @@ def define_common_targets(is_fbcode = False): _is_external_target = True, ) + runtime.python_library( + name = "gen_selected_prim_ops_lib", + srcs = ["gen_selected_prim_ops.py"], + base_module = "executorch.codegen.tools", + visibility = ["//executorch/..."], + external_deps = ["torchgen"], + ) + + runtime.python_binary( + name = "gen_selected_prim_ops", + main_module = "executorch.codegen.tools.gen_selected_prim_ops", + package_style = "inplace", + visibility = [ + "PUBLIC", + ], + deps = [ + ":gen_selected_prim_ops_lib", + ], + _is_external_target = True, + ) + if not runtime.is_oss: runtime.cxx_python_extension( name =
"selective_build", diff --git a/codegen/tools/test/test_gen_oplist.py b/codegen/tools/test/test_gen_oplist.py index f5c6829d6a0..18689cd2505 100644 --- a/codegen/tools/test/test_gen_oplist.py +++ b/codegen/tools/test/test_gen_oplist.py @@ -8,6 +8,7 @@ import os import tempfile import unittest +from pathlib import Path from typing import Dict, List from unittest.mock import NonCallableMock, patch @@ -77,7 +78,7 @@ def test_gen_op_list_with_valid_root_ops( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, {"aten::add": ["default"], "aten::mul": ["default"]}, False, @@ -100,7 +101,7 @@ def test_gen_op_list_with_root_ops_and_dtypes( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, { "aten::add": [ @@ -129,7 +130,7 @@ def test_gen_op_list_with_both_op_list_and_ops_schema_yaml_merges( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add.out", "aten::mul.out", "aten::relu.out"], - output_path, + Path(output_path), test_path, { "aten::relu.out": ["default"], @@ -153,7 +154,7 @@ def test_gen_op_list_with_include_all_operators( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, {"aten::add": ["default"], "aten::mul": ["default"]}, True, @@ -164,7 +165,7 @@ def test_get_custom_build_selector_with_both_allowlist_and_yaml( ) -> None: op_list = ["aten::add", "aten::mul"] filename = os.path.join(self.temp_dir.name, "selected_operators.yaml") - gen_oplist._dump_yaml(op_list, filename, "model.pte") + gen_oplist._dump_yaml(op_list, Path(filename), "model.pte") self.assertTrue(os.path.isfile(filename)) with open(filename) as f: es = yaml.safe_load(f) diff --git a/examples/selective_build/targets.bzl b/examples/selective_build/targets.bzl index 72639fef842..bd11a53e3e0 100644 --- a/examples/selective_build/targets.bzl +++ b/examples/selective_build/targets.bzl @@ -1,6 +1,118 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "is_xplat", "runtime") load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib", "ScalarType") +def define_selective_build_prim_ops_example(): + """ + Example showing how selected_prim_operators_genrule works to combine + prim ops headers from multiple dependencies. 
+ """ + + # Define several operator libraries with automatic prim ops extraction + et_operator_library( + name = "model_a_ops", + ops = [ + "aten::add.out", + "aten::mul.out", + "executorch_prim::et_view.default", # Auto-extracted to prim ops + "aten::sym_size.int", # Auto-extracted to prim ops + ], + visibility = ["//executorch/..."], + ) + # This creates: "model_a_ops" + "model_a_ops_selected_prim_ops" + + et_operator_library( + name = "model_b_ops", + ops = [ + "aten::sub.out", + "aten::div.out", + "executorch_prim::add.Scalar", # Auto-extracted to prim ops + "aten::sym_numel.int", # Auto-extracted to prim ops + ], + visibility = ["//executorch/..."], + ) + # This creates: "model_b_ops" + "model_b_ops_selected_prim_ops" + + # Define a manual prim ops target as well + et_operator_library( + name = "extra_prim_ops", + ops = [ + "executorch_prim::mul.Scalar", + "executorch_prim::sym_max.Scalar", + ], + visibility = ["//executorch/..."], + ) + # Use the combined header in an executorch_generated_lib + executorch_generated_lib( + name = "library_with_combined_prim_ops", + deps = [ + ":model_a_ops", + ":model_b_ops", + ":extra_prim_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + include_all_prim_ops = False, + ) + + # Prim ops selected separately + et_operator_library( + name = "model_b_ops_no_prim_ops", + ops = [ + "aten::sub.out", + "aten::div.out", + ], + visibility = ["//executorch/..."], + ) + + # Use the combined header in an executorch_generated_lib + executorch_generated_lib( + name = "library_with_combined_prim_ops_1", + deps = [ + ":model_b_ops_no_prim_ops", + ":extra_prim_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + include_all_prim_ops = False, + ) + + # No prim ops selected. So include all prim ops. + executorch_generated_lib( + name = "library_with_combined_prim_ops_2", + deps = [ + ":model_b_ops_no_prim_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + include_all_prim_ops = False, + ) + + # default to selecting all prim ops + executorch_generated_lib( + name = "library_with_all_prim_ops", + deps = [ + ":model_b_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + ) + def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -165,3 +277,5 @@ def define_common_targets(): define_static_target = True, **get_oss_build_kwargs() ) + + define_selective_build_prim_ops_example() diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 8607c36204d..dc6ed9ac26f 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -12,6 +12,18 @@ #include #include +/* +For internal builds using buck rules, the target that depends on +selective prim ops, will manage its own artifacts. It is in the +artifacts directory where the geneated selected_prim_ops.h resides +and thus compilation sources must be copied there including +selective_build_prim_ops.h. 
Hence it is not included via a fully +qualified path, unlike the header files above. */ +#ifdef ET_PRIM_OPS_SELECTIVE_BUILD +#include "selective_build_prim_ops.h" +#endif + #include #include @@ -87,6 +99,8 @@ void floor_div_double(double a, double b, EValue& out) { } static Kernel prim_ops[] = { +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_ATEN_SYM_SIZE_INT) // aten::sym_size.int(Tensor self, int dim) -> SymInt Kernel( "aten::sym_size.int", @@ -108,6 +122,9 @@ static Kernel prim_ops[] = { int64_t size = self_tensor.size(dim_val); out = EValue(size); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_ATEN_LOCAL_SCALAR_DENSE) // aten::_local_scalar_dense(Tensor self) -> Scalar Kernel( "aten::_local_scalar_dense", @@ -134,6 +151,9 @@ static Kernel prim_ops[] = { out = EValue(Scalar(self_tensor.const_data_ptr()[0])); }); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_ATEN_SYM_NUMEL) // aten::sym_numel(Tensor self) -> SymInt Kernel( "aten::sym_numel", @@ -153,6 +173,9 @@ static Kernel prim_ops[] = { int64_t numel = self_tensor.numel(); out = EValue(numel); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_MAX_SCALAR) // executorch_prim::sym_max.Scalar(SymInt a, SymInt b) -> SymInt Kernel( "executorch_prim::sym_max.Scalar", @@ -182,6 +205,9 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_MIN_SCALAR) // executorch_prim::sym_min.Scalar(SymInt a, SymInt b) -> SymInt Kernel( "executorch_prim::sym_min.Scalar", @@ -210,27 +236,39 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ADD_SCALAR) // executorch_prim::add.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::add.Scalar", [](KernelRuntimeContext& context, Span stack) { ALGEBRA_ET_PRIM_OP(+, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SUB_SCALAR) // executorch_prim::sub.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::sub.Scalar", [](KernelRuntimeContext& context, Span stack) { ALGEBRA_ET_PRIM_OP(-, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_MUL_SCALAR) // executorch_prim::mul.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::mul.Scalar", [](KernelRuntimeContext& context, Span stack) { ALGEBRA_ET_PRIM_OP(*, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_FLOORDIV_SCALAR) /** * Python's __floordiv__ operator is more complicated than just floor(a / * b).
It aims to maintain the property: a == (a // b) * b + remainder(a, b) @@ -280,8 +318,11 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif - // executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_TRUEDIV_SCALAR) + // executorch_prim::truediv.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::truediv.Scalar", [](KernelRuntimeContext& context, Span stack) { @@ -318,7 +359,10 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_FLOAT_SCALAR) // executorch_prim::sym_float.Scalar(Scalar) -> Scalar Kernel( "executorch_prim::sym_float.Scalar", @@ -346,41 +390,60 @@ static Kernel prim_ops[] = { context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_EQ_SCALAR) // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::eq.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(==, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_GT_SCALAR) // executorch_prim::gt.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::gt.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(>, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_LT_SCALAR) // executorch_prim::lt.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::lt.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(<, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_GE_SCALAR) // executorch_prim::ge.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::ge.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(>=, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_LE_SCALAR) // executorch_prim::le.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::le.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(<=, stack, context); }), +#endif + +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_NEG_SCALAR) // executorch_prim::neg.Scalar(Scalar) -> Scalar Kernel( "executorch_prim::neg.Scalar", @@ -404,7 +467,10 @@ static Kernel prim_ops[] = { context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_FLOORDIV_INT) // executorch_prim::floordiv.int(int, int) -> int Kernel( "executorch_prim::floordiv.int", @@ -422,7 +488,10 @@ static Kernel prim_ops[] = { EValue& out = *stack[2]; out = EValue(a.toInt() / b.toInt()); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_MOD_INT) // executorch_prim::mod.int(int, int) -> int Kernel( "executorch_prim::mod.int", @@ -440,7 +509,10 @@ static Kernel prim_ops[] = { EValue& out = *stack[2]; out = EValue(a.toInt() % b.toInt()); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_MOD_SCALAR) // executorch_prim::mod.Scalar(Scalar, Scalar) -> Scalar Kernel( 
"executorch_prim::mod.Scalar", @@ -469,7 +541,10 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_CEIL_SCALAR) // ceil.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::ceil.Scalar", @@ -496,7 +571,10 @@ static Kernel prim_ops[] = { (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ROUND_SCALAR) // round.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::round.Scalar", @@ -540,7 +618,10 @@ static Kernel prim_ops[] = { (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_TRUNC_SCALAR) // trunc.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::trunc.Scalar", @@ -562,19 +643,27 @@ static Kernel prim_ops[] = { context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ET_COPY_INDEX_TENSOR) // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor Kernel( "executorch_prim::et_copy_index.tensor", [](KernelRuntimeContext& context, Span stack) { et_copy_index(context, stack); }), +#endif + +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT) // executorch_prim::et_view.default(Tensor, int[]) -> Tensor Kernel( "executorch_prim::et_view.default", [](KernelRuntimeContext& context, Span stack) { et_view(context, stack); }), +#endif }; diff --git a/kernels/prim_ops/selective_build_prim_ops.h b/kernels/prim_ops/selective_build_prim_ops.h new file mode 100644 index 00000000000..78181405b11 --- /dev/null +++ b/kernels/prim_ops/selective_build_prim_ops.h @@ -0,0 +1,12 @@ +#pragma once +/** + * Generated by executorch/kernels/prim_ops/selective_build_prim_ops.h + * This header conditionally includes selected_prim_ops.h when selective build + * for prim ops is enabled. + */ + +// If no prim ops are selected, then the header is empty. +// that would mean all prim ops are enabled. +#ifdef ET_PRIM_OPS_SELECTIVE_BUILD +#include "selected_prim_ops.h" +#endif diff --git a/kernels/prim_ops/targets.bzl b/kernels/prim_ops/targets.bzl index 8bdc44fe553..eea66c1afa7 100644 --- a/kernels/prim_ops/targets.bzl +++ b/kernels/prim_ops/targets.bzl @@ -7,13 +7,31 @@ def define_common_targets(): TARGETS and BUCK files that call this function. """ + # Define the filegroup once outside the loop since it doesn't vary by aten mode + runtime.filegroup( + name = "prim_ops_sources", + srcs = ["register_prim_ops.cpp"], + visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], + ) + + runtime.filegroup( + name = "selective_build_prim_ops.h", + srcs = ["selective_build_prim_ops.h"], + visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], + ) + for aten_mode in get_aten_mode_options(): aten_suffix = ("_aten" if aten_mode else "") runtime.cxx_library( name = "et_copy_index" + aten_suffix, srcs = ["et_copy_index.cpp"], - visibility = [], # Private + # To allow for selective prim ops to depend on this library. 
+ # Used by selective_build.bzl + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], exported_headers = ["et_copy_index.h"], deps = [ "//executorch/runtime/kernel:kernel_includes" + aten_suffix, @@ -28,7 +46,12 @@ def define_common_targets(): runtime.cxx_library( name = "et_view" + aten_suffix, srcs = ["et_view.cpp"], - visibility = [], # Private + # To allow for selective prim ops to depend on this library. + # Used by selective_build.bzl + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], exported_headers = ["et_view.h"], deps = [ "//executorch/runtime/kernel:kernel_includes" + aten_suffix, diff --git a/shim_et/xplat/executorch/codegen/codegen.bzl b/shim_et/xplat/executorch/codegen/codegen.bzl index ae6b42e2d8f..3546b64cdb6 100644 --- a/shim_et/xplat/executorch/codegen/codegen.bzl +++ b/shim_et/xplat/executorch/codegen/codegen.bzl @@ -7,6 +7,7 @@ load( "get_vec_deps", "get_vec_preprocessor_flags", ) +load("@fbsource//xplat/executorch/kernels/prim_ops:selective_build.bzl", "prim_ops_registry_selective") # Headers that declare the function signatures of the C++ functions that # map to entries in functions.yaml and custom_ops.yaml. @@ -81,6 +82,83 @@ ScalarType = enum( "Uint64", ) +def _get_prim_ops_registry_target(name, deps, aten_suffix, platforms): + """ + Helper function to determine which prim ops registry target to use. + + Args: + name: Base name for creating the selective registry target + deps: List of dependencies; targets labeled et_operator_library in their + transitive closure provide the prim ops headers + aten_suffix: Suffix for aten mode (e.g. "_aten") + platforms: Platforms configuration + + Returns: + String: Target name for the appropriate prim ops registry + """ + # Create a selective prim ops registry from the deps' combined prim ops header + selective_prim_ops_registry_name = name + "_selected_prim_ops_registry" + combined_prim_ops_header_target_name = name + "_combined_prim_ops_header" + selected_prim_operators_genrule(combined_prim_ops_header_target_name, deps, platforms) + # Use the existing prim_ops_registry_selective function + prim_ops_registry_selective( + name = selective_prim_ops_registry_name, + selected_prim_ops_header_target = ":" + combined_prim_ops_header_target_name, + aten_suffix = aten_suffix, + platforms = platforms, + ) + + # Return the selective registry target + return ":" + selective_prim_ops_registry_name + +def _extract_prim_ops_from_lists(ops, ops_dict): + """ + Utility function to extract prim ops from ops list and ops_dict.
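+ + Prim ops are ops in the executorch_prim:: namespace plus the aten:: sym_* ops handled by register_prim_ops.cpp (see _is_prim_op below).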
+ + Args: + ops: List of operator names + ops_dict: Dictionary mapping ops to metadata + + Returns: + Tuple of (prim_ops, remaining_ops, remaining_ops_dict) + """ + def _is_aten_prim_op(op_name): + if not op_name.startswith("aten::"): + return False + for prim_suffix in [ + "sym_size", "sym_numel", "sym_max", "sym_min", "sym_float" + ]: + if prim_suffix in op_name: + return True + return False + + def _is_prim_op(op_name): + """Check if an operator is a primitive operation.""" + return op_name.startswith("executorch_prim::") or ( + _is_aten_prim_op(op_name) + ) + + prim_ops = [] + remaining_ops = [] + remaining_ops_dict = {} + + # Extract from ops list + for op in ops: + if _is_prim_op(op): + prim_ops.append(op) + else: + remaining_ops.append(op) + + # Extract from ops_dict + for op, metadata in ops_dict.items(): + if _is_prim_op(op): + prim_ops.append(op) + else: + remaining_ops_dict[op] = metadata + + return prim_ops, remaining_ops, remaining_ops_dict + # Hide the dependency to caffe2 internally. def et_operator_library( name, @@ -91,6 +169,27 @@ def et_operator_library( ops_schema_yaml_target = None, server_generated_yaml_target = None, **kwargs): + + # Check if we should extract prim ops from the operator lists + # Note that selective build for prim ops doesn't support model, ops_schema_yaml_target, or server_generated_yaml_target + # TODO: Add support for selective build for prim ops with model or ops_schema_yaml_target or server_generated_yaml_target + should_extract_prim_ops = (ops or ops_dict) and not (model or ops_schema_yaml_target or server_generated_yaml_target or include_all_operators) + + if should_extract_prim_ops: + # Extract prim ops from ops and ops_dict + prim_ops, remaining_ops, remaining_ops_dict = _extract_prim_ops_from_lists(ops, ops_dict) + # Use the remaining ops (with prim ops removed) for the main et_operator_library + final_ops = remaining_ops + final_ops_dict = remaining_ops_dict + else: + # No prim ops extraction needed - use original ops and ops_dict + prim_ops = [] + final_ops = ops + final_ops_dict = ops_dict + + selected_operator_yaml_filename = "selected_operators.yaml" + selected_prim_ops_filename = "selected_prim_ops.h" + # Generate the main operator library with the final ops # do a dummy copy if server_generated_yaml_target is set if server_generated_yaml_target: if include_all_operators or ops_schema_yaml_target or model or ops or ops_dict: @@ -98,7 +197,7 @@ def et_operator_library( genrule_cmd = [ "cp", "$(location {})".format(server_generated_yaml_target), - "$OUT", + "$OUT/{}".format(selected_operator_yaml_filename), ] else: genrule_cmd = [ @@ -109,12 +208,12 @@ def et_operator_library( genrule_cmd.append( "--ops_schema_yaml_path=$(location {})".format(ops_schema_yaml_target), ) - if ops: + if final_ops: genrule_cmd.append( - "--root_ops=" + ",".join(ops), + "--root_ops=" + ",".join(final_ops), ) - if ops_dict: - ops_dict_json = struct_to_json(ops_dict) + if final_ops_dict: + ops_dict_json = struct_to_json(final_ops_dict) genrule_cmd.append( "--ops_dict='{}'".format(ops_dict_json), ) @@ -127,6 +226,15 @@ def et_operator_library( "--include_all_operators", ) + prim_ops_genrule_cmd = [ + "$(exe //executorch/codegen/tools:gen_selected_prim_ops)", + "--prim_op_names=" + ",".join(prim_ops), + "--output_dir=${OUT}", + ] + # Here we generate both selected_prim_ops.h and selected_operators.yaml + # with a single genrule + genrule_cmd = genrule_cmd + [" && "] + prim_ops_genrule_cmd + # TODO(larryliu0820): Remove usages of this flag.
if "define_static_targets" in kwargs: kwargs.pop("define_static_targets") @@ -134,7 +242,8 @@ def et_operator_library( name = name, macros_only = False, cmd = " ".join(genrule_cmd), - out = "selected_operators.yaml", + outs = {selected_operator_yaml_filename: [selected_operator_yaml_filename], selected_prim_ops_filename: [selected_prim_ops_filename]}, + default_outs = ["."], labels = ["et_operator_library"], **kwargs ) @@ -615,6 +724,31 @@ def selected_operators_genrule( platforms = platforms, ) +def selected_prim_operators_genrule( + name, + deps, + platforms = get_default_executorch_platforms(), +): + """Generates selected_prim_ops.h from the list of deps. We look into the transitive closure of all the deps, + and look for targets with label `et_operator_library`. + + `combine_prim_ops_headers` is the python binary we use to aggregate all the `selected_prim_ops.h` headers + from `et_prim_ops_library` targets into a single combined `selected_prim_ops.h` file. + + This file can be used to enable selective build for prim ops across multiple dependencies. + """ + cmd = ("$(exe //executorch/codegen/tools:combine_prim_ops_headers) " + + "--header_files $(@query_outputs \'attrfilter(labels, et_operator_library, deps(set({deps})))\') " + + "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])) + runtime.genrule( + name = name, + macros_only = False, + cmd = cmd, + outs = {"selected_prim_ops.h": ["selected_prim_ops.h"]}, + default_outs = ["."], + platforms = platforms, + ) + def dtype_header_genrule( name, visibility, @@ -677,7 +811,8 @@ def executorch_generated_lib( dtype_selective_build = False, feature = None, expose_operator_symbols = False, - support_exceptions = True): + support_exceptions = True, + include_all_prim_ops = True): """Emits 0-3 C++ library targets (in fbcode or xplat) containing code to dispatch the operators specified in the provided yaml files. @@ -738,6 +873,9 @@ def executorch_generated_lib( support_exceptions: enable try/catch wrapper around operator implementations to make sure exceptions thrown will not bring down the process. Disable if your use case disables exceptions in the build. + include_all_prim_ops: If true, include all prim ops in the generated library. This option + allows for selecting only some prim ops to reduce code size for extremely constrained + environments. For selecting only some prim ops, see examples in //executorch/examples/selective_build """ if functions_yaml_target and aten_mode: fail("{} is providing functions_yaml_target in ATen mode, it will be ignored. `native_functions.yaml` will be the source of truth.".format(name)) @@ -903,6 +1041,12 @@ def executorch_generated_lib( if name in libs: lib_name = name + + if include_all_prim_ops: + prim_ops_registry_target = "//executorch/kernels/prim_ops:prim_ops_registry" + aten_suffix + else: + prim_ops_registry_target = _get_prim_ops_registry_target(name, deps, aten_suffix, platforms) + runtime.cxx_library( name = lib_name, srcs = [ @@ -927,7 +1071,7 @@ def executorch_generated_lib( }) + compiler_flags, deps = [ "//executorch/runtime/kernel:operator_registry" + aten_suffix, - "//executorch/kernels/prim_ops:prim_ops_registry" + aten_suffix, + prim_ops_registry_target, # Use the appropriate prim ops registry "//executorch/runtime/core:evalue" + aten_suffix, "//executorch/codegen:macros", ] + deps + kernel_deps,