Skip to content

Commit

Permalink
[ONNX] Diagnostic option 'warnings_as_errors'
Browse files Browse the repository at this point in the history
If set, diagnostics with level as WARNING will be logged as level
with ERROR, and immediately raised.

TODO: bikeshed public export api.

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Jul 25, 2023
1 parent b8eb827 commit e1ac710
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
24 changes: 24 additions & 0 deletions test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,30 @@ def test_diagnostic_context_raises_if_diagnostic_is_error(self):
)
)

def test_diagnostic_context_raises_original_exception_from_diagnostic_created_from_it(
self,
):
with self.assertRaises(ValueError):
try:
raise ValueError("original exception")
except ValueError as e:
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.ERROR
)
diagnostic = diagnostic.with_source_exception(e)
self.context.log_and_raise_if_error(diagnostic)

def test_diagnostic_context_raises_if_diagnostic_is_warning_and_warnings_as_errors_is_true(
self,
):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.options.warnings_as_errors = True
self.context.log_and_raise_if_error(
infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.WARNING
)
)


if __name__ == "__main__":
common_utils.run_tests()
2 changes: 2 additions & 0 deletions torch/onnx/_internal/diagnostics/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,5 @@ class DiagnosticOptions:

log_verbose: bool = dataclasses.field(default=False)
log_level: Level = dataclasses.field(default=Level.ERROR)
warnings_as_errors: bool = dataclasses.field(default=False)
"""If True, warnings are treated as errors."""
8 changes: 5 additions & 3 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,18 @@ def log(self, diagnostic: Diagnostic) -> None:
raise TypeError(
f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
)
if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING:
diagnostic.level = infra.Level.ERROR
self.diagnostics.append(diagnostic)
self.logger.log(diagnostic.level, diagnostic.message)
self.logger.log(diagnostic.level, diagnostic.additional_message)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
self.log(diagnostic)
if diagnostic.level == infra.Level.ERROR:
raise RuntimeErrorWithDiagnostic(
diagnostic
) from diagnostic.source_exception
if diagnostic.source_exception is not None:
raise diagnostic.source_exception
raise RuntimeErrorWithDiagnostic(diagnostic)

@contextlib.contextmanager
def add_inflight_diagnostic(
Expand Down

0 comments on commit e1ac710

Please sign in to comment.