Skip to content

Commit

Permalink
[ONNX] Migrate to PT2 logging
Browse files Browse the repository at this point in the history
Summary
- 'dynamo_export' diagnostics now utilize PT2 artifact logger to manage
the verbosity level of logs recorded into each diagnostic in SARIF log.
The terminal logging is by default turned off. Setting environment variable
'TORCH_LOGS="+onnx_diagnostics"' turns on terminal logging, as well as
adjusts the verbosity level which overrides that in diagnostic options.
Replaces 'with_additional_message' with 'Logger.log'-like APIs.
- Introduce 'LazyString', adopted from 'torch._dynamo.utils', to skip
evaluation if the message will not be logged into diagnostic.
- Introduce 'log_source_exception' for easier exception logging.
- Introduce 'log_section' for easier markdown title logging.
- Updated all existing code to use new api.
- Removed 'arg_format_too_verbose' diagnostic.
- Rename legacy diagnostic classes for TorchScript Onnx Exporter to avoid
confusion.

Follow ups
- The 'dynamo_export' diagnostic no longer captures Python stack
information at the point of diagnostic creation. This will be added back in
follow-up PRs for debug-level logging.
- There is a type mismatch due to subclassing 'Diagnostic' and 'DiagnosticContext'
for 'dynamo_export' to integrate with PT2 logging. A follow-up PR will
attempt to fix it.
- More docstrings with examples.

ghstack-source-id: 4cc785e10fb4a9eccf9b62e7f9540cb5aa2bc0e8
Pull Request resolved: #106592
  • Loading branch information
BowenBao committed Aug 4, 2023
1 parent d0a7aab commit 570d594
Show file tree
Hide file tree
Showing 16 changed files with 507 additions and 242 deletions.
3 changes: 1 addition & 2 deletions test/onnx/dynamo/test_registry_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import]
from onnxscript.function_libs.torch_lib import ops # type: ignore[import]
from onnxscript.onnx_opset import opset15 as op # type: ignore[import]
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration
from torch.testing._internal import common_utils

Expand Down Expand Up @@ -85,7 +84,7 @@ def setUp(self):
self.registry = registration.OnnxRegistry(opset_version=18)
# TODO: remove this once we have a better way to do this
logger = logging.getLogger("TestDispatcher")
self.diagnostic_context = infra.DiagnosticContext(
self.diagnostic_context = diagnostics.DiagnosticContext(
"torch.onnx.dynamo_export", torch.__version__
)
self.dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
Expand Down
64 changes: 61 additions & 3 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import dataclasses
import io
import logging
import typing
import unittest
from typing import AbstractSet, Protocol, Tuple
Expand Down Expand Up @@ -123,7 +124,7 @@ def setUp(self):

def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
self,
) -> diagnostics.ExportDiagnostic:
) -> diagnostics.TorchScriptOnnxExportDiagnostic:
class CustomAdd(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
Expand All @@ -147,7 +148,9 @@ def forward(self, x):
diagnostic.rule == rule
and diagnostic.level == diagnostics.levels.WARNING
):
return typing.cast(diagnostics.ExportDiagnostic, diagnostic)
return typing.cast(
diagnostics.TorchScriptOnnxExportDiagnostic, diagnostic
)
raise AssertionError("No diagnostic found.")

def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
Expand Down Expand Up @@ -203,7 +206,7 @@ def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
diagnostics.export_context().log(diagnostic)

def test_diagnostics_records_python_call_stack(self):
diagnostic = diagnostics.ExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
# Do not break the above line, otherwise it will not work with Python-3.8+
stack = diagnostic.python_call_stack
assert stack is not None # for mypy
Expand Down Expand Up @@ -288,6 +291,61 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
)
self.context.log(diagnostic2)

