[user errors] compulsory case names, allow multiple #110878

Closed
3 changes: 2 additions & 1 deletion torch/_dynamo/eval_frame.py
@@ -869,6 +869,7 @@ def check_signature_rewritable(graph):
             "like its value to be embedded as an exported constant, wrap its access "
             "in a function marked with @assume_constant_result.\n\n"
             + "\n\n".join(input_errors),
+            case_names=[],
         )


@@ -1289,7 +1290,7 @@ def graph_with_interpreter(*args):
         raise UserError(
             UserErrorType.DYNAMIC_CONTROL_FLOW,
             str(e),
-            case_name="cond_operands",
+            case_names=["cond_predicate", "cond_operands"],
         )

     if same_signature:
29 changes: 13 additions & 16 deletions torch/_dynamo/exc.py
@@ -2,25 +2,22 @@
 import textwrap
 from enum import auto, Enum
 from traceback import extract_stack, format_exc, format_list, StackSummary
-from typing import cast, Optional
+from typing import cast, List, Optional

 import torch._guards

 from . import config
 from .config import is_fbcode

 from .utils import counters

 if is_fbcode():
     from torch.fb.exportdb.logging import exportdb_error_message
 else:

-    def exportdb_error_message(case_name):
-        return (
-            "For more information about this error, see: "
-            + "https://pytorch.org/docs/main/generated/exportdb/index.html#"
-            + case_name.replace("_", "-")
-        )
+    def exportdb_error_message(case_names):
+        case_names_str = ", ".join(
+            "https://pytorch.org/docs/main/generated/exportdb/index.html#"
+            + case_name.replace("_", "-")
+            for case_name in case_names
+        )
+        return f"For more information about this error, see: {case_names_str}"


 import logging
@@ -127,22 +124,22 @@ class UserErrorType(Enum):


 class UserError(Unsupported):
-    def __init__(self, error_type: UserErrorType, msg, case_name=None):
+    def __init__(self, error_type: UserErrorType, msg: str, case_names: List[str]):
         """
         Type of errors that would be valid in Eager, but not supported in TorchDynamo.
         The error message should tell user about next actions.

         error_type: Type of user error
         msg: Actionable error message
-        case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
+        case_names: Unique names (snake case) for relevant examples in exportdb.
         """
-        if case_name is not None:
-            assert isinstance(case_name, str)
+        if case_names:
+            assert all(isinstance(case_name, str) for case_name in case_names)
             if msg.endswith("."):
                 msg += " "
             else:
                 msg += "\n"
-            msg += exportdb_error_message(case_name)
+            msg += exportdb_error_message(case_names)
         super().__init__(msg)
         self.error_type = error_type
         self.message = msg
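
Review note (not part of the diff): a quick standalone sketch of what the reworked OSS fallback now produces for a list of case names. The function body is copied verbatim from the hunk above; only the print calls are illustrative.

def exportdb_error_message(case_names):
    case_names_str = ", ".join(
        "https://pytorch.org/docs/main/generated/exportdb/index.html#"
        + case_name.replace("_", "-")
        for case_name in case_names
    )
    return f"For more information about this error, see: {case_names_str}"

# Multiple case names yield one comma-separated list of exportdb links:
print(exportdb_error_message(["cond_predicate", "cond_operands"]))
# For more information about this error, see:
#   https://pytorch.org/docs/main/generated/exportdb/index.html#cond-predicate,
#   https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

Because the suffix logic sits under the `if case_names:` guard, callers that pass `case_names=[]` get their message through unchanged; the new required argument just forces every raise site to make that choice explicitly instead of silently omitting a pointer.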
2 changes: 1 addition & 1 deletion torch/_dynamo/symbolic_convert.py
@@ -383,7 +383,7 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
             exc.UserErrorType.DYNAMIC_CONTROL_FLOW,
             "Dynamic control flow is not supported at the moment. Please use "
             "functorch.experimental.control_flow.cond to explicitly capture the control flow.",
-            case_name="cond_operands",
+            case_names=["cond_predicate", "cond_operands"],
         )

     return inner
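
Review note: the error this points at is the usual data-dependent `if`. A minimal sketch of the rewrite the message asks for, using the `functorch.experimental.control_flow.cond` API named in the message; the branch functions and operand packaging here are illustrative, not part of the diff.

import torch
from functorch.experimental.control_flow import cond

def f(x):
    # A plain `if x.sum() > 0:` is data-dependent control flow and raises
    # UserError(DYNAMIC_CONTROL_FLOW, ...); cond() captures both branches.
    return cond(
        x.sum() > 0,        # predicate  -- see exportdb: cond_predicate
        lambda t: t.cos(),  # true branch
        lambda t: t.sin(),  # false branch
        [x],                # operands   -- see exportdb: cond_operands
    )

print(f(torch.randn(3)))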
6 changes: 4 additions & 2 deletions torch/_dynamo/utils.py
@@ -1415,10 +1415,12 @@ def visit(n: torch.fx.Node):
             "This can happen when we encounter unbounded dynamic value that is unknown during tracing time."
             "You will need to explicitly give hint to the compiler. Please take a look at "
             "constrain_as_value OR constrain_as_size APIs",
-            case_name="constrain_as_size_example",
+            case_names=["constrain_as_value_example", "constrain_as_size_example"],
         )
     elif isinstance(cause, torch.utils._sympy.value_ranges.ValueRangeError):
-        raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e
+        raise UserError(
+            UserErrorType.CONSTRAINT_VIOLATION, e.args[0], case_names=[]
+        ) from e
     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
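
Review note: a rough sketch of the kind of hint this message steers users toward, assuming the `torch._export.constrain_as_size` API of this era (the import path, keyword names, and the example function are assumptions; signatures may differ slightly).

import torch
from torch._export import constrain_as_size  # era-specific API, assumed here

def f(x):
    n = x.item()  # unbacked symbolic int: value unknown at tracing time
    # Give the compiler an explicit range hint so downstream guards resolve:
    constrain_as_size(n, min=2, max=1024)
    return torch.zeros(n)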
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/builtin.py
@@ -652,7 +652,7 @@ def call_function(
             UserErrorType.STANDARD_LIBRARY,
             "Calling round() on symbolic value is not supported. "
             "You can use floor() to implement this functionality",
-            case_name="dynamic_shape_round",
+            case_names=["dynamic_shape_round"],
         )
         return super().call_function(tx, args, kwargs)

@@ -1273,7 +1273,7 @@ def call_type(self, tx, obj: VariableTracker):
             UserErrorType.ANTI_PATTERN,
             f"Can't call type() on generated custom object {obj}. "
             "Please use __class__ instead",
-            case_name="type_reflection_method",
+            case_names=["type_reflection_method"],
         )

     def call_reversed(self, tx, obj: VariableTracker):
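
Review note: the `dynamic_shape_round` message suggests floor(); a minimal sketch of that workaround (the halving computation is illustrative).

import math
import torch

def f(x):
    # round(x.shape[0] / 2) on a symbolic size raises STANDARD_LIBRARY;
    # floor() expresses the same computation in a supported way.
    half = math.floor(x.shape[0] / 2)
    return x[:half]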
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/constant.py
@@ -124,7 +124,7 @@ def const_getattr(self, tx, name):
             UserErrorType.ANTI_PATTERN,
             "Can't access members of type(obj) for a generated custom object. "
             "Please use __class__ instead",
-            case_name="type_reflection_method",
+            case_names=["type_reflection_method"],
         )
         member = getattr(self.value, name)
         if callable(member):
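
Review note: both `type_reflection_method` sites (here and in builtin.py above) steer users to the same fix; a minimal sketch, where the class and attribute access are illustrative.

class Counter:
    def __init__(self):
        self.n = 0

def f(c: Counter):
    # type(c) on a Dynamo-generated custom object raises ANTI_PATTERN
    # (exportdb: type_reflection_method); __class__ is the supported spelling.
    return c.__class__.__name__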
3 changes: 3 additions & 0 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -802,13 +802,15 @@ def to_python_ints(argnums):
         raise UserError(
             UserErrorType.INVALID_INPUT,
             f"argnums is expected to be int or tuple of ints. Got {argnums}.",
+            case_names=[],
         )

     if isinstance(argnums, ConstantVariable):
         if not isinstance(argnums.value, (int, tuple)):
             raise UserError(
                 UserErrorType.INVALID_INPUT,
                 f"argnums is expected to be int or tuple of ints. Got {argnums}.",
+                case_names=[],
             )
         return argnums.value
     else:
@@ -820,6 +822,7 @@ def to_python_ints(argnums):
             raise UserError(
                 UserErrorType.INVALID_INPUT,
                 f"argnums is expected to contain int only. Got {const_vars}.",
+                case_names=[],
             )
         return tuple(var.value for var in const_vars)
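
Review note: for context, the `argnums` values these checks accept, shown with `torch.func.grad` (the function being differentiated is illustrative).

import torch
from torch.func import grad

def f(x, y):
    return (x * y).sum()

x, y = torch.randn(3), torch.randn(3)
print(grad(f, argnums=0)(x, y))       # int: differentiate w.r.t. x
print(grad(f, argnums=(0, 1))(x, y))  # tuple of ints: w.r.t. both inputs
# Anything else (e.g. argnums=[0, 1] or argnums="x") trips the
# UserError(INVALID_INPUT, ..., case_names=[]) paths above.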
35 changes: 26 additions & 9 deletions torch/_export/__init__.py
@@ -104,11 +104,13 @@ def tree_zip(combined_args, dynamic_shapes):
                 UserErrorType.INVALID_INPUT,
                 f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                 f"got {dynamic_shapes} instead",
+                case_names=[],
             )
         if len(combined_args) != len(dynamic_shapes):
             raise UserError(
                 UserErrorType.INVALID_INPUT,
                 f"Expected {dynamic_shapes} to have {len(combined_args)} items",
+                case_names=[],
             )
         for i, shape in enumerate(dynamic_shapes):
             yield from tree_zip(combined_args[i], shape)
@@ -118,11 +120,13 @@ def tree_zip(combined_args, dynamic_shapes):
                 UserErrorType.INVALID_INPUT,
                 f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                 f"got {dynamic_shapes} instead",
+                case_names=[],
             )
         if len(combined_args) != len(dynamic_shapes):
             raise UserError(
                 UserErrorType.INVALID_INPUT,
                 f"Expected {dynamic_shapes} to have {len(combined_args)} items",
+                case_names=[],
             )
         for k, shape in dynamic_shapes.items():
             yield from tree_zip(combined_args[k], shape)
@@ -132,6 +136,7 @@ def tree_zip(combined_args, dynamic_shapes):
                 UserErrorType.INVALID_INPUT,
                 f"Expected dynamic_shapes of a {type(combined_args)} to be a {type(combined_args)}, "
                 f"got {dynamic_shapes} instead",
+                case_names=[],
             )
         for f in dataclasses.fields(combined_args):
             yield from tree_zip(getattr(combined_args, f.name), getattr(dynamic_shapes, f.name))
@@ -143,6 +148,7 @@ def tree_zip(combined_args, dynamic_shapes):
                 UserErrorType.INVALID_INPUT,
                 f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                 f"got {dynamic_shapes} instead",
+                case_names=[],
             )

     from collections import defaultdict
@@ -167,6 +173,7 @@ def update_symbols(tensor, shape):
                     UserErrorType.INVALID_INPUT,
                     f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                     "try None instead",
+                    case_names=[],
                 )
         elif isinstance(shape, (tuple, list)):
             for i, dim in enumerate(shape):
@@ -178,13 +185,15 @@ def update_symbols(tensor, shape):
                     UserErrorType.INVALID_INPUT,
                     f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                     "try None instead",
+                    case_names=[],
                 )
         else:
             if shape is not None:
                 raise UserError(
                     UserErrorType.INVALID_INPUT,
                     f"Unexpected dynamic_shape {shape} of Tensor, "
                     "try None instead",
+                    case_names=[],
                 )

     if isinstance(f, ExportedProgram):
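
Review note: all of the `tree_zip`/`update_symbols` checks above enforce that `dynamic_shapes` mirrors the input pytree. A rough sketch of a conforming spec, assuming the `torch.export.Dim` API of this era; the function, argument names, and shapes are illustrative.

import torch
from torch.export import Dim  # assumed available in this era

batch = Dim("batch")

def f(inputs):  # inputs: a list of two tensors
    return inputs[0] + inputs[1]

example_args = ([torch.randn(4, 8), torch.randn(4, 8)],)
# The spec mirrors the input structure: a list input gets a same-length
# list spec, and each tensor leaf gets {dim_index: Dim} entries or None.
dynamic_shapes = ([{0: batch}, {0: batch}],)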
@@ -216,20 +225,23 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
     if not isinstance(t, torch.Tensor):
         raise UserError(
             UserErrorType.DYNAMIC_DIM,
-            f"Expected tensor as input to dynamic_dim but got {type(t)}"
+            f"Expected tensor as input to dynamic_dim but got {type(t)}",
+            case_names=[],
         )

     if t.dim() < 1:
         raise UserError(
             UserErrorType.DYNAMIC_DIM,
-            "Cannot mark 0-dimension tensors to be dynamic"
+            "Cannot mark 0-dimension tensors to be dynamic",
+            case_names=[],
         )

     if index >= t.dim():
         raise UserError(
             UserErrorType.DYNAMIC_DIM,
             f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]"
-            f" but got {index}, which is out of bounds for the given tensor."
+            f" but got {index}, which is out of bounds for the given tensor.",
+            case_names=[],
         )

     return _create_constraint(
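
Review note: the three `dynamic_dim` checks above correspond to the three ways a call can go wrong. A quick sketch of valid vs. invalid usage, assuming the era's `torch._export.dynamic_dim` constraint API (import path and comparison-based composition assumed here).

import torch
from torch._export import dynamic_dim  # era-specific API, assumed here

t = torch.randn(4, 8)

ok = [dynamic_dim(t, 0), dynamic_dim(t, 1) <= 1024]  # in-range indices

# Each of these trips one UserError(DYNAMIC_DIM, ...) branch above:
# dynamic_dim([1, 2, 3], 0)         -> not a torch.Tensor
# dynamic_dim(torch.tensor(1.0), 0) -> 0-dimension tensor
# dynamic_dim(t, 2)                 -> index out of bounds for a 2-D tensor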
@@ -489,16 +501,20 @@ def _export(
     kwargs = kwargs or {}

     if not isinstance(args, tuple):
-        raise UserError(UserErrorType.INVALID_INPUT,
-                        f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}")
+        raise UserError(
+            UserErrorType.INVALID_INPUT,
+            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
+            case_names=[],
+        )

     # We convert to nn.Module because __call__ of ExportedProgram
     # is untracable right now.
     if isinstance(f, ExportedProgram):
         if len(constraints) > 0:
             raise UserError(
                 UserErrorType.INVALID_INPUT,
-                "Cannot provide constraints for already exported program."
+                "Cannot provide constraints for already exported program.",
+                case_names=[],
             )
         f = f.module()
@@ -510,7 +526,8 @@ def _export(
         if len(constraints) > 0:
             raise UserError(
                 UserErrorType.INVALID_INPUT,
-                "Cannot provide constraints for already exported program."
+                "Cannot provide constraints for already exported program.",
+                case_names=[],
             )
         gm_torch_level = f
     else:
@@ -525,12 +542,12 @@ def _export(
                 **kwargs,
             )
         except (ConstraintViolationError, ValueRangeError) as e:
-            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))
+            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e), case_names=[])
         except GuardOnDataDependentSymNode as e:
             raise UserError(
                 UserErrorType.ANTI_PATTERN,
                 f"Consider annotating your code using constrain_as_*(). {str(e)}",
-                case_name="constrain_as_size_example",
+                case_names=["constrain_as_value_example", "constrain_as_size_example"],
             )

     params_buffers: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {}