Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Set 'Generic[Diagnostic]' as base class for 'DiagnosticContext' #107165

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 24 additions & 1 deletion test/onnx/internal/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ def test_diagnostic_log_emit_correctly_formatted_string(self):
)
self.assertIn("hello world", diagnostic.additional_messages)

def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
self,
):
with self.diagnostic_context:
# Dynamo onnx exporter diagnostic context expects fx_diagnostics.Diagnostic
# instead of base infra.Diagnostic.
diagnostic = infra.Diagnostic(
self.rules.rule_without_message_args, infra.Level.NOTE
)
with self.assertRaises(TypeError):
self.diagnostic_context.log(diagnostic)


class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
"""Test cases for diagnostics emitted by the TorchScript ONNX export code."""
Expand Down Expand Up @@ -353,7 +365,9 @@ class TestDiagnosticsInfra(common_utils.TestCase):
def setUp(self):
self.rules = _RuleCollectionForTest()
with contextlib.ExitStack() as stack:
self.context = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.context: infra.DiagnosticContext[
infra.Diagnostic
] = stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
self.addCleanup(stack.pop_all().close)
return super().setUp()

Expand Down Expand Up @@ -591,6 +605,15 @@ def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_mes
self.assertIn("ValueError: original exception", diagnostic_message)
self.assertIn("Traceback (most recent call last):", diagnostic_message)

def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
self,
):
with self.context:
with self.assertRaises(TypeError):
# The method expects 'Diagnostic' or its subclasses as arguments.
# Passing any other type will trigger a TypeError.
self.context.log("This is a str message.")

def test_diagnostic_context_raises_if_diagnostic_is_error(self):
with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
self.context.log_and_raise_if_error(
Expand Down
8 changes: 4 additions & 4 deletions torch/onnx/_internal/diagnostics/_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import gzip
from collections.abc import Generator
from typing import List, Optional, Type
from typing import List, Optional

import torch

Expand Down Expand Up @@ -113,7 +113,6 @@ def create_diagnostic_context(
name: str,
version: str,
options: Optional[infra.DiagnosticOptions] = None,
diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic,
) -> infra.DiagnosticContext:
"""Creates a new diagnostic context.

Expand All @@ -127,7 +126,9 @@ def create_diagnostic_context(
"""
if options is None:
options = infra.DiagnosticOptions()
context = infra.DiagnosticContext(name, version, options)
context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
name, version, options
)
self.contexts.append(context)
return context

Expand Down Expand Up @@ -179,7 +180,6 @@ def create_export_diagnostic_context() -> (
_context = engine.create_diagnostic_context(
"torch.onnx.export",
torch.__version__,
diagnostic_type=TorchScriptOnnxExportDiagnostic,
)
try:
yield _context
Expand Down
37 changes: 24 additions & 13 deletions torch/onnx/_internal/diagnostics/infra/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@

import logging

from typing import Callable, Generator, List, Literal, Mapping, Optional, TypeVar
from typing import (
Callable,
Generator,
Generic,
List,
Literal,
Mapping,
Optional,
Type,
TypeVar,
)

from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
Expand Down Expand Up @@ -260,20 +270,21 @@ def __init__(self, diagnostic: Diagnostic):


@dataclasses.dataclass
class DiagnosticContext:
class DiagnosticContext(Generic[_Diagnostic]):
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list)
diagnostics: List[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
# TODO(bowbao): Implement this.
# _invocation: infra.Invocation = dataclasses.field(init=False)
_inflight_diagnostics: List[Diagnostic] = dataclasses.field(
_inflight_diagnostics: List[_Diagnostic] = dataclasses.field(
init=False, default_factory=list
)
_previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: Type = dataclasses.field(init=False, default=Diagnostic)

def __enter__(self):
self._previous_log_level = self.logger.level
Expand Down Expand Up @@ -318,7 +329,7 @@ def dump(self, file_path: str, compress: bool = False) -> None:
with open(file_path, "w") as f:
f.write(self.to_json())

def log(self, diagnostic: Diagnostic) -> None:
def log(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic.

This method should be used only after all the necessary information for the diagnostic
Expand All @@ -327,15 +338,15 @@ def log(self, diagnostic: Diagnostic) -> None:
Args:
diagnostic: The diagnostic to add.
"""
if not isinstance(diagnostic, Diagnostic):
if not isinstance(diagnostic, self._bound_diagnostic_type):
raise TypeError(
f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}"
f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
)
if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING:
diagnostic.level = infra.Level.ERROR
self.diagnostics.append(diagnostic)

def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:
def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic and raises an exception if it is an error.

Use this method for logging non inflight diagnostics where diagnostic level is not known or
Expand All @@ -357,8 +368,8 @@ def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None:

@contextlib.contextmanager
def add_inflight_diagnostic(
self, diagnostic: Diagnostic
) -> Generator[Diagnostic, None, None]:
self, diagnostic: _Diagnostic
) -> Generator[_Diagnostic, None, None]:
"""Adds a diagnostic to the context.

Use this method to add diagnostics that are not created by the context.
Expand All @@ -371,7 +382,7 @@ def add_inflight_diagnostic(
finally:
self._inflight_diagnostics.pop()

def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.

Args:
Expand All @@ -382,15 +393,15 @@ def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None:
"""
self._inflight_diagnostics.append(diagnostic)

def pop_inflight_diagnostic(self) -> Diagnostic:
def pop_inflight_diagnostic(self) -> _Diagnostic:
"""Pops the last diagnostic from the inflight diagnostics stack.

Returns:
The popped diagnostic.
"""
return self._inflight_diagnostics.pop()

def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic:
def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> _Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
Expand Down
5 changes: 4 additions & 1 deletion torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,11 @@ def log(self, level: int, message: str, *args, **kwargs) -> None:


@dataclasses.dataclass
class DiagnosticContext(infra.DiagnosticContext):
class DiagnosticContext(infra.DiagnosticContext[Diagnostic]):
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: type[Diagnostic] = dataclasses.field(
init=False, default=Diagnostic
)

def __enter__(self):
self._previous_log_level = self.logger.level
Expand Down