16 changes: 14 additions & 2 deletions test/onnx/test_fx_to_onnx.py
@@ -82,7 +82,7 @@ def func(x, y):
self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor)
self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor)

def test_mnist(self):
def test_mnist_exported_with_no_warnings_on_get_attr_node_in_op_level_debug(self):
class MNISTModel(nn.Module):
def __init__(self):
super().__init__()
@@ -105,7 +105,19 @@ def forward(self, tensor_x: torch.Tensor):
return output

tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
_ = dynamo_export(MNISTModel(), tensor_x, export_options=self.export_options)
export_output = dynamo_export(
MNISTModel(), tensor_x, export_options=ExportOptions(op_level_debug=True)
)

# NOTE: This additional check makes sure that op-level debug supports
# `get_attr` fx.Nodes, i.e., model weights in PyTorch. aten.convolution.default
# is one of the ops that takes a weight as input.
assert_has_diagnostics(
export_output.diagnostic_context,
diagnostics.rules.op_level_debugging,
diagnostics.levels.NONE,
expected_node="aten.convolution.default",
)

def test_trace_only_op_with_evaluator(self):
model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
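Note: for readers trying the new assertion locally, here is a minimal, hypothetical sketch of exporting with `op_level_debug=True` and inspecting the collected diagnostics. The `TinyConv` model and the printed fields are illustrative, not part of this PR; the diagnostics attributes assume the exporter's public `dynamo_export`/`ExportOptions` API.

```python
import torch
from torch.onnx import ExportOptions, dynamo_export

class TinyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Conv2d's weight/bias surface as `get_attr` nodes in the FX graph.
        self.conv = torch.nn.Conv2d(1, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

export_output = dynamo_export(
    TinyConv(),
    torch.rand(1, 1, 28, 28),
    export_options=ExportOptions(op_level_debug=True),
)

# With this PR, ops fed by get_attr nodes (e.g. aten.convolution.default)
# should validate without emitting op_level_debugging warnings.
for diagnostic in export_output.diagnostic_context.diagnostics:
    print(diagnostic.rule.id, diagnostic.level)
```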
9 changes: 0 additions & 9 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -978,9 +978,6 @@ def _test_fake_tensor_mode_exporter(
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
torch.testing.assert_close(ref_output, torch.tensor(ort_output))

@pytorch_test_common.skip_op_level_debug_test(
"https://github.com/pytorch/pytorch/issues/105490"
)
def test_fake_tensor_mode_simple(self):
def create_model() -> nn.Module:
class Model(torch.nn.Module):
@@ -1009,9 +1006,6 @@ def create_kwargs():
export_within_fake_mode=self.export_within_fake_mode,
)

@pytorch_test_common.skip_op_level_debug_test(
"https://github.com/pytorch/pytorch/issues/105490"
)
def test_large_scale_exporter_with_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"

@@ -1037,9 +1031,6 @@ def create_kwargs():
export_within_fake_mode=self.export_within_fake_mode,
)

@pytorch_test_common.skip_op_level_debug_test(
"https://github.com/pytorch/pytorch/issues/105490"
)
def test_large_scale_exporter_with_toy_mlp(self):
class MLPModel(nn.Module):
def __init__(self):
8 changes: 0 additions & 8 deletions torch/onnx/_internal/exporter.py
@@ -887,14 +887,6 @@ def pre_export_passes(
# Insert type casts explicitly where needed.
module = passes.InsertTypePromotion(diagnostic_context, module).run()

# Run ShapeInferenceWithFakeTensor to get static shape of nodes for op_level_debug purposes
# The pass added nodes with static shape into original node metadata:
# node.meta["static_shape"]: FakeTensor/int/float/SymInt/SynFloat
if options.op_level_debug:
module = passes.ShapeInferenceWithFakeTensor(diagnostic_context, module).run(
*fx_module_args
)

analysis.UnsupportedFxNodesAnalysis(
diagnostic_context, module, options.onnxfunction_dispatcher
).analyze(infra.levels.ERROR)
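Note: the removed pass duplicated shape metadata into `node.meta["static_shape"]`; op-level debug now reads the fake tensors that FX propagation already stores in `node.meta["val"]`. A minimal standalone sketch of that existing mechanism (standard `torch.fx` behavior, not exporter code):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x):
    return torch.relu(x) + 1

gm = symbolic_trace(f)
# FakeTensorProp runs the graph with fake tensors and records each node's
# result in node.meta["val"], carrying shape and dtype but no data.
FakeTensorProp(gm).propagate(torch.rand(2, 3))

for node in gm.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):
        print(node.name, tuple(val.shape), val.dtype)
```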
39 changes: 13 additions & 26 deletions torch/onnx/_internal/fx/fx_onnx_interpreter.py
@@ -405,6 +405,7 @@ def run_node(
fx_name_to_onnxscript_value,
onnxfunction_dispatcher,
op_level_debug,
fx_graph_module,
)
elif node.op == "call_method":
self.call_method(node)
@@ -520,6 +521,7 @@ def call_function(
],
onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
op_level_debug: bool,
fx_graph_module: torch.fx.GraphModule,
):
# aten ops and other stateless functions.
if node.target == operator.getitem and isinstance(
@@ -540,10 +542,10 @@

# Map FX inputs to ONNX inputs and fill optional inputs with default values.
# torch_args and torch_kwargs are for op-level validation
complete_args, complete_kwargs = _fill_in_default_kwargs(node)
fx_args, fx_kwargs = _fill_in_default_kwargs(node)
onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args(
complete_args,
complete_kwargs,
fx_args,
fx_kwargs,
fx_name_to_onnxscript_value,
onnxscript_tracer,
)
@@ -582,29 +584,14 @@ def call_function(
and node.target != torch.ops.aten.sym_size
and not isinstance(node.target, types.BuiltinFunctionType)
):
(
node_with_fixed_shape_args,
node_with_fixed_shape_kwargs,
) = _fill_in_default_kwargs(node)
try:
torch_args, torch_kwargs = op_validation.wrap_fx_args_as_torch_args(
node_with_fixed_shape_args, node_with_fixed_shape_kwargs
)
except ValueError as value_error:
diagnostic = self.diagnostic_context.inflight_diagnostic()
diagnostic.with_additional_message(
f"### Op level debug fails due to unsupported input types\n"
f"{diagnostics.decorator.format_exception_in_markdown(value_error)}"
)
diagnostic.level = diagnostics.levels.ERROR
else:
op_validation.validate_op_between_ort_torch(
self.diagnostic_context,
node,
symbolic_fn,
torch_args,
torch_kwargs,
)
op_validation.validate_op_between_ort_torch(
self.diagnostic_context,
node,
symbolic_fn,
fx_args,
fx_kwargs,
fx_graph_module,
)
fx_name_to_onnxscript_value[node.name] = output

@_beartype.beartype
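Note: threading `fx_graph_module` into `call_function` is what lets validation resolve `get_attr` inputs: when a node has no `meta["val"]`, the real parameter can be fetched from the owning module. A small standalone sketch of that lookup (the `WithWeight` module is hypothetical):

```python
import torch
from torch.fx import symbolic_trace

class WithWeight(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(4, 3))

    def forward(self, x):
        return x @ self.weight

gm = symbolic_trace(WithWeight())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        # Same pattern as the exporter's fallback: the node's target names
        # an attribute (here a Parameter) on the owning GraphModule.
        tensor = getattr(gm, node.target)
        print(node.target, tuple(tensor.shape))
```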
91 changes: 68 additions & 23 deletions torch/onnx/_internal/fx/op_validation.py
@@ -9,6 +9,8 @@

import torch
import torch.fx

from torch.fx.experimental import symbolic_shapes
from torch.onnx import _constants, _type_utils as jit_type_utils
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import (
@@ -43,8 +45,9 @@ def validate_op_between_ort_torch(
diagnostic_context: diagnostics.DiagnosticContext,
node: torch.fx.Node,
symbolic_fn: Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction],
torch_args: List[fx_type_utils.Argument],
torch_kwargs: Dict[str, fx_type_utils.Argument],
fx_args: List[fx_type_utils.Argument],
fx_kwargs: Dict[str, fx_type_utils.Argument],
fx_graph_module: torch.fx.GraphModule,
):
"""Validate the op between ONNX Runtime and PyTorch.

@@ -62,10 +65,24 @@
symbolic_fn (Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]): The corresponding ONNX function
fx_args (list): fx argument inputs
fx_kwargs (dict): fx keyword argument inputs
fx_graph_module (torch.fx.GraphModule): The fx.GraphModule that contains the nodes
"""
# op-level validation
# Symbolic_fn should have the same output as node.target (torch ops)

try:
torch_args, torch_kwargs = _wrap_fx_args_as_torch_args(
fx_args, fx_kwargs, fx_graph_module
)
except ValueError as value_error:
diagnostic = diagnostic_context.inflight_diagnostic()
diagnostic.with_additional_message(
f"### Op level debug fails due to unsupported input types\n"
f"{diagnostics.decorator.format_exception_in_markdown(value_error)}"
)
diagnostic.level = diagnostics.levels.WARNING
return

with evaluator.default_as(evaluator.ort_evaluator):
try:
expected_outputs = node.target(*torch_args, **torch_kwargs) # type: ignore[operator]
@@ -165,8 +182,37 @@ def validate_op_between_ort_torch(
diagnostic.level = diagnostics.levels.WARNING


@_beartype.beartype
def _convert_symint_to_int_in_shape(shape: torch.Size) -> torch.Size:
"""Convert SymInt to int in shape

Args:
shape (torch.Size): The shape of a tensor
Raises:
ValueError: When an unbacked SymInt with no hint is found in the shape
Returns:
torch.Size: The shape of a tensor with SymInt converted to int

"""
list_int_shape = []
for dim in shape:
if isinstance(dim, torch.SymInt):
if symbolic_shapes.has_hint(dim):
list_int_shape.append(symbolic_shapes.hint_int(dim))
else:
raise ValueError(
f"An unbacked SymInt found in shape. SymInt: {dim}; "
f"torch.Size: {shape}. There is no hint for SymInt."
)
else:
list_int_shape.append(dim)
return torch.Size(list_int_shape)


@_beartype.beartype
def generate_random_tensors(shape: torch.Size, dtype: torch.dtype):
shape = _convert_symint_to_int_in_shape(shape)

if dtype == torch.uint8:
return torch.randint(
low=_constants.UINT8_MIN, high=_constants.UINT8_MAX, size=shape, dtype=dtype
@@ -197,38 +243,35 @@ def generate_random_tensors(shape: torch.Size, dtype: torch.dtype):

@_beartype.beartype
def _fx_args_to_torch_args(
complete_args: List[fx_type_utils.Argument],
fx_args: List[fx_type_utils.Argument], fx_graph_module: torch.fx.GraphModule
) -> List[fx_type_utils.Argument]:
"""Recursively convert fx args to torch args"""
wrapped_args: List[fx_type_utils.Argument] = []
for arg in complete_args:
for arg in fx_args:
if isinstance(arg, torch.fx.Node):
# NOTE(titaiwang): The arg type here should align to the type handled in
# shape.inference.FakeTensorPropGetStaticShapes. Currently, we are aware
# of FakeTensor/Tensor/SymInt/SymFloat/Symbool/int/float/bool could be in
# arg.meta["static_shape"].
fake_tensor = arg.meta.get("static_shape", None)
fake_tensor = arg.meta.get("val")
if fake_tensor is None and arg.op == "get_attr":
fake_tensor = getattr(fx_graph_module, arg.target) # type: ignore[operator]
# NOTE: Currently, FakeTensor/Tensor/SymInt/SymFloat/SymBool/int/float/bool
# are the types known to appear in arg.meta["val"] or as get_attr values.
if isinstance(fake_tensor, torch.Tensor):
real_tensor = generate_random_tensors(
fake_tensor.shape, fake_tensor.dtype
)
wrapped_args.append(real_tensor)
elif isinstance(fake_tensor, (int, float, bool)):
wrapped_args.append(fake_tensor)
elif fx_type_utils.is_torch_symbolic_type(fake_tensor):
raise ValueError(
f"Unexpected input argument Sym type found inside fx.Node. arg: {arg}; "
f"arg.meta['static_shape']: {fake_tensor}; type(arg.meta['static_shape']): "
f"{type(fake_tensor)}. Sym type is not supported in op_level_debug."
)
elif symbolic_shapes.has_hint(fake_tensor):
wrapped_args.append(symbolic_shapes.hint_int(fake_tensor))
else:
raise ValueError(
f"Unexpected input argument type found inside fx.Node. arg: {arg}; "
f"arg.meta['static_shape']: {fake_tensor}; type(arg.meta['static_shape']): "
f"arg.meta['val']/get_attr: {fake_tensor}; type(arg.meta['val']/get_attr): "
f"{type(fake_tensor)}."
)
elif isinstance(arg, Sequence):
wrapped_args.append(_fx_args_to_torch_args(arg))
wrapped_args.append(_fx_args_to_torch_args(arg, fx_graph_module))
elif isinstance(arg, (int, float, torch.dtype)) or arg is None:
wrapped_args.append(arg)
elif isinstance(arg, torch.device):
@@ -242,16 +285,18 @@ def _fx_args_to_torch_args(


@_beartype.beartype
def wrap_fx_args_as_torch_args(
complete_args: List[fx_type_utils.Argument],
complete_kwargs: Dict[str, fx_type_utils.Argument],
def _wrap_fx_args_as_torch_args(
fx_args: List[fx_type_utils.Argument],
fx_kwargs: Dict[str, fx_type_utils.Argument],
fx_graph_module: torch.fx.GraphModule,
) -> Tuple[List[fx_type_utils.Argument], Dict[str, fx_type_utils.Argument]]:
"""Prepare torch format args and kwargs for op-level validation by using fake tensor to create real tensor to feed in ops"""

# NOTE: This function only supports FakeTensor with concrete shapes
torch_args: List[fx_type_utils.Argument] = _fx_args_to_torch_args(complete_args)
torch_kwargs = complete_kwargs
return torch_args, torch_kwargs
torch_args: List[fx_type_utils.Argument] = _fx_args_to_torch_args(
fx_args, fx_graph_module
)
return torch_args, fx_kwargs


# NOTE: Referenced from onnxscript internal function: _tag_arguments_with_param_schemas.
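Note: putting the pieces above together, op-level validation amounts to: materialize real tensors from fake metadata, run the original torch op, run the ONNX decomposition, and compare. A conceptual sketch under that assumption (the `validate` helper and tolerances are hypothetical; the real code routes the ONNX side through the ONNX Runtime evaluator):

```python
import torch

def validate(torch_op, onnx_fn, fake_args, rtol=1e-3, atol=1e-5):
    # Replace fake tensors with random real tensors of the same shape/dtype.
    real_args = [
        torch.rand(a.shape).to(a.dtype) if isinstance(a, torch.Tensor) else a
        for a in fake_args
    ]
    expected = torch_op(*real_args)
    actual = onnx_fn(*real_args)  # in the exporter this runs via ONNX Runtime
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

# Self-contained demo with two torch implementations standing in for the
# torch-op / ONNX-function pair:
validate(torch.relu, torch.nn.functional.relu, [torch.rand(2, 3)])
```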
2 changes: 0 additions & 2 deletions torch/onnx/_internal/fx/passes/__init__.py
@@ -1,7 +1,6 @@
from .decomp import Decompose
from .functionalization import Functionalize, RemoveInputMutation
from .readability import RestoreParameterAndBufferNames
from .shape_inference import ShapeInferenceWithFakeTensor
from .type_promotion import InsertTypePromotion
from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder

@@ -13,5 +12,4 @@
"RemoveInputMutation",
"RestoreParameterAndBufferNames",
"ReplaceGetAttrWithPlaceholder",
"ShapeInferenceWithFakeTensor",
]