Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
class ArmCompileSpecBuilder:
class DebugMode(Enum):
JSON = 1
TOSA = 2

def __init__(self):
self.compile_spec: List[CompileSpec] = []
Expand Down
51 changes: 34 additions & 17 deletions backends/arm/debug/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import json

from dataclasses import asdict, dataclass
from typing import Any
from typing import Any, Optional

import serializer.tosa_serializer as ts # type: ignore
import torch

from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

from torch.fx.traceback import NodeSource


Expand Down Expand Up @@ -97,37 +99,52 @@ def from_node(node: torch.fx.Node) -> TorchDebugSchema:
class DebugSchema:
event_id: int
aten_info: ATenDebugSchema
tosa_info: TosaDebugSchema
tosa_info: Optional[TosaDebugSchema]
torch_info: TorchDebugSchema

def to_dict(self) -> dict[str, Any]:
output = asdict(self)

if self.tosa_info is None:
output.pop("tosa_info")

return output


class DebugHook:
def __init__(self) -> None:
def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None:
self._debug_events: list[DebugSchema] = []
self.__op_id_to_name = {}
self.mode = debug_mode

# Build up a mapping from TOSA 1.0 operator IDs to their names
for name, val in vars(ts.Op).items():
self.__op_id_to_name[val] = name

def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
tosa_debug_info = TosaDebugSchema(
node_name=str(tosa_op),
operator_name=self.__op_id_to_name[tosa_op_id],
operator_id=tosa_op_id,
)
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
tosa_debug_info = None

# If the debug data is being embedded into the TOSA flatbuffer
# do not collect TOSADebugSchema data, it's redundent
if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA:
tosa_debug_info = TosaDebugSchema(
node_name=str(tosa_op),
operator_name=self.__op_id_to_name[tosa_op_id],
operator_id=tosa_op_id,
)

aten_debug_info = ATenDebugSchema.from_node(node)
torch_debug_info = TorchDebugSchema.from_node(node)

self._debug_events.append(
DebugSchema(
event_id=len(self._debug_events),
aten_info=aten_debug_info,
tosa_info=tosa_debug_info,
torch_info=torch_debug_info,
)
debug_info = DebugSchema(
event_id=len(self._debug_events),
aten_info=aten_debug_info,
tosa_info=tosa_debug_info,
torch_info=torch_debug_info,
)
self._debug_events.append(debug_info)

return debug_info

def serialize(self) -> str:
return json.dumps([asdict(event) for event in self._debug_events], indent=4)
return json.dumps([event.to_dict() for event in self._debug_events], indent=4)
21 changes: 14 additions & 7 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

# pyre-unsafe

import json
from typing import Any, Dict, List, Optional

import torch

from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.debug.schema import DebugHook
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import TosaSpecification
Expand Down Expand Up @@ -49,20 +51,25 @@ def _serialize_operator(
outputs: List[str],
attributes: Optional[Any] = None,
) -> None:
op_location = ""
if self.debug_hook:
debug_info = self.debug_hook.add(
node,
tosa_op=outputs[0],
tosa_op_id=tosa_op,
)

if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA:
op_location = json.dumps(debug_info.to_dict())

tosa_graph.addOperator(
tosa_op,
inputs=inputs,
outputs=outputs,
attributes=attributes,
location=op_location,
)

if self.debug_hook:
self.debug_hook.add(
node,
tosa_op=outputs[0],
tosa_op_id=tosa_op,
)

def define_node(
self,
node: torch.fx.Node,
Expand Down
21 changes: 21 additions & 0 deletions backends/arm/test/misc/test_debug_feats.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,27 @@ def test_dump_tosa_debug_json(test_data: input_t1):
pytest.fail("Failed to load debug JSON file")


@common.parametrize("test_data", Linear.inputs)
def test_dump_tosa_debug_tosa(test_data: input_t1):
with tempfile.TemporaryDirectory() as tmpdir:
pipeline = TosaPipelineINT[input_t1](
module=Linear(),
test_data=test_data,
aten_op=[],
exir_op=[],
custom_path=tmpdir,
tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.TOSA,
)

pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()

json_output_path = Path(tmpdir) / "debug.json"

# A JSON file should not be created when TOSA mode used
assert not json_output_path.exists()


@common.parametrize("test_data", Linear.inputs)
def test_dump_tosa_ops(caplog, test_data: input_t1):
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], [])
Expand Down
19 changes: 17 additions & 2 deletions backends/arm/test/misc/test_debug_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from types import SimpleNamespace

from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.debug.schema import DebugHook, DebugSchema
from executorch.backends.arm.test import common

Expand Down Expand Up @@ -156,8 +157,8 @@ def _compare_node_and_schema(debug_event: DebugSchema, mocked_node):


@common.parametrize("test_data", TESTCASES)
def test_debug_hook_add_1(test_data: DebugHookTestCase):
hook = DebugHook()
def test_debug_hook_add_json(test_data: DebugHookTestCase):
hook = DebugHook(ArmCompileSpecBuilder.DebugMode.JSON)
hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)

debug_events = hook._debug_events
Expand All @@ -166,3 +167,17 @@ def test_debug_hook_add_1(test_data: DebugHookTestCase):

_compare_tosa_and_schema(debug_events[0], test_data.tosa_op)
_compare_node_and_schema(debug_events[0], test_data.mock_node)


@common.parametrize("test_data", TESTCASES)
def test_debug_hook_add_tosa(test_data: DebugHookTestCase):
hook = DebugHook(ArmCompileSpecBuilder.DebugMode.TOSA)
hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)

debug_events = hook._debug_events
assert len(debug_events) == test_data.expected_events
assert len(debug_events[0].torch_info.node_trace) == test_data.num_nodes_traced

assert debug_events[0].tosa_info is None

_compare_node_and_schema(debug_events[0], test_data.mock_node)
12 changes: 7 additions & 5 deletions backends/arm/tosa/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import cast, final, List

import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump
from executorch.backends.arm.debug.schema import DebugHook
from executorch.backends.arm.process_node import (
Expand Down Expand Up @@ -100,7 +101,7 @@ def preprocess( # noqa: C901

debug_hook = None
if dump_debug_info is not None:
debug_hook = DebugHook()
debug_hook = DebugHook(ArmCompileSpecBuilder.DebugMode[dump_debug_info])

# TODO: Fix the need to lazily import this.
from executorch.backends.arm.operators.node_visitor import get_node_visitors
Expand Down Expand Up @@ -136,10 +137,11 @@ def preprocess( # noqa: C901
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
)

if debug_hook:
json_output = debug_hook.serialize()
with open(f"{artifact_path}/debug.json", "w") as f:
f.write(json_output)
if debug_hook is not None:
if debug_hook.mode == ArmCompileSpecBuilder.DebugMode.JSON:
json_output = debug_hook.serialize()
with open(f"{artifact_path}/debug.json", "w") as f:
f.write(json_output)

# Serialize and return the TOSA flatbuffer.
binary = bytes(tosa_graph.serialize())
Expand Down
Loading