Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions torch/onnx/_internal/diagnostics/infra/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import functools

import inspect
import traceback
from typing import Any, Callable, Dict, Mapping, Tuple
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple

from torch.onnx._internal import _beartype
from torch.onnx._internal.diagnostics.infra import _infra, formatter
Expand Down Expand Up @@ -41,13 +43,24 @@ def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infr
return stack


@functools.lru_cache()
def _function_source_info(fn: Callable) -> Tuple[Sequence[str], int, Optional[str]]:
"""Returns the source lines, line number, and source file path for the given function.

Essentially, inspect.getsourcelines() and inspect.getsourcefile() combined.
Caching is applied to reduce the performance impact of this function.
"""
source_lines, lineno = inspect.getsourcelines(fn)
return source_lines, lineno, inspect.getsourcefile(fn)


@_beartype.beartype
def function_location(fn: Callable) -> _infra.Location:
"""Returns a Location for the given function."""
source_lines, lineno = inspect.getsourcelines(fn)
source_lines, lineno, uri = _function_source_info(fn)
snippet = source_lines[0].strip() if len(source_lines) > 0 else "<unknown>"
return _infra.Location(
uri=inspect.getsourcefile(fn),
uri=uri,
line=lineno,
snippet=snippet,
message=formatter.display_name(fn),
Expand Down
13 changes: 11 additions & 2 deletions torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,23 @@ def format_argument(obj: Any) -> str:
return result_str


# NOTE: EDITING BELOW? READ THIS FIRST!
#
# The below functions register the `format_argument` function for different types via
# `functools.singledispatch` registry. These are invoked by the diagnostics system
# when recording function arguments and return values as part of a diagnostic.
# Hence, code with heavy workload should be avoided. Things to avoid for example:
# `torch.fx.GraphModule.print_readable()`.


@_format_argument.register
def _torch_nn_module(obj: torch.nn.Module) -> str:
return f"{obj.__class__.__name__}"
return f"torch.nn.Module({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
return f"{obj.print_readable(print_output=False)}"
return f"torch.fx.GraphModule({obj.__class__.__name__})"


@_format_argument.register
Expand Down