[ONNX] Register list/tuple/dict to format_argument and refactor fx.Node format_argument in diagnostics #105263

Closed
wants to merge 7 commits
2 changes: 1 addition & 1 deletion .ci/docker/common/install_onnx.sh
@@ -24,7 +24,7 @@ pip_install \
transformers==4.25.1

# TODO: change this when onnx-script is on testPypi
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@2bb3e9f2d094912f81cb63cecb412efb14c65738"
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@ab54cb17feb256ca52fa4c616e27e06b3ce67139"

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
9 changes: 3 additions & 6 deletions test/onnx/test_fx_op_consistency.py
@@ -330,17 +330,14 @@
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"),
),
xfail(
"nn.functional.adaptive_avg_pool1d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.index.Tensor"),
),
xfail(
"nn.functional.adaptive_avg_pool2d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.index.Tensor"),
reason=onnx_test_common.reason_dynamo_does_not_support("RecursionError: maximum recursion depth exceeded \
while calling a Python object in Decompose pass"),
),

Collaborator: decompose pass.. probably our bug, not dynamo's?

Collaborator: I'm getting this error: onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (AveragePool_0) Op (AveragePool) [ShapeInferenceError] Attribute strides has incorrect size ...

Collaborator (Author): I will try.

Collaborator (Author): I just repro'd it. Maybe try the latest commit of onnx-script; I had an update on avg_pool in torchlib.

xfail(
"nn.functional.adaptive_avg_pool3d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.index.Tensor"),
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"),
),
xfail(
"nn.functional.conv1d",
121 changes: 102 additions & 19 deletions torch/onnx/_internal/fx/diagnostics.py
@@ -15,7 +15,12 @@
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils

_LENGTH_LIMIT: int = 89
from torch.onnx._internal.fx import type_utils as fx_type_utils

# NOTE: Symbolic shapes could be a calculation of values, such as
# Tensor(i64[s0, 64, (s1//2) - 2, (s1//2) - 2]) where s0 and s1 are symbolic
# so we need to relax the length limit.
_LENGTH_LIMIT: int = 120

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx`, and loaded with `torch`. Hence anything related to `onnxscript`
@@ -31,21 +36,23 @@ def format_argument(obj: Any) -> str:
formatter = _format_argument.dispatch(type(obj))
result_str = formatter(obj)

if len(result_str) > _LENGTH_LIMIT:
# TODO(bowbao): group diagnostics.
# Related fields of sarif.Result: occurance_count, fingerprints.
# Do a final process to group results before outputing sarif log.
diag = infra.Diagnostic(
*diagnostics.rules.arg_format_too_verbose.format(
level=infra.levels.WARNING,
length=len(result_str),
length_limit=_LENGTH_LIMIT,
argument_type=type(obj),
formatter_type=type(format_argument),
result_str_lines = result_str.splitlines()
for line in result_str_lines:
if len(line) > _LENGTH_LIMIT:
# TODO(bowbao): group diagnostics.
# Related fields of sarif.Result: occurrence_count, fingerprints.
# Do a final process to group results before outputting sarif log.
diag = infra.Diagnostic(
*diagnostics.rules.arg_format_too_verbose.format(
level=infra.levels.WARNING,
length=len(result_str),
length_limit=_LENGTH_LIMIT,
argument_type=type(obj),
formatter_type=type(format_argument),
)
)
)
diag.with_location(utils.function_location(formatter))
diagnostics.export_context().log(diag)
diag.with_location(utils.function_location(formatter))
diagnostics.export_context().log(diag)

return result_str

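For intuition, the per-line verbosity check above can be sketched as a tiny standalone helper (a hypothetical illustration, not the diagnostics API; the limit and names are stubbed in):

def _overlong_lines(result_str: str, length_limit: int = 120) -> list:
    # splitlines() mirrors the loop above; symbolic shapes such as
    # Tensor(i64[s0, 64, (s1//2) - 2, (s1//2) - 2]) are what motivated raising the limit to 120.
    return [line for line in result_str.splitlines() if len(line) > length_limit]

Only when this list is non-empty would an arg_format_too_verbose diagnostic be emitted.
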
@@ -71,12 +78,52 @@ def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:

@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
return f"fx.Node({obj.name}[{obj.op}]'{obj.target}')"
node_string = f"fx.Node({obj.target})[{obj.op}]:"
if "val" not in obj.meta:
return node_string + "None"
return node_string + _format_nested_argument_by_dtype(obj.meta["val"])


@_format_argument.register
def _torch_fx_symbolic_value(
obj, # NOTE: singledispatch does not support Union until 3.11, so we use Any here.
) -> str:
return f"Sym({obj})"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
return f"Tensor(shape={obj.shape}, dtype={obj.dtype})"
return f"Tensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})"

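For a rough sense of the new rendering (illustrative strings only; exact output depends on the node's target, op, and meta):

# A call_function node carrying a fake tensor in meta["val"] would render roughly as
#   fx.Node(aten.add.Tensor)[call_function]:Tensor(f32[3, 4])
# a node with no "val" entry as
#   fx.Node(aten.add.Tensor)[call_function]:None
# and a symbolic value as
#   Sym(s0 + 1)
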

@_format_argument.register
def _list(obj: list) -> str:
list_string = f"List[length={len(obj)}](\n"
if not obj:
return list_string + "None)"
for item in obj:
list_string += f"{_format_nested_argument_by_dtype(item)},\n"
return list_string + ")"


@_format_argument.register
def _tuple(obj: tuple) -> str:
tuple_string = f"Tuple[length={len(obj)}](\n"
if not obj:
return tuple_string + "None)"
for item in obj:
tuple_string += f"{_format_nested_argument_by_dtype(item)},\n"
return tuple_string + ")"


@_format_argument.register
def _dict(obj: dict) -> str:
dict_string = f"Dict[length={len(obj)}](\n"
if not obj:
return dict_string + "None)"
for key, value in obj.items():
dict_string += f"{key}: {_format_nested_argument_by_dtype(value)},\n"
return dict_string + ")"

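A quick hedged example of the container formatting registered above (hypothetical values):

# e.g. a list of two tensors would render roughly as
#   List[length=2](
#   Tensor(f32[2, 3]),
#   Tensor(i64[5]),
#   )
# and an empty dict as
#   Dict[length=0](
#   None)
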

@_format_argument.register
@@ -86,15 +133,51 @@ def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:

@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
# TODO(bowbao) obj.dtype throws error.
return f"`TorchScriptTensor({obj.name}, {obj.onnx_dtype}, {obj.shape}, {obj.symbolic_value()})`"
return f"`TorchScriptTensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})`"


@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.OnnxFunction) -> str:
return f"`OnnxFunction({obj.name})`"


@_format_argument.register
def _onnxscript_traced_onnx_function(obj: onnxscript.TracedOnnxFunction) -> str:
return f"`TracedOnnxFunction({obj.name})`"

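For reference, the script-tensor and function formatters above would produce strings along these lines (names, dtypes, and shapes are made up):

#   `TorchScriptTensor(f32[3, 4])`   for a TorchScriptTensor with dtype float32 and shape (3, 4)
#   `OnnxFunction(aten_add)`         for an OnnxFunction whose name is "aten_add"
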

# from torch/fx/graph.py to follow torch format
def _stringify_shape(shape: Optional[torch.Size]) -> str:
if shape is None:
return ""
return f"[{', '.join(str(x) for x in shape)}]"


def _format_nested_argument_by_dtype(obj: Any) -> str:
"""Dispatch to the correct formatter based on the type of the argument."""
if isinstance(obj, torch.Tensor):
return _torch_tensor(obj)
if isinstance(obj, torch.nn.Parameter):
return _torch_nn_parameter(obj)
if isinstance(obj, torch.fx.Node):
return _torch_fx_node(obj)
if fx_type_utils.is_torch_symbolic_type(obj):
return _torch_fx_symbolic_value(obj)
if isinstance(obj, graph_building.TorchScriptTensor):
return _onnxscript_torch_script_tensor(obj)
if isinstance(obj, onnxscript.OnnxFunction):
return _onnxscript_onnx_function(obj)
if isinstance(obj, onnxscript.TracedOnnxFunction):
return _onnxscript_traced_onnx_function(obj)
if isinstance(obj, list):
return _list(obj)
if isinstance(obj, tuple):
return _tuple(obj)
if isinstance(obj, dict):
return _dict(obj)
return format_argument(obj)

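The manual isinstance chain above works around the fact that functools.singledispatch dispatches only on the top-level argument type, so nested container elements must be routed explicitly. A minimal standalone sketch of the same idea (assumed names, simplified output):

import functools

@functools.singledispatch
def fmt(obj) -> str:
    return str(obj)

@fmt.register
def _(obj: list) -> str:
    # Recurse manually so each element gets its own type-specific rendering.
    return "List(" + ", ".join(fmt(item) for item in obj) + ")"
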

diagnose_call = functools.partial(
decorator.diagnose_call,
diagnostic_type=diagnostics.ExportDiagnostic,
2 changes: 1 addition & 1 deletion torch/onnx/_internal/fx/op_validation.py
@@ -215,7 +215,7 @@ def _fx_args_to_torch_args(
wrapped_args.append(real_tensor)
elif isinstance(fake_tensor, (int, float, bool)):
wrapped_args.append(fake_tensor)
elif isinstance(fake_tensor, (torch.SymBool, torch.SymInt, torch.SymFloat)):
elif fx_type_utils.is_torch_symbolic_type(fake_tensor):
raise ValueError(
f"Unexpected input argument Sym type found inside fx.Node. arg: {arg}; "
f"arg.meta['static_shape']: {fake_tensor}; type(arg.meta['static_shape']): "
26 changes: 26 additions & 0 deletions torch/onnx/_internal/fx/type_utils.py
@@ -66,6 +66,16 @@ def from_python_type_to_onnx_attribute_type(
return _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype)


def is_torch_symbolic_type(t: Any) -> bool:
return isinstance(t, (torch.SymBool, torch.SymInt, torch.SymFloat))


def from_torch_dtype_to_abbr(dtype: Optional[torch.dtype]) -> str:
if dtype is None:
return ""
return _TORCH_DTYPE_TO_ABBREVIATION.get(dtype, "")


# NOTE: this is a mapping from torch dtype to a set of compatible onnx types
# It's used in dispatcher to find the best match overload for the input dtypes
_TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[
@@ -93,6 +103,22 @@ def from_python_type_to_onnx_attribute_type(
torch.complex128: torch.float64, # NOTE: ORT doesn't support torch.float64
}

_TORCH_DTYPE_TO_ABBREVIATION = {
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.bool: "b8",
torch.uint8: "u8",
}

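Hedged usage notes for the helpers added above (assuming the table is in scope; values follow the mapping):

# from_torch_dtype_to_abbr(torch.float32) -> "f32"
# from_torch_dtype_to_abbr(None)          -> ""   (shape-only rendering)
# is_torch_symbolic_type(x)               -> True when x is a torch.SymBool, torch.SymInt, or torch.SymFloat
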
# NOTE: Belows are from torch/fx/node.py
BaseArgumentTypes = Union[
str,
8 changes: 8 additions & 0 deletions torch/onnx/_type_utils.py
@@ -184,6 +184,14 @@ def from_value(
return cls.from_dtype(value.type().getElementType().dtype())
except RuntimeError:
return cls._from_name(str(value.type().getElementType()))
if isinstance(value.type(), torch._C.OptionalType):
if value.type().getElementType().dtype() is None:
if isinstance(default, JitScalarType):
return default
raise errors.OnnxExporterError(
"default value must be a JitScalarType object."
)
return cls.from_dtype(value.type().getElementType().dtype())

scalar_type = None
if value.node().kind() != "prim::Constant" or not isinstance(
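A hedged summary of the new Optional handling in from_value (assumed behavior, written as pseudocode comments rather than a verbatim trace):

# Optional value with a known element dtype    -> JitScalarType.from_dtype(<that dtype>)
# Optional value with an unknown element dtype -> `default`, if it is a JitScalarType
#                                                 (otherwise OnnxExporterError is raised)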