Skip to content

Commit

Permalink
[ONNX] Diagnostic 'log' and 'log_and_raise_if_error'
Browse files Browse the repository at this point in the history
ghstack-source-id: b0598bd8bd68f96917582c54bfe66a5afda5793d
Pull Request resolved: #100407

patch previous pr
  • Loading branch information
BowenBao committed May 11, 2023
1 parent 502e791 commit ae2287e
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 88 deletions.
41 changes: 38 additions & 3 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import dataclasses
import io
import logging
import typing
import unittest
from typing import AbstractSet, Protocol, Tuple
Expand Down Expand Up @@ -199,7 +200,8 @@ def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
self._sample_rule,
sample_level,
):
diagnostics.export_context().diagnose(self._sample_rule, sample_level)
diagnostic = infra.Diagnostic(self._sample_rule, sample_level)
diagnostics.export_context().log(diagnostic)

def test_diagnostics_records_python_call_stack(self):
diagnostic = diagnostics.ExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
Expand Down Expand Up @@ -277,12 +279,45 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
(custom_rules.custom_rule_2, infra.Level.ERROR), # type: ignore[attr-defined]
},
):
self.context.diagnose(
diagnostic1 = infra.Diagnostic(
custom_rules.custom_rule, infra.Level.WARNING # type: ignore[attr-defined]
)
self.context.diagnose(
self.context.log(diagnostic1)

diagnostic2 = infra.Diagnostic(
custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined]
)
self.context.log(diagnostic2)

def test_diagnostic_context_logs_with_correct_logger_level_based_on_diagnostic_level(
self,
):
diagnostic_logging_level_pairs = [
(infra.Level.NONE, logging.DEBUG),
(infra.Level.NOTE, logging.INFO),
(infra.Level.WARNING, logging.WARNING),
(infra.Level.ERROR, logging.ERROR),
]

for diagnostic_level, expected_logger_level in diagnostic_logging_level_pairs:
with self.assertLogs(
self.context.logger, level=expected_logger_level
) as assert_log_context:
self.context.log(
infra.Diagnostic(
self.rules.rule_without_message_args, diagnostic_level
)
)
for record in assert_log_context.records:
self.assertEqual(record.levelno, expected_logger_level)

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.log_and_raise_if_error(
infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.ERROR
)
)


if __name__ == "__main__":
Expand Down
9 changes: 3 additions & 6 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def __init__(self) -> None:
self._background_context = infra.DiagnosticContext(
name="torch.onnx",
version=torch.__version__,
diagnostic_type=ExportDiagnostic,
)

