From f5c7bf26f655c72aa7fabcecae293e908ba0a358 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 22 May 2025 14:44:29 -0700 Subject: [PATCH] Copy unit tests from torchgen to ET codegen (#11074) Summary: As a follow up to #10939, we need to add the same coverage to this code. Reviewed By: kirklandsign, manuelcandales Differential Revision: D75236020 --- codegen/test/TARGETS | 8 + codegen/test/targets.bzl | 21 + codegen/test/test_executorch_custom_ops.py | 153 +++++ codegen/test/test_executorch_gen.py | 695 +++++++++++++++++++++ codegen/test/test_executorch_signatures.py | 65 ++ codegen/test/test_executorch_types.py | 121 ++++ codegen/test/test_executorch_unboxing.py | 165 +++++ codegen/test/test_selective_build.py | 51 ++ pytest.ini | 1 + 9 files changed, 1280 insertions(+) create mode 100644 codegen/test/TARGETS create mode 100644 codegen/test/targets.bzl create mode 100644 codegen/test/test_executorch_custom_ops.py create mode 100644 codegen/test/test_executorch_gen.py create mode 100644 codegen/test/test_executorch_signatures.py create mode 100644 codegen/test/test_executorch_types.py create mode 100644 codegen/test/test_executorch_unboxing.py create mode 100644 codegen/test/test_selective_build.py diff --git a/codegen/test/TARGETS b/codegen/test/TARGETS new file mode 100644 index 00000000000..1e8cc179228 --- /dev/null +++ b/codegen/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain xplat-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/codegen/test/targets.bzl b/codegen/test/targets.bzl new file mode 100644 index 00000000000..bf21a594554 --- /dev/null +++ b/codegen/test/targets.bzl @@ -0,0 +1,21 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + runtime.python_test( + name = "test_gen", + srcs = glob(["test_*.py"]), + package_style = "inplace", + deps = [ + "//executorch/codegen:gen_lib", + "fbsource//third-party/pypi/expecttest:expecttest", + ], + external_deps = [ + "torchgen", + ], + ) diff --git a/codegen/test/test_executorch_custom_ops.py b/codegen/test/test_executorch_custom_ops.py new file mode 100644 index 00000000000..847f87ab352 --- /dev/null +++ b/codegen/test/test_executorch_custom_ops.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import tempfile +import unittest +from typing import Any +from unittest.mock import ANY, Mock, patch + +import expecttest + +import torchgen +from executorch.codegen.api.custom_ops import ComputeNativeFunctionStub +from executorch.codegen.model import ETKernelIndex +from torchgen.gen_executorch import gen_headers +from torchgen.model import Location, NativeFunction +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import FileManager + + +SPACES = " " + + +def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction: + native_function, _ = NativeFunction.from_yaml( + yaml_obj, + loc=Location(__file__, 1), + valid_tags=set(), + ) + return native_function + + +class TestComputeNativeFunctionStub(expecttest.TestCase): + """ + Could use torch.testing._internal.common_utils to reduce boilerplate. + GH CI job doesn't build torch before running tools unit tests, hence + manually adding these parametrized tests. + """ + + def _test_function_schema_generates_correct_kernel( + self, obj: dict[str, Any], expected: str + ) -> None: + func = _get_native_function_from_yaml(obj) + + gen = ComputeNativeFunctionStub() + res = gen(func) + self.assertIsNotNone(res) + self.assertExpectedInline( + str(res), + expected, + ) + + def test_function_schema_generates_correct_kernel_tensor_out(self) -> None: + obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"} + expected = """ +at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) { + return out; +} + """ + self._test_function_schema_generates_correct_kernel(obj, expected) + + def test_function_schema_generates_correct_kernel_no_out(self) -> None: + obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"} + expected = """ +at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) { + return self; +} + """ + self._test_function_schema_generates_correct_kernel(obj, expected) + + def test_function_schema_generates_correct_kernel_no_return(self) -> None: + obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!)[] out) -> ()"} + expected = f""" +void wrapper_CPU_out_foo_out(const at::Tensor & self, at::TensorList out) {{ +{SPACES} +}} + """ + self._test_function_schema_generates_correct_kernel(obj, expected) + + def test_function_schema_generates_correct_kernel_3_returns(self) -> None: + obj = { + "func": "custom::foo(Tensor self, Tensor[] other) -> (Tensor, Tensor, Tensor)" + } + expected = """ +::std::tuple wrapper_CPU__foo(const at::Tensor & self, at::TensorList other) { + return ::std::tuple( + at::Tensor(), at::Tensor(), at::Tensor() + ); +} + """ + self._test_function_schema_generates_correct_kernel(obj, expected) + + def test_function_schema_generates_correct_kernel_1_return_no_out(self) -> None: + obj = {"func": "custom::foo(Tensor[] a) -> Tensor"} + expected = """ +at::Tensor wrapper_CPU__foo(at::TensorList a) { + return at::Tensor(); +} + """ + self._test_function_schema_generates_correct_kernel(obj, expected) + + def test_schema_has_no_return_type_argument_throws(self) -> None: + func = _get_native_function_from_yaml( + {"func": "custom::foo.bool(Tensor self) -> bool"} + ) + + gen = ComputeNativeFunctionStub() + with self.assertRaisesRegex(Exception, "Can't handle this return type"): + gen(func) + + +class TestGenCustomOpsHeader(unittest.TestCase): + @patch.object(torchgen.utils.FileManager, "write_with_template") + @patch.object(torchgen.utils.FileManager, "write") + def test_fm_writes_custom_ops_header_when_boolean_is_true( + self, unused: Mock, mock_method: Mock + ) -> None: + with tempfile.TemporaryDirectory() as tempdir: + fm = FileManager(tempdir, tempdir, False) + gen_headers( + native_functions=[], + gen_custom_ops_header=True, + custom_ops_native_functions=[], + selector=SelectiveBuilder.get_nop_selector(), + kernel_index=ETKernelIndex(index={}), # type: ignore[arg-type] + cpu_fm=fm, + use_aten_lib=False, + ) + mock_method.assert_called_once_with( + "CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY + ) + + @patch.object(torchgen.utils.FileManager, "write_with_template") + @patch.object(torchgen.utils.FileManager, "write") + def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false( + self, unused: Mock, mock_method: Mock + ) -> None: + with tempfile.TemporaryDirectory() as tempdir: + fm = FileManager(tempdir, tempdir, False) + gen_headers( + native_functions=[], + gen_custom_ops_header=False, + custom_ops_native_functions=[], + selector=SelectiveBuilder.get_nop_selector(), + kernel_index=ETKernelIndex(index={}), # type: ignore[arg-type] + cpu_fm=fm, + use_aten_lib=False, + ) + mock_method.assert_not_called() diff --git a/codegen/test/test_executorch_gen.py b/codegen/test/test_executorch_gen.py new file mode 100644 index 00000000000..23dcbecf64a --- /dev/null +++ b/codegen/test/test_executorch_gen.py @@ -0,0 +1,695 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os +import tempfile +import unittest + +import yaml + +from executorch.codegen.model import ETKernelIndex, ETKernelKey +from torchgen.gen import LineLoader +from torchgen.gen_executorch import ( + ComputeCodegenUnboxedKernels, + gen_functions_declarations, + parse_yaml_files, + translate_native_yaml, +) +from torchgen.model import ( + BackendIndex, + BackendMetadata, + DispatchKey, + Location, + NativeFunction, + OperatorName, +) +from torchgen.selective_build.selector import SelectiveBuilder + + +TEST_YAML = """ +- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + ufunc_inner_loop: + Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf) + ScalarOnly: add (Bool) + dispatch: + SparseCPU: add_out_sparse_cpu + SparseCUDA: add_out_sparse_cuda + SparseCsrCPU: add_out_sparse_csr_cpu + SparseCsrCUDA: add_out_sparse_csr_cuda + MkldnnCPU: mkldnn_add_out + MPS: add_out_mps + +- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: add.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: add_sparse + SparseCsrCPU, SparseCsrCUDA: add_sparse_csr + MkldnnCPU: mkldnn_add + ZeroTensor: add_zerotensor + NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor + tags: core + +- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + dispatch: + CPU, CUDA: mul_out + MPS: mul_out_mps + SparseCPU: mul_out_sparse_cpu + SparseCUDA: mul_out_sparse_cuda + SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr + MkldnnCPU: mkldnn_mul_out + +- func: mul.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: mul.out + variants: function, method + dispatch: + SparseCPU, SparseCUDA: mul_sparse + SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr + MkldnnCPU: mkldnn_mul + ZeroTensor: mul_zerotensor + NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor + tags: core + +""" + + +TEST_KERNEL_YAML = """ +- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + ufunc_inner_loop: + Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf) + ScalarOnly: add (Bool) + type_alias: + T0: [Float, Double] + T1: [Double, Int] + dim_order_alias: + D0: [0, 1, 2, 3] + D1: [0, 3, 2, 1] + kernels: + - arg_meta: null + kernel_name: default_impl + - arg_meta: + self: [T0, D0] + other: [T1, D0] + out: [T0, D0] + kernel_name: test_impl + - arg_meta: + self: [T1, D0] + other: [T1, D1] + out: [T0, D1] + kernel_name: test_impl_2 + +- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: add.out + variants: function, method + tags: core + +- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + type_alias: + T0: [Float] + T1: [Double] + dim_order_alias: + D0: [0, 1, 2, 3] + kernels: + - arg_meta: null + kernel_name: default_impl + - arg_meta: + self: [T0, D0] + other: [T1, D0] + out: [T0, D0] + kernel_name: test_impl + +- func: mul.Tensor(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + structured_delegate: mul.out + variants: function, method + tags: core + +""" + + +class TestParseNativeYaml(unittest.TestCase): + def setUp(self) -> None: + self.temp_dir = tempfile.mkdtemp() + + self.aten_yaml_path = os.path.join(self.temp_dir, "test_native_functions.yaml") + with open(self.aten_yaml_path, "w") as f: + f.write(TEST_YAML) + self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml") + self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml") + with open(self.tags_yaml_path, "w") as f: + f.write( + """ +- tag: core + desc: test + """ + ) + with open(self.ops_yaml_path, "w") as f: + f.write( + """ +- op: add.out + device_check: NoCheck # TensorIterator + dispatch: + CPU: torch::executor::add_out_kernel + +- op: mul.out + device_check: NoCheck # TensorIterator + dispatch: + CPU: torch::executor::mul_out_kernel + """ + ) + + def test_translate_native_yaml_writes_correct_data(self) -> None: + out_yaml_path = os.path.join(self.temp_dir, "out.yaml") + with open(out_yaml_path, "w") as out_file: + translate_native_yaml( + tags_yaml_path=self.tags_yaml_path, + aten_yaml_path=self.aten_yaml_path, + native_yaml_path=self.ops_yaml_path, + use_aten_lib=False, + out_file=out_file, + ) + with open(out_yaml_path) as out_file: + es = yaml.load(out_file, Loader=LineLoader) + self.assertTrue(all("func" in e for e in es)) + self.assertTrue(all(e.get("variants") == "function" for e in es)) + + # Check that kernel fields aren't introduced in yaml + for e in es: + self.assertFalse({"kernels", "type_alias", "dim_order_alias"} < e.keys()) + + def test_parse_yaml_files(self) -> None: + custom_ops_yaml_path = None + selector = SelectiveBuilder.get_nop_selector() + use_aten_lib = False + + parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( + aten_yaml_path=self.aten_yaml_path, + tags_yaml_path=self.tags_yaml_path, + native_yaml_path=self.ops_yaml_path, + custom_ops_yaml_path=custom_ops_yaml_path, + selector=selector, + use_aten_lib=use_aten_lib, + ) + + # Just the default kernel entry + expected_kernel_entry = {"add.out": 1, "mul.out": 1} + self.assertTrue(len(parsed_yaml.native_functions) == len(expected_kernel_entry)) + + op_entries = parsed_yaml.kernel_index.index + for op_name, kernel_mapping in op_entries.items(): + self.assertTrue( + len(kernel_mapping) == expected_kernel_entry.pop(str(op_name)) + ) + + self.assertTrue(len(expected_kernel_entry) == 0) + + def tearDown(self) -> None: + import shutil + + try: + shutil.rmtree(self.temp_dir) + except OSError: + pass + + +class TestParseKernelYamlFiles(unittest.TestCase): + def setUp(self) -> None: + self.temp_dir = tempfile.mkdtemp() + + self.aten_kernel_yaml_path = os.path.join( + self.temp_dir, "test_kernel_native_functions.yaml" + ) + with open(self.aten_kernel_yaml_path, "w") as f: + f.write(TEST_KERNEL_YAML) + self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml") + self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml") + with open(self.tags_yaml_path, "w") as f: + f.write( + """ +- tag: core + desc: test + """ + ) + with open(self.ops_yaml_path, "w") as f: + f.write( + """ +- op: add.out + device_check: NoCheck # TensorIterator + dispatch: + CPU: torch::executor::add_out_kernel + +- op: mul.out + device_check: NoCheck # TensorIterator + dispatch: + CPU: torch::executor::mul_out_kernel + """ + ) + + def test_translate_kernel_native_yaml_writes_correct_data(self) -> None: + out_yaml_path = os.path.join(self.temp_dir, "out2.yaml") + with open(out_yaml_path, "w") as out_file: + translate_native_yaml( + tags_yaml_path=self.tags_yaml_path, + aten_yaml_path=self.aten_kernel_yaml_path, + native_yaml_path=self.ops_yaml_path, + use_aten_lib=False, + out_file=out_file, + ) + with open(out_yaml_path) as out_file: + es = yaml.load(out_file, Loader=LineLoader) + self.assertTrue(all("func" in e for e in es)) + self.assertTrue(all(e.get("variants") == "function" for e in es)) + + # Check persistence of kernel fields in yaml + for e in es: + self.assertTrue({"kernels", "type_alias", "dim_order_alias"} < e.keys()) + + def test_parse_yaml_files(self) -> None: + custom_ops_yaml_path = None + selector = SelectiveBuilder.get_nop_selector() + use_aten_lib = False + + parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( + aten_yaml_path=self.aten_kernel_yaml_path, + tags_yaml_path=self.tags_yaml_path, + native_yaml_path=self.ops_yaml_path, + custom_ops_yaml_path=custom_ops_yaml_path, + selector=selector, + use_aten_lib=use_aten_lib, + ) + + expected_kernel_entry = {"add.out": 9, "mul.out": 2} + self.assertTrue(len(parsed_yaml.native_functions) == len(expected_kernel_entry)) + + op_entries = parsed_yaml.kernel_index.index + for op_name, kernel_mapping in op_entries.items(): + self.assertTrue( + len(kernel_mapping) == expected_kernel_entry.pop(str(op_name)) + ) + + self.assertTrue(len(expected_kernel_entry) == 0) + + def tearDown(self) -> None: + import shutil + + try: + shutil.rmtree(self.temp_dir) + except OSError: + pass + + +class TestGenFunctionsDeclarations(unittest.TestCase): + def setUp(self) -> None: + ( + self.custom_1_native_function, + custom_1_backend_index, + ) = NativeFunction.from_yaml( + {"func": "custom_1::op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, + loc=Location(__file__, 1), + valid_tags=set(), + ) + ( + self.custom_2_native_function, + custom_2_backend_index, + ) = NativeFunction.from_yaml( + { + "func": "custom_2::op_2() -> bool", + "dispatch": {"CPU": "kernel_2"}, + }, + loc=Location(__file__, 1), + valid_tags=set(), + ) + ( + self.custom_3_native_function, + custom_3_backend_index, + ) = NativeFunction.from_yaml( + { + "func": "custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!)", + "dispatch": {"CPU": "kernel_3"}, + "variants": "method", + }, + loc=Location(__file__, 1), + valid_tags=set(), + ) + + backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = { + DispatchKey.CPU: {}, + DispatchKey.QuantizedCPU: {}, + } + BackendIndex.grow_index(backend_indices, custom_1_backend_index) + BackendIndex.grow_index(backend_indices, custom_2_backend_index) + self.static_dispatch_idx = [ + BackendIndex( + dispatch_key=k, + use_out_as_primary=True, + external=False, + device_guard=False, + index=backend_indices[k], + ) + for k in backend_indices + ] + self.kernel_index = ETKernelIndex.from_backend_indices(backend_indices) + + def test_operators_with_different_namespaces_are_grouped_correctly(self) -> None: + declarations = gen_functions_declarations( + native_functions=[ + self.custom_1_native_function, + self.custom_2_native_function, + ], + kernel_index=self.kernel_index, # type: ignore[arg-type] + selector=SelectiveBuilder.get_nop_selector(), + use_aten_lib=False, + ) + self.assertTrue( + """ +namespace custom_1 { + +// custom_1::op_1() -> bool +TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) { + return ::at::native::kernel_1(context); +} + +} // namespace custom_1 +""" + in declarations + ) + + self.assertTrue( + """ +namespace custom_2 { + +// custom_2::op_2() -> bool +TORCH_API inline bool op_2(torch::executor::KernelRuntimeContext & context) { + return ::at::native::kernel_2(context); +} + +} // namespace custom_2 + """ + in declarations + ) + + def test_aten_lib_has_context_arg(self) -> None: + declarations = gen_functions_declarations( + native_functions=[ + self.custom_1_native_function, + ], + kernel_index=self.kernel_index, # type: ignore[arg-type] + selector=SelectiveBuilder.get_nop_selector(), + use_aten_lib=True, + ) + self.assertTrue( + """ +namespace custom_1 { + +// custom_1::op_1() -> bool +TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) { + return at::op_1(); +} + +} // namespace custom_1 + """ + in declarations + ) + + def test_aten_lib_method_variant(self) -> None: + declarations = gen_functions_declarations( + native_functions=[ + self.custom_3_native_function, + ], + kernel_index=self.kernel_index, # type: ignore[arg-type] + selector=SelectiveBuilder.get_nop_selector(), + use_aten_lib=True, + ) + self.assertTrue( + """ +namespace custom_3 { + +// custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!) +TORCH_API inline at::Tensor & op_3(torch::executor::KernelRuntimeContext & context, at::Tensor & self, const at::Tensor & x) { + return self.op_3(x); +} + +} // namespace custom_3 + """ + in declarations + ) + + +class TestComputeCodegenUnboxedKernels(unittest.TestCase): + def setUp(self) -> None: + ( + self.native_function_no_kern, + _, + ) = NativeFunction.from_yaml( + { + "func": "custom_1::op_1() -> bool", + "dispatch": {"CPU": "unused_kernel_1"}, + }, + loc=Location(__file__, 1), + valid_tags=set(), + ) + + self.default_kernel_key = ETKernelKey(default=True) + self.default_backend_metadata = BackendMetadata( + "default_kernel", False, "at::native" + ) + self.default_kernel_entry = ( + [self.default_kernel_key], + self.default_backend_metadata, + ) + + def test_codegen_unboxed_specialized(self) -> None: + specialized_kernel_key = ETKernelKey.gen_from_yaml( + {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, + {"T0": ["Double"]}, + {"D0": [0, 1, 2, 3]}, + ) + selector = SelectiveBuilder.from_yaml_dict( + { + "include_all_operators": True, + "et_kernel_metadata": { + "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] + }, + } + ) + use_aten_lib = False + entry = ( + self.native_function_no_kern, + (specialized_kernel_key, self.default_backend_metadata), + ) + + result = ComputeCodegenUnboxedKernels( + selector, use_aten_lib, add_exception_boundary=False + )(entry) + # Concat used to prevent whitespace stripping + expected_str = ( + """ +Kernel( + "custom_1::op_1", + "v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3", + [](torch::executor::KernelRuntimeContext & context, EValue** stack) { + """ + + """ + + + internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); + EXECUTORCH_SCOPE_PROF("native_call_op_1"); + bool result_ = at::native::default_kernel(context, ); + internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); + + *stack[0] = EValue(result_); + + } +), +""" + ) + + self.assertEqual(expected_str, result) + + def test_codegen_unboxed_specialized_not_matching(self) -> None: + specialized_kernel_key = ETKernelKey.gen_from_yaml( + {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, + {"T0": ["Double"]}, + {"D0": [0, 1, 2, 3]}, + ) + selector = SelectiveBuilder.from_yaml_dict( + { + "include_all_operators": True, + "et_kernel_metadata": { + "custom_1::op_1": ["v1/8;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] + }, + } + ) + use_aten_lib = False + entry = ( + self.native_function_no_kern, + (specialized_kernel_key, self.default_backend_metadata), + ) + + self.assertRaises( + Exception, + ComputeCodegenUnboxedKernels( + selector, use_aten_lib, add_exception_boundary=False + ), + entry, + ) + + def test_codegen_unboxed_specialized_missing_root_op(self) -> None: + specialized_kernel_key = ETKernelKey.gen_from_yaml( + {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, + {"T0": ["Double"]}, + {"D0": [0, 1, 2, 3]}, + ) + selector = SelectiveBuilder.from_yaml_dict( + { + "et_kernel_metadata": { + "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] + } + } + ) + use_aten_lib = False + entry = ( + self.native_function_no_kern, + (specialized_kernel_key, self.default_backend_metadata), + ) + + for add_exception_boundary in (True, False): + result = ComputeCodegenUnboxedKernels( + selector, use_aten_lib, add_exception_boundary + )(entry) + # Concat used to prevent whitespace stripping + expected_str = """""" + + self.assertEqual(expected_str, result) + + def test_codegen_unboxed_default(self) -> None: + """ + This test checks that if there is no specialized kernel, the default kernel is used. + """ + selector = SelectiveBuilder.from_yaml_dict( + { + "include_all_operators": True, + "et_kernel_metadata": { + "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] + }, + } + ) + use_aten_lib = False + entry = (self.native_function_no_kern, self.default_kernel_entry) + + result = ComputeCodegenUnboxedKernels( + selector, use_aten_lib, add_exception_boundary=False + )(entry) + # Concat used to prevent whitespace stripping + expected_str = ( + """ +Kernel( + "custom_1::op_1", + [](torch::executor::KernelRuntimeContext & context, EValue** stack) { + """ + + """ + + + internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); + EXECUTORCH_SCOPE_PROF("native_call_op_1"); + bool result_ = at::native::default_kernel(context, ); + internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); + + *stack[0] = EValue(result_); + + } +), +""" + ) + + self.assertEqual(expected_str, result) + + result = ComputeCodegenUnboxedKernels( + selector, use_aten_lib, add_exception_boundary=True + )(entry) + # Concat used to prevent whitespace stripping + expected_str = ( + """ +Kernel( + "custom_1::op_1", + [](torch::executor::KernelRuntimeContext & context, EValue** stack) { + """ + + """ + + try { + internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); + EXECUTORCH_SCOPE_PROF("native_call_op_1"); + bool result_ = at::native::default_kernel(context, ); + internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); + + *stack[0] = EValue(result_); + } catch (const std::exception& ex) { + ET_LOG(Error, "Kernel threw an exception: %s", ex.what()); + context.fail(torch::executor::Error::Internal); + } + } +), +""" + ) + self.maxDiff = None + self.assertEqual(expected_str, result) + + def test_codegen_unboxed_default_kernel_key_selected(self) -> None: + """ + This test checks that if there is no specialized kernel, the default kernel is used, when the selector only has default key. + """ + selector = SelectiveBuilder.from_yaml_dict( + { + "include_all_operators": True, + "et_kernel_metadata": {"custom_1::op_1": ["default"]}, + } + ) + use_aten_lib = False + entry = (self.native_function_no_kern, self.default_kernel_entry) + + result = ComputeCodegenUnboxedKernels( + selector, use_aten_lib, add_exception_boundary=False + )(entry) + # Concat used to prevent whitespace stripping + expected_str = ( + """ +Kernel( + "custom_1::op_1", + [](torch::executor::KernelRuntimeContext & context, EValue** stack) { + """ + + """ + + + internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); + EXECUTORCH_SCOPE_PROF("native_call_op_1"); + bool result_ = at::native::default_kernel(context, ); + internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); + + *stack[0] = EValue(result_); + + } +), +""" + ) + + self.assertEqual(expected_str, result) diff --git a/codegen/test/test_executorch_signatures.py b/codegen/test/test_executorch_signatures.py new file mode 100644 index 00000000000..9a019bd6908 --- /dev/null +++ b/codegen/test/test_executorch_signatures.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.codegen.api.types import ExecutorchCppSignature +from torchgen.local import parametrize +from torchgen.model import Location, NativeFunction + + +DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( + {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"}, + loc=Location(__file__, 1), + valid_tags=set(), +) + + +class ExecutorchCppSignatureTest(unittest.TestCase): + def setUp(self) -> None: + self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION) + + def test_runtime_signature_contains_runtime_context(self) -> None: + # test if `KernelRuntimeContext` argument exists in `RuntimeSignature` + with parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ): + args = self.sig.arguments(include_context=True) + self.assertEqual(len(args), 3) + self.assertTrue(any(a.name == "context" for a in args)) + + def test_runtime_signature_does_not_contain_runtime_context(self) -> None: + # test if `KernelRuntimeContext` argument is missing in `RuntimeSignature` + with parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ): + args = self.sig.arguments(include_context=False) + self.assertEqual(len(args), 2) + self.assertFalse(any(a.name == "context" for a in args)) + + def test_runtime_signature_declaration_correct(self) -> None: + with parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ): + decl = self.sig.decl(include_context=True) + self.assertEqual( + decl, + ( + "torch::executor::Tensor & foo_outf(" + "torch::executor::KernelRuntimeContext & context, " + "const torch::executor::Tensor & input, " + "torch::executor::Tensor & out)" + ), + ) + no_context_decl = self.sig.decl(include_context=False) + self.assertEqual( + no_context_decl, + ( + "torch::executor::Tensor & foo_outf(" + "const torch::executor::Tensor & input, " + "torch::executor::Tensor & out)" + ), + ) diff --git a/codegen/test/test_executorch_types.py b/codegen/test/test_executorch_types.py new file mode 100644 index 00000000000..e219c86ca57 --- /dev/null +++ b/codegen/test/test_executorch_types.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.codegen.api.et_cpp import argument_type, return_type, returns_type +from executorch.codegen.api.types import ArrayRefCType, scalarT, tensorListT, tensorT + +from torchgen import local +from torchgen.api.types import ( + BaseCType, + boolT, + ConstRefCType, + CType, + longT, + MutRefCType, + NamedCType, + OptionalCType, + TupleCType, + VectorCType, + voidT, +) +from torchgen.model import Argument, FunctionSchema, Return + + +class ExecutorchCppTest(unittest.TestCase): + """ + Test executorch.codegen.api.cpp + """ + + def _test_argumenttype_type(self, arg_str: str, expected: NamedCType) -> None: + arg = Argument.parse(arg_str) + self.assertEqual(str(argument_type(arg, binds=arg.name)), str(expected)) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def test_argumenttype_type(self) -> None: + data = [ + ("Tensor self", NamedCType("self", ConstRefCType(BaseCType(tensorT)))), + ("Tensor(a!) out", NamedCType("out", MutRefCType(BaseCType(tensorT)))), + ( + "Tensor? opt", + NamedCType("opt", ConstRefCType(OptionalCType(BaseCType(tensorT)))), + ), + ("Scalar scalar", NamedCType("scalar", ConstRefCType(BaseCType(scalarT)))), + ( + "Scalar? scalar", + NamedCType("scalar", ConstRefCType(OptionalCType(BaseCType(scalarT)))), + ), + ("int[] size", NamedCType("size", ArrayRefCType(BaseCType(longT)))), + ("int? dim", NamedCType("dim", OptionalCType(BaseCType(longT)))), + ("Tensor[] weight", NamedCType("weight", BaseCType(tensorListT))), + ( + "Scalar[] spacing", + NamedCType("spacing", ArrayRefCType(ConstRefCType(BaseCType(scalarT)))), + ), + ( + "Tensor?[] weight", + NamedCType("weight", ArrayRefCType(OptionalCType(BaseCType(tensorT)))), + ), + ( + "SymInt[]? output_size", + NamedCType( + "output_size", OptionalCType(ArrayRefCType(BaseCType(longT))) + ), + ), + ( + "int[]? dims", + NamedCType("dims", OptionalCType(ArrayRefCType(BaseCType(longT)))), + ), + ( + "bool[3] output_mask", + NamedCType("output_mask", ArrayRefCType(BaseCType(boolT))), + ), + ] + for d in data: + self._test_argumenttype_type(*d) + + def _test_returntype_type(self, ret_str: str, expected: CType) -> None: + ret = Return.parse(ret_str) + self.assertEqual(str(return_type(ret)), str(expected)) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def test_returntype_type(self) -> None: + data = [ + ("Tensor", BaseCType(tensorT)), + ("Tensor(a!)", MutRefCType(BaseCType(tensorT))), + ("Tensor[]", VectorCType(BaseCType(tensorT))), + ] + for d in data: + self._test_returntype_type(*d) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def test_returns_type(self) -> None: + func = FunctionSchema.parse( + "min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" + ) + expected = TupleCType([BaseCType(tensorT), BaseCType(tensorT)]) + self.assertEqual(str(returns_type(func.returns)), str(expected)) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def test_void_return_type(self) -> None: + func = FunctionSchema.parse( + "_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()" + ) + expected = BaseCType(voidT) + self.assertEqual(str(returns_type(func.returns)), str(expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/codegen/test/test_executorch_unboxing.py b/codegen/test/test_executorch_unboxing.py new file mode 100644 index 00000000000..4244291a674 --- /dev/null +++ b/codegen/test/test_executorch_unboxing.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from types import ModuleType + +from executorch.codegen.api import et_cpp as et_cpp, types as et_types +from executorch.codegen.api.unboxing import Unboxing + +from torchgen import local +from torchgen.api import cpp as aten_cpp, types as aten_types +from torchgen.api.types import ( + ArgName, + BaseCType, + ConstRefCType, + MutRefCType, + NamedCType, +) +from torchgen.model import BaseTy, BaseType, ListType, OptionalType, Type + + +def aten_argumenttype_type_wrapper( + t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False +) -> NamedCType: + return aten_cpp.argumenttype_type( + t, + mutable=mutable, + binds=binds, + remove_non_owning_ref_types=remove_non_owning_ref_types, + ) + + +ATEN_UNBOXING = Unboxing(argument_type_gen=aten_argumenttype_type_wrapper) +ET_UNBOXING = Unboxing(argument_type_gen=et_cpp.argumenttype_type) + + +class TestUnboxing(unittest.TestCase): + """ + Could use torch.testing._internal.common_utils to reduce boilerplate. + GH CI job doesn't build torch before running tools unit tests, hence + manually adding these parametrized tests. + """ + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def test_symint_argument_translate_ctype_aten(self) -> None: + # test if `SymInt[]` JIT argument can be translated into C++ argument correctly. + # should be `IntArrayRef` due to the fact that Executorch doesn't use symint sig. + + symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None) + + out_name, ctype, _, _ = ATEN_UNBOXING.argumenttype_evalue_convert( + t=symint_list_type, arg_name="size", mutable=False + ) + + self.assertEqual(out_name, "size_list_out") + self.assertIsInstance(ctype, BaseCType) + self.assertEqual(ctype, aten_types.BaseCType(aten_types.intArrayRefT)) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def test_symint_argument_translate_ctype_executorch(self) -> None: + # test if `SymInt[]` JIT argument can be translated into C++ argument correctly. + # should be `IntArrayRef` due to the fact that Executorch doesn't use symint sig. + + symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None) + + out_name, ctype, _, _ = ET_UNBOXING.argumenttype_evalue_convert( + t=symint_list_type, arg_name="size", mutable=False + ) + + self.assertEqual(out_name, "size_list_out") + self.assertIsInstance(ctype, et_types.ArrayRefCType) + self.assertEqual( + ctype, et_types.ArrayRefCType(elem=BaseCType(aten_types.longT)) + ) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def _test_const_tensor_argument_translate_ctype( + self, unboxing: Unboxing, types: ModuleType + ) -> None: + tensor_type = BaseType(BaseTy.Tensor) + + out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( + t=tensor_type, arg_name="self", mutable=False + ) + + self.assertEqual(out_name, "self_base") + self.assertEqual(ctype, ConstRefCType(BaseCType(types.tensorT))) + + def test_const_tensor_argument_translate_ctype_aten(self) -> None: + self._test_const_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types) + + def test_const_tensor_argument_translate_ctype_executorch(self) -> None: + self._test_const_tensor_argument_translate_ctype(ET_UNBOXING, et_types) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def _test_mutable_tensor_argument_translate_ctype( + self, unboxing: Unboxing, types: ModuleType + ) -> None: + tensor_type = BaseType(BaseTy.Tensor) + + out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( + t=tensor_type, arg_name="out", mutable=True + ) + + self.assertEqual(out_name, "out_base") + self.assertEqual(ctype, MutRefCType(BaseCType(types.tensorT))) + + def test_mutable_tensor_argument_translate_ctype_aten(self) -> None: + self._test_mutable_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types) + + def test_mutable_tensor_argument_translate_ctype_executorch(self) -> None: + self._test_mutable_tensor_argument_translate_ctype(ET_UNBOXING, et_types) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def _test_tensor_list_argument_translate_ctype( + self, unboxing: Unboxing, types: ModuleType + ) -> None: + tensor_list_type = ListType(elem=BaseType(BaseTy.Tensor), size=None) + + out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( + t=tensor_list_type, arg_name="out", mutable=True + ) + + self.assertEqual(out_name, "out_list_out") + self.assertEqual(ctype, BaseCType(types.tensorListT)) + + def test_tensor_list_argument_translate_ctype_aten(self) -> None: + self._test_tensor_list_argument_translate_ctype(ATEN_UNBOXING, aten_types) + + def test_tensor_list_argument_translate_ctype_executorch(self) -> None: + self._test_tensor_list_argument_translate_ctype(ET_UNBOXING, et_types) + + @local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False + ) + def _test_optional_int_argument_translate_ctype( + self, unboxing: Unboxing, types: ModuleType + ) -> None: + optional_int_type = OptionalType(elem=BaseType(BaseTy.int)) + + out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert( + t=optional_int_type, arg_name="something", mutable=True + ) + + self.assertEqual(out_name, "something_opt_out") + self.assertEqual(ctype, types.OptionalCType(BaseCType(types.longT))) + + def test_optional_int_argument_translate_ctype_aten(self) -> None: + self._test_optional_int_argument_translate_ctype(ATEN_UNBOXING, aten_types) + + def test_optional_int_argument_translate_ctype_executorch(self) -> None: + self._test_optional_int_argument_translate_ctype(ET_UNBOXING, et_types) diff --git a/codegen/test/test_selective_build.py b/codegen/test/test_selective_build.py new file mode 100644 index 00000000000..3754ca0345d --- /dev/null +++ b/codegen/test/test_selective_build.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from torchgen.selective_build.selector import SelectiveBuilder + + +class TestExecuTorchSelectiveBuild(unittest.TestCase): + def test_et_kernel_selected(self) -> None: + yaml_config = """ +et_kernel_metadata: + aten::add.out: + - "v1/6;0,1|6;0,1|6;0,1|6;0,1" + aten::sub.out: + - "v1/6;0,1|6;0,1|6;0,1|6;0,1" +""" + selector = SelectiveBuilder.from_yaml_str(yaml_config) + self.assertListEqual( + ["v1/6;0,1|6;0,1|6;0,1|6;0,1"], + selector.et_get_selected_kernels( + "aten::add.out", + [ + "v1/6;0,1|6;0,1|6;0,1|6;0,1", + "v1/3;0,1|3;0,1|3;0,1|3;0,1", + "v1/6;1,0|6;0,1|6;0,1|6;0,1", + ], + ), + ) + self.assertListEqual( + ["v1/6;0,1|6;0,1|6;0,1|6;0,1"], + selector.et_get_selected_kernels( + "aten::sub.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"] + ), + ) + self.assertListEqual( + [], + selector.et_get_selected_kernels( + "aten::mul.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"] + ), + ) + # We don't use version for now. + self.assertListEqual( + ["v2/6;0,1|6;0,1|6;0,1|6;0,1"], + selector.et_get_selected_kernels( + "aten::add.out", ["v2/6;0,1|6;0,1|6;0,1|6;0,1"] + ), + ) diff --git a/pytest.ini b/pytest.ini index 87d9ebc9c23..4dd7f4353d2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -55,6 +55,7 @@ addopts = # Runtime runtime # Tools + codegen/test tools/cmake # test TODO: fix these tests # test/end2end/test_end2end.py