diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 5b32c2fce5b..d0d540a3742 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -132,6 +132,18 @@ python_library( ], ) +python_library( + name = "graph_builder", + srcs = [ + "graph_builder.py", + ], + typing = True, + deps = [ + "fbcode//caffe2:torch", + "fbcode//executorch/exir:pass_base", + ], +) + python_library( name = "fuse_ops", srcs = [ @@ -150,3 +162,20 @@ python_library( "//executorch/exir/passes:spec_prop_pass", ], ) + +python_unittest( + name = "test_graph_builder", + srcs = [ + "tests/test_graph_builder.py", + ], + typing = True, + deps = [ + "//caffe2:torch", + "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//later:lib", + ":ops_registrations" + ], +) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 7e5fd3fec27..6b799d99f9e 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -196,7 +196,26 @@ def export_to_edge( # Export the model and lower it to an EdgeProgramManager (in edge IR), and # apply passes specific to Cadence DSP execution. Return both to print the # differences. -def export_to_cadence_edge_executorch( +def export_to_cadence( + model: torch.nn.Module, + inputs: tuple[object, ...], + dump_graphs: bool = False, + output_dir: Optional[str] = None, + opt_level: int = 1, +) -> EdgeProgramManager: + edge_prog_manager = export_to_edge(model, inputs) + cadence_passes = get_cadence_passes(opt_level) + + # Run a couple required passes for quant/dequant ops + cadence_prog_manager = edge_prog_manager.transform( + cast( + list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes + ) + ) + return cadence_prog_manager + + +def export_to_executorch_gen_etrecord( model: torch.nn.Module, inputs: tuple[object, ...], dump_graphs: bool = False, diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 092bbc4b192..146d4f806cd 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -16,7 +16,7 @@ from executorch.backends.cadence.aot.compiler import ( convert_pt2, - export_to_cadence_edge_executorch, + export_to_executorch_gen_etrecord, fuse_pt2, ) @@ -86,8 +86,8 @@ def export_model( quantized_model = fuse_pt2(converted_model, quantizer) # Get edge program after Cadence specific passes - exec_prog: ExecutorchProgramManager = export_to_cadence_edge_executorch( - quantized_model, example_inputs, working_dir + exec_prog: ExecutorchProgramManager = export_to_executorch_gen_etrecord( + quantized_model, example_inputs, output_dir=working_dir ) logging.info("Final exported graph:\n") diff --git a/backends/cadence/aot/graph_builder.py b/backends/cadence/aot/graph_builder.py new file mode 100644 index 00000000000..27604eac321 --- /dev/null +++ b/backends/cadence/aot/graph_builder.py @@ -0,0 +1,107 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +import logging +from typing import Optional, Sequence, Union + +import torch +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.node import Argument, Target +from torch.utils import _pytree as pytree + + +class GraphBuilder(ExportPass): + """Utility class for creating a graph module with user-specified ops. + + This class allows us to create test graph modules with any ops we want + directly, rather than relying on decomposition or passes. + + Usage: + builder = GraphBuilder() + # To insert placeholders, use builder.placeholder. + x = builder.placeholder("x", torch.randn(1, 3, 224, 224)) + # To insert an op, use builder.call_operator. + op = builder.call_operator( + some_op + (x, other_args, ...), + ) + # Insert outputs as a list of ProxyValues using builder.output. + builder.output([op]) + # Get GraphModule from builder. + gm = builder.get_graph_module() + """ + + def __init__(self) -> None: + self.exporter = ExportPass() + self.tracer: ExportPass.ExportTracer = self.ExportTracer( + self, torch.fx.graph.CodeGen() + ) + self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) + self.tracer.fake_tensor_mode = self.fake_tensor_mode + + # This will be called to create nodes in tracer. + self.interpreter = torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + + # pyre-ignore[14]: Inconsistent override. + def placeholder( + self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor] + ) -> ProxyValue: + if not isinstance(fake_tensor, FakeTensor): + fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor) + logging.info(f"Creating placeholder {target} => {fake_tensor.shape}") + placeholder = super().placeholder(target, fake_tensor, NodeMetadata({})) + return placeholder + + # pyre-ignore[14]: Inconsistent override. + def output(self, results: list[ProxyValue]) -> ProxyValue: + logging.info(f"Creating outputs {results}") + return super().output(results, NodeMetadata({})) + + def get_graph_module(self) -> torch.fx.GraphModule: + return torch.fx.GraphModule(self.tracer.root, self.tracer.graph) + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: Optional[dict[str, Argument]] = None, + meta: Optional[NodeMetadata] = None, + ) -> ProxyValue: + if meta is None: + meta = NodeMetadata({}) + if kwargs is None: + kwargs = {} + return super().call_operator(op, args, kwargs, meta) + + +def single_op_builder( + placeholders: Sequence[Union[torch.Tensor, FakeTensor]], + op: Target, + args: Sequence[Argument], + kwargs: Optional[dict[str, Argument]] = None, +) -> torch.fx.GraphModule: + """Create a graph module with a single op. + + Args: + placeholders: Placeholders to be used as inputs to the GraphModule. + op: The op to be inserted. + args: The args to be passed to the op. + kwargs: The kwargs to be passed to the op. + + Returns: + A graph module with a single op + """ + builder = GraphBuilder() + op_to_placeholder_dict = { + p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders) + } + proxy_args, proxy_kwargs = pytree.tree_map_only( + (torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs) + ) + node = builder.call_operator(op, proxy_args, proxy_kwargs) + builder.output([node]) + return builder.get_graph_module() diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 12a2f622389..ed56a1b85fb 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -89,3 +89,12 @@ def get_node_names_list_from_gm( continue graph_nodes.append(node.name) return graph_nodes + + +def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int: + """Count the number of nodes with target `target` in the graph.""" + total = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == target: + total += 1 + return total diff --git a/backends/cadence/aot/tests/test_graph_builder.py b/backends/cadence/aot/tests/test_graph_builder.py new file mode 100644 index 00000000000..04097c17255 --- /dev/null +++ b/backends/cadence/aot/tests/test_graph_builder.py @@ -0,0 +1,70 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +import executorch.backends.cadence.aot.ops_registrations # noqa +import torch +from executorch.backends.cadence.aot.graph_builder import ( + GraphBuilder, + single_op_builder, +) +from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from later.unittest import TestCase + + +class TestGraphBuilder(TestCase): + def test_graph_with_single_im2row(self) -> None: + # Create a graph with a single im2row node. + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 3, 224, 224)) + pad_value = builder.placeholder("pad", torch.randn(1)) + channels_last = False + im2row = builder.call_operator( + exir_ops.edge.cadence.im2row.default, + # pyre-ignore + ( + x, + (2, 2), + (1, 1), + (0, 0), + (1, 1), + pad_value, + channels_last, + ), + ) + builder.output([im2row]) + gm = builder.get_graph_module() + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + + # Check graph has a single im2row node. + self.assertEqual(len([gm.graph.nodes]), 1) + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) + + +class TestSingleOpBuilderUtility(TestCase): + def test_graph_with_single_im2row(self) -> None: + # Create a graph with a single im2row node. + x = torch.randn(1, 3, 224, 224) + pad_value = torch.randn(1) + channels_last = False + gm = single_op_builder( + (x, pad_value), + exir_ops.edge.cadence.im2row.default, + ( + x, + (2, 2), + (1, 1), + (0, 0), + (1, 1), + pad_value, + channels_last, + ), + ) + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + + # Check graph has a single im2row node. + self.assertEqual(len([gm.graph.nodes]), 1) + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)