6 changes: 6 additions & 0 deletions .mypy.ini
@@ -83,6 +83,12 @@ ignore_missing_imports = True
[mypy-tosa_tools.*]
ignore_missing_imports = True

[mypy-tosa_serializer]
ignore_missing_imports = True

[mypy-tosa_serializer.*]
ignore_missing_imports = True

[mypy-setuptools.*]
ignore_missing_imports = True

22 changes: 7 additions & 15 deletions backends/arm/common/debug.py
@@ -7,8 +7,9 @@
import os
from typing import Optional

import serializer.tosa_serializer as ts
import torch

import tosa_serializer as ts
from executorch.exir.print_program import inspect_node

logger = logging.getLogger(__name__)
@@ -50,29 +50,20 @@ def get_node_debug_info(
return output


# Output TOSA flatbuffer and test harness file
def debug_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
# Output TOSA flatbuffer for debugging
def debug_tosa_dump(tosa_graph: bytes, path: str, suffix: str = ""):
filename = f"output{suffix}.tosa"

logger.info(f"Emitting debug output to: {path=}, {suffix=}")

os.makedirs(path, exist_ok=True)

fb = tosa_graph.serialize()
js = tosa_graph.writeJson(filename)

filepath_tosa_fb = os.path.join(path, filename)
with open(filepath_tosa_fb, "wb") as f:
f.write(fb)
f.write(tosa_graph)
if not os.path.exists(filepath_tosa_fb):
raise IOError("Failed to write TOSA flatbuffer")

filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
with open(filepath_desc_json, "w") as f:
f.write(js)
if not os.path.exists(filepath_desc_json):
raise IOError("Failed to write TOSA JSON")


def debug_fail(
node,
@@ -81,7 +73,7 @@ def debug_fail(
path: Optional[str] = None,
):
logger.warning("Internal error due to poorly handled node:")
if tosa_graph is not None and path is not None:
debug_tosa_dump(tosa_graph, path)
if tosa_graph is not None and path:
debug_tosa_dump(tosa_graph.serialize(), path)
logger.warning(f"Debug output captured in '{path}'.")
debug_node(node, graph_module)
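
For context, a minimal sketch (not part of the diff) of the updated debug_tosa_dump contract: callers serialize the graph themselves and pass the raw TOSA flatbuffer bytes, and the helper now only writes output<suffix>.tosa (the desc*.json side file is gone). The wrapper name, dump_dir value, and "_debug" suffix below are illustrative assumptions.

from executorch.backends.arm.common.debug import debug_tosa_dump

def dump_for_inspection(tosa_graph, dump_dir: str) -> None:
    # tosa_graph is the TOSA serializer object the Arm backend builds;
    # serialize it first, then hand the raw bytes to the debug helper.
    debug_tosa_dump(tosa_graph.serialize(), dump_dir, suffix="_debug")
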
13 changes: 4 additions & 9 deletions backends/arm/debug/schema.py
@@ -10,8 +10,8 @@
from dataclasses import asdict, dataclass
from typing import Any, Optional

import serializer.tosa_serializer as ts
import torch
import tosa_serializer as ts

from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec

@@ -114,23 +114,18 @@ def to_dict(self) -> dict[str, Any]:
class DebugHook:
def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None:
self._debug_events: list[DebugSchema] = []
self.__op_id_to_name = {}
self.mode = debug_mode

# Build up a mapping from TOSA 1.0 operator IDs to their names
for name, val in vars(ts.Op).items():
self.__op_id_to_name[val] = name

def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: ts.Op) -> DebugSchema:
tosa_debug_info = None

# If the debug data is being embedded into the TOSA flatbuffer
# do not collect TosaDebugSchema data, as it's redundant
if self.mode != ArmCompileSpec.DebugMode.TOSA:
tosa_debug_info = TosaDebugSchema(
node_name=str(tosa_op),
operator_name=self.__op_id_to_name[tosa_op_id],
operator_id=tosa_op_id,
operator_name=str(tosa_op_id),
operator_id=int(tosa_op_id),
)

aten_debug_info = ATenDebugSchema.from_node(node)
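
As a small illustration (not from the diff) of why DebugHook no longer builds its own id-to-name table: with the tosa_serializer package, an Op value converts directly to the name and numeric id that DebugHook.add now records via str() and int(). ts.Op.ADD is used here only as an example.

import tosa_serializer as ts

op = ts.Op.ADD
operator_name = str(op)  # readable operator name recorded in TosaDebugSchema
operator_id = int(op)    # numeric operator id recorded alongside it
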
9 changes: 9 additions & 0 deletions backends/arm/ethosu/backend.py
@@ -51,6 +51,15 @@ def _compile_tosa_flatbuffer(
"compile_flags are required in the CompileSpec list for EthosUBackend"
)

# Vela tooling only supports flatbuffers up to 2 GiB.
max_flatbuffer_size = 2 * 1024 * 1024 * 1024
flatbuffer_size = len(tosa_flatbuffer)
if flatbuffer_size > max_flatbuffer_size:
raise RuntimeError(
"TOSA flatbuffer is too large for Vela "
f"({flatbuffer_size} bytes > {max_flatbuffer_size} bytes limit)."
)

# Pass on the TOSA flatbuffer to the vela compiler.
binary = vela_compile(
tosa_flatbuffer,
7 changes: 4 additions & 3 deletions backends/arm/operators/node_visitor.py
@@ -9,6 +9,7 @@
from typing import Any, Dict, List, Optional

import torch
import tosa_serializer as ts

from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.debug.schema import DebugHook
@@ -46,12 +47,12 @@ def _serialize_operator(
self,
node: torch.fx.Node,
tosa_graph: Any,
tosa_op: Any,
tosa_op: ts.Op,
inputs: List[str],
outputs: List[str],
attributes: Optional[Any] = None,
) -> None:
op_location = ""
op_location = ts.TosaOpLocation()
if self.debug_hook:
debug_info = self.debug_hook.add(
node,
@@ -60,7 +61,7 @@
)

if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA:
op_location = json.dumps(debug_info.to_dict())
op_location.text = json.dumps(debug_info.to_dict())

tosa_graph.addOperator(
tosa_op,
16 changes: 9 additions & 7 deletions backends/arm/operators/op_abs.py
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import Any, List

import serializer.tosa_serializer as ts
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
@@ -48,11 +48,13 @@ def define_node(
output.tosa_spec,
)

tosa_graph.addOperator(
ts.TosaOp.Op().ABS,
[
inputs[0].name,
],
attr = ts.TosaSerializerAttribute()
attr.AbsAttribute()
self._serialize_operator(
node,
tosa_graph,
ts.Op.ABS,
[inputs[0].name],
[output.name],
None,
attr,
)
16 changes: 9 additions & 7 deletions backends/arm/operators/op_add.py
@@ -9,7 +9,7 @@

import executorch.backends.arm.tosa.quant_utils as tqutils
import executorch.backends.arm.tosa.utils as tutils
import serializer.tosa_serializer as ts
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
@@ -81,15 +81,16 @@ def define_node(
add_output = output

input1, input2 = rescaled_inputs

attr = ts.TosaSerializerAttribute()
attr.AddAttribute()
# Do the INT32 Add
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().ADD,
ts.Op.ADD,
[input1.name, input2.name],
[add_output.name],
None,
attr,
)

if output.dtype == ts.DType.INT8:
@@ -143,13 +144,14 @@
)

input1, input2 = inputs

attr = ts.TosaSerializerAttribute()
attr.AddAttribute()
# FP lowering
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().ADD,
ts.Op.ADD,
[input1.name, input2.name],
[output.name],
None,
attr,
)
7 changes: 4 additions & 3 deletions backends/arm/operators/op_amax.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Any, List

import serializer.tosa_serializer as ts
import tosa_serializer as ts

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
@@ -60,11 +60,12 @@ def define_node(
)

attr = ts.TosaSerializerAttribute()
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1)
nan_mode = ts.NanPropagationMode.PROPAGATE
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=nan_mode)
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_MAX,
ts.Op.REDUCE_MAX,
[input.name],
[output.name],
attr,
8 changes: 5 additions & 3 deletions backends/arm/operators/op_amin.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Any, List

import serializer.tosa_serializer as ts
import tosa_serializer as ts

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
@@ -60,11 +60,13 @@ def define_node(
)

attr = ts.TosaSerializerAttribute()
attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1)
attr.ReduceMinAttribute(
axis=input.dim_order.index(dim), nan_mode=ts.NanPropagationMode.PROPAGATE
)
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_MIN,
ts.Op.REDUCE_MIN,
[input.name],
[output.name],
attr,
4 changes: 2 additions & 2 deletions backends/arm/operators/op_any.py
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import Any, cast, List

import serializer.tosa_serializer as ts
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import ( # type: ignore
NodeVisitor,
@@ -55,7 +55,7 @@ def define_node(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().REDUCE_ANY,
ts.Op.REDUCE_ANY,
[inputs[0].name],
[output.name],
attr,
15 changes: 6 additions & 9 deletions backends/arm/operators/op_avg_pool2d.py
@@ -6,10 +6,10 @@
# pyre-unsafe
from typing import Any, List

import serializer.tosa_serializer as ts

import torch

import tosa_serializer as ts

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
@@ -93,17 +93,14 @@ def _build_generic_avgpool2d(
pad=pad_size_list,
acc_type=accumulator_type,
)
input_zp_tensor = tosa_graph.addConst(
shape=[1], dtype=output.dtype, vals=[input_zp]
)
output_zp_tensor = tosa_graph.addConst(
shape=[1], dtype=output.dtype, vals=[output_zp]
)
dt: ts.DType = output.dtype
input_zp_tensor = tosa_graph.addConst(shape=[1], dtype=dt, vals=[input_zp])
output_zp_tensor = tosa_graph.addConst(shape=[1], dtype=dt, vals=[output_zp])

self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().AVG_POOL2D,
ts.Op.AVG_POOL2D,
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
[output.name],
attr,
8 changes: 6 additions & 2 deletions backends/arm/operators/op_bitwise_not.py
@@ -5,7 +5,7 @@

from typing import Any, List

import serializer.tosa_serializer as ts
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
@@ -49,10 +49,14 @@ def define_node(
output.tosa_spec,
)

attr = ts.TosaSerializerAttribute()
attr.BitwiseNotAttribute()

self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().BITWISE_NOT,
ts.Op.BITWISE_NOT,
[inputs[0].name],
[output.name],
attr,
)
4 changes: 2 additions & 2 deletions backends/arm/operators/op_cat.py
@@ -7,7 +7,7 @@

from typing import Any, List

import serializer.tosa_serializer as ts
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
@@ -50,7 +50,7 @@ def define_node(
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().CONCAT,
ts.Op.CONCAT,
[tensor.name for tensor in tensors],
[output.name],
attr,
8 changes: 5 additions & 3 deletions backends/arm/operators/op_ceil.py
@@ -5,10 +5,10 @@

from typing import Any, List

import serializer.tosa_serializer as ts

import torch.fx

import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
@@ -49,6 +49,8 @@ def define_node(
output.tosa_spec,
)

attr = ts.TosaSerializerAttribute()
attr.CeilAttribute()
self._serialize_operator(
node, tosa_graph, ts.TosaOp.Op().CEIL, [inputs[0].name], [output.name]
node, tosa_graph, ts.Op.CEIL, [inputs[0].name], [output.name], attr
)