Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Diagnostic 'log' and 'log_and_raise_if_error' #100407

Closed
wants to merge 7 commits into from
41 changes: 38 additions & 3 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import dataclasses
import io
import logging
import typing
import unittest
from typing import AbstractSet, Protocol, Tuple
Expand Down Expand Up @@ -199,7 +200,8 @@ def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
self._sample_rule,
sample_level,
):
diagnostics.export_context().diagnose(self._sample_rule, sample_level)
diagnostic = infra.Diagnostic(self._sample_rule, sample_level)
diagnostics.export_context().log(diagnostic)

def test_diagnostics_records_python_call_stack(self):
diagnostic = diagnostics.ExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
Expand Down Expand Up @@ -277,12 +279,45 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
(custom_rules.custom_rule_2, infra.Level.ERROR), # type: ignore[attr-defined]
},
):
self.context.diagnose(
diagnostic1 = infra.Diagnostic(
custom_rules.custom_rule, infra.Level.WARNING # type: ignore[attr-defined]
)
self.context.diagnose(
self.context.log(diagnostic1)

diagnostic2 = infra.Diagnostic(
custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined]
)
self.context.log(diagnostic2)

def test_diagnostic_context_logs_with_correct_logger_level_based_on_diagnostic_level(
    self,
):
    """Check that `context.log` emits records at the expected `logging` level.

    Each diagnostic level is logged once through `self.context`, and every
    record captured from `self.context.logger` must carry the paired
    `logging` level: NONE -> DEBUG, NOTE -> INFO, WARNING -> WARNING,
    ERROR -> ERROR.
    """
    # (diagnostic level, logging level its records are expected to carry)
    level_pairs = (
        (infra.Level.NONE, logging.DEBUG),
        (infra.Level.NOTE, logging.INFO),
        (infra.Level.WARNING, logging.WARNING),
        (infra.Level.ERROR, logging.ERROR),
    )

    for level, logger_level in level_pairs:
        # `assertLogs` also fails the test when nothing at or above
        # `logger_level` is logged inside the block.
        with self.assertLogs(self.context.logger, level=logger_level) as captured:
            self.context.log(
                infra.Diagnostic(self.rules.rule_without_message_args, level)
            )
            for captured_record in captured.records:
                self.assertEqual(captured_record.levelno, logger_level)

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
    """Check that `log_and_raise_if_error` raises for an ERROR-level diagnostic."""
    with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
        # Build the diagnostic inside the `assertRaises` block, as the
        # original single-expression call did.
        error_diagnostic = infra.Diagnostic(
            self.rules.rule_without_message_args, infra.Level.ERROR
        )
        self.context.log_and_raise_if_error(error_diagnostic)


if __name__ == "__main__":
Expand Down
9 changes: 3 additions & 6 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def __init__(self) -> None:
self._background_context = infra.DiagnosticContext(
name="torch.onnx",
version=torch.__version__,
diagnostic_type=ExportDiagnostic,
)

@property
Expand All @@ -131,9 +130,7 @@ def create_diagnostic_context(
"""
if options is None:
options = infra.DiagnosticOptions()
context = infra.DiagnosticContext(
name, version, options, diagnostic_type=diagnostic_type
)
context = infra.DiagnosticContext(name, version, options)
self.contexts.append(context)
return context

Expand Down Expand Up @@ -201,14 +198,14 @@ def diagnose(
) -> ExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.

This is a wrapper around `context.add_diagnostic` that uses the global diagnostic
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = ExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)
export_context().add_diagnostic(diagnostic)
export_context().log(diagnostic)
return diagnostic


Expand Down
3 changes: 2 additions & 1 deletion torch/onnx/_internal/diagnostics/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Tag,
ThreadFlowLocation,
)
from .context import Diagnostic, DiagnosticContext
from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic

__all__ = [
"Diagnostic",
Expand All @@ -25,6 +25,7 @@
"Location",
"Rule",
"RuleCollection",
"RuntimeErrorWithDiagnostic",
"Stack",
"StackFrame",
"Tag",
Expand Down
19 changes: 14 additions & 5 deletions torch/onnx/_internal/diagnostics/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.onnx._internal.diagnostics.infra import formatter, sarif


class Level(enum.Enum):
class Level(enum.IntEnum):
"""The level of a diagnostic.

This class is used to represent the level of a diagnostic. The levels are defined
Expand All @@ -22,12 +22,21 @@ class Level(enum.Enum):
- NOTE: An opportunity for improvement was found.
- WARNING: A potential problem was found.
- ERROR: A serious problem was found.

This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer
value maps to the logging levels in Python's logging module. The mapping is as
follows:

Level.NONE = logging.DEBUG = 10
Level.NOTE = logging.INFO = 20
Level.WARNING = logging.WARNING = 30
Level.ERROR = logging.ERROR = 40
"""

NONE = enum.auto()
NOTE = enum.auto()
WARNING = enum.auto()
ERROR = enum.auto()
NONE = 10
NOTE = 20
WARNING = 30
ERROR = 40


levels = Level
Expand Down
62 changes: 28 additions & 34 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@

import logging

from typing import Callable, Generator, List, Literal, Mapping, Optional, Type, TypeVar
from typing import Callable, Generator, List, Literal, Mapping, Optional, TypeVar

from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version


class DiagnosticError(RuntimeError):
pass


# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")

Expand All @@ -37,6 +33,8 @@ class Diagnostic:
)
additional_message: Optional[str] = None
tags: List[infra.Tag] = dataclasses.field(default_factory=list)
source_exception: Optional[Exception] = None
"""The exception that caused this diagnostic to be created."""

def sarif(self) -> sarif.Result:
"""Returns the SARIF Result representation of this diagnostic."""
Expand Down Expand Up @@ -105,6 +103,11 @@ def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic:
self.additional_message = f"{self.additional_message}\n{message}"
return self

def with_source_exception(self: _Diagnostic, exception: Exception) -> _Diagnostic:
    """Attach `exception` to this diagnostic as its source exception.

    Args:
        exception: The exception to store in `source_exception`. Any value
            stored there before is overwritten.

    Returns:
        This same diagnostic instance, so that calls can be chained.
    """
    self.source_exception = exception
    return self

def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current Python call stack."""
frames_to_skip += 1 # Skip this function.
Expand Down Expand Up @@ -173,14 +176,21 @@ def pretty_print(
# TODO: print help url to rule at the end.


class RuntimeErrorWithDiagnostic(RuntimeError):
    """A `RuntimeError` that carries the diagnostic it was raised for.

    The exception message is the diagnostic's `message`; the diagnostic
    object itself is kept on the `diagnostic` attribute.
    """

    def __init__(self, diagnostic: Diagnostic):
        # Keep the diagnostic reachable for callers that catch this error.
        self.diagnostic = diagnostic
        super().__init__(diagnostic.message)


@dataclasses.dataclass
class DiagnosticContext:
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic)
diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
logger: logging.Logger = dataclasses.field(
init=True, default_factory=lambda: logging.getLogger().getChild("diagnostics")
Expand Down Expand Up @@ -231,7 +241,7 @@ def dump(self, file_path: str, compress: bool = False) -> None:
with open(file_path, "w") as f:
f.write(self.to_json())

def add_diagnostic(self, diagnostic: Diagnostic) -> None:
def log(self, diagnostic: Diagnostic) -> None:
"""Adds a diagnostic to the context.

Use this method to add diagnostics that are not created by the context.
Expand All @@ -243,6 +253,15 @@ def add_diagnostic(self, diagnostic: Diagnostic) -> None:
f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
)
self.diagnostics.append(diagnostic)
self.logger.log(diagnostic.level, diagnostic.message)
self.logger.log(diagnostic.level, diagnostic.additional_message)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
    """Log `diagnostic`, then raise if its level is ERROR.

    Args:
        diagnostic: The diagnostic to pass to `log`.

    Raises:
        RuntimeErrorWithDiagnostic: If `diagnostic.level` equals
            `infra.Level.ERROR`. The error wraps `diagnostic` and is raised
            `from diagnostic.source_exception`; when that attribute is None
            this is a `raise ... from None`.
    """
    self.log(diagnostic)
    if diagnostic.level != infra.Level.ERROR:
        return
    error = RuntimeErrorWithDiagnostic(diagnostic)
    raise error from diagnostic.source_exception

@contextlib.contextmanager
def add_inflight_diagnostic(
Expand All @@ -260,31 +279,6 @@ def add_inflight_diagnostic(
finally:
self._inflight_diagnostics.pop()

def diagnose(
    self,
    rule: infra.Rule,
    level: infra.Level,
    message: Optional[str] = None,
    **kwargs,
) -> Diagnostic:
    """Build a diagnostic with the context's `diagnostic_type` and record it.

    Args:
        rule: The rule that triggered the diagnostic.
        level: The level of the diagnostic.
        message: The message of the diagnostic.
        **kwargs: Extra keyword arguments forwarded to the `diagnostic_type`
            constructor.

    Returns:
        The newly created diagnostic, after it was handed to `add_diagnostic`.

    NOTE(review): the earlier docstring listed `ValueError` for rules not
    supported by the tool. Nothing in this method raises it directly; confirm
    whether `add_diagnostic` or the diagnostic constructor does.
    """
    new_diagnostic = self.diagnostic_type(rule, level, message, **kwargs)
    self.add_diagnostic(new_diagnostic)
    return new_diagnostic

def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.

Expand All @@ -308,15 +302,15 @@ def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
raise DiagnosticError("No inflight diagnostics")
raise AssertionError("No inflight diagnostics")

return self._inflight_diagnostics[-1]
else:
# TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]]
for diagnostic in reversed(self._inflight_diagnostics):
if diagnostic.rule == rule:
return diagnostic
raise DiagnosticError(f"No inflight diagnostic for rule {rule.name}")
raise AssertionError(f"No inflight diagnostic for rule {rule.name}")

def pretty_print(
self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None
Expand Down
38 changes: 8 additions & 30 deletions torch/onnx/_internal/diagnostics/infra/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@_beartype.beartype
def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
return f"{formatter.display_name(fn)}"
return f"{formatter.display_name(fn)}. "


@_beartype.beartype
Expand Down Expand Up @@ -57,30 +57,14 @@ def format_return_values_in_markdown(
]


@_beartype.beartype
def modify_diagnostic(
    diag: infra.Diagnostic,
    fn: Callable,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    return_values: Any,
) -> None:
    """No-op diagnostic modifier: accepts the call details and changes nothing.

    Args:
        diag: The diagnostic for the call; left untouched.
        fn: The callable the diagnostic was created for.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.
        return_values: Whatever the call returned.
    """
    return None


@_beartype.beartype
def diagnose_call(
rule: infra.Rule,
*,
level: infra.Level = infra.Level.NONE,
exception_report_level: infra.Level = infra.Level.WARNING,
diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic,
format_argument: Callable[[Any], str] = formatter.format_argument,
diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
diagnostic_modifier: ModifierCallableType = modify_diagnostic,
report_criterion: Callable[
[Callable, Tuple[Any, ...], Dict[str, Any], Any], bool
] = lambda _1, _2, _3, _4: True,
) -> Callable:
def decorator(fn):
@functools.wraps(fn)
Expand Down Expand Up @@ -138,30 +122,24 @@ def wrapper(*args, **kwargs):
]

return_values: Any = None
report_diagnostic: bool = True
with ctx.add_inflight_diagnostic(diag) as diag:
try:
return_values = fn(*args, **kwargs)
additional_messages.append(
format_return_values_in_markdown(return_values, format_argument)
)
report_diagnostic = report_criterion(
fn, args, kwargs, return_values
)
return return_values
except Exception as e:
# Record exception.
report_diagnostic = True
diag.level = exception_report_level
diag.level = infra.levels.ERROR
# TODO(bowbao): Message emitting api.
diag.message = diag.message or ""
diag.message += f"Raised from:\n {type(e).__name__}: {e}"
diag.with_source_exception(e)
additional_messages.append(format_exception_in_markdown(e))
raise
finally:
if report_diagnostic:
diag.with_additional_message(
"\n".join(additional_messages).strip()
)
diagnostic_modifier(diag, fn, args, kwargs, return_values)
ctx.add_diagnostic(diag)
diag.with_additional_message("\n".join(additional_messages).strip())
ctx.log_and_raise_if_error(diag)

return wrapper

Expand Down
14 changes: 13 additions & 1 deletion torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,16 @@ def __init__(self, package_name: str, message: str):
self.package_name = package_name


class OnnxExporterError(RuntimeError):
    """Raised when an ONNX exporter error occurs.

    The diagnostic context passed to the constructor is enclosed on the
    exception as `diagnostic_context`.
    """

    # Assigned once in `__init__`; declared `Final` so type checkers flag
    # later reassignment.
    diagnostic_context: Final[infra.DiagnosticContext]

    def __init__(self, diagnostic_context: infra.DiagnosticContext, message: str):
        self.diagnostic_context = diagnostic_context
        super().__init__(message)


@_beartype.beartype
def _assert_dependencies(export_options: ResolvedExportOptions):
logger = export_options.logger
Expand Down Expand Up @@ -619,7 +629,9 @@ def dynamo_export(
f"Failed to export the model to ONNX. Generating SARIF report at {sarif_report_path}. "
f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}"
)
raise RuntimeError(message) from e
raise OnnxExporterError(
resolved_export_options.diagnostic_context, message
) from e


__all__ = [
Expand Down