Skip to content

Commit

Permalink
[ONNX] Migrate to PT2 logging
Browse files Browse the repository at this point in the history
Summary
- 'dynamo_export' diagnostics now utilize PT2 artifact logger to manage
the verbosity level of logs recorded into each diagnostic in SARIF log.
The terminal logging is by default turned off. Setting environment variable
'TORCH_LOGS="+onnx_diagnostics"' turns on terminal logging, as well as
adjusts the verbosity level which overrides that in diagnostic options.
Replaces 'with_additional_message' with 'Logger.log' like apis.
- Introduce 'LazyString', adopted from 'torch._dynamo.utils', to skip
evaluation if the message will not be logged into diagnostic.
- Introduce 'log_source_exception' for easier exception logging.
- Introduce 'log_section' for easier markdown title logging.
- Updated all existing code to use new api.
- Removed 'arg_format_too_verbose' diagnostic.
- Rename legacy diagnostic classes for TorchScript Onnx Exporter to avoid
confusion.

Follow ups
- The 'dynamo_export' diagnostic now will not capture python stack
information at point of diagnostic creation. This will be added back in
follow up PRs for debug level logging.
- There is type mismatch due to subclassing 'Diagnostic' and 'DiagnosticContext'
for 'dynamo_export' to incorporate with PT2 logging. Follow up PR will
attempt to fix it.
- More docstrings with examples.

ghstack-source-id: b979765ade83f6372095cfc81708a31dc8984dca
Pull Request resolved: #106592
  • Loading branch information
BowenBao committed Aug 4, 2023
1 parent d0a7aab commit d6f66e2
Show file tree
Hide file tree
Showing 18 changed files with 534 additions and 245 deletions.
2 changes: 1 addition & 1 deletion docs/source/onnx_diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ Diagnostic Rules
API Reference
-------------

.. autoclass:: torch.onnx._internal.diagnostics.ExportDiagnostic
.. autoclass:: torch.onnx._internal.diagnostics.TorchScriptOnnxExportDiagnostic
:members:
4 changes: 2 additions & 2 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.exporter import (
_DEFAULT_OPSET_VERSION,
ExportOutputSerializer,
ProtobufExportOutputSerializer,
ResolvedExportOptions,
)
from torch.onnx._internal.fx import diagnostics

from torch.testing._internal import common_utils

Expand Down Expand Up @@ -156,7 +156,7 @@ def test_raise_on_invalid_save_argument_type(self):
onnx.ModelProto(),
io_adapter.InputAdapter(),
io_adapter.OutputAdapter(),
infra.DiagnosticContext("test", "1.0"),
diagnostics.DiagnosticContext("test", "1.0"),
fake_context=None,
)
with self.assertRaises(roar.BeartypeException):
Expand Down
3 changes: 1 addition & 2 deletions test/onnx/dynamo/test_registry_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import]
from onnxscript.function_libs.torch_lib import ops # type: ignore[import]
from onnxscript.onnx_opset import opset15 as op # type: ignore[import]
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration
from torch.testing._internal import common_utils

Expand Down Expand Up @@ -85,7 +84,7 @@ def setUp(self):
self.registry = registration.OnnxRegistry(opset_version=18)
# TODO: remove this once we have a better way to do this
logger = logging.getLogger("TestDispatcher")
self.diagnostic_context = infra.DiagnosticContext(
self.diagnostic_context = diagnostics.DiagnosticContext(
"torch.onnx.dynamo_export", torch.__version__
)
self.dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
Expand Down
64 changes: 61 additions & 3 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import dataclasses
import io
import logging
import typing
import unittest
from typing import AbstractSet, Protocol, Tuple
Expand Down Expand Up @@ -123,7 +124,7 @@ def setUp(self):

def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
self,
) -> diagnostics.ExportDiagnostic:
) -> diagnostics.TorchScriptOnnxExportDiagnostic:
class CustomAdd(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
Expand All @@ -147,7 +148,9 @@ def forward(self, x):
diagnostic.rule == rule
and diagnostic.level == diagnostics.levels.WARNING
):
return typing.cast(diagnostics.ExportDiagnostic, diagnostic)
return typing.cast(
diagnostics.TorchScriptOnnxExportDiagnostic, diagnostic
)
raise AssertionError("No diagnostic found.")

