@dataclass
class TosaDebugSchema:
    """TOSA-side half of a debug event.

    Instances are created by ``DebugHook.add`` from the TOSA op handle and
    operator ID it is given.
    """

    node_name: str  # ``str()`` of the TOSA op handle passed to ``DebugHook.add``
    operator_name: str  # name looked up from ``operator_id``
    operator_id: int  # operator ID passed to ``DebugHook.add``


@dataclass
class ATenDebugSchema:
    """ATen-side half of a debug event: the FX node name and its operator."""

    node_name: str
    operator_name: str

    @staticmethod
    def from_node(node: torch.fx.Node) -> ATenDebugSchema:
        """Build the record from an FX node's ``name`` and ``target``.

        Args:
            node: Node providing ``name`` and ``target``.

        Returns:
            A record whose ``operator_name`` is ``target.__name__`` for
            callable targets and ``target`` itself otherwise.
        """
        target = node.target
        if callable(target):
            # node.target is Union[Callable[..., Any], str]. Not every callable
            # carries ``__name__`` (functools.partial objects, for example), so
            # fall back to its string form instead of raising AttributeError.
            operator_name = getattr(target, "__name__", str(target))
        else:
            operator_name = target

        return ATenDebugSchema(node_name=node.name, operator_name=operator_name)


@dataclass
class TorchDebugSchema:
    """Metadata copied from ``node.meta`` of the FX node behind a debug event.

    Every field falls back to a human-readable placeholder string when the
    corresponding ``meta`` entry is absent.
    """

    stack_trace: list[str]
    node_trace: list[dict[str, Any]] | str
    nn_module_stack: dict[str, Any] | str
    torch_fn: tuple[str, str] | str

    @staticmethod
    def serialize_node_trace(node_trace: list[NodeSource]) -> list[dict[str, Any]]:
        """Flatten a nested ``from_node`` trace into a list of plain dicts.

        Each entry of ``node_trace`` and, transitively, each entry of every
        visited node's own ``from_node`` list produces one dict. Nesting is
        replaced by the ``parent_graph_id`` key: the ``graph_id`` of the node
        the entry was reached from, or ``-1`` for the top-level entries.

        Args:
            node_trace: Top-level trace entries.

        Returns:
            One dict per visited node, in last-in-first-out visiting order.
        """
        flattened: list[dict[str, Any]] = []
        # An explicit stack replaces recursion; top-level entries have no parent.
        pending = [(source, -1) for source in node_trace]

        while pending:
            source, parent_id = pending.pop()
            flattened.append(
                {
                    "name": source.name,
                    "target": source.target,
                    "graph_id": source.graph_id,
                    "pass_name": source.pass_name,
                    "action": source._get_action_string(),
                    "parent_graph_id": parent_id,
                }
            )
            pending.extend((child, source.graph_id) for child in source.from_node)

        return flattened

    @staticmethod
    def from_node(node: torch.fx.Node) -> TorchDebugSchema:
        """Collect the debug-relevant entries of ``node.meta``.

        Args:
            node: Node whose ``meta`` mapping is read.

        Returns:
            A record with the stack trace split into lines, the flattened
            ``from_node`` trace, and the ``nn_module_stack`` / ``torch_fn``
            entries, each replaced by a placeholder string when missing.
        """
        meta = node.meta

        node_trace: str | list[dict[str, Any]] = "No node trace available."
        if "from_node" in meta:
            # Flatten the node trace so the result contains no nesting.
            node_trace = TorchDebugSchema.serialize_node_trace(meta["from_node"])

        # The key may be present but hold None; treat that like a missing
        # entry instead of failing on ``None.split``.
        stack_trace = meta.get("stack_trace")
        if stack_trace is None:
            stack_trace = "No stack trace available"

        return TorchDebugSchema(
            stack_trace=stack_trace.split("\n"),
            node_trace=node_trace,
            nn_module_stack=meta.get(
                "nn_module_stack", "No module stack trace available"
            ),
            torch_fn=meta.get("torch_fn", "No torch_fn available"),
        )


@dataclass
class DebugSchema:
    """One debug event: its ID plus the ATen, TOSA and torch records."""

    event_id: int
    aten_info: ATenDebugSchema
    tosa_info: TosaDebugSchema
    torch_info: TorchDebugSchema


class DebugHook:
    """Accumulates debug events and serializes them to JSON."""

    def __init__(self) -> None:
        """Create an empty hook and build the operator ID to name mapping."""
        self._debug_events: list[DebugSchema] = []

        # Build up a mapping from TOSA 1.0 operator IDs to their names.
        # ``vars()`` on a class also yields interpreter-added entries such as
        # ``__module__`` and ``__doc__``; skip underscore-prefixed names so
        # those never end up in the mapping.
        self.__op_id_to_name: dict[Any, str] = {
            val: name for name, val in vars(ts.Op).items() if not name.startswith("_")
        }

    def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
        """Record one event linking ``node`` to a TOSA operator.

        Args:
            node: FX node the ATen and torch records are derived from.
            tosa_op: TOSA op handle; stored as ``str(tosa_op)``.
            tosa_op_id: Operator ID used to look up the operator name.

        Raises:
            KeyError: If ``tosa_op_id`` is not in the operator ID mapping.
        """
        tosa_debug_info = TosaDebugSchema(
            node_name=str(tosa_op),
            operator_name=self.__op_id_to_name[tosa_op_id],
            operator_id=tosa_op_id,
        )

        self._debug_events.append(
            DebugSchema(
                # Event IDs are the insertion index, starting at 0.
                event_id=len(self._debug_events),
                aten_info=ATenDebugSchema.from_node(node),
                tosa_info=tosa_debug_info,
                torch_info=TorchDebugSchema.from_node(node),
            )
        )

    def serialize(self) -> str:
        """Return all recorded events as a JSON array indented by 4 spaces.

        Metadata values that ``json`` cannot encode natively are written as
        ``str(value)`` so that a single unusual entry does not make the whole
        debug dump fail with ``TypeError``.
        """
        events = [asdict(event) for event in self._debug_events]
        return json.dumps(events, indent=4, default=str)
