Skip to content

Commit

Permalink
[ONNX] Diagnostic 'log' and 'log_and_raise_if_error'
Browse files Browse the repository at this point in the history
ghstack-source-id: c445bb98c057ce850d8b89d14aaee7d9b3e6b97a
Pull Request resolved: #100407
  • Loading branch information
BowenBao committed May 1, 2023
1 parent e12cbd7 commit 69549b6
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 81 deletions.
41 changes: 38 additions & 3 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import dataclasses
import io
import logging
import typing
import unittest
from typing import AbstractSet, Protocol, Tuple
Expand Down Expand Up @@ -199,7 +200,8 @@ def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
self._sample_rule,
sample_level,
):
diagnostics.export_context().diagnose(self._sample_rule, sample_level)
diagnostic = infra.Diagnostic(self._sample_rule, sample_level)
diagnostics.export_context().log(diagnostic)

def test_diagnostics_records_python_call_stack(self):
diagnostic = diagnostics.ExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
Expand Down Expand Up @@ -277,12 +279,45 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
(custom_rules.custom_rule_2, infra.Level.ERROR), # type: ignore[attr-defined]
},
):
self.context.diagnose(
diagnostic1 = infra.Diagnostic(
custom_rules.custom_rule, infra.Level.WARNING # type: ignore[attr-defined]
)
self.context.diagnose(
self.context.log(diagnostic1)

diagnostic2 = infra.Diagnostic(
custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined]
)
self.context.log(diagnostic2)

def test_diagnostic_context_logs_with_correct_logger_level_based_on_diagnostic_level(
self,
):
diagnostic_logging_level_pairs = [
(infra.Level.NONE, logging.DEBUG),
(infra.Level.NOTE, logging.INFO),
(infra.Level.WARNING, logging.WARNING),
(infra.Level.ERROR, logging.ERROR),
]

for diagnostic_level, expected_logger_level in diagnostic_logging_level_pairs:
with self.assertLogs(
self.context.logger, level=expected_logger_level
) as assert_log_context:
self.context.log(
infra.Diagnostic(
self.rules.rule_without_message_args, diagnostic_level
)
)
for record in assert_log_context.records:
self.assertEqual(record.levelno, expected_logger_level)

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.DiagnosticError):
self.context.log_and_raise_if_error(
infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.ERROR
)
)


if __name__ == "__main__":
Expand Down
9 changes: 3 additions & 6 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def __init__(self) -> None:
self._background_context = infra.DiagnosticContext(
name="torch.onnx",
version=torch.__version__,
diagnostic_type=ExportDiagnostic,
)

