Skip to content

Commit 00d3dc7

Browse files
lucylqfacebook-github-bot
authored andcommitted
dtype selective build, link codegen and scalar_type_util (#1090)
Summary: Pull Request resolved: #1090 1. Let codegen generate `selected_mobile_ops.h` based on `selected_operators.yaml` 2. Include header in `scalar_type_utils`. Add to visibility for codegen. 3. Create test for dtype selective build in scalar_utils. 4. Add buck config for dtype selective build. Dependencies: 1. Make scalar_type_utils depend on codegen headers (header lib only) to access `selected_mobile_ops.h`. 2. Split deps into `deps` and `kernel_deps`. Add `kernel_deps` to executorch_generated_library. There is a circular dependency if we leave it as is, where: - scalar_type_utils -> executorch_generated_lib (added dep) - executorch_generated_lib -> portable_ops (portable ops are a dependency to the genrule for `selected_operators.yaml`, but they do not need to be) - portable_ops -> kernel impls (eg. `op_add.cpp` etc.) - kernel impls -> scalar_type_utils (for macros) User flow: 1. User passes in the required executorch_generated_lib through buck config 2. See example in selective build examples. To do: include another dtype selection example when we can use model.pte for selective build. Reviewed By: larryliu0820 Differential Revision: D50561135 fbshipit-source-id: c52b6318bc1909c9c73a0ba231ecc2a6f01d1d22
1 parent 423f3bc commit 00d3dc7

File tree

12 files changed

+266
-24
lines changed

12 files changed

+266
-24
lines changed

codegen/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def define_common_targets():
2929
visibility = [
3030
"//executorch/runtime/kernel/...",
3131
"//executorch/kernels/...",
32+
"//executorch/runtime/core/exec_aten/...",
3233
"@EXECUTORCH_CLIENTS",
3334
],
3435
)

codegen/tools/gen_selected_mobile_ops.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
import argparse
89
import os
10+
import sys
11+
from typing import Any, List
912

1013
import yaml
1114

@@ -57,16 +60,14 @@
5760
}
5861

5962

60-
def write_selected_mobile_ops(output_file_path: str) -> None:
61-
with open(
62-
os.path.join(output_file_path, "selected_operators.yaml"), "r"
63-
) as selected_operators_file:
63+
def write_selected_mobile_ops(yaml_file_path: str, output_dir: str) -> None:
64+
with open(yaml_file_path, "r") as selected_operators_file:
6465
# Collect et_kernel_metadata from selected_operators.yaml and extract dtypes
6566
# Example format: v1/6;0,1|6;0,1|6;0,1|6;0,1 # Float, 0, 1
6667
selected_operators_dict = yaml.safe_load(selected_operators_file)
6768
et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata", {})
6869
assert isinstance(et_kernel_metadata, dict)
69-
body = "return true;"
70+
body = "true"
7071
body_parts = []
7172
for operator_name, kernel_metadata_str in et_kernel_metadata.items():
7273
tensor_meta = []
@@ -85,14 +86,38 @@ def write_selected_mobile_ops(output_file_path: str) -> None:
8586
]
8687
body_parts.append(
8788
ops_and_dtypes_template.substitute(
88-
operator_name=operator_name.removeprefix("aten::"),
89+
operator_name=operator_name.replace("aten::", ""),
8990
dtype_checks=" || ".join(conditions),
9091
),
9192
)
9293
body = "\n || ".join(body_parts)
93-
header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
94-
selected_mobile_ops_path = os.path.join(
95-
output_file_path, "selected_mobile_ops.h"
96-
)
97-
with open(selected_mobile_ops_path, "wb") as out_file:
98-
out_file.write(header_contents.encode("utf-8"))
94+
header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
95+
selected_mobile_ops_path = os.path.join(output_dir, "selected_mobile_ops.h")
96+
with open(selected_mobile_ops_path, "wb") as out_file:
97+
out_file.write(header_contents.encode("utf-8"))
98+
99+
100+
def main(argv: List[Any]) -> None:
101+
parser = argparse.ArgumentParser(description="Generate operator lists")
102+
parser.add_argument(
103+
"--yaml-file-path",
104+
"--yaml_file_path",
105+
help=("The directory where selected_operators.yaml was generated)"),
106+
required=True,
107+
)
108+
parser.add_argument(
109+
"--output-dir",
110+
"--output_dir",
111+
help=(
112+
"The directory to store the output yaml files (selected_mobile_ops.h, "
113+
+ "selected_kernel_dtypes.h, selected_operators.yaml)"
114+
),
115+
required=True,
116+
)
117+
118+
options = parser.parse_args(argv)
119+
write_selected_mobile_ops(options.yaml_file_path, options.output_dir)
120+
121+
122+
if __name__ == "__main__":
123+
main(sys.argv[1:])

