Skip to content

Commit

Permalink
[ONNX] Set 'Generic[Diagnostic]' as base class for 'DiagnosticContext'
Browse files Browse the repository at this point in the history
Allows type checking for 'Diagnostic' argument when calling 'context.log'.

ghstack-source-id: c662af361b56c7bdb299e226375445763fbc02b8
Pull Request resolved: #107165
  • Loading branch information
BowenBao committed Aug 15, 2023
1 parent f5e7423 commit bf197bf
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 19 deletions.
25 changes: 24 additions & 1 deletion test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ def test_diagnostic_log_emit_correctly_formatted_string(self):
)
self.assertIn("hello world", diagnostic.additional_messages)

def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
self,
):
with self.diagnostic_context:
# Dynamo onnx exporter diagnostic context expects fx_diagnostics.Diagnostic
# instead of base infra.Diagnostic.
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)
with self.assertRaises(TypeError):
self.diagnostic_context.log(diagnostic)


class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
"""Test cases for diagnostics emitted by the TorchScript ONNX export code."""
Expand Down Expand Up @@ -353,7 +365,9 @@ class TestDiagnosticsInfra(common_utils.TestCase):
def setUp(self):
self.rules = _RuleCollectionForTest()
with contextlib.ExitStack() as stack:
self.context = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.context: infra.DiagnosticContext[
infra.Diagnostic
] = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.addCleanup(stack.pop_all().close)
return super().setUp()

Expand Down Expand Up @@ -591,6 +605,15 @@ def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_mes
self.assertIn("ValueError: original exception", diagnostic_message)
self.assertIn("Traceback (most recent call last):", diagnostic_message)

def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
self,
):
with self.context:
with self.assertRaises(TypeError):
# The method expects 'Diagnostic' or its subclasses as arguments.
# Passing any other type will trigger a TypeError.
self.context.log("This is a str message.")

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.log_and_raise_if_error(
Expand Down
8 changes: 4 additions & 4 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import gzip
from collections.abc import Generator
from typing import List, Optional, Type
from typing import List, Optional

import torch

Expand Down Expand Up @@ -113,7 +113,6 @@ def create_diagnostic_context(
name: str,
version: str,
options: Optional[infra.DiagnosticOptions] = None,
diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic,
) -> infra.DiagnosticContext:
"""Creates a new diagnostic context.
Expand All @@ -127,7 +126,9 @@ def create_diagnostic_context(
"""
if options is None:
options = infra.DiagnosticOptions()
context = infra.DiagnosticContext(name, version, options)
context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
name, version, options
)
self.contexts.append(context)
return context

Expand Down Expand Up @@ -179,7 +180,6 @@ def create_export_diagnostic_context() -> (
_context = engine.create_diagnostic_context(
"torch.onnx.export",
torch.__version__,
diagnostic_type=TorchScriptOnnxExportDiagnostic,
)
try:
yield _context
Expand Down
37 changes: 24 additions & 13 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@

import logging

from typing import Callable, Generator, List, Literal, Mapping, Optional, TypeVar
from typing import (
Callable,
Generator,
Generic,
List,
Literal,
Mapping,
Optional,
Type,
TypeVar,
)

from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
Expand Down Expand Up @@ -260,20 +270,21 @@ def __init__(self, diagnostic: Diagnostic):


@dataclasses.dataclass
class DiagnosticContext:
class DiagnosticContext(Generic[_Diagnostic]):
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
diagnostics: List[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
# TODO(bowbao): Implement this.
# _invocation: infra.Invocation = dataclasses.field(init=False)
_inflight_diagnostics: List[Diagnostic] = dataclasses.field(
_inflight_diagnostics: List[_Diagnostic] = dataclasses.field(
init=False, default_factory=list
)
_previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: Type = dataclasses.field(init=False, default=Diagnostic)

def __enter__(self):
self._previous_log_level = self.logger.level
Expand Down Expand Up @@ -318,7 +329,7 @@ def dump(self, file_path: str, compress: bool = False) -> None:
with open(file_path, "w") as f:
f.write(self.to_json())

def log(self, diagnostic: Diagnostic) -> None:
def log(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic.
This method should be used only after all the necessary information for the diagnostic
Expand All @@ -327,15 +338,15 @@ def log(self, diagnostic: Diagnostic) -> None:
Args:
diagnostic: The diagnostic to add.
"""
if not isinstance(diagnostic, Diagnostic):
if not isinstance(diagnostic, self._bound_diagnostic_type):
raise TypeError(
f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
)
if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING:
diagnostic.level = infra.Level.ERROR
self.diagnostics.append(diagnostic)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic and raises an exception if it is an error.
Use this method for logging non inflight diagnostics where diagnostic level is not known or
Expand All @@ -357,8 +368,8 @@ def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:

@contextlib.contextmanager
def add_inflight_diagnostic(
self, diagnostic: Diagnostic
) -> Generator[Diagnostic, None, None]:
self, diagnostic: _Diagnostic
) -> Generator[_Diagnostic, None, None]:
"""Adds a diagnostic to the context.
Use this method to add diagnostics that are not created by the context.
Expand All @@ -371,7 +382,7 @@ def add_inflight_diagnostic(
finally:
self._inflight_diagnostics.pop()

def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.
Args:
Expand All @@ -382,15 +393,15 @@ def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
"""
self._inflight_diagnostics.append(diagnostic)

def pop_inflight_diagnostic(self) -> Diagnostic:
def pop_inflight_diagnostic(self) -> _Diagnostic:
"""Pops the last diagnostic from the inflight diagnostics stack.
Returns:
The popped diagnostic.
"""
return self._inflight_diagnostics.pop()

def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> _Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
Expand Down
5 changes: 4 additions & 1 deletion torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,11 @@ def log(self, level: int, message: str, *args, **kwargs) -> None:


@dataclasses.dataclass
class DiagnosticContext(infra.DiagnosticContext):
class DiagnosticContext(infra.DiagnosticContext[Diagnostic]):
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: type[Diagnostic] = dataclasses.field(
init=False, default=Diagnostic
)

def __enter__(self):
self._previous_log_level = self.logger.level
Expand Down

0 comments on commit bf197bf

Please sign in to comment.