@property
Expand All @@ -131,9 +130,7 @@ def create_diagnostic_context(
"""
if options is None:
options = infra.DiagnosticOptions()
context = infra.DiagnosticContext(
name, version, options, diagnostic_type=diagnostic_type
)
context = infra.DiagnosticContext(name, version, options)
self.contexts.append(context)
return context

Expand Down Expand Up @@ -201,14 +198,14 @@ def diagnose(
) -> ExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.
This is a wrapper around `context.add_diagnostic` that uses the global diagnostic
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = ExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)
export_context().add_diagnostic(diagnostic)
export_context().log(diagnostic)
return diagnostic


Expand Down
3 changes: 2 additions & 1 deletion torch/onnx/_internal/diagnostics/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
Tag,
ThreadFlowLocation,
)
from .context import Diagnostic, DiagnosticContext
from .context import Diagnostic, DiagnosticContext, DiagnosticError

__all__ = [
"Diagnostic",
"DiagnosticContext",
"DiagnosticError",
"DiagnosticOptions",
"Graph",
"Invocation",
Expand Down
19 changes: 14 additions & 5 deletions torch/onnx/_internal/diagnostics/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.onnx._internal.diagnostics.infra import formatter, sarif


class Level(enum.Enum):
class Level(enum.IntEnum):
"""The level of a diagnostic.
This class is used to represent the level of a diagnostic. The levels are defined
Expand All @@ -22,12 +22,21 @@ class Level(enum.Enum):
- NOTE: An opportunity for improvement was found.
- WARNING: A potential problem was found.
- ERROR: A serious problem was found.
This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer
value maps to the logging levels in Python's logging module. The mapping is as
follows:
Level.NONE = logging.DEBUG = 10
Level.NOTE = logging.INFO = 20
Level.WARNING = logging.WARNING = 30
Level.ERROR = logging.ERROR = 40
"""

NONE = enum.auto()
NOTE = enum.auto()
WARNING = enum.auto()
ERROR = enum.auto()
NONE = 10
NOTE = 20
WARNING = 30
ERROR = 40


levels = Level
Expand Down
60 changes: 26 additions & 34 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import logging

from typing import Callable, Generator, List, Mapping, Optional, Type, TypeVar
from typing import Callable, Generator, List, Mapping, Optional, TypeVar

from typing_extensions import Literal

Expand All @@ -18,10 +18,6 @@
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version


class DiagnosticError(RuntimeError):
pass


# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")

Expand All @@ -39,6 +35,8 @@ class Diagnostic:
)
additional_message: Optional[str] = None
tags: List[infra.Tag] = dataclasses.field(default_factory=list)
source_exception: Optional[Exception] = None
"""The exception that caused this diagnostic to be created."""

def sarif(self) -> sarif.Result:
"""Returns the SARIF Result representation of this diagnostic."""
Expand Down Expand Up @@ -107,6 +105,11 @@ def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic:
self.additional_message = f"{self.additional_message}\n{message}"
return self

def with_source_exception(self: _Diagnostic, exception: Exception) -> _Diagnostic:
"""Adds the source exception to the diagnostic."""
self.source_exception = exception
return self

def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current Python call stack."""
frames_to_skip += 1 # Skip this function.
Expand Down Expand Up @@ -175,14 +178,21 @@ def pretty_print(
# TODO: print help url to rule at the end.


class DiagnosticError(RuntimeError):
"""Raised when a diagnostic is raised."""

def __init__(self, diagnostic: Diagnostic):
super().__init__(diagnostic.message)
self.diagnostic = diagnostic


@dataclasses.dataclass
class DiagnosticContext:
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic)
diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
logger: logging.Logger = dataclasses.field(
init=True, default_factory=lambda: logging.getLogger().getChild("diagnostics")
Expand Down Expand Up @@ -233,7 +243,7 @@ def dump(self, file_path: str, compress: bool = False) -> None:
with open(file_path, "w") as f:
f.write(self.to_json())

def add_diagnostic(self, diagnostic: Diagnostic) -> None:
def log(self, diagnostic: Diagnostic) -> None:
"""Adds a diagnostic to the context.
Use this method to add diagnostics that are not created by the context.
Expand All @@ -245,6 +255,13 @@ def add_diagnostic(self, diagnostic: Diagnostic) -> None:
f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
)
self.diagnostics.append(diagnostic)
self.logger.log(diagnostic.level, diagnostic.message)
self.logger.log(diagnostic.level, diagnostic.additional_message)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
self.log(diagnostic)
if diagnostic.level == infra.Level.ERROR:
raise DiagnosticError(diagnostic) from diagnostic.source_exception

@contextlib.contextmanager
def add_inflight_diagnostic(
Expand All @@ -262,31 +279,6 @@ def add_inflight_diagnostic(
finally:
self._inflight_diagnostics.pop()

def diagnose(
self,
rule: infra.Rule,
level: infra.Level,
message: Optional[str] = None,
**kwargs,
) -> Diagnostic:
"""Creates a diagnostic for the given arguments.
Args:
rule: The rule that triggered the diagnostic.
level: The level of the diagnostic.
message: The message of the diagnostic.
**kwargs: Additional arguments to pass to the Diagnostic constructor.
Returns:
The created diagnostic.
Raises:
ValueError: If the rule is not supported by the tool.
"""
diagnostic = self.diagnostic_type(rule, level, message, **kwargs)
self.add_diagnostic(diagnostic)
return diagnostic

def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.
Expand All @@ -310,15 +302,15 @@ def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
raise DiagnosticError("No inflight diagnostics")
raise AssertionError("No inflight diagnostics")

return self._inflight_diagnostics[-1]
else:
# TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]]
for diagnostic in reversed(self._inflight_diagnostics):
if diagnostic.rule == rule:
return diagnostic
raise DiagnosticError(f"No inflight diagnostic for rule {rule.name}")
raise AssertionError(f"No inflight diagnostic for rule {rule.name}")

def pretty_print(
self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None
Expand Down
30 changes: 3 additions & 27 deletions torch/onnx/_internal/diagnostics/infra/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,6 @@ def format_return_values_in_markdown(
]


@_beartype.beartype
def modify_diagnostic(
diag: infra.Diagnostic,
fn: Callable,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
return_values: Any,
) -> None:
return


@_beartype.beartype
def diagnose_call(
rule: infra.Rule,
Expand All @@ -77,10 +66,6 @@ def diagnose_call(
diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic,
format_argument: Callable[[Any], str] = formatter.format_argument,
diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
diagnostic_modifier: ModifierCallableType = modify_diagnostic,
report_criterion: Callable[
[Callable, Tuple[Any, ...], Dict[str, Any], Any], bool
] = lambda _1, _2, _3, _4: True,
) -> Callable:
def decorator(fn):
@functools.wraps(fn)
Expand Down Expand Up @@ -138,30 +123,21 @@ def wrapper(*args, **kwargs):
]

return_values: Any = None
report_diagnostic: bool = True
with ctx.add_inflight_diagnostic(diag) as diag:
try:
return_values = fn(*args, **kwargs)
additional_messages.append(
format_return_values_in_markdown(return_values, format_argument)
)
report_diagnostic = report_criterion(
fn, args, kwargs, return_values
)
return return_values
except Exception as e:
# Record exception.
report_diagnostic = True
diag.level = exception_report_level
diag.with_source_exception(e)
additional_messages.append(format_exception_in_markdown(e))
raise
finally:
if report_diagnostic:
diag.with_additional_message(
"\n".join(additional_messages).strip()
)
diagnostic_modifier(diag, fn, args, kwargs, return_values)
ctx.add_diagnostic(diag)
diag.with_additional_message("\n".join(additional_messages).strip())
ctx.log_and_raise_if_error(diag)

return wrapper

Expand Down
12 changes: 10 additions & 2 deletions torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import functools
from typing import Any

import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import graph_building # type: ignore[import]

import torch
import torch.fx
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils

_LENGTH_LIMIT: int = 80
_LENGTH_LIMIT: int = 89

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx`, and loaded with `torch`. Hence anything related to `onnxscript`
Expand Down Expand Up @@ -39,7 +42,7 @@ def format_argument(obj: Any) -> str:
)
)
diag.with_location(utils.function_location(formatter))
diagnostics.export_context().add_diagnostic(diag)
diagnostics.export_context().log(diag)

return result_str

Expand All @@ -63,6 +66,11 @@ def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
return f"torch.fx.GraphModule({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
return f"torch.fx.Node(target: {obj.target})"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
return f"Tensor(shape={obj.shape}, dtype={obj.dtype})"
Expand Down

0 comments on commit 69549b6

Please sign in to comment.