Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/arm/debug/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
133 changes: 133 additions & 0 deletions backends/arm/debug/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json

from dataclasses import asdict, dataclass
from typing import Any

import serializer.tosa_serializer as ts # type: ignore
import torch

from torch.fx.traceback import NodeSource


@dataclass
class TosaDebugSchema:
node_name: str
operator_name: str
operator_id: int


@dataclass
class ATenDebugSchema:
node_name: str
operator_name: str

@staticmethod
def from_node(node: torch.fx.Node) -> ATenDebugSchema:
# node.target is Union[Callable[..., Any], str], so we need to access this correctly depending on the type
if callable(node.target):
operator_name = node.target.__name__
else:
operator_name = node.target

return ATenDebugSchema(node_name=node.name, operator_name=operator_name)


@dataclass
class TorchDebugSchema:
stack_trace: list[str]
node_trace: list[dict[str, Any]] | str
nn_module_stack: dict[str, Any] | str
torch_fn: tuple[str, str] | str

@staticmethod
def serialize_node_trace(node_trace: list[NodeSource]) -> list[dict[str, Any]]:
"""Flatten the from_node dictionary to remove nesting."""
flattened = []
node_stack = []

for n in node_trace:
node_stack.append((n, -1))

while len(node_stack) > 0:
node, parent_id = node_stack.pop()
flattened.append(
{
"name": node.name,
"target": node.target,
"graph_id": node.graph_id,
"pass_name": node.pass_name,
"action": node._get_action_string(),
"parent_graph_id": parent_id,
}
)

for n in node.from_node:
node_stack.append((n, node.graph_id))

return flattened

@staticmethod
def from_node(node: torch.fx.Node) -> TorchDebugSchema:
node_trace: str | list[dict[str, Any]] = "No node trace available."

if "from_node" in node.meta:
# Flatten the node_trace dictionary, so there is no nesting
node_trace = TorchDebugSchema.serialize_node_trace(node.meta["from_node"])

return TorchDebugSchema(
stack_trace=node.meta.get("stack_trace", "No stack trace available").split(
"\n"
),
node_trace=node_trace,
nn_module_stack=node.meta.get(
"nn_module_stack", "No module stack trace available"
),
torch_fn=node.meta.get("torch_fn", "No torch_fn available"),
)


@dataclass
class DebugSchema:
event_id: int
aten_info: ATenDebugSchema
tosa_info: TosaDebugSchema
torch_info: TorchDebugSchema


class DebugHook:
def __init__(self) -> None:
self._debug_events: list[DebugSchema] = []
self.__op_id_to_name = {}

# Build up a mapping from TOSA 1.0 operator IDs to their names
for name, val in vars(ts.Op).items():
self.__op_id_to_name[val] = name

def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
tosa_debug_info = TosaDebugSchema(
node_name=str(tosa_op),
operator_name=self.__op_id_to_name[tosa_op_id],
operator_id=tosa_op_id,
)

aten_debug_info = ATenDebugSchema.from_node(node)
torch_debug_info = TorchDebugSchema.from_node(node)

self._debug_events.append(
DebugSchema(
event_id=len(self._debug_events),
aten_info=aten_debug_info,
tosa_info=tosa_debug_info,
torch_info=torch_debug_info,
)
)

def serialize(self) -> str:
return json.dumps([asdict(event) for event in self._debug_events], indent=4)
168 changes: 168 additions & 0 deletions backends/arm/test/misc/test_debug_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from types import SimpleNamespace

from executorch.backends.arm.debug.schema import DebugHook, DebugSchema
from executorch.backends.arm.test import common


@dataclass
class DebugHookTestCase:
mock_node: SimpleNamespace
tosa_op: str
op_id: int
expected_events: int
num_nodes_traced: int


def create_mock_node_1():
def _get_action_str() -> str:
return "create"