@dataclass
class DebugHookTestCase:
    """Inputs and expected outcome for one ``DebugHook.add`` scenario."""

    mock_node: SimpleNamespace
    tosa_op: str
    op_id: int
    expected_events: int
    num_nodes_traced: int


def create_mock_node_1():
    """Return a mock FX node with full metadata and a two-level node trace."""

    def _action_string() -> str:
        return "create"

    # Fields shared by both mocked trace entries.
    shared_fields = {
        "name": "convolution",
        "target": "aten.convolution.default",
        "action": "create",
        "_get_action_string": _action_string,
    }

    inner_source = SimpleNamespace(
        graph_id=6052414368,
        pass_name="ExportedProgram.module()",
        from_node=[],
        **shared_fields,
    )
    outer_source = SimpleNamespace(
        graph_id=5705954832,
        pass_name="Interpreter_PropagateUnbackedSymInts",
        from_node=[inner_source],
        **shared_fields,
    )

    node_meta = {
        "stack_trace": 'File "models/model.py", line 221, in forward\nreturn self.features(x)',
        "nn_module_stack": {"__self__": ["", "model.Model"]},
        "torch_fn": ("conv2d", "builtin_function_or_method.conv2d"),
        "from_node": [outer_source],
    }
    return SimpleNamespace(
        name="aten_convolution_default",
        target="aten.convolution.default",
        meta=node_meta,
    )


def create_mock_node_2():
    """Return a mock FX node whose meta only holds a one-entry node trace."""

    def _action_string() -> str:
        return "create"

    only_source = SimpleNamespace(
        name="convolution",
        target="aten.convolution.default",
        graph_id=5705954832,
        pass_name="Interpreter_PropagateUnbackedSymInts",
        action="create",
        from_node=[],
        _get_action_string=_action_string,
    )

    return SimpleNamespace(
        name="aten_convolution_default",
        target="aten.convolution.default",
        meta={"from_node": [only_source]},
    )


def create_mock_node_3():
    """Return a mock FX node whose meta holds an empty node trace."""
    return SimpleNamespace(
        name="aten_convolution_default",
        target="aten.convolution.default",
        meta={"from_node": []},
    )


def _compare_tosa_and_schema(debug_event: DebugSchema, tosa_op):
    """Check the TOSA record of ``debug_event`` against ``tosa_op``."""
    recorded = debug_event.tosa_info

    assert recorded.node_name == tosa_op

    # The mapping between op_ids to operator names could change
    # So just check operator_name is a string
    assert isinstance(recorded.operator_name, str)


def _compare_node_and_schema(debug_event: DebugSchema, mocked_node):
    """Check the ATen and torch records of ``debug_event`` against the mock."""
    # Check aten info
    aten_record = debug_event.aten_info
    assert aten_record.node_name == mocked_node.name
    assert aten_record.operator_name == mocked_node.target

    # Check torch info: each field equals the meta entry when present and
    # the placeholder otherwise.
    torch_record = debug_event.torch_info
    meta = mocked_node.meta

    expected_module_stack = meta.get(
        "nn_module_stack", "No module stack trace available"
    )
    assert torch_record.nn_module_stack == expected_module_stack

    expected_stack_trace = meta.get("stack_trace", "No stack trace available")
    assert torch_record.stack_trace == expected_stack_trace.split("\n")

    expected_torch_fn = meta.get("torch_fn", "No torch_fn available")
    assert torch_record.torch_fn == expected_torch_fn


# Scenario name -> test case, consumed by the parametrized test below.
TESTCASES = dict(
    mocked_node=DebugHookTestCase(
        mock_node=create_mock_node_1(),
        tosa_op="layer-1",
        op_id=3,
        expected_events=1,
        num_nodes_traced=2,
    ),
    mocked_node_partially_empty=DebugHookTestCase(
        mock_node=create_mock_node_2(),
        tosa_op="layer-1",
        op_id=1,
        expected_events=1,
        num_nodes_traced=1,
    ),
    mocked_node_all_empty=DebugHookTestCase(
        mock_node=create_mock_node_3(),
        tosa_op="layer-2",
        op_id=1,
        expected_events=1,
        num_nodes_traced=0,
    ),
)


@common.parametrize("test_data", TESTCASES)
def test_debug_hook_add_1(test_data: DebugHookTestCase):
    """Add one mocked node to a fresh hook and verify the recorded event."""
    hook = DebugHook()
    hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)

    recorded_events = hook._debug_events
    assert len(recorded_events) == test_data.expected_events

    first_event = recorded_events[0]
    assert len(first_event.torch_info.node_trace) == test_data.num_nodes_traced

    _compare_tosa_and_schema(first_event, test_data.tosa_op)
    _compare_node_and_schema(first_event, test_data.mock_node)