Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Migrate to PT2 logging #106592

Closed
wants to merge 8 commits into from
2 changes: 1 addition & 1 deletion docs/source/onnx_diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ Diagnostic Rules
API Reference
-------------

.. autoclass:: torch.onnx._internal.diagnostics.ExportDiagnostic
.. autoclass:: torch.onnx._internal.diagnostics.TorchScriptOnnxExportDiagnostic
titaiwangms marked this conversation as resolved.
Show resolved Hide resolved
:members:
4 changes: 2 additions & 2 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from beartype import roar
from torch.onnx import dynamo_export, ExportOptions, ExportOutput
from torch.onnx._internal import exporter, io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.exporter import (
ExportOutputSerializer,
ProtobufExportOutputSerializer,
ResolvedExportOptions,
)
from torch.onnx._internal.fx import diagnostics

from torch.testing._internal import common_utils

Expand Down Expand Up @@ -144,7 +144,7 @@ def test_raise_on_invalid_save_argument_type(self):
onnx.ModelProto(),
io_adapter.InputAdapter(),
io_adapter.OutputAdapter(),
infra.DiagnosticContext("test", "1.0"),
diagnostics.DiagnosticContext("test", "1.0"),
fake_context=None,
)
with self.assertRaises(roar.BeartypeException):
Expand Down
3 changes: 1 addition & 2 deletions test/onnx/dynamo/test_registry_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import]
from onnxscript.function_libs.torch_lib import ops # type: ignore[import]
from onnxscript.onnx_opset import opset15 as op # type: ignore[import]
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration
from torch.testing._internal import common_utils

Expand Down Expand Up @@ -85,7 +84,7 @@ def setUp(self):
self.registry = torch.onnx.OnnxRegistry()
# TODO: remove this once we have a better way to do this
logger = logging.getLogger("TestDispatcher")
self.diagnostic_context = infra.DiagnosticContext(
self.diagnostic_context = diagnostics.DiagnosticContext(
"torch.onnx.dynamo_export", torch.__version__
)
self.dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
Expand Down
208 changes: 190 additions & 18 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import dataclasses
import io
import logging
import typing
import unittest
from typing import AbstractSet, Protocol, Tuple
Expand All @@ -12,8 +13,9 @@
from torch.onnx import errors
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
from torch.onnx._internal.diagnostics.infra import sarif
from torch.testing._internal import common_utils
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.fx import diagnostics as fx_diagnostics
from torch.testing._internal import common_utils, logging_utils


class _SarifLogBuilder(Protocol):
Expand Down Expand Up @@ -43,6 +45,17 @@ def _assert_has_diagnostics(
)


@dataclasses.dataclass
class _RuleCollectionForTest(infra.RuleCollection):
rule_without_message_args: infra.Rule = dataclasses.field(
default=infra.Rule(
"1",
"rule-without-message-args",
message_default_template="rule message",
)
)


@contextlib.contextmanager
def assert_all_diagnostics(
test_suite: unittest.TestCase,
Expand Down Expand Up @@ -112,8 +125,93 @@ def assert_diagnostic(
return assert_all_diagnostics(test_suite, sarif_log_builder, {(rule, level)})


class TestOnnxDiagnostics(common_utils.TestCase):
"""Test cases for diagnostics emitted by the ONNX export code."""
class TestDynamoOnnxDiagnostics(common_utils.TestCase):
"""Test cases for diagnostics emitted by the Dynamo ONNX export code."""

def setUp(self):
self.diagnostic_context = fx_diagnostics.DiagnosticContext("dynamo_export", "")
self.rules = _RuleCollectionForTest()
return super().setUp()

def test_log_is_recorded_in_sarif_additional_messages_according_to_diagnostic_options_verbosity_level(
self,
):
logging_levels = [
logging.DEBUG,
logging.INFO,
logging.WARNING,
logging.ERROR,
]
for verbosity_level in logging_levels:
self.diagnostic_context.options.verbosity_level = verbosity_level
with self.diagnostic_context:
diagnostic = fx_diagnostics.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NONE
)
additional_messages_count = len(diagnostic.additional_messages)
for log_level in logging_levels:
diagnostic.log(level=log_level, message="log message")
if log_level >= verbosity_level:
self.assertGreater(
len(diagnostic.additional_messages),
additional_messages_count,
f"Additional message should be recorded when log level is {log_level} "
f"and verbosity level is {verbosity_level}",
)
else:
self.assertEqual(
len(diagnostic.additional_messages),
additional_messages_count,
f"Additional message should not be recorded when log level is "
f"{log_level} and verbosity level is {verbosity_level}",
)

def test_torch_logs_environment_variable_precedes_diagnostic_options_verbosity_level(
self,
):
self.diagnostic_context.options.verbosity_level = logging.ERROR
with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
diagnostic = fx_diagnostics.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NONE
)
additional_messages_count = len(diagnostic.additional_messages)
diagnostic.debug("message")
self.assertGreater(
len(diagnostic.additional_messages), additional_messages_count
)

def test_log_is_not_emitted_to_terminal_when_log_artifact_is_not_enabled(self):
self.diagnostic_context.options.verbosity_level = logging.INFO
with self.diagnostic_context:
diagnostic = fx_diagnostics.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NONE
)

