Skip to content

Commit

Permalink
[ONNX] Fix diagnostic log and add unittest
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
BowenBao committed Aug 14, 2023
1 parent b7a3e55 commit 538e202
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
28 changes: 28 additions & 0 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,20 @@ def test_log_is_emitted_to_terminal_when_log_artifact_is_enabled(self):
with self.assertLogs(diagnostic.logger, level=logging.INFO):
diagnostic.info("message")

def test_diagnostic_log_emit_correctly_formatted_string(self):
verbosity_level = logging.INFO
self.diagnostic_context.options.verbosity_level = verbosity_level
with self.diagnostic_context:
diagnostic = fx_diagnostics.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)
diagnostic.log(
logging.INFO,
"%s",
formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
)
self.assertIn("hello world", diagnostic.additional_messages)


class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
"""Test cases for diagnostics emitted by the TorchScript ONNX export code."""
Expand Down Expand Up @@ -522,6 +536,20 @@ def expensive_formatting_function() -> str:
"expensive_formatting_function should only be evaluated once after being wrapped under LazyString",
)

def test_diagnostic_log_emit_correctly_formatted_string(self):
verbosity_level = logging.INFO
self.context.options.verbosity_level = verbosity_level
with self.context:
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)
diagnostic.log(
logging.INFO,
"%s",
formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
)
self.assertIn("hello world", diagnostic.additional_messages)

def test_diagnostic_nested_log_section_emits_messages_with_correct_section_title_indentation(
self,
):
Expand Down
6 changes: 3 additions & 3 deletions torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ def log(self, level: int, message: str, *args, **kwargs) -> None:
if self.logger.isEnabledFor(level):
formatted_message = message % args
if is_onnx_diagnostics_log_artifact_enabled():
# Only log to terminal if artifact is not enabled.
# Only log to terminal if artifact is enabled.
# See [NOTE: `dynamo_export` diagnostics logging] for details.
self.logger.log(level, message, **kwargs)
self.logger.log(level, formatted_message, **kwargs)

self.additional_messages.append(message)
self.additional_messages.append(formatted_message)


@dataclasses.dataclass
Expand Down

0 comments on commit 538e202

Please sign in to comment.