codegen/tools/targets.bzl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,22 @@ def define_common_targets(is_fbcode = False):
123123
srcs = ["gen_selected_mobile_ops.py"],
124124
base_module = "executorch.codegen.tools",
125125
visibility = ["//executorch/..."],
126+
external_deps = [
127+
"gen-oplist-lib",
128+
],
129+
)
130+
131+
runtime.python_binary(
132+
name = "gen_selected_mobile_ops",
133+
main_module = "executorch.codegen.tools.gen_selected_mobile_ops",
134+
package_style = "inplace",
135+
visibility = [
136+
"PUBLIC",
137+
],
138+
deps = [
139+
":gen_selected_mobile_ops_lib",
140+
],
141+
_is_external_target = True,
126142
)
127143

128144
runtime.python_test(
@@ -136,7 +152,6 @@ def define_common_targets(is_fbcode = False):
136152
],
137153
deps = [
138154
":gen_selected_mobile_ops_lib",
139-
":gen_oplist_lib",
140155
"fbsource//third-party/pypi/expecttest:expecttest",
141156
],
142157
_is_external_target = True,

codegen/tools/test/test_gen_selected_mobile_ops.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
class TestGenSelectedMobileOpsHeader(expecttest.TestCase):
1616
def setUp(self):
1717
self.temp_dir = tempfile.TemporaryDirectory()
18+
self.addCleanup(self.temp_dir.cleanup)
1819
self.selected_ops_yaml = os.path.join(
1920
self.temp_dir.name, "selected_operators.yaml"
2021
)
@@ -54,7 +55,10 @@ def tearDown(self):
5455
self.temp_dir.cleanup()
5556