with self.assertLogs(
diagnostic.logger, level=logging.INFO
) as assert_log_context:
diagnostic.info("message")
# NOTE: self.assertNoLogs only exist >= Python 3.10
# Add this dummy log such that we can pass self.assertLogs, and inspect
# assert_log_context.records to check if the log we don't want is not emitted.
diagnostic.logger.log(logging.ERROR, "dummy message")

self.assertEqual(len(assert_log_context.records), 1)

def test_log_is_emitted_to_terminal_when_log_artifact_is_enabled(self):
self.diagnostic_context.options.verbosity_level = logging.INFO

with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
diagnostic = fx_diagnostics.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NONE
)

with self.assertLogs(diagnostic.logger, level=logging.INFO):
diagnostic.info("message")


class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
"""Test cases for diagnostics emitted by the TorchScript ONNX export code."""

def setUp(self):
engine = diagnostics.engine
Expand All @@ -123,7 +221,7 @@ def setUp(self):

def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
self,
) -> diagnostics.ExportDiagnostic:
) -> diagnostics.TorchScriptOnnxExportDiagnostic:
class CustomAdd(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
Expand All @@ -147,7 +245,9 @@ def forward(self, x):
diagnostic.rule == rule
and diagnostic.level == diagnostics.levels.WARNING
):
return typing.cast(diagnostics.ExportDiagnostic, diagnostic)
return typing.cast(
diagnostics.TorchScriptOnnxExportDiagnostic, diagnostic
)
raise AssertionError("No diagnostic found.")

def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
Expand Down Expand Up @@ -203,7 +303,7 @@ def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
diagnostics.export_context().log(diagnostic)

def test_diagnostics_records_python_call_stack(self):
diagnostic = diagnostics.ExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE) # fmt: skip
# Do not break the above line, otherwise it will not work with Python-3.8+
stack = diagnostic.python_call_stack
assert stack is not None # for mypy
Expand Down Expand Up @@ -232,17 +332,6 @@ def test_diagnostics_records_cpp_call_stack(self):
)


@dataclasses.dataclass
class _RuleCollectionForTest(infra.RuleCollection):
rule_without_message_args: infra.Rule = dataclasses.field(
default=infra.Rule(
"1",
"rule-without-message-args",
message_default_template="rule message",
)
)


class TestDiagnosticsInfra(common_utils.TestCase):
"""Test cases for diagnostics infra."""

Expand Down Expand Up @@ -288,6 +377,89 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
)
self.context.log(diagnostic2)

