diff --git a/test/onnx/dynamo/test_registry_dispatcher.py b/test/onnx/dynamo/test_registry_dispatcher.py index 4f56cc84a6930..427a06ecbbfdc 100644 --- a/test/onnx/dynamo/test_registry_dispatcher.py +++ b/test/onnx/dynamo/test_registry_dispatcher.py @@ -13,7 +13,13 @@ from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import] from onnxscript.function_libs.torch_lib import ops # type: ignore[import] from onnxscript.onnx_opset import opset15 as op # type: ignore[import] -from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration +from torch.onnx._internal.diagnostics import infra +from torch.onnx._internal.fx import ( + analysis, + diagnostics, + onnxfunction_dispatcher, + registration, +) from torch.testing._internal import common_utils # TODO: this can only be global. https://github.com/microsoft/onnxscript/issues/805 @@ -77,6 +83,60 @@ def test_custom(x, y): [test_original, test_custom], ) + def test_unsupported_nodes_analysis_with_missing_aten_op(self): + # NOTE: simulate unsupported nodes + aten_mul_tensor = registration.OpName.from_name_parts( + namespace="aten", op_name="mul", overload="Tensor" + ) + aten_mul_default = registration.OpName.from_name_parts( + namespace="aten", op_name="mul" + ) + aten_add_tensor = registration.OpName.from_name_parts( + namespace="aten", op_name="add", overload="Tensor" + ) + aten_add_default = registration.OpName.from_name_parts( + namespace="aten", op_name="add" + ) + + self.registry._registry.pop(aten_mul_tensor) + self.registry._registry.pop(aten_mul_default) + self.registry._registry.pop(aten_add_tensor) + self.registry._registry.pop(aten_add_default) + + diagnostic_context = diagnostics.DiagnosticContext( + "torch.onnx.dynamo_export", torch.__version__ + ) + dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( + self.registry, diagnostic_context + ) + + graph: torch.fx.Graph = torch.fx.Graph() + x: torch.fx.Node = graph.create_node("placeholder", "x") + x.meta["val"] = torch.tensor(3.0) + b: torch.fx.Node = graph.create_node( + "call_function", target=torch.ops.aten.mul.Tensor, args=(x, x) + ) + c: torch.fx.Node = graph.create_node( + "call_function", target=torch.ops.aten.add.Tensor, args=(b, b) + ) + output: torch.fx.Node = graph.output(c) + module = torch.fx.GraphModule(torch.nn.Module(), graph) + + with self.assertRaises(infra.RuntimeErrorWithDiagnostic): + analysis.UnsupportedFxNodesAnalysis( + diagnostic_context, module, dispatcher + ).analyze(infra.levels.ERROR) + + try: + analysis.UnsupportedFxNodesAnalysis( + diagnostic_context, module, dispatcher + ).analyze(infra.levels.ERROR) + except infra.RuntimeErrorWithDiagnostic as e: + self.assertIn( + "Unsupported FX nodes: {'call_function': ['aten.mul.Tensor', 'aten.add.Tensor']}.", + e.diagnostic.message, + ) + @common_utils.instantiate_parametrized_tests class TestDispatcher(common_utils.TestCase): diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index ec527e4dbd2e4..5427bdb4a2b29 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -266,6 +266,48 @@ def forward(self, input): expected_node="aten.clone.default", ) + def test_missing_complex_onnx_variant_raises_errors_in_dispatcher(self): + registry = torch.onnx.OnnxRegistry() + + # NOTE: simulate unsupported nodes + aten_mul_tensor = registration.OpName.from_name_parts( + namespace="aten", op_name="mul", overload="Tensor" + ) + + # Only keep real aten.mul to test missing complex aten.mul + registry._registry[aten_mul_tensor] = [ + onnx_func + for onnx_func in registry._registry[aten_mul_tensor] + if not onnx_func.is_complex + ] + + class TraceModel(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.mul.Tensor(input, input) + + x = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64) + + with self.assertRaises(torch.onnx.OnnxExporterError) as e: + torch.onnx.dynamo_export( + TraceModel(), + x, + export_options=torch.onnx.ExportOptions(onnx_registry=registry), + ) + + try: + torch.onnx.dynamo_export( + TraceModel(), + x, + export_options=torch.onnx.ExportOptions(onnx_registry=registry), + ) + except torch.onnx.OnnxExporterError as e: + assert_has_diagnostics( + e.onnx_program.diagnostic_context, + diagnostics.rules.no_symbolic_function_for_call_function, + diagnostics.levels.ERROR, + expected_node="aten.mul.Tensor", + ) + def test_dynamo_export_retains_readable_parameter_and_buffer_names(self): class SubModule(torch.nn.Module): def __init__(self): diff --git a/torch/onnx/_internal/fx/analysis/unsupported_nodes.py b/torch/onnx/_internal/fx/analysis/unsupported_nodes.py index 93bdc8014f33f..5da0dbed3d919 100644 --- a/torch/onnx/_internal/fx/analysis/unsupported_nodes.py +++ b/torch/onnx/_internal/fx/analysis/unsupported_nodes.py @@ -1,10 +1,9 @@ from __future__ import annotations import dataclasses -from typing import Dict, List +from typing import Dict -import torch -from torch.onnx._internal.fx import _pass, diagnostics +from torch.onnx._internal.fx import _pass, diagnostics, registration @dataclasses.dataclass @@ -52,23 +51,35 @@ def analyze( RuntimeErrorWithDiagnostic: If diagnostics are emitted and the diagnostic level is `ERROR`. """ - unsupported_nodes: List[torch.fx.Node] = [] + + op_to_target_mapping: Dict[str, Dict[str, None]] = {} for node in self.module.graph.nodes: if node.op == "call_function": - try: - # NOTE: OPSchema matcher is not in this analysis scope. - self.onnxfunction_dispatcher.get_function_overloads( - node, self.diagnostic_context + # NOTE: OPSchema matcher is not in this analysis scope. + internal_opname: registration.OpName = ( + self.onnxfunction_dispatcher._get_aten_name( + node=node, diagnostic_context=self.diagnostic_context + ) + ) + overload_registration = ( + self.onnxfunction_dispatcher.onnx_registry.is_registered_op( + namespace=internal_opname.namespace, + op_name=internal_opname.op_name, + overload=internal_opname.overload, + ) + ) + # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload. + default_registration = ( + self.onnxfunction_dispatcher.onnx_registry.is_registered_op( + namespace=internal_opname.namespace, + op_name=internal_opname.op_name, + overload=None, + ) + ) + if not overload_registration and not default_registration: + op_to_target_mapping.setdefault(node.op, {}).setdefault( + str(node.target), None ) - except diagnostics.RuntimeErrorWithDiagnostic as e: - unsupported_nodes.append(node) - - op_to_target_mapping: Dict[str, Dict[str, None]] = {} - - for node in unsupported_nodes: - op = node.op - target = node.target - op_to_target_mapping.setdefault(op, {}).setdefault(str(target), None) analysis_result = UnsupportedFxNodesAnalysisResult(op_to_target_mapping) self._lint(analysis_result, diagnostic_level)