Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions codegen/tools/combine_prim_ops_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Script to combine multiple selected_prim_ops.h header files into a single header.
This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies.
"""

import argparse
import os
import sys
from pathlib import Path
from typing import List, Set


def read_header_file(file_path: Path) -> Set[str]:
    """
    Read a selected_prim_ops.h file and extract its prim op macro defines.

    Only lines that, after stripping surrounding whitespace, start with
    ``#define INCLUDE_`` are collected; every other line is ignored.

    Args:
        file_path: Path to the header file

    Returns:
        Set of unique ``#define INCLUDE_...`` lines found in the file. When the
        file is missing or cannot be read, a message is printed to stderr and
        whatever was collected so far (possibly nothing) is returned.
    """
    macros: Set[str] = set()

    try:
        # gen_selected_prim_ops.py encodes selected_prim_ops.h as UTF-8, so
        # decode explicitly rather than relying on the platform default.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()

                # A line that starts with "#define INCLUDE_" can never start
                # with "#define EXECUTORCH_ENABLE", so one prefix test is enough.
                if line.startswith("#define INCLUDE_"):
                    macros.add(line)
    except FileNotFoundError:
        print(f"Warning: Header file not found: {file_path}", file=sys.stderr)
    except Exception as e:
        # Deliberately best-effort: report the problem and keep going so one
        # unreadable dependency header does not abort the whole merge.
        print(f"Error reading {file_path}: {e}", file=sys.stderr)

    return macros


def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None:
    """
    Combine multiple selected_prim_ops.h files into a single header.

    When at least one macro is found, the output defines
    EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD followed by the sorted union of
    the macros. Otherwise the output only carries comments stating that
    selective build is disabled.

    Args:
        header_file_paths: Directories that are each expected to contain a
            selected_prim_ops.h file. Directories without one are reported
            on stderr and skipped.
        output_path: Path of the combined header to write. Missing parent
            directories are created.
    """
    all_macros: Set[str] = set()

    # Read all header files and collect unique macros
    for header_file_path in header_file_paths:
        header_file = Path(header_file_path) / "selected_prim_ops.h"
        if os.path.exists(header_file):
            all_macros.update(read_header_file(header_file))
        else:
            print(
                f"Warning: Header file does not exist: {header_file}", file=sys.stderr
            )

    # Generate combined header
    header_content = [
        "// Combined header for selective prim ops build",
        "// This file is auto-generated by combining multiple selected_prim_ops.h files",
        "// Do not edit manually.",
        "",
        "#pragma once",
        "",
    ]

    if all_macros:
        header_content.extend(
            [
                "// Enable selective build for prim ops",
                "#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD",
                "",
                "// Combined prim ops macros from all dependencies",
            ]
        )

        # Sort macros for deterministic output
        header_content.extend(sorted(all_macros))
    else:
        header_content.extend(
            [
                "// No prim ops found in dependencies - all prim ops will be included",
                "// Selective build is disabled",
            ]
        )

    # Empty last element so the joined text ends with a newline.
    header_content.append("")

    # Write the combined header. Path.parent is "." for a bare file name, so
    # this also works when output_path has no directory component
    # (os.makedirs("") would raise in that case).
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(header_content))


def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]:
"""
Parse the output of a Buck query command to extract header file paths.

Args:
query_output_file: Path to the file containing the query output

Returns:
List of header file paths
"""
header_file_paths = []
assert (
query_output_file[0] == "@"
), "query_output_file is not a valid file path, or it doesn't start with '@'."
query_output_file = query_output_file[1:]

with open(query_output_file, "r") as f:
for line in f:
# Extract the header file path from the query output
header_file_paths += line.split()
return header_file_paths


def main() -> None:
    """
    Command-line entry point: merge selected_prim_ops.h headers into one file.

    Reads the directories to merge from the '@'-prefixed file given by
    --header_files and writes the combined header to
    <output_dir>/selected_prim_ops.h. Exits with status 1 when that file
    lists no paths.
    """
    parser = argparse.ArgumentParser(
        description="Combine multiple selected_prim_ops.h header files"
    )
    parser.add_argument(
        "--header_files",
        required=True,
        # The value is handed to _get_header_file_paths_from_query_output, which
        # requires the '@' prefix and splits the file's content on whitespace.
        help=(
            "'@'-prefixed path to a file listing, separated by whitespace, the "
            "directories that contain selected_prim_ops.h"
        ),
    )
    parser.add_argument(
        "--output_dir", required=True, help="Output directory for combined header"
    )

    args = parser.parse_args()

    header_file_paths = _get_header_file_paths_from_query_output(args.header_files)

    if not header_file_paths:
        print("Error: No header files provided", file=sys.stderr)
        sys.exit(1)

    # Generate output path
    output_path = os.path.join(args.output_dir, "selected_prim_ops.h")

    combine_prim_ops_headers(header_file_paths, output_path)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
20 changes: 19 additions & 1 deletion codegen/tools/gen_all_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
from functools import reduce
from pathlib import Path
from typing import Any, List
from typing import Any, Dict, List

import yaml
from torchgen.selective_build.selector import (
Expand Down Expand Up @@ -72,6 +72,19 @@ def _raise_if_check_prim_ops_fail(options):
raise Exception(error)


def _selected_ops_model_dict_is_empty(model_dict: Dict[str, Any]) -> bool:
return (
not model_dict.get("build_features", [])
and not model_dict.get("custom_classes", [])
and not model_dict.get("et_kernel_metadata", None)
and not model_dict.get("include_all_non_op_selectives", False)
and not model_dict.get("include_all_operators", False)
and not model_dict.get("kernel_metadata", {})
and not model_dict.get("operators", {})
)


# flake8: noqa: C901
def main(argv: List[Any]) -> None:
"""This binary generates 3 files:

Expand Down Expand Up @@ -171,6 +184,11 @@ def main(argv: List[Any]) -> None:
), f"{model_file_name} is not a valid file path. This is likely a BUCK issue."
with open(model_file_name, "rb") as model_file:
model_dict = yaml.safe_load(model_file)
# It is possible that we created an empty yaml file.
# This is because et_operator_library may only contain prim ops.
# In that case selected_operators.yaml will be empty.
if _selected_ops_model_dict_is_empty(model_dict):
continue
resolved = resolve_model_file_path_to_buck_target(model_file_name)
for op in model_dict["operators"]:
model_dict["operators"][op]["debug_info"] = [resolved]
Expand Down
20 changes: 15 additions & 5 deletions codegen/tools/gen_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import sys
from enum import IntEnum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import yaml
Expand Down Expand Up @@ -158,7 +159,7 @@ def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[

def _dump_yaml(
op_list: List[str],
output_path: str,
output_path: Path,
model_name: Optional[str] = None,
et_kernel_metadata: Optional[Dict[str, List[str]]] = None,
include_all_operators: bool = False,
Expand Down Expand Up @@ -212,20 +213,23 @@ def create_kernel_key(maybe_kernel_key: str) -> str:


def gen_oplist(
output_path: str,
output_path: Path,
model_file_path: Optional[str] = None,
ops_schema_yaml_path: Optional[str] = None,
root_ops: Optional[str] = None,
ops_dict: Optional[str] = None,
include_all_operators: bool = False,
):
assert (
if not (
model_file_path
or ops_schema_yaml_path
or root_ops
or ops_dict
or include_all_operators
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
):
# dump empty yaml file
_dump_yaml([], output_path)
return

assert output_path, "Need to provide output_path for dumped yaml file."
op_set = set()
Expand Down Expand Up @@ -326,9 +330,15 @@ def main(args: List[Any]) -> None:
)
options = parser.parse_args(args)

# check if the output_path is a directory, then generate operators
# under selected_operators.yaml
if Path(options.output_path).is_dir():
output_path = Path(options.output_path) / "selected_operators.yaml"
else:
output_path = Path(options.output_path)
try:
gen_oplist(
output_path=options.output_path,
output_path=output_path,
model_file_path=options.model_file_path,
ops_schema_yaml_path=options.ops_schema_yaml_path,
root_ops=options.root_ops,
Expand Down
96 changes: 96 additions & 0 deletions codegen/tools/gen_selected_prim_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import os
import sys
from typing import Any, List

from torchgen.code_template import CodeTemplate # type: ignore[import-not-found]


# Text of the generated selected_prim_ops.h; the `$defines` placeholder is
# filled in through CodeTemplate.substitute(defines=...) when the header is
# written.
selected_prim_ops_h_template_str = """#pragma once
/**
* Generated by executorch/codegen/tools/gen_selected_prim_ops.py
*/

$defines
"""
selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str)


def normalize_op_name(op_name: str) -> str:
    """
    Turn an operator name into the name of its INCLUDE_ macro.

    Every "::" and every "." becomes a single underscore, the result is
    upper-cased, and "INCLUDE_" is put in front. For example,
    "executorch_prim::et_view.default" becomes
    "INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT" and "aten::sym_size.int" becomes
    "INCLUDE_ATEN_SYM_SIZE_INT".
    """
    # Splitting on a separator and re-joining with "_" replaces each
    # occurrence of that separator with one underscore.
    without_namespace_sep = "_".join(op_name.split("::"))
    without_dots = "_".join(without_namespace_sep.split("."))
    return "INCLUDE_" + without_dots.upper()


def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None:
    """
    Write selected_prim_ops.h with one ``#define`` per requested prim op.

    Args:
        prim_op_names: Prim op names such as
            ["executorch_prim::et_view.default", "aten::sym_size.int"]
        output_dir: Directory in which selected_prim_ops.h is written
    """
    # One "#define INCLUDE_..." line per op, in the order given.
    define_block = "\n".join(
        f"#define {normalize_op_name(name)}" for name in prim_op_names
    )

    # Fill the header template and encode it up front; the file is opened in
    # binary mode and receives the UTF-8 bytes as-is.
    rendered = selected_prim_ops_h_template.substitute(defines=define_block)
    payload = rendered.encode("utf-8")

    target = os.path.join(output_dir, "selected_prim_ops.h")
    with open(target, "wb") as header_out:
        header_out.write(payload)


def main(argv: List[Any]) -> None:
    """
    Parse command-line arguments and generate selected_prim_ops.h.

    Args:
        argv: Command-line arguments, without the program name.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate selected prim ops header"
    )
    arg_parser.add_argument(
        "--prim-op-names",
        "--prim_op_names",
        required=True,
        help="Comma-separated list of prim op names to include",
    )
    arg_parser.add_argument(
        "--output-dir",
        "--output_dir",
        required=True,
        help="The directory to store the output header file (selected_prim_ops.h)",
    )

    parsed = arg_parser.parse_args(argv)

    # Split on commas, trim each piece, and drop pieces that end up empty
    # (e.g. from a trailing comma).
    trimmed = (piece.strip() for piece in parsed.prim_op_names.split(","))
    selected_names = [name for name in trimmed if name]

    write_selected_prim_ops(selected_names, parsed.output_dir)


# When run as a script, pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
Loading
Loading