From ba8c6c58640ba1fd8bdcc73cdd429fb26b53b86e Mon Sep 17 00:00:00 2001
From: Hardik Sharma
Date: Wed, 4 Sep 2024 16:20:22 -0700
Subject: [PATCH] Buckify backends/arm for Meta-internal use. (#5023)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5023

1. Add Buck targets for `executorch/backends/arm/...`
2. Pyre typing cleanup
3. Invoke the Vela compiler in-process as `vela.main(args)` instead of `subprocess.run([vela_command]...)` (a minimal sketch follows the patch)

Differential Revision: D62062674
---
 backends/arm/TARGETS                          | 83 +++++++++++++++++++
 backends/arm/arm_backend.py                   |  2 +-
 backends/arm/arm_vela.py                      | 22 ++---
 backends/arm/operators/TARGETS                | 34 ++++++++
 backends/arm/operators/op_bmm.py              |  1 +
 backends/arm/operators/op_conv2d.py           |  7 +-
 backends/arm/operators/op_mm.py               |  1 +
 backends/arm/operators/op_mul.py              | 10 ++-
 backends/arm/operators/op_output.py           |  4 +-
 backends/arm/operators/op_view.py             |  2 +-
 backends/arm/passes/TARGETS                   | 12 +++
 .../annotate_channels_last_dim_order_pass.py  |  4 +-
 backends/arm/passes/arm_pass_manager.py       |  4 +-
 .../passes/convert_expand_copy_to_repeat.py   |  4 +-
 .../arm/passes/size_adjust_conv2d_pass.py     |  6 +-
 backends/arm/quantizer/TARGETS                | 31 +++++++
 backends/arm/quantizer/arm_quantizer_utils.py | 10 ++-
 .../quantizer/quantization_annotation/TARGETS | 12 +++
 .../quantization_annotation/cat_annotator.py  |  4 +-
 backends/arm/tosa_quant_utils.py              | 10 ++-
 backends/arm/tosa_utils.py                    | 11 ++-
 21 files changed, 232 insertions(+), 42 deletions(-)
 create mode 100644 backends/arm/TARGETS
 create mode 100644 backends/arm/operators/TARGETS
 create mode 100644 backends/arm/passes/TARGETS
 create mode 100644 backends/arm/quantizer/TARGETS
 create mode 100644 backends/arm/quantizer/quantization_annotation/TARGETS

diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS
new file mode 100644
index 00000000000..220db373710
--- /dev/null
+++ b/backends/arm/TARGETS
@@ -0,0 +1,83 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "arm_partitioner",
+    srcs = [
+        "arm_partitioner.py",
+    ],
+    typing = True,
+    deps = [
+        ":arm_backend",
+        "//executorch/backends/arm/passes:passes",
+        "//executorch/exir:lib",
+    ],
+)
+
+python_library(
+    name = "arm_backend",
+    srcs = [
+        "arm_backend.py",
+    ],
+    typing = True,
+    deps = [
+        "fbsource//third-party/pypi/flatbuffers:flatbuffers",
+        "fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
+        "fbsource//third-party/serialization_lib/python/serializer:serializer",
+        "fbsource//third-party/serialization_lib/python/tosa:tosa",
+        ":arm_vela",
+        "//executorch/backends/arm/operators:lib",
+        "//executorch/backends/arm/operators:node_visitor",
+        "//executorch/backends/arm/passes:passes",
+    ],
+)
+
+python_library(
+    name = "arm_vela",
+    srcs = [
+        "arm_vela.py",
+    ],
+    typing = True,
+    deps = [
+        "fbsource//third-party/pypi/ethos-u-vela:ethos-u-vela",
+    ],
+)
+
+python_library(
+    name = "tosa_mapping",
+    srcs = [
+        "tosa_mapping.py",
+    ],
+    typing = True,
+    deps = [
+        "fbsource//third-party/serialization_lib/python/serializer:serializer",
+        "//caffe2:torch",
+    ],
+)
+
+python_library(
+    name = "tosa_quant_utils",
+    srcs = [
+        "tosa_quant_utils.py",
+    ],
+    typing = True,
+    deps = [
+        "fbsource//third-party/pypi/numpy:numpy",
+        "fbsource//third-party/serialization_lib/python/serializer:serializer",
+        "fbsource//third-party/serialization_lib/python/tosa:tosa",
+        ":tosa_mapping",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
+python_library(
+    name = "tosa_utils",
+    srcs = [
+        "tosa_utils.py",
+    ],
+    typing = True,
+    deps = [
+        "fbsource//third-party/serialization_lib/python/serializer:serializer",
+        ":tosa_quant_utils",
+        "//executorch/backends/arm/operators:node_visitor",
+    ],
+)
diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py
index f187191fee0..27fd36ca0e1 100644
--- a/backends/arm/arm_backend.py
+++ b/backends/arm/arm_backend.py
@@ -159,7 +159,7 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
     return False
 
 
-def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
+def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
     for spec in compile_spec:
         if spec.key == "debug_artifact_path":
             return spec.value.decode()
diff --git a/backends/arm/arm_vela.py b/backends/arm/arm_vela.py
index f387672b7b4..53533947c49 100644
--- a/backends/arm/arm_vela.py
+++ b/backends/arm/arm_vela.py
@@ -5,12 +5,12 @@
 
 import os
 import struct
-import subprocess
 import tempfile
 from typing import List
 
 import numpy as np
+from ethosu.vela import vela
 
 
 # Pack either input or output tensor block, compose the related arrays into
@@ -38,21 +38,17 @@ def vela_compile(tosa_graph, args: List[str]):
     with tempfile.TemporaryDirectory() as tmpdir:
         tosaname = "out.tosa"
         flatbuffer = tosa_graph.serialize()
-        with open(os.path.join(tmpdir, tosaname), "wb") as f:
+        tosa_path = os.path.join(tmpdir, tosaname)
+        with open(tosa_path, "wb") as f:
             f.write(flatbuffer)
 
         # invoke vela
-        vela_command = f"cd {tmpdir}; vela {' '.join(args)} {tosaname}"
-        try:
-            subprocess.run([vela_command], shell=True, check=True, capture_output=True)
-        except subprocess.CalledProcessError as process_error:
-            raise RuntimeError(
-                f"Vela compiler ('{vela_command}') failed with error:\n \
-                {process_error.stderr.decode()}\n \
-                Stdout:\n{process_error.stdout.decode()}"
-            )
-
-        np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
+        output_dir = os.path.join(tmpdir, "output")
+        args.append(f"--output-dir={output_dir}")
+        args.append(tosa_path)
+        vela.main(" ".join(args).split(" "))
+
+        np_path = os.path.join(output_dir, "out_sg0_vela.npz")
         blocks = b""
 
         with np.load(np_path, allow_pickle=False) as data:
diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS
new file mode 100644
index 00000000000..fd04d5fb847
--- /dev/null
+++ b/backends/arm/operators/TARGETS
@@ -0,0 +1,34 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "node_visitor",
+    srcs = ["node_visitor.py"],
+    typing = True,
+    deps = [
+        "//executorch/backends/arm:tosa_mapping",
+    ],
+)
+
+python_library(
+    name = "ops",
+    srcs = glob(["op_*.py"]),
+    typing = True,
+    deps = [
+        "fbsource//third-party/serialization_lib/python/tosa:tosa",
+        ":node_visitor",
+        "//executorch/backends/arm:tosa_mapping",
+        "//executorch/backends/arm:tosa_quant_utils",
+        "//executorch/backends/arm:tosa_utils",
+        "//executorch/exir:lib",
+    ],
+)
+
+python_library(
+    name = "lib",
+    srcs = ["__init__.py"],
+    typing = True,
+    deps = [
+        ":node_visitor",
+        ":ops",
+    ],
+)
diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py
index 8d0235ebe73..59f28d3bad8 100644
--- a/backends/arm/operators/op_bmm.py
+++ b/backends/arm/operators/op_bmm.py
@@ -72,6 +72,7 @@ def define_node(
             build_rescale(
                 tosa_fb=tosa_graph,
                 scale=final_output_scale,
+                # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
                 input_node=bmm_result,
                 output_name=output.name,
                 output_type=ts.DType.INT8,
diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py
index 9437e96f5e9..935c923ba42 100644
--- a/backends/arm/operators/op_conv2d.py
+++ b/backends/arm/operators/op_conv2d.py
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import List
+from typing import cast, List
 
 import serializer.tosa_serializer as ts
 import torch
@@ -156,11 +156,12 @@ def define_node(
         # integer value domain of the next op. Otherwise return float32 output.
         if is_quant_node:
             # Get scale_factor from input, weight, and output.
-            _, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
-            _, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
+            _, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
+            _, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
             _, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
             build_rescale_conv_output(
                 tosa_graph,
+                # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
                 conv2d_res,
                 output.name,
                 actual_out_type,
diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py
index f7097022f12..98152215035 100644
--- a/backends/arm/operators/op_mm.py
+++ b/backends/arm/operators/op_mm.py
@@ -96,6 +96,7 @@ def define_node(
             build_rescale(
                 tosa_fb=tosa_graph,
                 scale=final_output_scale,
+                # pyre-ignore[61]: Uninitialized local [61]: Local variable `reshape_intermediate` is undefined, or not always defined.
                 input_node=reshape_intermediate,
                 output_name=output.name,
                 output_type=ts.DType.INT8,
diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py
index e9cbfcbd7cc..f7c593e9fe3 100644
--- a/backends/arm/operators/op_mul.py
+++ b/backends/arm/operators/op_mul.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List
+from typing import cast, List
 
 import executorch.backends.arm.tosa_quant_utils as tqutils
 import executorch.backends.arm.tosa_utils as tutils
@@ -35,8 +35,12 @@ def define_node(
         if is_quant_node:
             input_A = inputs[0]
             input_B = inputs[1]
-            input_A_qargs = tqutils.get_quant_node_args(node.args[0])
-            input_B_qargs = tqutils.get_quant_node_args(node.args[1])
+            input_A_qargs = tqutils.get_quant_node_args(
+                cast(torch.fx.Node, node.args[0])
+            )
+            input_B_qargs = tqutils.get_quant_node_args(
+                cast(torch.fx.Node, node.args[1])
+            )
 
             input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
             input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
diff --git a/backends/arm/operators/op_output.py b/backends/arm/operators/op_output.py
index 7d163114aa8..89654ed2d48 100644
--- a/backends/arm/operators/op_output.py
+++ b/backends/arm/operators/op_output.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast
+
 import serializer.tosa_serializer as ts
 import torch
 
@@ -11,7 +13,7 @@ def process_output(
     node: torch.fx.Node,
     tosa_graph: ts.TosaSerializer,
 ):
-    for output in node.args[0]:
+    for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
         tosa_graph.addOutputTensor(
             tosa_graph.currRegion.currBasicBlock.tensors[output.name]
         )
diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py
index 682eacd5e38..5baedfc9627 100644
--- a/backends/arm/operators/op_view.py
+++ b/backends/arm/operators/op_view.py
@@ -6,6 +6,7 @@
 
 import serializer.tosa_serializer as ts
 import torch
+import tosa.Op as TosaOp
 
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
@@ -13,7 +14,6 @@
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_utils import tosa_shape
-from serializer.tosa_serializer import TosaOp
 
 
 @register_node_visitor
diff --git a/backends/arm/passes/TARGETS b/backends/arm/passes/TARGETS
new file mode 100644
index 00000000000..ca20b03fccd
--- /dev/null
+++ b/backends/arm/passes/TARGETS
@@ -0,0 +1,12 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "passes",
+    srcs = glob(["*.py"]),
+    typing = True,
+    deps = [
+        "//executorch/backends/arm:tosa_quant_utils",
+        "//executorch/backends/arm:tosa_utils",
+        "//executorch/exir:lib",
+    ],
+)
diff --git a/backends/arm/passes/annotate_channels_last_dim_order_pass.py b/backends/arm/passes/annotate_channels_last_dim_order_pass.py
index ea3c171c580..8ba02c2f7e3 100644
--- a/backends/arm/passes/annotate_channels_last_dim_order_pass.py
+++ b/backends/arm/passes/annotate_channels_last_dim_order_pass.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast
+
 import torch
 from executorch.backends.arm.tosa_quant_utils import dq_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
@@ -28,7 +30,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
             if node.target != dq_op:
                 return False
             prev_node = node.args[0]
-            if prev_node.op != "placeholder":
+            if cast(torch.fx.Node, prev_node).op != "placeholder":
                 return False
             return is_consumer_node_depthwise_conv2d(node)
         elif node.op == "placeholder":
diff --git a/backends/arm/passes/arm_pass_manager.py b/backends/arm/passes/arm_pass_manager.py
index 8cac53b1347..914bf57aabc 100644
--- a/backends/arm/passes/arm_pass_manager.py
+++ b/backends/arm/passes/arm_pass_manager.py
@@ -23,11 +23,11 @@
 
 class ArmPassManager(PassManager):
 
-    def _transform(self, graph_module: torch.fx.Graph):
+    def _transform(self, graph_module: torch.fx.GraphModule):
         return self(graph_module).graph_module
 
     def transform_to_backend_pipeline(
-        self, graph_module: torch.fx.Graph, compile_spec: CompileSpec
+        self, graph_module: torch.fx.GraphModule, compile_spec: list[CompileSpec]
     ):
         """Apply passes before transforming program to backend"""
         self.add_pass(SizeAdjustConv2DPass())
diff --git a/backends/arm/passes/convert_expand_copy_to_repeat.py b/backends/arm/passes/convert_expand_copy_to_repeat.py
index 53138682d56..5f409e1ae5f 100644
--- a/backends/arm/passes/convert_expand_copy_to_repeat.py
+++ b/backends/arm/passes/convert_expand_copy_to_repeat.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import cast
+
 import torch.fx
 from executorch.backends.arm.tosa_mapping import extract_tensor_meta
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -31,7 +33,7 @@ def call(self, graph_module: torch.fx.GraphModule):
 
             expand_node = src_partition.nodes[0]
             _, shape, _ = extract_tensor_meta(expand_node.all_input_nodes[0].meta)
-            multiples = expand_node.args[1]
+            multiples = cast(tuple[int], expand_node.args[1])
             expanded_rank = len(multiples)
 
             # Expanded shape is 'shape' front-padded with ones.
diff --git a/backends/arm/passes/size_adjust_conv2d_pass.py b/backends/arm/passes/size_adjust_conv2d_pass.py
index 25d27e7f40f..ea161b74928 100644
--- a/backends/arm/passes/size_adjust_conv2d_pass.py
+++ b/backends/arm/passes/size_adjust_conv2d_pass.py
@@ -85,8 +85,8 @@ def call(self, graph_module: torch.fx.GraphModule):
             input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = (
                 conv_node.args
             )
-            weight_shape = weight.meta["val"].shape
-            input_shape = input_node.meta["val"].shape
+            weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
+            input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
 
             slice_args = []
             for stride, pad, dilation, dim in zip(
@@ -119,7 +119,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 last_node = dq_node
             else:
                 last_node = slice_node
-            conv_node.replace_input_with(input_node, last_node)
+            conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
             modified_graph = True
 
         if modified_graph:
diff --git a/backends/arm/quantizer/TARGETS b/backends/arm/quantizer/TARGETS
new file mode 100644
index 00000000000..840586488bf
--- /dev/null
+++ b/backends/arm/quantizer/TARGETS
@@ -0,0 +1,31 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "arm_quantizer",
+    srcs = ["arm_quantizer.py"],
+    typing = True,
+    deps = [
+        ":arm_quantizer_utils",
+        "//caffe2:torch",
+        "//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation",
+        "//executorch/exir:lib",
+    ],
+)
+
+python_library(
+    name = "quantization_config",
+    srcs = ["quantization_config.py"],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+python_library(
+    name = "arm_quantizer_utils",
+    srcs = ["arm_quantizer_utils.py"],
+    typing = True,
+    deps = [
+        ":quantization_config",
+    ],
+)
diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py
index 417aa454a8e..1cac297bc92 100644
--- a/backends/arm/quantizer/arm_quantizer_utils.py
+++ b/backends/arm/quantizer/arm_quantizer_utils.py
@@ -10,7 +10,7 @@
 #
 
 import operator
-from typing import Callable, cast, List
+from typing import Callable, cast, List, Union
 
 import torch
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
@@ -72,7 +72,7 @@ def get_shared_qspec(
 
     Both outputs are None if one of the inputs is a node that can't be quantized.
     """
-    input_act0 = node.args[0]
+    input_act0 = cast(Node, node.args[0])
     input_act1 = node.args[1]
 
     input_act_qspec = quantization_config.get_input_act_qspec()
@@ -169,7 +169,9 @@ def propagate_annotation(model: GraphModule) -> None:
         n = cast(Node, n)
         if is_annotated(n):
             continue
-        if n.op != "call_function" or not is_share_obs_or_fq_op(n.target):
+        if n.op != "call_function" or not is_share_obs_or_fq_op(
+            cast(Callable, n.target)
+        ):
             continue
 
         prev_node = n.args[0]
@@ -217,7 +219,7 @@ def convert_scalars_to_attrs(model: GraphModule) -> GraphModule:
             prefix = "_tensor_constant_"
             get_new_attr_name = get_new_attr_name_with_prefix(prefix)
             tensor_constant_name = get_new_attr_name(model)
-            float_tensor = torch.tensor(float(args[i]))
+            float_tensor = torch.tensor(float(cast(Union[int, float], args[i])))
             model.register_buffer(tensor_constant_name, float_tensor)
             fake_mode = n.meta["val"].fake_mode
             with model.graph.inserting_before(n):
diff --git a/backends/arm/quantizer/quantization_annotation/TARGETS b/backends/arm/quantizer/quantization_annotation/TARGETS
new file mode 100644
index 00000000000..4ce8b5cad2c
--- /dev/null
+++ b/backends/arm/quantizer/quantization_annotation/TARGETS
@@ -0,0 +1,12 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "quantization_annotation",
+    srcs = glob(["*.py"]),
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/arm/quantizer:arm_quantizer_utils",
+        "//executorch/backends/arm/quantizer:quantization_config",
+    ],
+)
diff --git a/backends/arm/quantizer/quantization_annotation/cat_annotator.py b/backends/arm/quantizer/quantization_annotation/cat_annotator.py
index 40dd19526b3..992070ac172 100644
--- a/backends/arm/quantizer/quantization_annotation/cat_annotator.py
+++ b/backends/arm/quantizer/quantization_annotation/cat_annotator.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Callable, List, Optional
+from typing import Callable, cast, List, Optional
 
 import torch.fx
 from executorch.backends.arm.quantizer import arm_quantizer_utils
@@ -34,7 +34,7 @@ def _annotate_cat(
         if arm_quantizer_utils.is_annotated(cat_node):
             continue
 
-        input_acts = cat_node.args[0]
+        input_acts = cast(list[torch.fx.Node], cat_node.args[0])
         input_act0 = input_acts[0]
 
         input_act_qspec = quantization_config.get_input_act_qspec()
diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index c0d16d51b25..d93f2544070 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -6,15 +6,16 @@
 # Utiliy functions for TOSA quantized lowerings
 
 import math
-from typing import NamedTuple
+from typing import NamedTuple, Sequence
 
 import numpy as np
 import serializer.tosa_serializer as ts
 import torch.fx
+import tosa.Op as TosaOp
 
 from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
-from serializer.tosa_serializer import TosaOp, TosaSerializerTensor
+from serializer.tosa_serializer import TosaSerializerTensor
 from torch.fx import Node
 
 q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
@@ -65,6 +66,7 @@ def is_quant_node(node: torch.fx.Node):
 
 
 def get_quant_node_dtype(node: torch.fx.Node):
+    # pyre-ignore[16]: Undefined attribute.
     if "tosa" in node.target.__name__:
         return node.meta["val"].dtype
 
@@ -231,7 +233,7 @@ def build_rescale_from_int32(
     rescale_scale,
     is_scale32=True,
    is_double_round=False,
-) -> TosaSerializerTensor:
+) -> None:
     multiplier, shift = compute_multiplier_and_shift(rescale_scale)
     attr_rescale_output = ts.TosaSerializerAttribute()
     attr_rescale_output.RescaleAttribute(
@@ -254,7 +256,7 @@
 
 
 def rescale_nodes_to_int32(
-    nodes: list[Node], tosa_graph: ts.TosaSerializer
+    nodes: Sequence[Node], tosa_graph: ts.TosaSerializer
 ) -> tuple[list[TosaSerializerTensor], float]:
     """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
     The scales are adjusted using the smallest scale of all 'nodes'.
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index f84e371279b..5353dd49fae 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -5,7 +5,7 @@
 
 import logging
 import os
-from typing import Any, Dict
+from typing import Any, cast, Dict
 
 import numpy as np
 import serializer.tosa_serializer as ts
@@ -235,7 +235,7 @@ def build_avg_pool_2d_common(
     output_zp = 0
 
     if is_quant_node:
-        input_zp = get_quant_node_args(node.args[0]).zp
+        input_zp = get_quant_node_args(cast(torch.fx.Node, node.args[0])).zp
         output_zp = get_quant_node_args(list(node.users)[0]).zp
 
     attr = ts.TosaSerializerAttribute()
@@ -306,7 +306,9 @@ def process_call_function(
     )
 
     # Visiting each Node
+    # pyre-ignore[16]: Undefined attribute.
     if node.target.__name__ in node_visitors:
+        # pyre-ignore[16]: Undefined attribute.
         node_visitors[node.target.__name__].define_node(
             node,
             tosa_graph,
@@ -319,7 +321,10 @@
 
 
 def expand_dims(
-    tosa_graph: ts.TosaSerializer, input_node: TosaArg, dtype: ts.DType, dim: int
+    tosa_graph: ts.TosaSerializer,
+    input_node: TosaArg,
+    dtype: int,
+    dim: int,
 ) -> Any:
     """Inserts TOSA operators into the tosa_graph, that perform the equivalent
     of the expand_dims (a.k.a unsqueeze) operation. A new axis is created at the
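
Illustrative sketch (not part of the patch): summary item 3 swaps the `vela` subprocess call for an in-process `vela.main()` invocation, as shown in the `arm_vela.py` hunk above. A minimal standalone version of that flow follows. It assumes only what the diff shows (`ethosu.vela.vela.main` accepting CLI-style argv, and Vela writing `out_sg0_vela.npz` under `--output-dir`); the helper name `compile_tosa_flatbuffer` is hypothetical:

    import os
    import tempfile

    import numpy as np
    from ethosu.vela import vela


    def compile_tosa_flatbuffer(flatbuffer: bytes, args: list) -> dict:
        """Compile a serialized TOSA graph in-process; return the npz arrays."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Write the serialized TOSA graph where Vela can read it.
            tosa_path = os.path.join(tmpdir, "out.tosa")
            with open(tosa_path, "wb") as f:
                f.write(flatbuffer)

            # Same call pattern as vela_compile() in the patch: pass argv
            # directly to vela.main() instead of shelling out via subprocess.
            output_dir = os.path.join(tmpdir, "output")
            vela.main(args + [f"--output-dir={output_dir}", tosa_path])

            # Read the compiled command stream before the temp dir is removed.
            np_path = os.path.join(output_dir, "out_sg0_vela.npz")
            with np.load(np_path, allow_pickle=False) as data:
                return {key: data[key] for key in data.files}

One practical upside of the in-process call, visible in the removed hunk lines, is that the stderr/stdout capture and the RuntimeError re-raise around `subprocess.run` become unnecessary.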