[ONNX] Register list/tuple/dict to format_argument and refactor fx.Node format_argument in diagnostics #105263

Closed
wants to merge 7 commits
104 changes: 85 additions & 19 deletions torch/onnx/_internal/fx/diagnostics.py
@@ -15,7 +15,12 @@
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter, utils

_LENGTH_LIMIT: int = 89
from torch.onnx._internal.fx import type_utils as fx_type_utils

# NOTE: Symbolic shapes could be a calculation of values, such as
# Tensor(i64[s0, 64, (s1//2) - 2, (s1//2) - 2]) where s0 and s1 are symbolic
# so we need to relax the length limit.
_LENGTH_LIMIT: int = 120

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx`, and loaded with `torch`. Hence anything related to `onnxscript`
@@ -31,21 +36,23 @@ def format_argument(obj: Any) -> str:
formatter = _format_argument.dispatch(type(obj))
result_str = formatter(obj)

if len(result_str) > _LENGTH_LIMIT:
# TODO(bowbao): group diagnostics.
# Related fields of sarif.Result: occurrence_count, fingerprints.
# Do a final process to group results before outputting the sarif log.
diag = infra.Diagnostic(
*diagnostics.rules.arg_format_too_verbose.format(
level=infra.levels.WARNING,
length=len(result_str),
length_limit=_LENGTH_LIMIT,
argument_type=type(obj),
formatter_type=type(format_argument),
result_str_lines = result_str.splitlines()
for line in result_str_lines:
if len(line) > _LENGTH_LIMIT:
# TODO(bowbao): group diagnostics.
# Related fields of sarif.Result: occurrence_count, fingerprints.
# Do a final process to group results before outputting the sarif log.
diag = infra.Diagnostic(
*diagnostics.rules.arg_format_too_verbose.format(
level=infra.levels.WARNING,
length=len(result_str),
length_limit=_LENGTH_LIMIT,
argument_type=type(obj),
formatter_type=type(format_argument),
)
)
)
diag.with_location(utils.function_location(formatter))
diagnostics.export_context().log(diag)
diag.with_location(utils.function_location(formatter))
diagnostics.export_context().log(diag)

return result_str
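
For context, `format_argument` is a thin wrapper over a `functools.singledispatch` registry: each `@_format_argument.register` helper below handles one type, and the wrapper now checks each line of the result against `_LENGTH_LIMIT` rather than the whole string, since the new container formatters emit multi-line output. A minimal, self-contained sketch of that pattern, with illustrative names and a plain warning in place of a SARIF diagnostic:

```python
import functools


@functools.singledispatch
def _fmt(obj) -> str:
    return str(obj)


@_fmt.register
def _(obj: list) -> str:
    # Recurse through the public wrapper, like the real helpers do.
    return f"List[length={len(obj)}](\n" + "".join(f"{fmt(item)},\n" for item in obj) + ")"


_LENGTH_LIMIT = 120  # illustrative; mirrors the relaxed limit above


def fmt(obj) -> str:
    result = _fmt.dispatch(type(obj))(obj)
    # Check per line so a long multi-line container is not flagged as a whole.
    for line in result.splitlines():
        if len(line) > _LENGTH_LIMIT:
            print(f"warning: formatted line exceeds {_LENGTH_LIMIT} characters")
    return result


print(fmt([1, 2.5, "x"]))
```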

@@ -71,12 +78,60 @@ def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:

@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
return f"fx.Node({obj.name}[{obj.op}]'{obj.target}')"
node_string = f"fx.Node({obj.target})[{obj.op}]:"
if "val" not in obj.meta:
return node_string + "None"
return node_string + format_argument(obj.meta["val"])
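
The reworked formatter prints the node's `target`, its `op`, and whatever is stashed in `meta["val"]` (if anything), producing strings such as `fx.Node(aten.relu.default)[call_function]:Tensor(f32[3, 4])` for a node carrying a FakeTensor. A small sketch of the node fields it reads (hypothetical graph, not from this PR):

```python
import torch

g = torch.fx.Graph()
x = g.placeholder("x")  # op="placeholder", target="x", meta is empty by default

# Mirrors the formatter above: without meta["val"] the value part is "None".
print(f"fx.Node({x.target})[{x.op}]:" + ("None" if "val" not in x.meta else "..."))
# -> fx.Node(x)[placeholder]:None
```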


@_format_argument.register
def _torch_fx_symbolic_bool(obj: torch.SymBool) -> str:
return f"SymBool({obj})"


@_format_argument.register
def _torch_fx_symbolic_int(obj: torch.SymInt) -> str:
return f"SymInt({obj})"


@_format_argument.register
def _torch_fx_symbolic_float(obj: torch.SymFloat) -> str:
return f"SymFloat({obj})"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
return f"Tensor(shape={obj.shape}, dtype={obj.dtype})"
return f"Tensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})"


@_format_argument.register
def _list(obj: list) -> str:
list_string = f"List[length={len(obj)}](\n"
if not obj:
return list_string + "None)"
for item in obj:
list_string += f"{format_argument(item)},\n"
return list_string + ")"


@_format_argument.register
def _tuple(obj: tuple) -> str:
tuple_string = f"Tuple[length={len(obj)}](\n"
if not obj:
return tuple_string + "None)"
for item in obj:
tuple_string += f"{format_argument(item)},\n"
return tuple_string + ")"


@_format_argument.register
def _dict(obj: dict) -> str:
dict_string = f"Dict[length={len(obj)}](\n"
if not obj:
return dict_string + "None)"
for key, value in obj.items():
dict_string += f"{key}: {format_argument(value)},\n"
return dict_string + ")"
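
Because each container helper calls back into `format_argument` for its elements, nested structures are expanded recursively. A self-contained stand-in for the list/dict helpers above, showing the shape of the output (the real helpers dispatch through the singledispatch registry; this is only an illustration):

```python
def format_argument(obj) -> str:
    # Simplified stand-in mirroring the _list/_dict helpers above.
    if isinstance(obj, list):
        if not obj:
            return "List[length=0](\nNone)"
        return f"List[length={len(obj)}](\n" + "".join(f"{format_argument(i)},\n" for i in obj) + ")"
    if isinstance(obj, dict):
        if not obj:
            return "Dict[length=0](\nNone)"
        return f"Dict[length={len(obj)}](\n" + "".join(
            f"{k}: {format_argument(v)},\n" for k, v in obj.items()
        ) + ")"
    return str(obj)


print(format_argument({"args": [1, 2], "kwargs": {}}))
# Dict[length=2](
# args: List[length=2](
# 1,
# 2,
# ),
# kwargs: Dict[length=0](
# None),
# )
```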


@_format_argument.register
@@ -86,15 +141,26 @@ def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:

@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
# TODO(bowbao) obj.dtype throws error.
return f"`TorchScriptTensor({obj.name}, {obj.onnx_dtype}, {obj.shape}, {obj.symbolic_value()})`"
return f"`TorchScriptTensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})`"


@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.OnnxFunction) -> str:
return f"`OnnxFunction({obj.name})`"


@_format_argument.register
def _onnxscript_traced_onnx_function(obj: onnxscript.TracedOnnxFunction) -> str:
return f"`TracedOnnxFunction({obj.name})`"


# from torch/fx/graph.py to follow torch format
def _stringify_shape(shape: Optional[torch.Size]) -> str:
if shape is None:
return ""
return f"[{', '.join(str(x) for x in shape)}]"


diagnose_call = functools.partial(
decorator.diagnose_call,
diagnostic_type=diagnostics.ExportDiagnostic,
2 changes: 1 addition & 1 deletion torch/onnx/_internal/fx/op_validation.py
@@ -215,7 +215,7 @@ def _fx_args_to_torch_args(
wrapped_args.append(real_tensor)
elif isinstance(fake_tensor, (int, float, bool)):
wrapped_args.append(fake_tensor)
elif isinstance(fake_tensor, (torch.SymBool, torch.SymInt, torch.SymFloat)):
elif fx_type_utils.is_torch_symbolic_type(fake_tensor):
raise ValueError(
f"Unexpected input argument Sym type found inside fx.Node. arg: {arg}; "
f"arg.meta['static_shape']: {fake_tensor}; type(arg.meta['static_shape']): "
26 changes: 26 additions & 0 deletions torch/onnx/_internal/fx/type_utils.py
@@ -66,6 +66,16 @@ def from_python_type_to_onnx_attribute_type(
return _PYTHON_TYPE_TO_ONNX_ATTRIBUTE_TYPE.get(dtype)


def is_torch_symbolic_type(value: Any) -> bool:
return isinstance(value, (torch.SymBool, torch.SymInt, torch.SymFloat))


def from_torch_dtype_to_abbr(dtype: Optional[torch.dtype]) -> str:
if dtype is None:
return ""
return _TORCH_DTYPE_TO_ABBREVIATION.get(dtype, "")


# NOTE: this is a mapping from torch dtype to a set of compatible onnx types
# It's used in dispatcher to find the best match overload for the input dtypes
_TORCH_DTYPE_TO_COMPATIBLE_ONNX_TYPE_STRINGS: Dict[
@@ -93,6 +103,22 @@ def from_python_type_to_onnx_attribute_type(
torch.complex128: torch.float64, # NOTE: ORT doesn't support torch.float64
}

_TORCH_DTYPE_TO_ABBREVIATION = {
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.bool: "b8",
torch.uint8: "u8",
}
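
Assuming this PR is applied, the two new helpers can be used directly from `torch.onnx._internal.fx.type_utils`; a small usage sketch:

```python
import torch
from torch.onnx._internal.fx import type_utils as fx_type_utils

# Dtype abbreviation consumed by the diagnostics formatters above.
assert fx_type_utils.from_torch_dtype_to_abbr(torch.float32) == "f32"
assert fx_type_utils.from_torch_dtype_to_abbr(None) == ""

# Symbolic-type check that replaces the hand-written isinstance in op_validation.py.
assert not fx_type_utils.is_torch_symbolic_type(3)
```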

# NOTE: Below are from torch/fx/node.py
BaseArgumentTypes = Union[
str,
8 changes: 8 additions & 0 deletions torch/onnx/_type_utils.py
@@ -184,6 +184,14 @@ def from_value(
return cls.from_dtype(value.type().getElementType().dtype())
except RuntimeError:
return cls._from_name(str(value.type().getElementType()))
if isinstance(value.type(), torch._C.OptionalType):
if value.type().getElementType().dtype() is None:
if isinstance(default, JitScalarType):
return default
raise errors.OnnxExporterError(
"default value must be a JitScalarType object."
)
return cls.from_dtype(value.type().getElementType().dtype())

scalar_type = None
if value.node().kind() != "prim::Constant" or not isinstance(
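
A hedged sketch of how the new `OptionalType` branch might be exercised: when the optional value's element type carries no dtype, `from_value` falls back to the caller-supplied default instead of raising (the scripted function and variable names below are illustrative, not part of the PR):

```python
from typing import Optional

import torch
from torch.onnx._type_utils import JitScalarType


@torch.jit.script
def f(x: Optional[torch.Tensor]) -> torch.Tensor:
    if x is None:
        return torch.zeros(1)
    return x


# The first graph input is typed Optional[Tensor]; its element dtype is unknown,
# so from_value is expected to return the provided default rather than raise.
optional_input = list(f.graph.inputs())[0]
print(JitScalarType.from_value(optional_input, JitScalarType.FLOAT))
```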