From 7a5e7dd792c23e6e1d18e344ec182da6e7a0ea23 Mon Sep 17 00:00:00 2001
From: Zonglin Peng
Date: Tue, 19 Nov 2024 09:35:53 -0800
Subject: [PATCH] add graph builder in oss for fuse ops (#6877)

Summary: titled

Reviewed By: skrtskrtfb

Differential Revision: D65911233
---
 backends/cadence/aot/TARGETS          |  12 +++
 backends/cadence/aot/graph_builder.py | 107 ++++++++++++++++++++++++++
 2 files changed, 119 insertions(+)
 create mode 100644 backends/cadence/aot/graph_builder.py

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 5b32c2fce5b..549d9729902 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -132,6 +132,18 @@ python_library(
     ],
 )
 
+python_library(
+    name = "graph_builder",
+    srcs = [
+        "graph_builder.py",
+    ],
+    typing = True,
+    deps = [
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/exir:pass_base",
+    ],
+)
+
 python_library(
     name = "fuse_ops",
     srcs = [
diff --git a/backends/cadence/aot/graph_builder.py b/backends/cadence/aot/graph_builder.py
new file mode 100644
index 00000000000..27604eac321
--- /dev/null
+++ b/backends/cadence/aot/graph_builder.py
@@ -0,0 +1,107 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+import logging
+from typing import Optional, Sequence, Union
+
+import torch
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from torch._subclasses import FakeTensor, FakeTensorMode
+from torch.fx.node import Argument, Target
+from torch.utils import _pytree as pytree
+
+
+class GraphBuilder(ExportPass):
+    """Utility class for creating a graph module with user-specified ops.
+
+    This class allows us to create test graph modules with any ops we want
+    directly, rather than relying on decomposition or passes.
+
+    Usage:
+        builder = GraphBuilder()
+        # To insert placeholders, use builder.placeholder.
+        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
+        # To insert an op, use builder.call_operator.
+        op = builder.call_operator(
+            some_op,
+            (x, other_args, ...),
+        )
+        # Insert outputs as a list of ProxyValues using builder.output.
+        builder.output([op])
+        # Get GraphModule from builder.
+        gm = builder.get_graph_module()
+    """
+
+    def __init__(self) -> None:
+        self.exporter = ExportPass()
+        self.tracer: ExportPass.ExportTracer = self.ExportTracer(
+            self, torch.fx.graph.CodeGen()
+        )
+        self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
+        self.tracer.fake_tensor_mode = self.fake_tensor_mode
+
+        # This will be called to create nodes in tracer.
+        self.interpreter = torch.fx.Interpreter(
+            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
+        )
+
+    # pyre-ignore[14]: Inconsistent override.
+    def placeholder(
+        self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
+    ) -> ProxyValue:
+        if not isinstance(fake_tensor, FakeTensor):
+            fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
+        logging.info(f"Creating placeholder {target} => {fake_tensor.shape}")
+        placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
+        return placeholder
+
+    # pyre-ignore[14]: Inconsistent override.
+    def output(self, results: list[ProxyValue]) -> ProxyValue:
+        logging.info(f"Creating outputs {results}")
+        return super().output(results, NodeMetadata({}))
+
+    def get_graph_module(self) -> torch.fx.GraphModule:
+        return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
+
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: tuple[Argument, ...],
+        kwargs: Optional[dict[str, Argument]] = None,
+        meta: Optional[NodeMetadata] = None,
+    ) -> ProxyValue:
+        if meta is None:
+            meta = NodeMetadata({})
+        if kwargs is None:
+            kwargs = {}
+        return super().call_operator(op, args, kwargs, meta)
+
+
+def single_op_builder(
+    placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
+    op: Target,
+    args: Sequence[Argument],
+    kwargs: Optional[dict[str, Argument]] = None,
+) -> torch.fx.GraphModule:
+    """Create a graph module with a single op.
+
+    Args:
+        placeholders: Placeholders to be used as inputs to the GraphModule.
+        op: The op to be inserted.
+        args: The args to be passed to the op.
+        kwargs: The kwargs to be passed to the op.
+
+    Returns:
+        A graph module with a single op
+    """
+    builder = GraphBuilder()
+    op_to_placeholder_dict = {
+        p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
+    }
+    proxy_args, proxy_kwargs = pytree.tree_map_only(
+        (torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs)
+    )
+    node = builder.call_operator(op, proxy_args, proxy_kwargs)
+    builder.output([node])
+    return builder.get_graph_module()
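
Note (not part of the patch): a minimal usage sketch of the new single_op_builder
helper, e.g. from a test. The chosen op, tensor shapes, and import path are
assumptions for illustration only.

    # Hypothetical example: build a graph module wrapping a single
    # aten.add.Tensor node fed by two placeholders (named p_0 and p_1 by
    # single_op_builder). Import path is assumed from the file location.
    import torch
    from executorch.backends.cadence.aot.graph_builder import single_op_builder

    x = torch.randn(2, 3)
    y = torch.randn(2, 3)
    # The same tensor objects are passed as placeholders and as op args, so
    # the builder can map them to the created placeholder nodes by identity.
    gm = single_op_builder(
        placeholders=(x, y),
        op=torch.ops.aten.add.Tensor,
        args=(x, y),
    )
    print(gm.graph)  # two placeholders, one add node, one output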