5657
def test_generates_correct_header(self) -> None:
57-
gen_selected_mobile_ops.write_selected_mobile_ops(self.temp_dir.name)
58+
gen_selected_mobile_ops.write_selected_mobile_ops(
59+
os.path.join(self.temp_dir.name, "selected_operators.yaml"),
60+
self.temp_dir.name,
61+
)
5862
with open(
5963
os.path.join(self.temp_dir.name, "selected_mobile_ops.h"), "r"
6064
) as result:
@@ -78,3 +82,51 @@ def test_generates_correct_header(self) -> None:
7882
}
7983
""",
8084
)
85+
86+
87+
class TestGenSelectedMobileOpsHeader_Empty(expecttest.TestCase):
88+
def setUp(self):
89+
self.temp_dir = tempfile.TemporaryDirectory()
90+
self.addCleanup(self.temp_dir.cleanup)
91+
self.selected_ops_yaml = os.path.join(
92+
self.temp_dir.name, "selected_operators.yaml"
93+
)
94+
with open(self.selected_ops_yaml, "w") as f:
95+
f.write(
96+
"""
97+
build_features: []
98+
custom_classes: []
99+
et_kernel_metadata: {}
100+
include_all_non_op_selectives: false
101+
include_all_operators: true
102+
kernel_metadata: {}
103+
operators: {}
104+
"""
105+
)
106+
107+
def tearDown(self):
108+
self.temp_dir.cleanup()
109+
110+
def test_generates_correct_header(self) -> None:
111+
gen_selected_mobile_ops.write_selected_mobile_ops(
112+
os.path.join(self.temp_dir.name, "selected_operators.yaml"),
113+
self.temp_dir.name,
114+
)
115+
with open(
116+
os.path.join(self.temp_dir.name, "selected_mobile_ops.h"), "r"
117+
) as result:
118+
self.assertExpectedInline(
119+
result.read(),
120+
"""#pragma once
121+
/**
122+
* Generated by executorch/codegen/tools/gen_selected_mobile_ops_header.py
123+
*/
124+
125+
inline constexpr bool should_include_kernel_dtype(
126+
const char *operator_name,
127+
exec_aten::ScalarType scalar_type
128+
) {
129+
return true;
130+
}
131+
""",
132+
)

examples/selective_build/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Check out `targets.bzl` for demo of 3 selective build APIs:
2121

2222
Other configs:
2323
- `--config executorch.max_kernel_num=N`: Only allocate memory for the required number of operators. Take this result from `selected_operators.yaml`.
24+
- `--config executorch.dtype_selective_build_lib=<executorch_generated_lib_name>`: Use dtype selective build. For each op, we register the dtypes that are used. Eg. if the model only uses the float implementation of add, then only the float add will be registered. Pass in the executorch_generated_lib name (see buck2 list example).
2425

2526
## CMake examples
2627

examples/selective_build/targets.bzl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ def define_common_targets():
1717
executorch_generated_lib(
1818
name = "select_all_lib",
1919
functions_yaml_target = "//executorch/kernels/portable:functions.yaml",
20-
deps = [
20+
kernel_deps = [
2121
"//executorch/kernels/portable:operators",
22+
],
23+
deps = [
2224
":select_all_ops",
2325
],
2426
)
@@ -35,10 +37,13 @@ def define_common_targets():
3537
executorch_generated_lib(
3638
name = "select_ops_in_list_lib",
3739
functions_yaml_target = "//executorch/kernels/portable:functions.yaml",
38-
deps = [
40+
kernel_deps = [
3941
"//executorch/kernels/portable:operators",
42+
],
43+
deps = [
4044
":select_ops_in_list",
4145
],
46+
visibility = ["//executorch/runtime/core/..."],
4247
)
4348

4449
# Select all ops from a yaml file
@@ -50,9 +55,11 @@ def define_common_targets():
5055
executorch_generated_lib(
5156
name = "select_ops_from_yaml_lib",
5257
custom_ops_yaml_target = "//executorch/examples/portable/custom_ops:custom_ops.yaml",
53-
deps = [
58+
kernel_deps = [
5459
"//executorch/examples/portable/custom_ops:custom_ops_1",
5560
"//executorch/examples/portable/custom_ops:custom_ops_2",
61+
],
62+
deps = [
5663
":select_ops_from_yaml",
5764
],
5865
)

examples/selective_build/test_selective_build.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ test_buck2_select_ops_in_list() {
3535
# set max_kernel_num=17: 14 primops, add, mul
3636
$BUCK run //examples/selective_build:selective_build_test \
3737
--config=executorch.max_kernel_num=17 \
38-
--config=executorch.select_ops=list -- --model_path=./add_mul.pte
38+
--config=executorch.select_ops=list \
39+
--config=executorch.dtype_selective_build_lib="//examples/selective_build:select_ops_in_list_lib" \
40+
-- --model_path=./add_mul.pte
3941

4042
echo "Removing add_mul.pte"
4143
rm "./add_mul.pte"

runtime/core/exec_aten/util/targets.bzl

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
23

34
def define_common_targets():
45
"""Defines targets that should be shared between fbcode and xplat.
@@ -7,9 +8,68 @@ def define_common_targets():
78
TARGETS and BUCK files that call this function.
89
"""
910

11+
et_operator_library(
12+
name = "two_ops",
13+
ops = [
14+
"aten::add.out",
15+
"aten::mm.out",
16+
],
17+
)
18+
19+
executorch_generated_lib(
20+
name = "two_ops_lib",
21+
functions_yaml_target = "//executorch/kernels/portable:functions.yaml",
22+
kernel_deps = ["//executorch/kernels/portable:operators"],
23+
deps = [":two_ops"],
24+
)
25+
26+
runtime.cxx_library(
27+
name = "scalar_type_util_TEST_ONLY",
28+
srcs = [],
29+
exported_headers = [
30+
"scalar_type_util.h",
31+
],
32+
visibility = [
33+
"//executorch/...",
34+
"@EXECUTORCH_CLIENTS",
35+
],
36+
exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"],
37+
exported_deps = [
38+
"//executorch/runtime/core/portable_type:scalar_type",
39+
":two_ops_lib_headers",
40+
],
41+
)
42+
43+
dtype_selective_build_lib = native.read_config("executorch", "dtype_selective_build_lib", None)
44+
if dtype_selective_build_lib != None:
45+
# retrieve selected_mobile_ops.h from codegen
46+
genrule_name = dtype_selective_build_lib + "_et_op_dtype_gen[selected_mobile_ops]"
47+
runtime.cxx_library(
48+
name = "dtype_headers",
49+
srcs = [],
50+
exported_headers = {
51+
"selected_mobile_ops.h": genrule_name,
52+
},
53+
visibility = [
54+
"//executorch/...",
55+
"@EXECUTORCH_CLIENTS",
56+
],
57+
)
58+
1059
for aten_mode in (True, False):
1160
aten_suffix = "_aten" if aten_mode else ""
1261

62+
exported_preprocessor_flags_ = []
63+
exported_deps_ = ["//executorch/runtime/core:core"]
64+
if aten_mode:
65+
exported_preprocessor_flags_ += ["-DUSE_ATEN_LIB"]
66+
else:
67+
exported_deps_ += ["//executorch/runtime/core/portable_type:scalar_type"]
68+
69+
if dtype_selective_build_lib != None:
70+
exported_preprocessor_flags_ += ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"]
71+
exported_deps_ += [":dtype_headers"]
72+
1373
runtime.cxx_library(
1474
name = "scalar_type_util" + aten_suffix,
1575
srcs = [],
@@ -20,10 +80,8 @@ def define_common_targets():
2080
"//executorch/...",
2181
"@EXECUTORCH_CLIENTS",
2282
],
23-
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
24-
exported_deps = [
25-
"//executorch/runtime/core:core",
26-
] + [] if aten_mode else ["//executorch/runtime/core/portable_type:scalar_type"],
83+
exported_preprocessor_flags = exported_preprocessor_flags_,
84+
exported_deps = exported_deps_,
2785
exported_external_deps = ["libtorch"] if aten_mode else [],
2886
)
2987

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
10+
#include <executorch/test/utils/DeathTest.h>
11+
#include <gtest/gtest.h>
12+
13+
using namespace ::testing;
14+
using exec_aten::ScalarType;
15+
using torch::executor::ScalarTypeToCppType;
16+
17+
// Add test for specific dtypes when we can run on model file.
18+
19+
TEST(SelectedMobileOpsHeaderTest, UnknownOp) {
20+
ET_EXPECT_DEATH(
21+
ET_SWITCH_TWO_TYPES(
22+
Float,
23+
Int,
24+
exec_aten::ScalarType::Float,
25+
ctx,
26+
"addmm.out",
27+
CTYPE_OUT,
28+
[&] { return true; }),
29+
"");
30+
}
31+
32+
TEST(SelectedMobileOpsHeaderTest, OpWithDtype) {
33+
ASSERT_EQ(
34+
ET_SWITCH_TWO_TYPES(
35+
Float,
36+
Int,
37+
exec_aten::ScalarType::Float,
38+
ctx,
39+
"add.out",
40+
CTYPE_OUT,
41+
[&] { return true; }),
42+
true);
43+
ASSERT_EQ(
44+
ET_SWITCH_TWO_TYPES(
45+
Float,
46+
Int,
47+
exec_aten::ScalarType::Int,
48+
ctx,
49+
"mm.out",
50+
CTYPE_OUT,
51+
[&] { return true; }),
52+
true);
53+
}

0 commit comments

Comments
 (0)