diff --git a/backends/apple/metal/README.md b/backends/apple/metal/README.md new file mode 100644 index 00000000000..0f010ae8920 --- /dev/null +++ b/backends/apple/metal/README.md @@ -0,0 +1,5 @@ +# Metal Backend + +⚠️ **EXPERIMENTAL BACKEND** + +This backend is currently in experimental development and may not be fully functional or stable. Use with caution. diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py new file mode 100644 index 00000000000..db57bed8fc7 --- /dev/null +++ b/backends/apple/metal/metal_backend.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import typing +from enum import Enum + +from typing import Any, Dict, final, List, Optional, Set + +import torch +from executorch.backends.apple.metal.replace_slice_copy_with_slice import ( + ReplaceSliceCopyWithSlicePass, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import ( + BackendDetails, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +# exist fallback operators in et namespace; +supported_fallback_kernels: Dict[str, Any] = { + "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_convolution": None, + "aoti_torch_mps_mm_out": None, + "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, +} + +# required fallback kernels but not supported +missing_fallback_kernels: Set[str] = set() + + +class COMPILE_SPEC_KEYS(Enum): + METHOD_NAME = "method_name" + + +# context manager for non-fallback guarantee +# it will raise exception when generating fallback kernels during aoti compile +@contextlib.contextmanager +def collect_unsupported_fallback_kernels(): + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + debug_handle: Optional[int] = None, + ): + if kernel not in supported_fallback_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + + +@final +@experimental( + "This API and all of Metal backend related functionality are experimental." +) +class MetalBackend(BackendDetails): + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + print("entering the lowerable parts in MetalBackend.preprocess....") + # Move the edge_program from CPU to MPS for aoti compile + mps_edge_program = move_to_device_pass(edge_program, "mps") + + # replace slice_copy with slice + ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module) + + edge_program_module = mps_edge_program.module() + + # Grab all input placeholders from the graph + user_input_names = mps_edge_program.graph_signature.user_inputs + user_input_placeholders = [] + for node in mps_edge_program.graph.nodes: + if node.op == "placeholder" and node.name in user_input_names: + user_input_placeholders.append(node.meta["val"]) + + # Base options for all devices + options: dict[str, typing.Any] = { + # Do not link against the full PyTorch/libtorch library + "aot_inductor.link_libtorch": False, + # Package model constants and other generated files directly in the shared object (.so) file + "aot_inductor.package_constants_in_so": True, + # Enable maximum automatic tuning for optimal performance + "max_autotune": True, + # "aot_inductor.debug_compile": True, + # "aot_inductor.force_mmap_weights": False, + } + + with collect_unsupported_fallback_kernels(): + so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] + if len(missing_fallback_kernels) > 0: + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) + raise RuntimeError( + f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" + "Please add them to the AOTI backend." + ) + + # pyre-ignorep[6]: Incompatible parameter type + with open(so_path, "rb") as f: + so_data = f.read() + + named_data_store = NamedDataStore() + method_name = MetalBackend.method_name_from_compile_specs(compile_specs) + named_data_store.add_named_data( + method_name + "_so_blob", so_data, 1, "aoti_metal_blob" + ) + + # Clean up the generated so file; it has been packaged into the NamdeDataStore + # pyre-ignorep[6]: Incompatible parameter type + os.remove(so_path) + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + @staticmethod + def generate_method_name_compile_spec( + method_name: str, + ) -> CompileSpec: + """ + Generates a CompileSpec for the given method name. + """ + return CompileSpec( + COMPILE_SPEC_KEYS.METHOD_NAME.value, + method_name.encode("utf-8"), + ) + + @staticmethod + def method_name_from_compile_specs( + compile_specs: List[CompileSpec], + ) -> str: + """ + Returns the method name from the compile specs. + """ + for spec in compile_specs: + if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: + return spec.value.decode("utf-8") + raise RuntimeError( + f"Could not find method name in compile specs: {compile_specs}" + ) diff --git a/backends/apple/metal/metal_partitioner.py b/backends/apple/metal/metal_partitioner.py new file mode 100644 index 00000000000..b103ac0f455 --- /dev/null +++ b/backends/apple/metal/metal_partitioner.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Dict, final, List, Optional, Tuple + +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend # usort: skip +from executorch.exir._warnings import experimental +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from torch.export.exported_program import ExportedProgram + + +@final +@experimental( + "This API and all of Metal backend related functionality are experimental." +) +class MetalPartitioner(Partitioner): + """ + Metal partitioner for AOTInductor backend integration. + + This partitioner creates a single partition containing all operators from the input graph. + It skips core ATen decomposition, allowing the Metal backend to handle decomposition using + AOTInductor's MPS-specific decomposition table. + + Only operators that cannot be handled by the aoti-mps library will be excluded from + the partition and fall back to ExecuTorch's default or custom handling. + """ + + def __init__(self, compile_spec: List[CompileSpec]) -> None: + self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """ + Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. + """ + + partition_tags: Dict[str, DelegationSpec] = {} + tag = "tag0" + + for node in exported_program.graph.nodes: + if node.op != "call_function": + continue + node.meta["delegation_tag"] = tag + + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + tag_mutated_buffer(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Return a list of operations that should not be decomposed and let the AOT compiler handle them. + Currently we skip ATen decompositon for all ops, and let the Metal backend handle them. + """ + do_not_decompose = set() + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + do_not_decompose.add(node.target) + return list(do_not_decompose), None diff --git a/backends/apple/metal/replace_slice_copy_with_slice.py b/backends/apple/metal/replace_slice_copy_with_slice.py new file mode 100644 index 00000000000..4f16759af35 --- /dev/null +++ b/backends/apple/metal/replace_slice_copy_with_slice.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Dict, Iterable, Tuple + +import torch +from executorch.exir.dialects._ops import ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch import fx + + +_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = ( + torch.ops.aten.slice_copy.Tensor, + ops.edge.aten.slice_copy.Tensor, +) + +_SLICE_TARGETS: Dict[ + torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload +] = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, +} + + +class ReplaceSliceCopyWithSlicePass(ExportPass): + """Replace non-mutated ``slice_copy`` results with ``slice`` views.""" + + def call(self, graph_module: fx.GraphModule) -> PassResult: + graph_changed = False + + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS: + continue + + if self._has_blocking_user(node, node.users.keys()): + continue + + node.target = _SLICE_TARGETS[node.target] + graph_changed = True + + if graph_changed: + graph_module.graph.lint() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) + + def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool: + for user in users: + if self._is_mutating_user(node, user) or self._is_view_user(node, user): + return True + return False + + def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat in-place tensor methods conservatively as mutations only when the + # method name ends with ``_`` which is the PyTorch convention for mutation. + return isinstance(user.target, str) and user.target.endswith("_") + + if user.op != "call_function": + return False + + target = user.target + if not hasattr(target, "_schema"): + return False + + schema = target._schema # pyre-ignore[16] + # Positional arguments + for index, arg in enumerate(user.args): + if arg is node and self._argument_mutates(schema, index): + return True + + # Keyword arguments + for name, arg in user.kwargs.items(): + if arg is node and self._argument_mutates(schema, name): + return True + + return False + + def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat tensor methods conservatively and assume they may be view-producing. + return True + + if user.op != "call_function": + return False + + target = user.target + if getattr(target, "is_view", False): + for arg in user.args: + if arg is node: + return True + for arg in user.kwargs.values(): + if arg is node: + return True + + return False + + def _argument_mutates( + self, schema: torch._C.FunctionSchema, key: int | str + ) -> bool: + arguments = schema.arguments + if isinstance(key, int): + if key >= len(arguments): + return False + argument = arguments[key] + else: + argument = next((arg for arg in arguments if arg.name == key), None) + if argument is None: + return False + + alias_info = argument.alias_info + return bool(alias_info and alias_info.is_write) diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/apple/metal/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py new file mode 100644 index 00000000000..5caf7a3adc6 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.backends.apple.metal.metal_backend import ( + COMPILE_SPEC_KEYS, + MetalBackend, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +class TestMetalBackend(unittest.TestCase): + """Test Metal backend utility functions.""" + + def test_generate_method_name_compile_spec(self): + """Test that compile spec is generated correctly with method name.""" + method_name = "forward" + compile_spec = MetalBackend.generate_method_name_compile_spec(method_name) + + # Verify compile spec structure + self.assertIsInstance(compile_spec, CompileSpec) + self.assertEqual(compile_spec.key, COMPILE_SPEC_KEYS.METHOD_NAME.value) + self.assertEqual(compile_spec.value, method_name.encode("utf-8")) + + def test_method_name_from_compile_specs(self): + """Test extracting method name from compile specs.""" + method_name = "forward" + compile_specs = [MetalBackend.generate_method_name_compile_spec(method_name)] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_with_multiple_specs(self): + """Test extracting method name when there are multiple compile specs.""" + method_name = "forward" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + MetalBackend.generate_method_name_compile_spec(method_name), + CompileSpec("another_key", b"another_value"), + ] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_missing(self): + """Test that RuntimeError is raised when method name is missing.""" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + ] + + # Should raise RuntimeError when method name is not found + with self.assertRaises(RuntimeError) as context: + MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertIn("Could not find method name", str(context.exception)) + + def test_compile_spec_roundtrip(self): + """Test that method name survives encode/decode roundtrip.""" + original_name = "my_custom_method" + + # Generate compile spec + compile_spec = MetalBackend.generate_method_name_compile_spec(original_name) + + # Extract from compile specs list + extracted_name = MetalBackend.method_name_from_compile_specs([compile_spec]) + + self.assertEqual(original_name, extracted_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py new file mode 100644 index 00000000000..1b29410ab6c --- /dev/null +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Tuple + +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend +from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner +from executorch.exir.backend.partitioner import PartitionResult +from torch.export import export + + +class TestMetalPartitioner(unittest.TestCase): + """ + Test Metal partitioner functionality. + + After Metal partitioning, there should be exactly one partitioned graph that contains + all operators from the input graph. This means all operators should be tagged with + the same delegation tag, indicating they will all be executed by the Metal backend. + """ + + def _get_partition_result( + self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] + ) -> PartitionResult: + """Helper method to get partition result for a given module.""" + # Export the model + exported_program = export(module, inputs, strict=True) + + # Create partitioner with compile specs + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get partition result + partition_result = partitioner.partition(exported_program) + + # Verify partition result structure + self.assertIsNotNone(partition_result) + self.assertTrue(hasattr(partition_result, "tagged_exported_program")) + self.assertTrue(hasattr(partition_result, "partition_tags")) + + return partition_result + + def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool: + """Check if the graph is fully partitioned (all operators have the same tag).""" + tagged_nodes = [] + untagged_ops = [] + + for node in partition_result.tagged_exported_program.graph.nodes: + if node.op == "call_function": + if hasattr(node, "meta") and "delegation_tag" in node.meta: + tagged_nodes.append(node) + else: + untagged_ops.append(node) + + # Check if we have any tagged nodes + if not tagged_nodes: + return False + + # Check if all tagged nodes have the same tag + first_tag = tagged_nodes[0].meta["delegation_tag"] + all_same_tag = all( + node.meta.get("delegation_tag") == first_tag for node in tagged_nodes + ) + + # Should have no untagged operations for full partitioning + fully_partitioned = len(untagged_ops) == 0 and all_same_tag + + return fully_partitioned + + def test_simple_add_partition(self): + """ + Test that Metal partitioner creates exactly one partition containing all operators. + Simple element-wise addition should result in a single graph with all ops tagged identically. + """ + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + # Create test inputs + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # Get partition result + partition_result = self._get_partition_result(AddModule(), (x, y)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + # Verify exactly one partition tag exists + self.assertEqual( + len(partition_result.partition_tags), + 1, + "Expected exactly one partition tag for fully delegated graph", + ) + + def test_linear_partition(self): + """ + Test Metal partitioner with a linear layer. + All matrix operations should be in a single partition. + """ + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Create test input + x = torch.randn(2, 10) + + # Get partition result + partition_result = self._get_partition_result(LinearModule(), (x,)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + def test_ops_to_not_decompose(self): + """ + Test that ops_to_not_decompose returns all call_function ops. + Metal backend should handle decomposition via AOTInductor. + """ + + class SimpleModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.relu(x + 1.0) + + # Create test input + x = torch.randn(2, 3) + + # Export the model + exported_program = export(SimpleModule(), (x,), strict=True) + + # Create partitioner + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get ops to not decompose + ops_to_not_decompose, _ = partitioner.ops_to_not_decompose(exported_program) + + # Verify it returns a list + self.assertIsInstance(ops_to_not_decompose, list) + + # All call_function ops should be in the list + call_function_ops = [ + node.target + for node in exported_program.graph.nodes + if node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + ] + + self.assertEqual( + set(ops_to_not_decompose), + set(call_function_ops), + "ops_to_not_decompose should contain all call_function ops", + ) + + +if __name__ == "__main__": + unittest.main()