def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
self,
):
verbosity_level = logging.INFO
self.context.options.verbosity_level = verbosity_level
with self.context:
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)

with self.assertLogs(
diagnostic.logger, level=verbosity_level
) as assert_log_context:
diagnostic.log(logging.DEBUG, "debug message")
# NOTE: self.assertNoLogs only exist >= Python 3.10
# Add this dummy log such that we can pass self.assertLogs, and inspect
# assert_log_context.records to check if the log level is correct.
diagnostic.log(logging.INFO, "info message")

for record in assert_log_context.records:
self.assertGreaterEqual(record.levelno, logging.INFO)
self.assertFalse(
any(
message.find("debug message") >= 0
for message in diagnostic.additional_messages
)
)

def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
self,
):
verbosity_level = logging.INFO
self.context.options.verbosity_level = verbosity_level
with self.context:
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)

level_message_pairs = [
(logging.INFO, "info message"),
(logging.WARNING, "warning message"),
(logging.ERROR, "error message"),
]

for level, message in level_message_pairs:
with self.assertLogs(diagnostic.logger, level=verbosity_level):
diagnostic.log(level, message)

self.assertTrue(
any(
message.find(message) >= 0
for message in diagnostic.additional_messages
)
)

def test_diagnostic_log_lazy_string_is_not_evaluated_when_level_less_than_diagnostic_options_verbosity_level(
self,
):
verbosity_level = logging.INFO
self.context.options.verbosity_level = verbosity_level
with self.context:
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)

reference_val = 0

def expensive_formatting_function() -> str:
# Modify the reference_val to reflect this function is evaluated
nonlocal reference_val
reference_val += 1
return f"expensive formatting {reference_val}"

# `expensive_formatting_function` should NOT be evaluated.
diagnostic.log(
logging.DEBUG, "%s", formatter.LazyString(expensive_formatting_function)
)
self.assertEqual(
reference_val,
0,
"expensive_formatting_function should not be evaluated after being wrapped under LazyString",
)

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.log_and_raise_if_error(
Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/_internal/diagnostics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
diagnose,
engine,
export_context,
ExportDiagnostic,
ExportDiagnosticEngine,
TorchScriptOnnxExportDiagnostic,
)
from ._rules import rules
from .infra import levels

__all__ = [
"ExportDiagnostic",
"TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine",
"rules",
"levels",
Expand Down
13 changes: 6 additions & 7 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.S
)


class ExportDiagnostic(infra.Diagnostic):
class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
"""Base class for all export diagnostics.

This class is used to represent all export diagnostics. It is a subclass of
Expand Down Expand Up @@ -77,9 +77,6 @@ def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
self.with_stack(stack)
return stack

def record_fx_graphmodule(self, gm: torch.fx.GraphModule) -> None:
self.with_graph(infra.Graph(gm.print_readable(False), gm.__class__.__name__))


class ExportDiagnosticEngine:
"""PyTorch ONNX Export diagnostic engine.
Expand Down Expand Up @@ -180,7 +177,9 @@ def create_export_diagnostic_context() -> (
_context == engine.background_context
), "Export context is already set. Nested export is not supported."
_context = engine.create_diagnostic_context(
"torch.onnx.export", torch.__version__, diagnostic_type=ExportDiagnostic
"torch.onnx.export",
torch.__version__,
diagnostic_type=TorchScriptOnnxExportDiagnostic,
)
try:
yield _context
Expand All @@ -194,14 +193,14 @@ def diagnose(
message: Optional[str] = None,
frames_to_skip: int = 2,
**kwargs,
) -> ExportDiagnostic:
) -> TorchScriptOnnxExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.

This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
# NOTE: Cannot use `@_beartype.beartype`. It somehow erases the cpp stack frame info.
diagnostic = ExportDiagnostic(
diagnostic = TorchScriptOnnxExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)
export_context().log(diagnostic)
Expand Down