83 changes: 83 additions & 0 deletions backends/arm/TARGETS
@@ -0,0 +1,83 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "arm_partitioner",
srcs = [
"arm_partitioner.py",
],
typing = True,
deps = [
":arm_backend",
"//executorch/backends/arm/passes:passes",
"//executorch/exir:lib",
],
)

python_library(
name = "arm_backend",
srcs = [
"arm_backend.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
"fbsource//third-party/serialization_lib/python/serializer:serializer",
"fbsource//third-party/serialization_lib/python/tosa:tosa",
":arm_vela",
"//executorch/backends/arm/operators:lib",
"//executorch/backends/arm/operators:node_visitor",
"//executorch/backends/arm/passes:passes",
],
)

python_library(
name = "arm_vela",
srcs = [
"arm_vela.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/ethos-u-vela:ethos-u-vela",
],
)

python_library(
name = "tosa_mapping",
srcs = [
"tosa_mapping.py",
],
typing = True,
deps = [
"fbsource//third-party/serialization_lib/python/serializer:serializer",
"//caffe2:torch",
],
)

python_library(
name = "tosa_quant_utils",
srcs = [
"tosa_quant_utils.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/numpy:numpy",
"fbsource//third-party/serialization_lib/python/serializer:serializer",
"fbsource//third-party/serialization_lib/python/tosa:tosa",
":tosa_mapping",
"//executorch/exir/dialects:lib",
],
)

python_library(
name = "tosa_utils",
srcs = [
"tosa_utils.py",
],
typing = True,
deps = [
"fbsource//third-party/serialization_lib/python/serializer:serializer",
":tosa_quant_utils",
"//executorch/backends/arm/operators:node_visitor",
],
)
2 changes: 1 addition & 1 deletion backends/arm/arm_backend.py
@@ -159,7 +159,7 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
return False


-def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
+def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
for spec in compile_spec:
if spec.key == "debug_artifact_path":
return spec.value.decode()
22 changes: 9 additions & 13 deletions backends/arm/arm_vela.py
@@ -5,12 +5,12 @@

import os
import struct
-import subprocess
import tempfile

from typing import List

import numpy as np
+from ethosu.vela import vela


# Pack either input or output tensor block, compose the related arrays into
@@ -38,21 +38,17 @@ def vela_compile(tosa_graph, args: List[str]):
with tempfile.TemporaryDirectory() as tmpdir:
tosaname = "out.tosa"
flatbuffer = tosa_graph.serialize()
-with open(os.path.join(tmpdir, tosaname), "wb") as f:
+tosa_path = os.path.join(tmpdir, tosaname)
+with open(tosa_path, "wb") as f:
f.write(flatbuffer)

# invoke vela
-vela_command = f"cd {tmpdir}; vela {' '.join(args)} {tosaname}"
-try:
-subprocess.run([vela_command], shell=True, check=True, capture_output=True)
-except subprocess.CalledProcessError as process_error:
-raise RuntimeError(
-f"Vela compiler ('{vela_command}') failed with error:\n \
-{process_error.stderr.decode()}\n \
-Stdout:\n{process_error.stdout.decode()}"
-)
-
-np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
+output_dir = os.path.join(tmpdir, "output")
+args.append(f"--output-dir={output_dir}")
+args.append(tosa_path)
+vela.main(" ".join(args).split(" "))
+
+np_path = os.path.join(output_dir, "out_sg0_vela.npz")
blocks = b""

with np.load(np_path, allow_pickle=False) as data:
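For context on the arm_vela.py change above: instead of shelling out to the vela CLI through subprocess, compilation now runs in-process via vela.main() from the ethosu.vela package, which takes the same argv-style strings as the command-line tool. Below is a minimal, self-contained sketch of that calling pattern; the function name and the fixed out.tosa / out_sg0_vela.npz file names mirror the diff but are otherwise illustrative assumptions.

import os
import tempfile

import numpy as np
from ethosu.vela import vela  # provided by the ethos-u-vela pip package


def compile_tosa_with_vela(tosa_flatbuffer: bytes, vela_args: list) -> dict:
    # Write the serialized TOSA graph to a temporary file, compile it with
    # Vela in-process, and return the arrays from the resulting .npz.
    with tempfile.TemporaryDirectory() as tmpdir:
        tosa_path = os.path.join(tmpdir, "out.tosa")
        with open(tosa_path, "wb") as f:
            f.write(tosa_flatbuffer)

        output_dir = os.path.join(tmpdir, "output")
        # vela.main() accepts argv-style arguments, same as the CLI.
        vela.main(vela_args + [f"--output-dir={output_dir}", tosa_path])

        # Vela names its output after the input subgraph; out_sg0_vela.npz
        # corresponds to the out.tosa input written above.
        np_path = os.path.join(output_dir, "out_sg0_vela.npz")
        with np.load(np_path, allow_pickle=False) as data:
            return {key: data[key] for key in data.files}
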
34 changes: 34 additions & 0 deletions backends/arm/operators/TARGETS
@@ -0,0 +1,34 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "node_visitor",
srcs = ["node_visitor.py"],
typing = True,
deps = [
"//executorch/backends/arm:tosa_mapping",
],
)

python_library(
name = "ops",
srcs = glob(["op_*.py"]),
typing = True,
deps = [
"fbsource//third-party/serialization_lib/python/tosa:tosa",
":node_visitor",
"//executorch/backends/arm:tosa_mapping",
"//executorch/backends/arm:tosa_quant_utils",
"//executorch/backends/arm:tosa_utils",
"//executorch/exir:lib",
],
)

python_library(
name = "lib",
srcs = ["__init__.py"],
typing = True,
deps = [
":node_visitor",
":ops",
],
)
1 change: 1 addition & 0 deletions backends/arm/operators/op_bmm.py
@@ -72,6 +72,7 @@ def define_node(
build_rescale(
tosa_fb=tosa_graph,
scale=final_output_scale,
+# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
input_node=bmm_result,
output_name=output.name,
output_type=ts.DType.INT8,
7 changes: 4 additions & 3 deletions backends/arm/operators/op_conv2d.py
@@ -2,7 +2,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import List
+from typing import cast, List

import serializer.tosa_serializer as ts
import torch
@@ -156,11 +156,12 @@ def define_node(
# integer value domain of the next op. Otherwise return float32 output.
if is_quant_node:
# Get scale_factor from input, weight, and output.
-_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
-_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
+_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
+_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
build_rescale_conv_output(
tosa_graph,
+# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
conv2d_res,
output.name,
actual_out_type,
1 change: 1 addition & 0 deletions backends/arm/operators/op_mm.py
@@ -96,6 +96,7 @@ def define_node(
build_rescale(
tosa_fb=tosa_graph,
scale=final_output_scale,
+# pyre-ignore[61]: Uninitialized local [61]: Local variable `reshape_intermediate` is undefined, or not always defined.
input_node=reshape_intermediate,
output_name=output.name,
output_type=ts.DType.INT8,
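A note on the pyre-ignore[61] / pyre-fixme[61] comments added in op_bmm.py, op_conv2d.py, and op_mm.py above: they all suppress the same Pyre diagnostic, which fires when a local variable is assigned only inside one conditional branch and then used under a second condition that Pyre does not correlate with the first. A small self-contained illustration of the pattern (hypothetical function, not taken from the diff):

def scale_values(values, is_quant_node):
    if is_quant_node:
        total = sum(values)  # assigned only on the quantized branch

    result = list(values)
    if is_quant_node:
        # Pyre cannot prove this branch only runs when the first branch did,
        # so `total` is reported as possibly uninitialized at this use.
        # pyre-ignore[61]: Local variable `total` is undefined, or not always defined.
        result = [v * total for v in values]
    return result
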
10 changes: 7 additions & 3 deletions backends/arm/operators/op_mul.py
@@ -3,7 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from typing import List
+from typing import cast, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils
@@ -35,8 +35,12 @@ def define_node(
if is_quant_node:
input_A = inputs[0]
input_B = inputs[1]
-input_A_qargs = tqutils.get_quant_node_args(node.args[0])
-input_B_qargs = tqutils.get_quant_node_args(node.args[1])
+input_A_qargs = tqutils.get_quant_node_args(
+cast(torch.fx.Node, node.args[0])
+)
+input_B_qargs = tqutils.get_quant_node_args(
+cast(torch.fx.Node, node.args[1])
+)

input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
4 changes: 3 additions & 1 deletion backends/arm/operators/op_output.py
@@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+from typing import cast
+
import serializer.tosa_serializer as ts
import torch

@@ -11,7 +13,7 @@ def process_output(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
):
-for output in node.args[0]:
+for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
tosa_graph.addOutputTensor(
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
)
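The cast() calls introduced in op_conv2d.py, op_mul.py, op_output.py, and the passes further down do not change runtime behaviour: torch.fx.Node.args is typed as a tuple over a broad Argument union, so typing.cast only narrows the static type to what the Arm backend expects at that position. A minimal sketch of the idiom (the helper name is an illustrative assumption):

from typing import cast

import torch.fx


def first_input_node(node: torch.fx.Node) -> torch.fx.Node:
    # node.args[0] is statically typed as an Argument union; at this call
    # site it is known to be a Node. cast() returns its second argument
    # unchanged at runtime and only informs the type checker.
    return cast(torch.fx.Node, node.args[0])
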
2 changes: 1 addition & 1 deletion backends/arm/operators/op_view.py
@@ -6,14 +6,14 @@

import serializer.tosa_serializer as ts
import torch
-import tosa.Op as TosaOp

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
+from serializer.tosa_serializer import TosaOp


@register_node_visitor
12 changes: 12 additions & 0 deletions backends/arm/passes/TARGETS
@@ -0,0 +1,12 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "passes",
srcs = glob(["*.py"]),
typing = True,
deps = [
"//executorch/backends/arm:tosa_quant_utils",
"//executorch/backends/arm:tosa_utils",
"//executorch/exir:lib",
],
)
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+from typing import cast
+
import torch
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
@@ -28,7 +30,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
if node.target != dq_op:
return False
prev_node = node.args[0]
-if prev_node.op != "placeholder":
+if cast(torch.fx.Node, prev_node).op != "placeholder":
return False
return is_consumer_node_depthwise_conv2d(node)
elif node.op == "placeholder":
4 changes: 2 additions & 2 deletions backends/arm/passes/arm_pass_manager.py
@@ -23,11 +23,11 @@

class ArmPassManager(PassManager):

-def _transform(self, graph_module: torch.fx.Graph):
+def _transform(self, graph_module: torch.fx.GraphModule):
return self(graph_module).graph_module

def transform_to_backend_pipeline(
-self, graph_module: torch.fx.Graph, compile_spec: CompileSpec
+self, graph_module: torch.fx.GraphModule, compile_spec: list[CompileSpec]
):
"""Apply passes before transforming program to backend"""
self.add_pass(SizeAdjustConv2DPass())
4 changes: 3 additions & 1 deletion backends/arm/passes/convert_expand_copy_to_repeat.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+from typing import cast
+
import torch.fx
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.exir.dialects._ops import ops as exir_ops
@@ -31,7 +33,7 @@ def call(self, graph_module: torch.fx.GraphModule):

expand_node = src_partition.nodes[0]
_, shape, _ = extract_tensor_meta(expand_node.all_input_nodes[0].meta)
-multiples = expand_node.args[1]
+multiples = cast(tuple[int], expand_node.args[1])
expanded_rank = len(multiples)

# Expanded shape is 'shape' front-padded with ones.
6 changes: 3 additions & 3 deletions backends/arm/passes/size_adjust_conv2d_pass.py
@@ -85,8 +85,8 @@ def call(self, graph_module: torch.fx.GraphModule):
input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = (
conv_node.args
)
-weight_shape = weight.meta["val"].shape
-input_shape = input_node.meta["val"].shape
+weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
+input_shape = cast(torch.fx.Node, input_node).meta["val"].shape

slice_args = []
for stride, pad, dilation, dim in zip(
@@ -119,7 +119,7 @@ def call(self, graph_module: torch.fx.GraphModule):
last_node = dq_node
else:
last_node = slice_node
-conv_node.replace_input_with(input_node, last_node)
+conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
modified_graph = True

if modified_graph:
31 changes: 31 additions & 0 deletions backends/arm/quantizer/TARGETS
@@ -0,0 +1,31 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "arm_quantizer",
srcs = ["arm_quantizer.py"],
typing = True,
deps = [
":arm_quantizer_utils",
"//caffe2:torch",
"//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation",
"//executorch/exir:lib",
],
)

python_library(
name = "quantization_config",
srcs = ["quantization_config.py"],
typing = True,
deps = [
"//caffe2:torch",
],
)

python_library(
name = "arm_quantizer_utils",
srcs = ["arm_quantizer_utils.py"],
typing = True,
deps = [
":quantization_config",
],
)