from_node_2 = SimpleNamespace(
name="convolution",
target="aten.convolution.default",
graph_id=6052414368,
pass_name="ExportedProgram.module()",
action="create",
from_node=[],
_get_action_string=_get_action_str,
)

from_node_1 = SimpleNamespace(
name="convolution",
target="aten.convolution.default",
graph_id=5705954832,
pass_name="Interpreter_PropagateUnbackedSymInts",
action="create",
from_node=[from_node_2],
_get_action_string=_get_action_str,
)

fx_node_mock = SimpleNamespace(
name="aten_convolution_default",
target="aten.convolution.default",
meta={
"stack_trace": 'File "models/model.py", line 221, in forward\nreturn self.features(x)',
"nn_module_stack": {"__self__": ["", "model.Model"]},
"torch_fn": ("conv2d", "builtin_function_or_method.conv2d"),
"from_node": [from_node_1],
},
)

return fx_node_mock


def create_mock_node_2():
def _get_action_str() -> str:
return "create"

from_node_1 = SimpleNamespace(
name="convolution",
target="aten.convolution.default",
graph_id=5705954832,
pass_name="Interpreter_PropagateUnbackedSymInts",
action="create",
from_node=[],
_get_action_string=_get_action_str,
)

fx_node_mock = SimpleNamespace(
name="aten_convolution_default",
target="aten.convolution.default",
meta={
"from_node": [from_node_1],
},
)

return fx_node_mock


def create_mock_node_3():
fx_node_mock = SimpleNamespace(
name="aten_convolution_default",
target="aten.convolution.default",
meta={
"from_node": [],
},
)

return fx_node_mock


def _compare_tosa_and_schema(debug_event: DebugSchema, tosa_op):
tosa_info = debug_event.tosa_info

assert tosa_info.node_name == tosa_op

# The mapping between op_ids to operator names could change
# So just check operator_name is a string
assert isinstance(tosa_info.operator_name, str)


def _compare_node_and_schema(debug_event: DebugSchema, mocked_node):
# Check aten info
aten_info = debug_event.aten_info

assert aten_info.node_name == mocked_node.name
assert aten_info.operator_name == mocked_node.target

# Check torch info
torch_info = debug_event.torch_info

if "nn_module_stack" in mocked_node.meta:
assert torch_info.nn_module_stack == mocked_node.meta["nn_module_stack"]
else:
assert torch_info.nn_module_stack == "No module stack trace available"

if "stack_trace" in mocked_node.meta:
assert torch_info.stack_trace == mocked_node.meta["stack_trace"].split("\n")
else:
assert torch_info.stack_trace == ["No stack trace available"]

if "torch_fn" in mocked_node.meta:
assert torch_info.torch_fn == mocked_node.meta["torch_fn"]
else:
assert torch_info.torch_fn == "No torch_fn available"


TESTCASES = {
"mocked_node": DebugHookTestCase(
mock_node=create_mock_node_1(),
tosa_op="layer-1",
op_id=3,
expected_events=1,
num_nodes_traced=2,
),
"mocked_node_partially_empty": DebugHookTestCase(
mock_node=create_mock_node_2(),
tosa_op="layer-1",
op_id=1,
expected_events=1,
num_nodes_traced=1,
),
"mocked_node_all_empty": DebugHookTestCase(
mock_node=create_mock_node_3(),
tosa_op="layer-2",
op_id=1,
expected_events=1,
num_nodes_traced=0,
),
}


@common.parametrize("test_data", TESTCASES)
def test_debug_hook_add_1(test_data: DebugHookTestCase):
hook = DebugHook()
hook.add(test_data.mock_node, test_data.tosa_op, test_data.op_id)

debug_events = hook._debug_events
assert len(debug_events) == test_data.expected_events
assert len(debug_events[0].torch_info.node_trace) == test_data.num_nodes_traced

_compare_tosa_and_schema(debug_events[0], test_data.tosa_op)
_compare_node_and_schema(debug_events[0], test_data.mock_node)
Loading