Skip to content

Commit

Permalink
[ONNX] Set 'Generic[Diagnostic]' as base class for 'DiagnosticContext'
Browse files Browse the repository at this point in the history
Allows type checking for 'Diagnostic' argument when calling 'context.log'.

ghstack-source-id: a7803c2c5ee6ede41b5b84e43a9f3098dcc453a0
Pull Request resolved: #107165
  • Loading branch information
BowenBao committed Aug 14, 2023
1 parent f5e7423 commit e5cffb1
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
23 changes: 22 additions & 1 deletion test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ def test_diagnostic_log_emit_correctly_formatted_string(self):
)
self.assertIn("hello world", diagnostic.additional_messages)

def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
self,
):
with self.diagnostic_context:
# Dynamo onnx exporter diagnostic context expects fx_diagnostics.Diagnostic
# instead of base infra.Diagnostic.
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)
with self.assertRaises(TypeError):
self.diagnostic_context.log(diagnostic)


class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
"""Test cases for diagnostics emitted by the TorchScript ONNX export code."""
Expand Down Expand Up @@ -353,7 +365,9 @@ class TestDiagnosticsInfra(common_utils.TestCase):
def setUp(self):
self.rules = _RuleCollectionForTest()
with contextlib.ExitStack() as stack:
self.context = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.context: infra.DiagnosticContext[
infra.Diagnostic
] = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.addCleanup(stack.pop_all().close)
return super().setUp()

Expand Down Expand Up @@ -591,6 +605,13 @@ def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_mes
self.assertIn("ValueError: original exception", diagnostic_message)
self.assertIn("Traceback (most recent call last):", diagnostic_message)

def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
self,
):
with self.context:
with self.assertRaises(TypeError):
self.context.log("I thought I should put a message here.")

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.log_and_raise_if_error(
Expand Down
37 changes: 24 additions & 13 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@

import logging

from typing import Callable, Generator, List, Literal, Mapping, Optional, TypeVar
from typing import (
Callable,
Generator,
Generic,
List,
Literal,
Mapping,
Optional,
Type,
TypeVar,
)

from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
Expand Down Expand Up @@ -260,20 +270,21 @@ def __init__(self, diagnostic: Diagnostic):


@dataclasses.dataclass
class DiagnosticContext:
class DiagnosticContext(Generic[_Diagnostic]):
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
diagnostics: List[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
# TODO(bowbao): Implement this.
# _invocation: infra.Invocation = dataclasses.field(init=False)
_inflight_diagnostics: List[Diagnostic] = dataclasses.field(
_inflight_diagnostics: List[_Diagnostic] = dataclasses.field(
init=False, default_factory=list
)
_previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: Type = dataclasses.field(init=False, default=Diagnostic)

def __enter__(self):
self._previous_log_level = self.logger.level
Expand Down Expand Up @@ -318,7 +329,7 @@ def dump(self, file_path: str, compress: bool = False) -> None:
with open(file_path, "w") as f:
f.write(self.to_json())

def log(self, diagnostic: Diagnostic) -> None:
def log(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic.
This method should be used only after all the necessary information for the diagnostic
Expand All @@ -327,15 +338,15 @@ def log(self, diagnostic: Diagnostic) -> None:
Args:
diagnostic: The diagnostic to add.
"""
if not isinstance(diagnostic, Diagnostic):
if not isinstance(diagnostic, self._bound_diagnostic_type):
raise TypeError(
f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
)
if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING:
diagnostic.level = infra.Level.ERROR
self.diagnostics.append(diagnostic)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic and raises an exception if it is an error.
Use this method for logging non inflight diagnostics where diagnostic level is not known or
Expand All @@ -357,8 +368,8 @@ def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:

@contextlib.contextmanager
def add_inflight_diagnostic(
self, diagnostic: Diagnostic
) -> Generator[Diagnostic, None, None]:
self, diagnostic: _Diagnostic
) -> Generator[_Diagnostic, None, None]:
"""Adds a diagnostic to the context.
Use this method to add diagnostics that are not created by the context.
Expand All @@ -371,7 +382,7 @@ def add_inflight_diagnostic(
finally:
self._inflight_diagnostics.pop()

def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.
Args:
Expand All @@ -382,15 +393,15 @@ def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
"""
self._inflight_diagnostics.append(diagnostic)

def pop_inflight_diagnostic(self) -> Diagnostic:
def pop_inflight_diagnostic(self) -> _Diagnostic:
"""Pops the last diagnostic from the inflight diagnostics stack.
Returns:
The popped diagnostic.
"""
return self._inflight_diagnostics.pop()

def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> _Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
Expand Down
5 changes: 4 additions & 1 deletion torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,11 @@ def log(self, level: int, message: str, *args, **kwargs) -> None:


@dataclasses.dataclass
class DiagnosticContext(infra.DiagnosticContext):
class DiagnosticContext(infra.DiagnosticContext[Diagnostic]):
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: type[Diagnostic] = dataclasses.field(
init=False, default=Diagnostic
)

def __enter__(self):
self._previous_log_level = self.logger.level
Expand Down

0 comments on commit e5cffb1

Please sign in to comment.