@property
Expand All @@ -131,9 +130,7 @@ def create_diagnostic_context(
"""
if options is None:
options = infra.DiagnosticOptions()
context = infra.DiagnosticContext(
name, version, options, diagnostic_type=diagnostic_type
)
context = infra.DiagnosticContext(name, version, options)
self.contexts.append(context)
return context

Expand Down Expand Up @@ -201,14 +198,14 @@ def diagnose(
) -> ExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.
This is a wrapper around `context.add_diagnostic` that uses the global diagnostic
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = ExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)
export_context().add_diagnostic(diagnostic)
export_context().log(diagnostic)
return diagnostic


Expand Down
3 changes: 2 additions & 1 deletion torch/onnx/_internal/diagnostics/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Tag,
ThreadFlowLocation,
)
from .context import Diagnostic, DiagnosticContext
from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic

__all__ = [
"Diagnostic",
Expand All @@ -25,6 +25,7 @@
"Location",
"Rule",
"RuleCollection",
"RuntimeErrorWithDiagnostic",
"Stack",
"StackFrame",
"Tag",
Expand Down
19 changes: 14 additions & 5 deletions torch/onnx/_internal/diagnostics/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.onnx._internal.diagnostics.infra import formatter, sarif


class Level(enum.Enum):
class Level(enum.IntEnum):
"""The level of a diagnostic.
This class is used to represent the level of a diagnostic. The levels are defined
Expand All @@ -22,12 +22,21 @@ class Level(enum.Enum):
- NOTE: An opportunity for improvement was found.
- WARNING: A potential problem was found.
- ERROR: A serious problem was found.
This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer
value maps to the logging levels in Python's logging module. The mapping is as
follows:
Level.NONE = logging.DEBUG = 10
Level.NOTE = logging.INFO = 20
Level.WARNING = logging.WARNING = 30
Level.ERROR = logging.ERROR = 40
"""

NONE = enum.auto()
NOTE = enum.auto()
WARNING = enum.auto()
ERROR = enum.auto()
NONE = 10
NOTE = 20
WARNING = 30
ERROR = 40


levels = Level
Expand Down
62 changes: 28 additions & 34 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import logging

from typing import Callable, Generator, List, Mapping, Optional, Type, TypeVar
from typing import Callable, Generator, List, Mapping, Optional, TypeVar

from typing_extensions import Literal

Expand All @@ -18,10 +18,6 @@
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version


class DiagnosticError(RuntimeError):
pass


# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")

Expand All @@ -39,6 +35,8 @@ class Diagnostic:
)
additional_message: Optional[str] = None
tags: List[infra.Tag] = dataclasses.field(default_factory=list)
source_exception: Optional[Exception] = None
"""The exception that caused this diagnostic to be created."""

def sarif(self) -> sarif.Result:
"""Returns the SARIF Result representation of this diagnostic."""
Expand Down Expand Up @@ -107,6 +105,11 @@ def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic:
self.additional_message = f"{self.additional_message}\n{message}"
return self

def with_source_exception(self: _Diagnostic, exception: Exception) -> _Diagnostic:
"""Adds the source exception to the diagnostic."""
self.source_exception = exception
return self

def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current Python call stack."""
frames_to_skip += 1 # Skip this function.
Expand Down Expand Up @@ -175,14 +178,21 @@ def pretty_print(
# TODO: print help url to rule at the end.


class RuntimeErrorWithDiagnostic(RuntimeError):
"""Runtime error with enclosed diagnostic information."""

def __init__(self, diagnostic: Diagnostic):
super().__init__(diagnostic.message)
self.diagnostic = diagnostic


@dataclasses.dataclass
class DiagnosticContext:
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic)
diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
logger: logging.Logger = dataclasses.field(
init=True, default_factory=lambda: logging.getLogger().getChild("diagnostics")
Expand Down Expand Up @@ -233,7 +243,7 @@ def dump(self, file_path: str, compress: bool = False) -> None:
with open(file_path, "w") as f:
f.write(self.to_json())

def add_diagnostic(self, diagnostic: Diagnostic) -> None:
def log(self, diagnostic: Diagnostic) -> None:
"""Adds a diagnostic to the context.
Use this method to add diagnostics that are not created by the context.
Expand All @@ -245,6 +255,15 @@ def add_diagnostic(self, diagnostic: Diagnostic) -> None:
f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
)
self.diagnostics.append(diagnostic)
self.logger.log(diagnostic.level, diagnostic.message)
self.logger.log(diagnostic.level, diagnostic.additional_message)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
self.log(diagnostic)
if diagnostic.level == infra.Level.ERROR:
raise RuntimeErrorWithDiagnostic(
diagnostic
) from diagnostic.source_exception

@contextlib.contextmanager
def add_inflight_diagnostic(
Expand All @@ -262,31 +281,6 @@ def add_inflight_diagnostic(
finally:
self._inflight_diagnostics.pop()

def diagnose(
self,
rule: infra.Rule,
level: infra.Level,
message: Optional[str] = None,
**kwargs,
) -> Diagnostic:
"""Creates a diagnostic for the given arguments.
Args:
rule: The rule that triggered the diagnostic.
level: The level of the diagnostic.
message: The message of the diagnostic.
**kwargs: Additional arguments to pass to the Diagnostic constructor.
Returns:
The created diagnostic.
Raises:
ValueError: If the rule is not supported by the tool.
"""
diagnostic = self.diagnostic_type(rule, level, message, **kwargs)
self.add_diagnostic(diagnostic)
return diagnostic

def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.
Expand All @@ -310,15 +304,15 @@ def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
raise DiagnosticError("No inflight diagnostics")
raise AssertionError("No inflight diagnostics")

return self._inflight_diagnostics[-1]
else:
# TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]]
for diagnostic in reversed(self._inflight_diagnostics):
if diagnostic.rule == rule:
return diagnostic
raise DiagnosticError(f"No inflight diagnostic for rule {rule.name}")
raise AssertionError(f"No inflight diagnostic for rule {rule.name}")

def pretty_print(
self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None
Expand Down
38 changes: 8 additions & 30 deletions torch/onnx/_internal/diagnostics/infra/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@_beartype.beartype
def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
return f"{formatter.display_name(fn)}"
return f"{formatter.display_name(fn)}. "


@_beartype.beartype
Expand Down Expand Up @@ -57,30 +57,14 @@ def format_return_values_in_markdown(
]


@_beartype.beartype
def modify_diagnostic(
diag: infra.Diagnostic,
fn: Callable,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
return_values: Any,
) -> None:
return


@_beartype.beartype
def diagnose_call(
rule: infra.Rule,
*,
level: infra.Level = infra.Level.NONE,
exception_report_level: infra.Level = infra.Level.WARNING,
diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic,
format_argument: Callable[[Any], str] = formatter.format_argument,
diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
diagnostic_modifier: ModifierCallableType = modify_diagnostic,
report_criterion: Callable[
[Callable, Tuple[Any, ...], Dict[str, Any], Any], bool
] = lambda _1, _2, _3, _4: True,
) -> Callable:
def decorator(fn):
@functools.wraps(fn)
Expand Down Expand Up @@ -138,30 +122,24 @@ def wrapper(*args, **kwargs):
]

return_values: Any = None
report_diagnostic: bool = True
with ctx.add_inflight_diagnostic(diag) as diag:
try:
return_values = fn(*args, **kwargs)
additional_messages.append(
format_return_values_in_markdown(return_values, format_argument)
)
report_diagnostic = report_criterion(
fn, args, kwargs, return_values
)
return return_values
except Exception as e:
# Record exception.
report_diagnostic = True
diag.level = exception_report_level
diag.level = infra.levels.ERROR
# TODO(bowbao): Message emitting api.
diag.message = diag.message or ""
diag.message += f"Raised from:\n {type(e).__name__}: {e}"
diag.with_source_exception(e)
additional_messages.append(format_exception_in_markdown(e))
raise
finally:
if report_diagnostic:
diag.with_additional_message(
"\n".join(additional_messages).strip()
)
diagnostic_modifier(diag, fn, args, kwargs, return_values)
ctx.add_diagnostic(diag)
diag.with_additional_message("\n".join(additional_messages).strip())
ctx.log_and_raise_if_error(diag)

return wrapper

Expand Down
14 changes: 13 additions & 1 deletion torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,16 @@ def __init__(self, package_name: str, message: str):
self.package_name = package_name


class OnnxExporterError(RuntimeError):
"""Raised when an ONNX exporter error occurs. Diagnostic context is enclosed."""

diagnostic_context: Final[infra.DiagnosticContext]

def __init__(self, diagnostic_context: infra.DiagnosticContext, message: str):
super().__init__(message)
self.diagnostic_context = diagnostic_context


@_beartype.beartype
def _assert_dependencies(export_options: ResolvedExportOptions):
logger = export_options.logger
Expand Down Expand Up @@ -619,7 +629,9 @@ def dynamo_export(
f"Failed to export the model to ONNX. Generating SARIF report at {sarif_report_path}. "
f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}"
)
raise RuntimeError(message) from e
raise OnnxExporterError(
resolved_export_options.diagnostic_context, message
) from e


__all__ = [
Expand Down

0 comments on commit ae2287e

Please sign in to comment.