def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
Expand Down Expand Up @@ -203,7 +206,7 @@ def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
diagnostics.export_context().log(diagnostic)

def test_diagnostics_records_python_call_stack(self):
diagnostic = diagnostics.ExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
# Do not break the above line, otherwise it will not work with Python-3.8+
stack = diagnostic.python_call_stack
assert stack is not None # for mypy
Expand Down Expand Up @@ -288,6 +291,61 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
)
self.context.log(diagnostic2)

def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
self,
):
verbosity_level = logging.INFO
self.context.options.verbosity_level = verbosity_level
with self.context:
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)

with self.assertLogs(
diagnostic.logger, level=verbosity_level
) as assert_log_context:
diagnostic.log(logging.DEBUG, "debug message")
# NOTE: self.assertNoLogs only exist >= Python 3.10
# Add this dummy log such that we can pass self.assertLogs, and inspect
# assert_log_context.records to check if the log level is correct.
diagnostic.log(logging.INFO, "info message")

for record in assert_log_context.records:
self.assertGreaterEqual(record.levelno, logging.INFO)
self.assertFalse(
any(
message.find("debug message") >= 0
for message in diagnostic.additional_messages
)
)

def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
self,
):
verbosity_level = logging.INFO
self.context.options.verbosity_level = verbosity_level
with self.context:
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)

level_message_pairs = [
(logging.INFO, "info message"),
(logging.WARNING, "warning message"),
(logging.ERROR, "error message"),
]

for level, message in level_message_pairs:
with self.assertLogs(diagnostic.logger, level=verbosity_level):
diagnostic.log(level, message)

self.assertTrue(
any(
message.find(message) >= 0
for message in diagnostic.additional_messages
)
)

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.log_and_raise_if_error(
Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/_internal/diagnostics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
diagnose,
engine,
export_context,
ExportDiagnostic,
ExportDiagnosticEngine,
TorchScriptOnnxExportDiagnostic,
)
from ._rules import rules
from .infra import levels

__all__ = [
"ExportDiagnostic",
"TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine",
"rules",
"levels",
Expand Down
13 changes: 6 additions & 7 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.S
)


class ExportDiagnostic(infra.Diagnostic):
class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
"""Base class for all export diagnostics.
This class is used to represent all export diagnostics. It is a subclass of
Expand Down Expand Up @@ -77,9 +77,6 @@ def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
self.with_stack(stack)
return stack

def record_fx_graphmodule(self, gm: torch.fx.GraphModule) -> None:
self.with_graph(infra.Graph(gm.print_readable(False), gm.__class__.__name__))


class ExportDiagnosticEngine:
"""PyTorch ONNX Export diagnostic engine.
Expand Down Expand Up @@ -180,7 +177,9 @@ def create_export_diagnostic_context() -> (
_context == engine.background_context
), "Export context is already set. Nested export is not supported."
_context = engine.create_diagnostic_context(
"torch.onnx.export", torch.__version__, diagnostic_type=ExportDiagnostic
"torch.onnx.export",
torch.__version__,
diagnostic_type=TorchScriptOnnxExportDiagnostic,
)
try:
yield _context
Expand All @@ -194,14 +193,14 @@ def diagnose(
message: Optional[str] = None,
frames_to_skip: int = 2,
**kwargs,
) -> ExportDiagnostic:
) -> TorchScriptOnnxExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = ExportDiagnostic(
diagnostic = TorchScriptOnnxExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)
export_context().log(diagnostic)
Expand Down
4 changes: 4 additions & 0 deletions torch/onnx/_internal/diagnostics/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import dataclasses
import enum
import logging
from typing import FrozenSet, List, Mapping, Optional, Sequence, Tuple

from torch.onnx._internal.diagnostics.infra import formatter, sarif
Expand Down Expand Up @@ -272,5 +273,8 @@ class DiagnosticOptions:
Options for diagnostic context.
"""

verbosity_level: int = dataclasses.field(default=logging.INFO)
"""The verbosity level of the diagnostic context."""

warnings_as_errors: bool = dataclasses.field(default=False)
"""If True, warnings are treated as errors."""

0 comments on commit d6f66e2

Please sign in to comment.