def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
    self,
):
    """A DEBUG message is neither emitted nor recorded when verbosity is INFO."""
    threshold = logging.INFO
    self.context.options.verbosity_level = threshold
    with self.context:
        diagnostic = infra.Diagnostic(
            self.rules.rule_without_message_args, infra.Level.NOTE
        )

        with self.assertLogs(diagnostic.logger, level=threshold) as captured_logs:
            diagnostic.log(logging.DEBUG, "debug message")
            # NOTE: `self.assertNoLogs` only exists in Python >= 3.10.
            # Emit this extra INFO log so that `self.assertLogs` passes, then
            # inspect the captured records to check that nothing below INFO
            # was emitted.
            diagnostic.log(logging.INFO, "info message")

    for record in captured_logs.records:
        self.assertGreaterEqual(record.levelno, logging.INFO)

    debug_message_recorded = any(
        recorded.find("debug message") >= 0
        for recorded in diagnostic.additional_messages
    )
    self.assertFalse(debug_message_recorded)

def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
    self,
):
    """Messages at or above the verbosity level are emitted and recorded.

    Each message logged at INFO, WARNING and ERROR must be emitted through the
    diagnostic's logger and must also appear in `additional_messages`.
    """
    verbosity_level = logging.INFO
    self.context.options.verbosity_level = verbosity_level
    with self.context:
        diagnostic = infra.Diagnostic(
            self.rules.rule_without_message_args, infra.Level.NOTE
        )

        level_message_pairs = [
            (logging.INFO, "info message"),
            (logging.WARNING, "warning message"),
            (logging.ERROR, "error message"),
        ]

        for level, message in level_message_pairs:
            with self.assertLogs(diagnostic.logger, level=verbosity_level):
                diagnostic.log(level, message)

            # The generator variable must not be named `message`: reusing that
            # name shadows the expected text and turns the check into
            # `x.find(x) >= 0`, which is always true.
            self.assertTrue(
                any(
                    recorded.find(message) >= 0
                    for recorded in diagnostic.additional_messages
                )
            )

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.log_and_raise_if_error(
Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/_internal/diagnostics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
diagnose,
engine,
export_context,
ExportDiagnostic,
ExportDiagnosticEngine,
TorchScriptOnnxExportDiagnostic,
)
from ._rules import rules
from .infra import levels

__all__ = [
"ExportDiagnostic",
"TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine",
"rules",
"levels",
Expand Down
16 changes: 8 additions & 8 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.S
)


class ExportDiagnostic(infra.Diagnostic):
class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
"""Base class for all export diagnostics.
This class is used to represent all export diagnostics. It is a subclass of
Expand Down Expand Up @@ -77,9 +77,6 @@ def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
self.with_stack(stack)
return stack

def record_fx_graphmodule(self, gm: torch.fx.GraphModule) -> None:
    """Attach a readable dump of ``gm`` to this diagnostic as a graph.

    The graph text is ``gm.print_readable(False)`` and the graph is named
    after ``gm``'s class.
    """
    readable_text = gm.print_readable(False)
    graph_name = gm.__class__.__name__
    self.with_graph(infra.Graph(readable_text, graph_name))


class ExportDiagnosticEngine:
"""PyTorch ONNX Export diagnostic engine.
Expand Down Expand Up @@ -180,7 +177,9 @@ def create_export_diagnostic_context() -> (
_context == engine.background_context
), "Export context is already set. Nested export is not supported."
_context = engine.create_diagnostic_context(
"torch.onnx.export", torch.__version__, diagnostic_type=ExportDiagnostic
"torch.onnx.export",
torch.__version__,
diagnostic_type=TorchScriptOnnxExportDiagnostic,
)
try:
yield _context
Expand All @@ -194,15 +193,16 @@ def diagnose(
message: Optional[str] = None,
frames_to_skip: int = 2,
**kwargs,
) -> ExportDiagnostic:
) -> TorchScriptOnnxExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = ExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
logger = torch._logging.getArtifactLogger("torch.onnx", "onnx_diagnostics")
diagnostic = TorchScriptOnnxExportDiagnostic(
rule, level, message, logger=logger, frames_to_skip=frames_to_skip, **kwargs
)
export_context().log(diagnostic)
return diagnostic
Expand Down
4 changes: 4 additions & 0 deletions torch/onnx/_internal/diagnostics/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import dataclasses
import enum
import logging
from typing import FrozenSet, List, Mapping, Optional, Sequence, Tuple

