diff --git a/codegen/tools/combine_prim_ops_headers.py b/codegen/tools/combine_prim_ops_headers.py
new file mode 100644
index 00000000000..b579de2047d
--- /dev/null
+++ b/codegen/tools/combine_prim_ops_headers.py
@@ -0,0 +1,164 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Script to combine multiple selected_prim_ops.h header files into a single header.
+This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies.
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+from typing import List, Set
+
+
+def read_header_file(file_path: Path) -> Set[str]:
+    """
+    Read a selected_prim_ops.h file and extract its prim op macro defines.
+
+    Args:
+        file_path: Path to the header file
+
+    Returns:
+        Set of unique "#define INCLUDE_..." lines found in the header
+    """
+    macros = set()
+
+    try:
+        with open(file_path, "r") as f:
+            for line in f:
+                line = line.strip()
+
+                # Extract #define statements for prim ops
+                if line.startswith("#define INCLUDE_") and not line.startswith(
+                    "#define EXECUTORCH_ENABLE"
+                ):
+                    macros.add(line)
+    except FileNotFoundError:
+        print(f"Warning: Header file not found: {file_path}", file=sys.stderr)
+    except Exception as e:
+        print(f"Error reading {file_path}: {e}", file=sys.stderr)
+
+    return macros
+
+
+def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None:
+    """
+    Combine multiple selected_prim_ops.h files into a single header.
+
+    Args:
+        header_file_paths: List of directories containing the header files to combine
+        output_path: Path to output the combined header
+    """
+    all_macros = set()
+    has_selective_build = False
+
+    # Read all header files and collect unique macros
+    for header_file_path in header_file_paths:
+        header_file = Path(header_file_path) / "selected_prim_ops.h"
+        if os.path.exists(header_file):
+            macros = read_header_file(header_file)
+            all_macros.update(macros)
+            if len(all_macros) > 0:
+                has_selective_build = True
+        else:
+            print(
+                f"Warning: Header file does not exist: {header_file}", file=sys.stderr
+            )
+
+    # Generate combined header
+    header_content = [
+        "// Combined header for selective prim ops build",
+        "// This file is auto-generated by combining multiple selected_prim_ops.h files",
+        "// Do not edit manually.",
+        "",
+        "#pragma once",
+        "",
+    ]
+
+    if all_macros and has_selective_build:
+        header_content.extend(
+            [
+                "// Enable selective build for prim ops",
+                "#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD",
+                "",
+                "// Combined prim ops macros from all dependencies",
+            ]
+        )
+
+        # Sort macros for deterministic output
+        sorted_macros = sorted(all_macros)
+        header_content.extend(sorted_macros)
+    else:
+        header_content.extend(
+            [
+                "// No prim ops found in dependencies - all prim ops will be included",
+                "// Selective build is disabled",
+            ]
+        )
+
+    header_content.append("")
+
+    # Write the combined header
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    with open(output_path, "w") as f:
+        f.write("\n".join(header_content))
+
+
+def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]:
+    """
+    Parse the output of a Buck query command to extract header file paths.
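+
+    Example (hypothetical paths): if the calling genrule passes
+    ``@buck-out/prim_ops_query_outputs.txt`` and that file contains
+    ``buck-out/dep_a buck-out/dep_b``, this returns
+    ``["buck-out/dep_a", "buck-out/dep_b"]``.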
+ + Args: + query_output_file: Path to the file containing the query output + + Returns: + List of header file paths + """ + header_file_paths = [] + assert ( + query_output_file[0] == "@" + ), "query_output_file is not a valid file path, or it doesn't start with '@'." + query_output_file = query_output_file[1:] + + with open(query_output_file, "r") as f: + for line in f: + # Extract the header file path from the query output + header_file_paths += line.split() + return header_file_paths + + +def main(): + parser = argparse.ArgumentParser( + description="Combine multiple selected_prim_ops.h header files" + ) + parser.add_argument( + "--header_files", + required=True, + help="Comma-separated list of header file paths", + ) + parser.add_argument( + "--output_dir", required=True, help="Output directory for combined header" + ) + + args = parser.parse_args() + import os + + header_file_paths = _get_header_file_paths_from_query_output(args.header_files) + + if not header_file_paths: + print("Error: No header files provided", file=sys.stderr) + sys.exit(1) + + # Generate output path + output_path = os.path.join(args.output_dir, "selected_prim_ops.h") + + combine_prim_ops_headers(header_file_paths, output_path) + + +if __name__ == "__main__": + main() diff --git a/codegen/tools/gen_all_oplist.py b/codegen/tools/gen_all_oplist.py index 5cb93bb9153..f33c3dc935d 100644 --- a/codegen/tools/gen_all_oplist.py +++ b/codegen/tools/gen_all_oplist.py @@ -10,7 +10,7 @@ import sys from functools import reduce from pathlib import Path -from typing import Any, List +from typing import Any, Dict, List import yaml from torchgen.selective_build.selector import ( @@ -72,6 +72,19 @@ def _raise_if_check_prim_ops_fail(options): raise Exception(error) +def _selected_ops_model_dict_is_empty(model_dict: Dict[str, Any]) -> bool: + return ( + not model_dict.get("build_features", []) + and not model_dict.get("custom_classes", []) + and not model_dict.get("et_kernel_metadata", None) + and not model_dict.get("include_all_non_op_selectives", False) + and not model_dict.get("include_all_operators", False) + and not model_dict.get("kernel_metadata", {}) + and not model_dict.get("operators", {}) + ) + + +# flake8: noqa: C901 def main(argv: List[Any]) -> None: """This binary generates 3 files: @@ -171,6 +184,11 @@ def main(argv: List[Any]) -> None: ), f"{model_file_name} is not a valid file path. This is likely a BUCK issue." with open(model_file_name, "rb") as model_file: model_dict = yaml.safe_load(model_file) + # It is possible that we created an empty yaml file. + # This is because et_operator_library may only contain prim ops. + # In that case selected_operators.yaml will be empty. 
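+        # For example (hypothetical), a model_dict such as
+        #   {"operators": {}, "kernel_metadata": {}, "include_all_operators": False}
+        # is treated as empty and the file is skipped.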
+ if _selected_ops_model_dict_is_empty(model_dict): + continue resolved = resolve_model_file_path_to_buck_target(model_file_name) for op in model_dict["operators"]: model_dict["operators"][op]["debug_info"] = [resolved] diff --git a/codegen/tools/gen_oplist.py b/codegen/tools/gen_oplist.py index cca5bf1b1d2..28506050a8e 100644 --- a/codegen/tools/gen_oplist.py +++ b/codegen/tools/gen_oplist.py @@ -9,6 +9,7 @@ import os import sys from enum import IntEnum +from pathlib import Path from typing import Any, Dict, List, Optional, Set import yaml @@ -158,7 +159,7 @@ def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[ def _dump_yaml( op_list: List[str], - output_path: str, + output_path: Path, model_name: Optional[str] = None, et_kernel_metadata: Optional[Dict[str, List[str]]] = None, include_all_operators: bool = False, @@ -212,20 +213,23 @@ def create_kernel_key(maybe_kernel_key: str) -> str: def gen_oplist( - output_path: str, + output_path: Path, model_file_path: Optional[str] = None, ops_schema_yaml_path: Optional[str] = None, root_ops: Optional[str] = None, ops_dict: Optional[str] = None, include_all_operators: bool = False, ): - assert ( + if not ( model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators - ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators." + ): + # dump empty yaml file + _dump_yaml([], output_path) + return assert output_path, "Need to provide output_path for dumped yaml file." op_set = set() @@ -326,9 +330,15 @@ def main(args: List[Any]) -> None: ) options = parser.parse_args(args) + # check if the output_path is a directory, then generate operators + # under selected_operators.yaml + if Path(options.output_path).is_dir(): + output_path = Path(options.output_path) / "selected_operators.yaml" + else: + output_path = Path(options.output_path) try: gen_oplist( - output_path=options.output_path, + output_path=output_path, model_file_path=options.model_file_path, ops_schema_yaml_path=options.ops_schema_yaml_path, root_ops=options.root_ops, diff --git a/codegen/tools/gen_selected_prim_ops.py b/codegen/tools/gen_selected_prim_ops.py new file mode 100644 index 00000000000..4535ffaa57a --- /dev/null +++ b/codegen/tools/gen_selected_prim_ops.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import argparse +import os +import sys +from typing import Any, List + +from torchgen.code_template import CodeTemplate # type: ignore[import-not-found] + + +selected_prim_ops_h_template_str = """#pragma once +/** + * Generated by executorch/codegen/tools/gen_selected_prim_ops.py + */ + +$defines +""" +selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str) + + +def normalize_op_name(op_name: str) -> str: + """ + Normalize an operator name to a macro-safe format. 
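+    The result is additionally prefixed with "INCLUDE_" to match the #ifdef
+    guards in register_prim_ops.cpp; the conversions below show the un-prefixed stem.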
+ Convert op names like "executorch_prim::et_view.default" to "EXECUTORCH_PRIM_ET_VIEW_DEFAULT" + or "aten::sym_size.int" to "ATEN_SYM_SIZE_INT" + """ + # Remove namespace separator and replace with underscore + normalized = op_name.replace("::", "_") + # Replace dots with underscores + normalized = normalized.replace(".", "_") + # Convert to uppercase + normalized = normalized.upper() + # Add INCLUDE_ prefix + normalized = f"INCLUDE_{normalized}" + return normalized + + +def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None: + """ + Generate selected_prim_ops.h from a list of prim op names. + + Args: + prim_op_names: List of prim op names like ["executorch_prim::et_view.default", "aten::sym_size.int"] + output_dir: Directory where to write selected_prim_ops.h + """ + # Generate #define statements for each op + defines = [] + for op_name in prim_op_names: + macro_name = normalize_op_name(op_name) + defines.append(f"#define {macro_name}") + + # Join all defines with newlines + defines_str = "\n".join(defines) + + # Generate header content + header_contents = selected_prim_ops_h_template.substitute(defines=defines_str) + + # Write to file + selected_prim_ops_path = os.path.join(output_dir, "selected_prim_ops.h") + with open(selected_prim_ops_path, "wb") as out_file: + out_file.write(header_contents.encode("utf-8")) + + +def main(argv: List[Any]) -> None: + parser = argparse.ArgumentParser(description="Generate selected prim ops header") + parser.add_argument( + "--prim-op-names", + "--prim_op_names", + help="Comma-separated list of prim op names to include", + required=True, + ) + parser.add_argument( + "--output-dir", + "--output_dir", + help="The directory to store the output header file (selected_prim_ops.h)", + required=True, + ) + + options = parser.parse_args(argv) + + # Parse comma-separated prim op names + prim_op_names = [ + name.strip() for name in options.prim_op_names.split(",") if name.strip() + ] + + write_selected_prim_ops(prim_op_names, options.output_dir) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/codegen/tools/targets.bzl b/codegen/tools/targets.bzl index acea3370e7d..d594b7178b8 100644 --- a/codegen/tools/targets.bzl +++ b/codegen/tools/targets.bzl @@ -103,6 +103,26 @@ def define_common_targets(is_fbcode = False): _is_external_target = True, ) + runtime.python_library( + name = "combine_prim_ops_headers_lib", + srcs = ["combine_prim_ops_headers.py"], + base_module = "executorch.codegen.tools", + visibility = ["//executorch/..."], + ) + + runtime.python_binary( + name = "combine_prim_ops_headers", + main_module = "executorch.codegen.tools.combine_prim_ops_headers", + package_style = "inplace", + visibility = [ + "PUBLIC", + ], + deps = [ + ":combine_prim_ops_headers_lib", + ], + _is_external_target = True, + ) + runtime.python_test( name = "test_gen_all_oplist", srcs = [ @@ -155,6 +175,27 @@ def define_common_targets(is_fbcode = False): _is_external_target = True, ) + runtime.python_library( + name = "gen_selected_prim_ops_lib", + srcs = ["gen_selected_prim_ops.py"], + base_module = "executorch.codegen.tools", + visibility = ["//executorch/..."], + external_deps = ["torchgen"], + ) + + runtime.python_binary( + name = "gen_selected_prim_ops", + main_module = "executorch.codegen.tools.gen_selected_prim_ops", + package_style = "inplace", + visibility = [ + "PUBLIC", + ], + deps = [ + ":gen_selected_prim_ops_lib", + ], + _is_external_target = True, + ) + if not runtime.is_oss: runtime.cxx_python_extension( name = 
"selective_build", diff --git a/codegen/tools/test/test_gen_oplist.py b/codegen/tools/test/test_gen_oplist.py index f5c6829d6a0..18689cd2505 100644 --- a/codegen/tools/test/test_gen_oplist.py +++ b/codegen/tools/test/test_gen_oplist.py @@ -8,6 +8,7 @@ import os import tempfile import unittest +from pathlib import Path from typing import Dict, List from unittest.mock import NonCallableMock, patch @@ -77,7 +78,7 @@ def test_gen_op_list_with_valid_root_ops( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, {"aten::add": ["default"], "aten::mul": ["default"]}, False, @@ -100,7 +101,7 @@ def test_gen_op_list_with_root_ops_and_dtypes( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, { "aten::add": [ @@ -129,7 +130,7 @@ def test_gen_op_list_with_both_op_list_and_ops_schema_yaml_merges( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add.out", "aten::mul.out", "aten::relu.out"], - output_path, + Path(output_path), test_path, { "aten::relu.out": ["default"], @@ -153,7 +154,7 @@ def test_gen_op_list_with_include_all_operators( gen_oplist.main(args) mock_dump_yaml.assert_called_once_with( ["aten::add", "aten::mul"], - output_path, + Path(output_path), None, {"aten::add": ["default"], "aten::mul": ["default"]}, True, @@ -164,7 +165,7 @@ def test_get_custom_build_selector_with_both_allowlist_and_yaml( ) -> None: op_list = ["aten::add", "aten::mul"] filename = os.path.join(self.temp_dir.name, "selected_operators.yaml") - gen_oplist._dump_yaml(op_list, filename, "model.pte") + gen_oplist._dump_yaml(op_list, Path(filename), "model.pte") self.assertTrue(os.path.isfile(filename)) with open(filename) as f: es = yaml.safe_load(f) diff --git a/examples/selective_build/targets.bzl b/examples/selective_build/targets.bzl index 72639fef842..bd11a53e3e0 100644 --- a/examples/selective_build/targets.bzl +++ b/examples/selective_build/targets.bzl @@ -1,6 +1,118 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "is_xplat", "runtime") load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib", "ScalarType") +def define_selective_build_prim_ops_example(): + """ + Example showing how selected_prim_operators_genrule works to combine + prim ops headers from multiple dependencies. 
+ """ + + # Define several operator libraries with automatic prim ops extraction + et_operator_library( + name = "model_a_ops", + ops = [ + "aten::add.out", + "aten::mul.out", + "executorch_prim::et_view.default", # Auto-extracted to prim ops + "aten::sym_size.int", # Auto-extracted to prim ops + ], + visibility = ["//executorch/..."], + ) + # This creates: "model_a_ops" + "model_a_ops_selected_prim_ops" + + et_operator_library( + name = "model_b_ops", + ops = [ + "aten::sub.out", + "aten::div.out", + "executorch_prim::add.Scalar", # Auto-extracted to prim ops + "aten::sym_numel.int", # Auto-extracted to prim ops + ], + visibility = ["//executorch/..."], + ) + # This creates: "model_b_ops" + "model_b_ops_selected_prim_ops" + + # Define a manual prim ops target as well + et_operator_library( + name = "extra_prim_ops", + ops = [ + "executorch_prim::mul.Scalar", + "executorch_prim::sym_max.Scalar", + ], + visibility = ["//executorch/..."], + ) + # Use the combined header in an executorch_generated_lib + executorch_generated_lib( + name = "library_with_combined_prim_ops", + deps = [ + ":model_a_ops", + ":model_b_ops", + ":extra_prim_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + include_all_prim_ops = False, + ) + + # Prim ops selected separately + et_operator_library( + name = "model_b_ops_no_prim_ops", + ops = [ + "aten::sub.out", + "aten::div.out", + ], + visibility = ["//executorch/..."], + ) + + # Use the combined header in an executorch_generated_lib + executorch_generated_lib( + name = "library_with_combined_prim_ops_1", + deps = [ + ":model_b_ops_no_prim_ops", + ":extra_prim_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + include_all_prim_ops = False, + ) + + # No prim ops selected. So include all prim ops. + executorch_generated_lib( + name = "library_with_combined_prim_ops_2", + deps = [ + ":model_b_ops_no_prim_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + include_all_prim_ops = False, + ) + + # default to selecting all prim ops + executorch_generated_lib( + name = "library_with_all_prim_ops", + deps = [ + ":model_b_ops", + ], + kernel_deps = [ + "//executorch/kernels/portable:operators", + ], + functions_yaml_target = "//executorch/kernels/portable:functions.yaml", + aten_mode = False, + visibility = ["PUBLIC"], + ) + def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -165,3 +277,5 @@ def define_common_targets(): define_static_target = True, **get_oss_build_kwargs() ) + + define_selective_build_prim_ops_example() diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 8607c36204d..dc6ed9ac26f 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -12,6 +12,18 @@ #include #include +/* +For internal builds using buck rules, the target that depends on +selective prim ops, will manage its own artifacts. It is in the +artifacts directory where the geneated selected_prim_ops.h resides +and thus compilation sources must be copied there including +selective_build_prim_ops.h. 
Hence it does not have fully qualified +name unlike the header files above. +*/ +#ifdef ET_PRIM_OPS_SELECTIVE_BUILD +#include "selective_build_prim_ops.h" +#endif + #include #include @@ -87,6 +99,8 @@ void floor_div_double(double a, double b, EValue& out) { } static Kernel prim_ops[] = { +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_ATEN_SYM_SIZE_INT) // aten::sym_size.int(Tensor self, int dim) -> SymInt Kernel( "aten::sym_size.int", @@ -108,6 +122,9 @@ static Kernel prim_ops[] = { int64_t size = self_tensor.size(dim_val); out = EValue(size); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_ATEN_LOCAL_SCALAR_DENSE) // aten::_local_scalar_dense(Tensor self) -> Scalar Kernel( "aten::_local_scalar_dense", @@ -134,6 +151,9 @@ static Kernel prim_ops[] = { out = EValue(Scalar(self_tensor.const_data_ptr()[0])); }); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_ATEN_SYM_NUMEL) // aten::sym_numel(Tensor self) -> SymInt Kernel( "aten::sym_numel", @@ -153,6 +173,9 @@ static Kernel prim_ops[] = { int64_t numel = self_tensor.numel(); out = EValue(numel); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_MAX_SCALAR) // executorch_prim::sym_max.Scalar(SymInt a, SymInt b) -> SymInt Kernel( "executorch_prim::sym_max.Scalar", @@ -182,6 +205,9 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_MIN_SCALAR) // executorch_prim::sym_min.Scalar(SymInt a, SymInt b) -> SymInt Kernel( "executorch_prim::sym_min.Scalar", @@ -210,27 +236,39 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ADD_SCALAR) // executorch_prim::add.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::add.Scalar", [](KernelRuntimeContext& context, Span stack) { ALGEBRA_ET_PRIM_OP(+, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SUB_SCALAR) // executorch_prim::sub.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::sub.Scalar", [](KernelRuntimeContext& context, Span stack) { ALGEBRA_ET_PRIM_OP(-, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_MUL_SCALAR) // executorch_prim::mul.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::mul.Scalar", [](KernelRuntimeContext& context, Span stack) { ALGEBRA_ET_PRIM_OP(*, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_FLOORDIV_SCALAR) /** * Python's __floordiv__ operator is more complicated than just floor(a / * b). 
It aims to maintain the property: a == (a // b) * b + remainder(a, b) @@ -280,8 +318,11 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif - // executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_TRUEDIV_SCALAR) + // executorch_prim::truediv.Scalar(Scalar, Scalar) -> Scalar Kernel( "executorch_prim::truediv.Scalar", [](KernelRuntimeContext& context, Span stack) { @@ -318,7 +359,10 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_FLOAT_SCALAR) // executorch_prim::sym_float.Scalar(Scalar) -> Scalar Kernel( "executorch_prim::sym_float.Scalar", @@ -346,41 +390,60 @@ static Kernel prim_ops[] = { context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_EQ_SCALAR) // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::eq.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(==, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_GT_SCALAR) // executorch_prim::gt.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::gt.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(>, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_LT_SCALAR) // executorch_prim::lt.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::lt.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(<, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_GE_SCALAR) // executorch_prim::ge.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::ge.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(>=, stack, context); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_LE_SCALAR) // executorch_prim::le.Scalar(Scalar, Scalar) -> bool Kernel( "executorch_prim::le.Scalar", [](KernelRuntimeContext& context, Span stack) { BOOLEAN_ET_PRIM_OP(<=, stack, context); }), +#endif + +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_NEG_SCALAR) // executorch_prim::neg.Scalar(Scalar) -> Scalar Kernel( "executorch_prim::neg.Scalar", @@ -404,7 +467,10 @@ static Kernel prim_ops[] = { context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_FLOORDIV_INT) // executorch_prim::floordiv.int(int, int) -> int Kernel( "executorch_prim::floordiv.int", @@ -422,7 +488,10 @@ static Kernel prim_ops[] = { EValue& out = *stack[2]; out = EValue(a.toInt() / b.toInt()); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_MOD_INT) // executorch_prim::mod.int(int, int) -> int Kernel( "executorch_prim::mod.int", @@ -440,7 +509,10 @@ static Kernel prim_ops[] = { EValue& out = *stack[2]; out = EValue(a.toInt() % b.toInt()); }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_MOD_SCALAR) // executorch_prim::mod.Scalar(Scalar, Scalar) -> Scalar Kernel( 
"executorch_prim::mod.Scalar", @@ -469,7 +541,10 @@ static Kernel prim_ops[] = { (size_t)b.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_CEIL_SCALAR) // ceil.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::ceil.Scalar", @@ -496,7 +571,10 @@ static Kernel prim_ops[] = { (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ROUND_SCALAR) // round.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::round.Scalar", @@ -540,7 +618,10 @@ static Kernel prim_ops[] = { (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_TRUNC_SCALAR) // trunc.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::trunc.Scalar", @@ -562,19 +643,27 @@ static Kernel prim_ops[] = { context, false, InvalidType, /* void */, "%zu", (size_t)a.tag); } }), +#endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ET_COPY_INDEX_TENSOR) // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor Kernel( "executorch_prim::et_copy_index.tensor", [](KernelRuntimeContext& context, Span stack) { et_copy_index(context, stack); }), +#endif + +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT) // executorch_prim::et_view.default(Tensor, int[]) -> Tensor Kernel( "executorch_prim::et_view.default", [](KernelRuntimeContext& context, Span stack) { et_view(context, stack); }), +#endif }; diff --git a/kernels/prim_ops/selective_build_prim_ops.h b/kernels/prim_ops/selective_build_prim_ops.h new file mode 100644 index 00000000000..78181405b11 --- /dev/null +++ b/kernels/prim_ops/selective_build_prim_ops.h @@ -0,0 +1,12 @@ +#pragma once +/** + * Generated by executorch/kernels/prim_ops/selective_build_prim_ops.h + * This header conditionally includes selected_prim_ops.h when selective build + * for prim ops is enabled. + */ + +// If no prim ops are selected, then the header is empty. +// that would mean all prim ops are enabled. +#ifdef ET_PRIM_OPS_SELECTIVE_BUILD +#include "selected_prim_ops.h" +#endif diff --git a/kernels/prim_ops/targets.bzl b/kernels/prim_ops/targets.bzl index 8bdc44fe553..eea66c1afa7 100644 --- a/kernels/prim_ops/targets.bzl +++ b/kernels/prim_ops/targets.bzl @@ -7,13 +7,31 @@ def define_common_targets(): TARGETS and BUCK files that call this function. """ + # Define the filegroup once outside the loop since it doesn't vary by aten mode + runtime.filegroup( + name = "prim_ops_sources", + srcs = ["register_prim_ops.cpp"], + visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], + ) + + runtime.filegroup( + name = "selective_build_prim_ops.h", + srcs = ["selective_build_prim_ops.h"], + visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], + ) + for aten_mode in get_aten_mode_options(): aten_suffix = ("_aten" if aten_mode else "") runtime.cxx_library( name = "et_copy_index" + aten_suffix, srcs = ["et_copy_index.cpp"], - visibility = [], # Private + # To allow for selective prim ops to depend on this library. 
+            # Used by selective_build.bzl
+            visibility = [
+                "//executorch/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
             exported_headers = ["et_copy_index.h"],
             deps = [
                 "//executorch/runtime/kernel:kernel_includes" + aten_suffix,
@@ -28,7 +46,12 @@ def define_common_targets():
         runtime.cxx_library(
             name = "et_view" + aten_suffix,
             srcs = ["et_view.cpp"],
-            visibility = [],  # Private
+            # To allow for selective prim ops to depend on this library.
+            # Used by selective_build.bzl
+            visibility = [
+                "//executorch/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
             exported_headers = ["et_view.h"],
             deps = [
                 "//executorch/runtime/kernel:kernel_includes" + aten_suffix,
diff --git a/shim_et/xplat/executorch/codegen/codegen.bzl b/shim_et/xplat/executorch/codegen/codegen.bzl
index ae6b42e2d8f..3546b64cdb6 100644
--- a/shim_et/xplat/executorch/codegen/codegen.bzl
+++ b/shim_et/xplat/executorch/codegen/codegen.bzl
@@ -7,6 +7,7 @@ load(
     "get_vec_deps",
     "get_vec_preprocessor_flags",
 )
+load("@fbsource//xplat/executorch/kernels/prim_ops:selective_build.bzl", "prim_ops_registry_selective")
 
 # Headers that declare the function signatures of the C++ functions that
 # map to entries in functions.yaml and custom_ops.yaml.
@@ -81,6 +82,83 @@ ScalarType = enum(
     "Uint64",
 )
 
+def _get_prim_ops_registry_target(name, deps, aten_suffix, platforms):
+    """
+    Helper function to determine which prim ops registry target to use.
+
+    Args:
+        name: Base name for creating the selective registry target
+        deps: Dependencies to scan; only targets carrying the et_operator_library
+            label contribute selected prim ops headers
+        aten_suffix: Suffix for aten mode (e.g. "_aten")
+        platforms: Platforms configuration
+
+    Returns:
+        String: Target name for the appropriate prim ops registry
+    """
+    # Build a combined prim ops header from the deps, then create a selective
+    # prim ops registry from it.
+    selective_prim_ops_registry_name = name + "_selected_prim_ops_registry"
+    combined_prim_ops_header_target_name = name + "_combined_prim_ops_header"
+    selected_prim_operators_genrule(combined_prim_ops_header_target_name, deps, platforms)
+    # Use the existing prim_ops_registry_selective function
+    prim_ops_registry_selective(
+        name = selective_prim_ops_registry_name,
+        selected_prim_ops_header_target = ":" + combined_prim_ops_header_target_name,
+        aten_suffix = aten_suffix,
+        platforms = platforms,
+    )
+
+    # Return the selective registry target
+    return ":" + selective_prim_ops_registry_name
+
+def _extract_prim_ops_from_lists(ops, ops_dict):
+    """
+    Utility function to extract prim ops from an ops list and ops_dict.
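+
+    For example (hypothetical inputs):
+        ops = ["aten::add.out", "aten::sym_size.int"], ops_dict = {}
+    returns (["aten::sym_size.int"], ["aten::add.out"], {}).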
+
+    Args:
+        ops: List of operator names
+        ops_dict: Dictionary mapping ops to metadata
+
+    Returns:
+        Tuple of (prim_ops, remaining_ops, remaining_ops_dict)
+    """
+    def _is_aten_prim_op(op_name):
+        if not op_name.startswith("aten::"):
+            return False
+        for prim_suffix in [
+            "sym_size", "sym_numel", "sym_max", "sym_min", "sym_float"
+        ]:
+            if prim_suffix in op_name:
+                return True
+        return False
+
+    def _is_prim_op(op_name):
+        """Check if an operator is a primitive operation."""
+        return op_name.startswith("executorch_prim::") or (
+            _is_aten_prim_op(op_name)
+        )
+
+    prim_ops = []
+    remaining_ops = []
+    remaining_ops_dict = {}
+
+    # Extract from ops list
+    for op in ops:
+        if _is_prim_op(op):
+            prim_ops.append(op)
+        else:
+            remaining_ops.append(op)
+
+    # Extract from ops_dict
+    for op, metadata in ops_dict.items():
+        if _is_prim_op(op):
+            prim_ops.append(op)
+        else:
+            remaining_ops_dict[op] = metadata
+
+    return prim_ops, remaining_ops, remaining_ops_dict
+
 # Hide the dependency to caffe2 internally.
 def et_operator_library(
         name,
@@ -91,6 +169,27 @@ def et_operator_library(
         ops_schema_yaml_target = None,
         server_generated_yaml_target = None,
         **kwargs):
+
+    # Check if we should extract prim ops from the operator lists.
+    # Note: selective build for prim ops doesn't support model, ops_schema_yaml_target, or server_generated_yaml_target.
+    # TODO: Add support for selective build for prim ops with model or ops_schema_yaml_target or server_generated_yaml_target
+    should_extract_prim_ops = (ops or ops_dict) and not (model or ops_schema_yaml_target or server_generated_yaml_target or include_all_operators)
+
+    if should_extract_prim_ops:
+        # Extract prim ops from ops and ops_dict
+        prim_ops, remaining_ops, remaining_ops_dict = _extract_prim_ops_from_lists(ops, ops_dict)
+        # Use the remaining ops (with prim ops removed) for the main et_operator_library
+        final_ops = remaining_ops
+        final_ops_dict = remaining_ops_dict
+    else:
+        # No prim ops extraction needed - use the original ops and ops_dict
+        prim_ops = []
+        final_ops = ops
+        final_ops_dict = ops_dict
+
+    selected_operator_yaml_filename = "selected_operators.yaml"
+    selected_prim_ops_filename = "selected_prim_ops.h"
+    # Generate the main operator library with the final ops
     # do a dummy copy if server_generated_yaml_target is set
     if server_generated_yaml_target:
         if include_all_operators or ops_schema_yaml_target or model or ops or ops_dict:
@@ -98,7 +197,7 @@ def et_operator_library(
         genrule_cmd = [
             "cp",
             "$(location {})".format(server_generated_yaml_target),
-            "$OUT",
+            "$OUT/{}".format(selected_operator_yaml_filename),
         ]
     else:
         genrule_cmd = [
@@ -109,12 +208,12 @@ def et_operator_library(
            genrule_cmd.append(
                "--ops_schema_yaml_path=$(location {})".format(ops_schema_yaml_target),
            )
-        if ops:
+        if final_ops:
             genrule_cmd.append(
-                "--root_ops=" + ",".join(ops),
+                "--root_ops=" + ",".join(final_ops),
             )
-        if ops_dict:
-            ops_dict_json = struct_to_json(ops_dict)
+        if final_ops_dict:
+            ops_dict_json = struct_to_json(final_ops_dict)
             genrule_cmd.append(
                 "--ops_dict='{}'".format(ops_dict_json),
             )
@@ -127,6 +226,15 @@ def et_operator_library(
                 "--include_all_operators",
             )
 
+    prim_ops_genrule_cmd = [
+        "$(exe //executorch/codegen/tools:gen_selected_prim_ops)",
+        "--prim_op_names=" + ",".join(prim_ops),
+        "--output_dir=${OUT}",
+    ]
+    # Generate both selected_prim_ops.h and the selected_operators.yaml
+    # file from a single genrule.
+    genrule_cmd = genrule_cmd + [" && "] + prim_ops_genrule_cmd
+
     # TODO(larryliu0820): Remove usages of this flag.
if "define_static_targets" in kwargs: kwargs.pop("define_static_targets") @@ -134,7 +242,8 @@ def et_operator_library( name = name, macros_only = False, cmd = " ".join(genrule_cmd), - out = "selected_operators.yaml", + outs = {selected_operator_yaml_filename: [selected_operator_yaml_filename], selected_prim_ops_filename: [selected_prim_ops_filename]}, + default_outs = ["."], labels = ["et_operator_library"], **kwargs ) @@ -615,6 +724,31 @@ def selected_operators_genrule( platforms = platforms, ) +def selected_prim_operators_genrule( + name, + deps, + platforms = get_default_executorch_platforms(), +): + """Generates selected_prim_ops.h from the list of deps. We look into the transitive closure of all the deps, + and look for targets with label `et_operator_library`. + + `combine_prim_ops_headers` is the python binary we use to aggregate all the `selected_prim_ops.h` headers + from `et_prim_ops_library` targets into a single combined `selected_prim_ops.h` file. + + This file can be used to enable selective build for prim ops across multiple dependencies. + """ + cmd = ("$(exe //executorch/codegen/tools:combine_prim_ops_headers) " + + "--header_files $(@query_outputs \'attrfilter(labels, et_operator_library, deps(set({deps})))\') " + + "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])) + runtime.genrule( + name = name, + macros_only = False, + cmd = cmd, + outs = {"selected_prim_ops.h": ["selected_prim_ops.h"]}, + default_outs = ["."], + platforms = platforms, + ) + def dtype_header_genrule( name, visibility, @@ -677,7 +811,8 @@ def executorch_generated_lib( dtype_selective_build = False, feature = None, expose_operator_symbols = False, - support_exceptions = True): + support_exceptions = True, + include_all_prim_ops = True): """Emits 0-3 C++ library targets (in fbcode or xplat) containing code to dispatch the operators specified in the provided yaml files. @@ -738,6 +873,9 @@ def executorch_generated_lib( support_exceptions: enable try/catch wrapper around operator implementations to make sure exceptions thrown will not bring down the process. Disable if your use case disables exceptions in the build. + include_all_prim_ops: If true, include all prim ops in the generated library. This option + allows for selecting only some prim ops to reduce code size for extremely constrained + environments. For selecting only some prim ops, see examples in //executorch/examples/selective_build """ if functions_yaml_target and aten_mode: fail("{} is providing functions_yaml_target in ATen mode, it will be ignored. `native_functions.yaml` will be the source of truth.".format(name)) @@ -903,6 +1041,12 @@ def executorch_generated_lib( if name in libs: lib_name = name + + if include_all_prim_ops: + prim_ops_registry_target = "//executorch/kernels/prim_ops:prim_ops_registry" + aten_suffix + else: + prim_ops_registry_target = _get_prim_ops_registry_target(name, deps, aten_suffix, platforms) + runtime.cxx_library( name = lib_name, srcs = [ @@ -927,7 +1071,7 @@ def executorch_generated_lib( }) + compiler_flags, deps = [ "//executorch/runtime/kernel:operator_registry" + aten_suffix, - "//executorch/kernels/prim_ops:prim_ops_registry" + aten_suffix, + prim_ops_registry_target, # Use the appropriate prim ops registry "//executorch/runtime/core:evalue" + aten_suffix, "//executorch/codegen:macros", ] + deps + kernel_deps,