Skip to content

Commit

Permalink
[ONNX] Register list/tuple/dict to format_argumment and refactor fx.N…
Browse files Browse the repository at this point in the history
…ode format_argument in diagnostics

ghstack-source-id: 0c6507f7624380ce7983069b0cae855621a526c4
Pull Request resolved: #105263
  • Loading branch information
titaiwangms committed Jul 19, 2023
1 parent f139aab commit 7dd9b09
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/common/install_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip_install \
transformers==4.25.1

# TODO: change this when onnx-script is on testPypi
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@2bb3e9f2d094912f81cb63cecb412efb14c65738"
pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@ab54cb17feb256ca52fa4c616e27e06b3ce67139"

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
Expand Down
9 changes: 3 additions & 6 deletions test/onnx/test_fx_op_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,17 +330,14 @@
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"),
),
xfail(
"nn.functional.adaptive_avg_pool1d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.index.Tensor"),
),
xfail(
"nn.functional.adaptive_avg_pool2d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.index.Tensor"),
reason=onnx_test_common.reason_dynamo_does_not_support("RecursionError: maximum recursion depth exceeded \
while calling a Python object in Decompose pass"),
),
xfail(
"nn.functional.adaptive_avg_pool3d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.index.Tensor"),
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"),
),
xfail(
"nn.functional.conv1d",
Expand Down
121 changes: 102 additions & 19 deletions torch/onnx/_internal/fx/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils

_LENGTH_LIMIT: int = 89
from torch.onnx._internal.fx import type_utils as fx_type_utils

# NOTE: Symbolic shapes could be a calculation of values, such as
# Tensor(i64[s0, 64, (s1//2) - 2, (s1//2) - 2]) where s0 and s1 are symbolic
# so we need to relax the length limit.
_LENGTH_LIMIT: int = 120

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx`, and loaded with `torch`. Hence anything related to `onnxscript`
Expand All @@ -31,21 +36,23 @@ def format_argument(obj: Any) -> str:
formatter = _format_argument.dispatch(type(obj))
result_str = formatter(obj)

if len(result_str) > _LENGTH_LIMIT:
# TODO(bowbao): group diagnostics.
# Related fields of sarif.Result: occurance_count, fingerprints.
# Do a final process to group results before outputing sarif log.
diag = infra.Diagnostic(
*diagnostics.rules.arg_format_too_verbose.format(
level=infra.levels.WARNING,
length=len(result_str),
length_limit=_LENGTH_LIMIT,
argument_type=type(obj),
formatter_type=type(format_argument),
result_str_lines = result_str.splitlines()
for line in result_str_lines:
if len(line) > _LENGTH_LIMIT:
# TODO(bowbao): group diagnostics.
# Related fields of sarif.Result: occurance_count, fingerprints.
# Do a final process to group results before outputing sarif log.
diag = infra.Diagnostic(
*diagnostics.rules.arg_format_too_verbose.format(
level=infra.levels.WARNING,
length=len(result_str),
length_limit=_LENGTH_LIMIT,
argument_type=type(obj),
formatter_type=type(format_argument),
)
)
)
diag.with_location(utils.function_location(formatter))
diagnostics.export_context().log(diag)
diag.with_location(utils.function_location(formatter))
diagnostics.export_context().log(diag)

return result_str

Expand All @@ -71,12 +78,52 @@ def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:

@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
return f"fx.Node({obj.name}[{obj.op}]'{obj.target}')"
node_string = f"fx.Node({obj.target})[{obj.op}]:"
if "val" not in obj.meta:
return node_string + "None"
return node_string + _format_nested_argument_by_dtype(obj.meta["val"])


@_format_argument.register
def _torch_fx_symbolic_value(
obj, # NOTE: functools.singledispatch does not support Union until 3.11, so we use Any here.
) -> str:
return f"Sym({obj})"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
return f"Tensor(shape={obj.shape}, dtype={obj.dtype})"
return f"Tensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})"


@_format_argument.register
def _list(obj: list) -> str:
list_string = f"List[length={len(obj)}](\n"
if not obj:
return list_string + "None)"
for item in obj:
list_string += f"{_format_nested_argument_by_dtype(item)},\n"
return list_string + ")"


@_format_argument.register
def _tuple(obj: tuple) -> str:
tuple_string = f"Tuple[length={len(obj)}](\n"
if not obj:
return tuple_string + "None)"
for item in obj:
tuple_string += f"{_format_nested_argument_by_dtype(item)},\n"
return tuple_string + ")"


@_format_argument.register
def _dict(obj: dict) -> str:
dict_string = f"Dict[length={len(obj)}](\n"
if not obj:
return dict_string + "None)"
for key, value in obj.items():
dict_string += f"{key}: {_format_nested_argument_by_dtype(value)},\n"
return dict_string + ")"


@_format_argument.register
Expand All @@ -86,15 +133,51 @@ def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:

@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
# TODO(bowbao) obj.dtype throws error.
return f"`TorchScriptTensor({obj.name}, {obj.onnx_dtype}, {obj.shape}, {obj.symbolic_value()})`"
return f"`TorchScriptTensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})`"


@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.OnnxFunction) -> str:
return f"`OnnxFunction({obj.name})`"


@_format_argument.register
def _onnxscript_traced_onnx_function(obj: onnxscript.TracedOnnxFunction) -> str:
return f"`TracedOnnxFunction({obj.name})`"


# from torch/fx/graph.py to follow torch format
def _stringify_shape(shape: Optional[torch.Size]) -> str:
if shape is None:
return ""
return f"[{', '.join(str(x) for x in shape)}]"


def _format_nested_argument_by_dtype(obj: Any) -> str:
"""Dispatch to the correct formatter based on the type of the argument."""
if isinstance(obj, torch.Tensor):
return _torch_tensor(obj)
if isinstance(obj, torch.nn.Parameter):
return _torch_nn_parameter(obj)
if isinstance(obj, torch.fx.Node):
return _torch_fx_node(obj)
if fx_type_utils.is_torch_symbolic_type(obj):
return _torch_fx_symbolic_value(obj)
if isinstance(obj, graph_building.TorchScriptTensor):
return _onnxscript_torch_script_tensor(obj)
if isinstance(obj, onnxscript.OnnxFunction):
return _onnxscript_onnx_function(obj)
if isinstance(obj, onnxscript.TracedOnnxFunction):
return _onnxscript_traced_onnx_function(obj)
if isinstance(obj, list):
return _list(obj)
if isinstance(obj, tuple):
return _tuple(obj)
if isinstance(obj, dict):
return _dict(obj)
return format_argument(obj)


diagnose_call = functools.partial(
decorator.diagnose_call,
diagnostic_type=diagnostics.ExportDiagnostic,
Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/_internal/fx/op_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _fx_args_to_torch_args(
wrapped_args.append(real_tensor)
elif isinstance(fake_tensor, (int, float, bool)):
wrapped_args.append(fake_tensor)
elif isinstance(fake_tensor, (torch.SymBool, torch.SymInt, torch.SymFloat)):
elif fx_type_utils.is_torch_symbolic_type(fake_tensor):
raise ValueError(
f"Unexpected input argument Sym type found inside fx.Node. arg: {arg}; "
f"arg.meta['static_shape']: {fake_tensor}; type(arg.meta['static_shape']): "
Expand Down
26 changes: 26 additions & 0 deletions torch/onnx/_internal/fx/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def from_python_type_to_onnx_attribute_type(
return _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype)


def is_torch_symbolic_type(t: Any) -> bool:
return isinstance(t, (torch.SymBool, torch.SymInt, torch.SymFloat))


def from_torch_dtype_to_abbr(dtype: Optional[torch.dtype]) -> str:
if dtype is None:
return ""
return _TORCH_DTYPE_TO_ABBREVIATION.get(dtype, "")


# NOTE: this is a mapping from torch dtype to a set of compatible onnx types
# It's used in dispatcher to find the best match overload for the input dtypes
_TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[
Expand Down Expand Up @@ -93,6 +103,22 @@ def from_python_type_to_onnx_attribute_type(
torch.complex128: torch.float64, # NOTE: ORT doesn't support torch.float64
}

_TORCH_DTYPE_TO_ABBREVIATION = {
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.bool: "b8",
torch.uint8: "u8",
}

# NOTE: Belows are from torch/fx/node.py
BaseArgumentTypes = Union[
str,
Expand Down
8 changes: 8 additions & 0 deletions torch/onnx/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,14 @@ def from_value(
return cls.from_dtype(value.type().getElementType().dtype())
except RuntimeError:
return cls._from_name(str(value.type().getElementType()))
if isinstance(value.type(), torch._C.OptionalType):
if value.type().getElementType().dtype() is None:
if isinstance(default, JitScalarType):
return default
raise errors.OnnxExporterError(
"default value must be a JitScalarType object."
)
return cls.from_dtype(value.type().getElementType().dtype())

scalar_type = None
if value.node().kind() != "prim::Constant" or not isinstance(
Expand Down

0 comments on commit 7dd9b09

Please sign in to comment.