from torch.onnx._internal.diagnostics.infra import formatter, sarif
Expand Down Expand Up @@ -272,5 +273,8 @@ class DiagnosticOptions:
Options for diagnostic context.
"""

verbosity_level: int = dataclasses.field(default=logging.INFO)
"""The verbosity level of the diagnostic context."""

warnings_as_errors: bool = dataclasses.field(default=False)
"""If True, warnings are treated as errors."""
128 changes: 115 additions & 13 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import dataclasses
import gzip

from typing import Callable, Generator, List, Literal, Mapping, Optional, TypeVar
import logging

from typing import Callable, Generator, List, Literal, Mapping, Optional, TypeVar, Union

from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
Expand All @@ -16,6 +18,7 @@

# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
diagnostic_logger: logging.Logger = logging.getLogger(__name__)


@dataclasses.dataclass
Expand All @@ -29,20 +32,25 @@ class Diagnostic:
thread_flow_locations: List[infra.ThreadFlowLocation] = dataclasses.field(
default_factory=list
)
additional_message: Optional[str] = None
additional_messages: List[str] = dataclasses.field(default_factory=list)
tags: List[infra.Tag] = dataclasses.field(default_factory=list)
source_exception: Optional[Exception] = None
"""The exception that caused this diagnostic to be created."""
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
"""The logger for this diagnostic. Defaults to 'diagnostic_logger' which has the same
log level setting with `DiagnosticOptions.verbosity_level`."""
_current_log_section_depth: int = 0

def __post_init__(self) -> None:
    """Post-initialization hook; currently does nothing."""
    return None

def sarif(self) -> sarif.Result:
"""Returns the SARIF Result representation of this diagnostic."""
message = self.message or self.rule.message_default_template
if self.additional_message:
if self.additional_messages:
additional_message = "\n".join(self.additional_messages)
message_markdown = (
f"{message}\n\n## Additional Message:\n\n{self.additional_message}"
f"{message}\n\n## Additional Message:\n\n{additional_message}"
)
else:
message_markdown = message
Expand Down Expand Up @@ -96,13 +104,87 @@ def with_graph(self: _Diagnostic, graph: infra.Graph) -> _Diagnostic:
self.graphs.append(graph)
return self

def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic:
    """Append ``message`` to the diagnostic's additional message.

    The first message is stored as is; each later message is joined to the
    existing text with a newline. Returns ``self`` so calls can be chained.
    """
    existing = self.additional_message
    self.additional_message = (
        message if existing is None else f"{existing}\n{message}"
    )
    return self
@contextlib.contextmanager
def log_section(
    self, level: int, title: Union[str, formatter.LazyString]
) -> Generator[None, None, None]:
    """Log a markdown section title and nest subsequent sections under it.

    The title is logged through `log` at ``level`` as a markdown heading with
    two ``#`` characters plus one per currently open section. The section
    depth is incremented for the duration of the ``with`` body and decremented
    again on exit, even if the body raises.

    Args:
        level: The log level used for the title.
        title: The section title; a string or a `formatter.LazyString`.

    Yields:
        None.
    """
    heading_prefix = "#" * self._current_log_section_depth
    self.log(level, "##%s %s", heading_prefix, title)
    self._current_log_section_depth += 1
    try:
        yield
    finally:
        # Always restore the depth so a failure inside the section does not
        # leave later headings nested too deeply.
        self._current_log_section_depth -= 1

def log(self, level: int, message: str, *args, **kwargs) -> None:
    """Log a message within the diagnostic. Same API as `logging.Logger.log`.

    The message is always forwarded to ``self.logger``. If the logger is
    enabled for ``level``, the formatted message is also appended to
    ``self.additional_messages``, followed by a newline entry. Otherwise
    nothing is recorded on the diagnostic.

    Args:
        level: The log level.
        message: The message to log; may contain %-style placeholders.
        *args: The arguments to the message. Use `formatter.LazyString` to
            defer expensive evaluation until the message is actually logged.
        **kwargs: The keyword arguments for `logging.Logger.log`.
    """
    self.logger.log(level, message, *args, **kwargs)
    if self.logger.isEnabledFor(level):
        # Mirror `logging`: apply %-formatting only when arguments are given,
        # so a literal '%' in an argument-free message is recorded verbatim
        # instead of raising.
        formatted_message = message % args if args else message
        self.additional_messages.append(formatted_message)
        self.additional_messages.append("\n")

def debug(self, message: str, *args, **kwargs) -> None:
    """Log ``message`` at DEBUG level by delegating to `log`.

    Takes the same arguments as `logging.Logger.debug`; see `log` for how the
    message is recorded.
    """
    level = logging.DEBUG
    self.log(level, message, *args, **kwargs)

def info(self, message: str, *args, **kwargs) -> None:
    """Log ``message`` at INFO level by delegating to `log`.

    Takes the same arguments as `logging.Logger.info`; see `log` for how the
    message is recorded.
    """
    level = logging.INFO
    self.log(level, message, *args, **kwargs)

def warning(self, message: str, *args, **kwargs) -> None:
    """Log ``message`` at WARNING level by delegating to `log`.

    Takes the same arguments as `logging.Logger.warning`; see `log` for how
    the message is recorded.
    """
    level = logging.WARNING
    self.log(level, message, *args, **kwargs)

def error(self, message: str, *args, **kwargs) -> None:
    """Log ``message`` at ERROR level by delegating to `log`.

    Takes the same arguments as `logging.Logger.error`; see `log` for how the
    message is recorded.
    """
    level = logging.ERROR
    self.log(level, message, *args, **kwargs)

def log_source_exception(self, level: int, exception: Exception) -> None:
    """Record ``exception`` as the source exception and log it as a section.

    Stores the exception on ``self.source_exception``, then opens an
    "Exception log" section via `log_section` and logs the result of
    `formatter.lazy_format_exception` inside it at ``level``.
    """
    self.source_exception = exception
    section_title = "Exception log"
    with self.log_section(level, section_title):
        exception_text = formatter.lazy_format_exception(exception)
        self.log(level, "%s", exception_text)

def with_source_exception(self: _Diagnostic, exception: Exception) -> _Diagnostic:
"""Adds the source exception to the diagnostic."""
Expand Down Expand Up @@ -163,11 +245,16 @@ class DiagnosticContext:
_inflight_diagnostics: List[Diagnostic] = dataclasses.field(
init=False, default_factory=list
)
_previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)

def __enter__(self):
    """Enter the context and apply ``options.verbosity_level`` to the logger.

    The logger's current level is saved in ``_previous_log_level`` so that
    `__exit__` can restore it.

    Returns:
        This context, for use as the ``with ... as`` target.
    """
    self._previous_log_level = self.logger.level
    # Use `setLevel` rather than assigning `level` directly: `setLevel` also
    # clears the logger's `isEnabledFor` cache. A plain assignment can leave
    # stale cached answers from the previous level in place.
    self.logger.setLevel(self.options.verbosity_level)
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Restore the logger level that was saved in ``_previous_log_level``.

    Returns:
        None, so an exception raised inside the ``with`` body propagates.
    """
    # `setLevel` clears the logger's `isEnabledFor` cache. A direct `level`
    # assignment would keep answers cached for the level used inside the
    # context.
    self.logger.setLevel(self._previous_log_level)
    return None

def sarif(self) -> sarif.Run:
Expand Down Expand Up @@ -205,9 +292,11 @@ def dump(self, file_path: str, compress: bool = False) -> None:
f.write(self.to_json())

def log(self, diagnostic: Diagnostic) -> None:
"""Adds a diagnostic to the context.
"""Logs a diagnostic.
This method should be used only after all the necessary information for the diagnostic
has been collected.
Use this method to add diagnostics that are not created by the context.
Args:
diagnostic: The diagnostic to add.
"""
Expand All @@ -220,6 +309,19 @@ def log(self, diagnostic: Diagnostic) -> None:
self.diagnostics.append(diagnostic)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
"""Logs a diagnostic and raises an exception if it is an error.
Use this method for logging non inflight diagnostics where diagnostic level is not known or
lower than ERROR. If it is always expected raise, use `log` and explicit
`raise` instead. Otherwise there is no way to convey the message that it always
raises to Python intellisense and type checking tools.
This method should be used only after all the necessary information for the diagnostic
has been collected.
Args:
diagnostic: The diagnostic to add.
"""
self.log(diagnostic)
if diagnostic.level == infra.Level.ERROR:
if diagnostic.source_exception is not None:
Expand Down

0 comments on commit 570d594

Please sign in to comment.