diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 4dd49d8528f..2e71f91dbb6 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -23,6 +23,7 @@ class ArmCompileSpecBuilder: class DebugMode(Enum): JSON = 1 + TOSA = 2 def __init__(self): self.compile_spec: List[CompileSpec] = [] diff --git a/backends/arm/debug/schema.py b/backends/arm/debug/schema.py index bb06ddba864..82f0fd6bf7e 100644 --- a/backends/arm/debug/schema.py +++ b/backends/arm/debug/schema.py @@ -8,11 +8,13 @@ import json from dataclasses import asdict, dataclass -from typing import Any +from typing import Any, Optional import serializer.tosa_serializer as ts # type: ignore import torch +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder + from torch.fx.traceback import NodeSource @@ -97,37 +99,52 @@ def from_node(node: torch.fx.Node) -> TorchDebugSchema: class DebugSchema: event_id: int aten_info: ATenDebugSchema - tosa_info: TosaDebugSchema + tosa_info: Optional[TosaDebugSchema] torch_info: TorchDebugSchema + def to_dict(self) -> dict[str, Any]: + output = asdict(self) + + if self.tosa_info is None: + output.pop("tosa_info") + + return output + class DebugHook: - def __init__(self) -> None: + def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None: self._debug_events: list[DebugSchema] = [] self.__op_id_to_name = {} + self.mode = debug_mode # Build up a mapping from TOSA 1.0 operator IDs to their names for name, val in vars(ts.Op).items(): self.__op_id_to_name[val] = name - def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None: - tosa_debug_info = TosaDebugSchema( - node_name=str(tosa_op), - operator_name=self.__op_id_to_name[tosa_op_id], - operator_id=tosa_op_id, - ) + def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema: + tosa_debug_info = None + + # If the debug data is being embedded into the TOSA flatbuffer + # do not collect TOSADebugSchema data, it's redundent + if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA: + tosa_debug_info = TosaDebugSchema( + node_name=str(tosa_op), + operator_name=self.__op_id_to_name[tosa_op_id], + operator_id=tosa_op_id, + ) aten_debug_info = ATenDebugSchema.from_node(node) torch_debug_info = TorchDebugSchema.from_node(node) - self._debug_events.append( - DebugSchema( - event_id=len(self._debug_events), - aten_info=aten_debug_info, - tosa_info=tosa_debug_info, - torch_info=torch_debug_info, - ) + debug_info = DebugSchema( + event_id=len(self._debug_events), + aten_info=aten_debug_info, + tosa_info=tosa_debug_info, + torch_info=torch_debug_info, ) + self._debug_events.append(debug_info) + + return debug_info def serialize(self) -> str: - return json.dumps([asdict(event) for event in self._debug_events], indent=4) + return json.dumps([event.to_dict() for event in self._debug_events], indent=4) diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 5881486bfef..54a81bdaaff 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -5,10 +5,12 @@ # pyre-unsafe +import json from typing import Any, Dict, List, Optional import torch +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.specification import TosaSpecification @@ -49,20 +51,25 @@ def _serialize_operator( outputs: List[str], attributes: Optional[Any] = None, ) -> None: + op_location = "" + if self.debug_hook: + debug_info = self.debug_hook.add( + node, + tosa_op=outputs[0], + tosa_op_id=tosa_op, + ) + + if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA: + op_location = json.dumps(debug_info.to_dict()) + tosa_graph.addOperator( tosa_op, inputs=inputs, outputs=outputs, attributes=attributes, + location=op_location, ) - if self.debug_hook: - self.debug_hook.add( - node, - tosa_op=outputs[0], - tosa_op_id=tosa_op, - ) - def define_node( self, node: torch.fx.Node, diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 5648070f869..3e10a9336f9 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -222,6 +222,27 @@ def test_dump_tosa_debug_json(test_data: input_t1): pytest.fail("Failed to load debug JSON file") +@common.parametrize("test_data", Linear.inputs) +def test_dump_tosa_debug_tosa(test_data: input_t1): + with tempfile.TemporaryDirectory() as tmpdir: + pipeline = TosaPipelineINT[input_t1]( + module=Linear(), + test_data=test_data, + aten_op=[], + exir_op=[], + custom_path=tmpdir, + tosa_debug_mode=ArmCompileSpecBuilder.DebugMode.TOSA, + ) + + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + json_output_path = Path(tmpdir) / "debug.json" + + # A JSON file should not be created when TOSA mode used + assert not json_output_path.exists() + + @common.parametrize("test_data", Linear.inputs) def test_dump_tosa_ops(caplog, test_data: input_t1): pipeline = TosaPipelineINT[input_t1](Linear(), test_data, [], []) diff --git a/backends/arm/test/misc/test_debug_hook.py b/backends/arm/test/misc/test_debug_hook.py index c6b0dffffbf..935f3984403 100644 --- a/backends/arm/test/misc/test_debug_hook.py +++ b/backends/arm/test/misc/test_debug_hook.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from types import SimpleNamespace +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.debug.schema import DebugHook, DebugSchema from executorch.backends.arm.test import common @@ -156,8 +157,8 @@ def _compare_node_and_schema(debug_event: DebugSchema, mocked_node): @common.parametrize("test_data", TESTCASES) -def test_debug_hook_add_1(test_data: DebugHookTestCase): - hook = DebugHook() +def test_debug_hook_add_json(test_data: DebugHookTestCase): + hook = DebugHook(ArmCompileSpecBuilder.DebugMode.JSON) hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) debug_events = hook._debug_events @@ -166,3 +167,17 @@ def test_debug_hook_add_1(test_data: DebugHookTestCase): _compare_tosa_and_schema(debug_events[0], test_data.tosa_op) _compare_node_and_schema(debug_events[0], test_data.mock_node) + + +@common.parametrize("test_data", TESTCASES) +def test_debug_hook_add_tosa(test_data: DebugHookTestCase): + hook = DebugHook(ArmCompileSpecBuilder.DebugMode.TOSA) + hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id) + + debug_events = hook._debug_events + assert len(debug_events) == test_data.expected_events + assert len(debug_events[0].torch_info.node_trace) == test_data.num_nodes_traced + + assert debug_events[0].tosa_info is None + + _compare_node_and_schema(debug_events[0], test_data.mock_node) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 151e37d18a9..d1e400a7fd6 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -14,6 +14,7 @@ from typing import cast, final, List import serializer.tosa_serializer as ts # type: ignore +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.process_node import ( @@ -100,7 +101,7 @@ def preprocess( # noqa: C901 debug_hook = None if dump_debug_info is not None: - debug_hook = DebugHook() + debug_hook = DebugHook(ArmCompileSpecBuilder.DebugMode[dump_debug_info]) # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors @@ -136,10 +137,11 @@ def preprocess( # noqa: C901 suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"), ) - if debug_hook: - json_output = debug_hook.serialize() - with open(f"{artifact_path}/debug.json", "w") as f: - f.write(json_output) + if debug_hook is not None: + if debug_hook.mode == ArmCompileSpecBuilder.DebugMode.JSON: + json_output = debug_hook.serialize() + with open(f"{artifact_path}/debug.json", "w") as f: + f.write(json_output) # Serialize and return the TOSA flatbuffer. binary = bytes(tosa_graph.serialize())