Update base for Update on "Extend SampleInput str representation with tensor data."

As in the title. The aim of this addition is to make debugging certain CI failures (that cannot be reproduced locally) easier. For instance, currently we see messages like
```
Exception: Caused by sample input at index 0: SampleInput(input=Tensor[size=(20,), device="cuda:0", dtype=torch.float64], args=(), kwargs={}, broadcasts_input=False, name='')
```
which is not really useful without showing the actual sample data, since all those sample parameters can often be determined by other means. The sample data can then be related to the `index` part of error messages such as:
```
Mismatched elements: 2 / 20 (10.0%)
Greatest absolute difference: nan at index (10,) (up to 1e-05 allowed)
Greatest relative difference: nan at index (10,) (up to 1e-07 allowed)
```
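To make the relationship between the printed `data=` field and the tensor itself concrete, here is a minimal sketch of that kind of truncated summary. The helper name `describe_sample` and the `max_items` parameter are hypothetical, for illustration only; they are not the code this PR adds.

```python
import torch

def describe_sample(t: torch.Tensor, max_items: int = 20) -> str:
    # Summarize a tensor roughly the way the extended repr above does:
    # shape/device/dtype plus a truncated, flattened view of the values.
    data = t.flatten()[:max_items].tolist()
    return (
        f'Tensor[size={tuple(t.shape)}, device="{t.device}", '
        f'dtype={t.dtype}, data={data}]'
    )

x = torch.arange(25, dtype=torch.float64).reshape(5, 5)
print(describe_sample(x))
# prints something like:
# Tensor[size=(5, 5), device="cpu", dtype=torch.float64, data=[0.0, 1.0, ..., 19.0]]
```

With the values visible, the element reported at `index (10,)` in a mismatch report like the one above can be looked up directly in the sample.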

As an example of the usefulness of this PR, consider the following failure message:
```
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 ('RERUN', {'yellow': True}) [1.5510s] [ 70%]
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 ('RERUN', {'yellow': True}) [0.0473s] [ 70%]
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 FAILED [0.0493s] [ 70%]

==================================== RERUNS ====================================
__ TestInductorOpInfoCPU.test_comprehensive_polygamma_polygamma_n_0_cpu_int32 __
Traceback (most recent call last):
<snip>
AssertionError: Tensor-likes are not close!

Mismatched elements: 9 / 25 (36.0%)
Greatest absolute difference: inf at index (0, 0) (up to 1e-05 allowed), inf vs 20177651499008.0
Greatest relative difference: inf at index (0, 0) (up to 1.3e-06 allowed)

The above exception was the direct cause of the following exception:

<snip>
Exception: Caused by sample input at index 0: SampleInput(input=Tensor[size=(5, 5), device="cpu", dtype=torch.int32, data=[-8, 6, 9, 0, 0, 5, 5, 7, 6, 5, 1, -5, 2, -1, 8, -4, 0, -6, 3, -5]], args=(1), kwargs={}, broadcasts_input=False, name='')
```
from which we learn that the `torch.polygamma` result is actually correct, because `polygamma(0, -8) -> inf`, while the reference value used (20177651499008.0) is wrong (see #106692 for more details).
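With the sample data printed, such a conclusion is easy to sanity-check locally. The snippet below is a rough sketch of that kind of check, reconstructing only the first few elements of the printed sample; it is not part of this PR.

```python
import torch

# First few elements of the failing sample above: int32 values fed to
# polygamma with n=0 (i.e. digamma).
x = torch.tensor([-8, 6, 9, 0, 0], dtype=torch.int32)
out = torch.polygamma(0, x.to(torch.float64))
# digamma has poles at non-positive integers, so a non-finite value at x = -8
# (reported as inf in the failure above) is the mathematically expected result,
# whereas the finite reference value 20177651499008.0 is not.
print(out)
```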





[ghstack-poisoned]
pearu committed Feb 10, 2024
2 parents 1f1e62a + 57d8f67 commit 5371022
Showing 8 changed files with 132 additions and 131 deletions.
6 changes: 3 additions & 3 deletions test/inductor/test_compiled_autograd.py
@@ -816,7 +816,7 @@ def wrapped(self: EagerAutogradTests):
"test_mark_non_differentiable_mixed", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue
"test_materialize_grads", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_naughty_autograd_function_stashing_ctx", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function
"test_no_grad_copy", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable() SkipFilesVariable()
"test_no_grad_copy", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable() SkipFunctionVariable()
"test_no_grad_copy_sparse", # torch._dynamo.exc.Unsupported: Tensor.data_ptr
"test_reentrant_priority", # torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of
"test_reentrant_with_callbacks_both_depths", # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable
@@ -827,10 +827,10 @@ def wrapped(self: EagerAutogradTests):
"test_return_leaf", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_save_none_for_backward", # AssertionError:
"test_save_output_nr", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_saved_variables_deprecated", # torch._dynamo.exc.Unsupported: UNPACK_SEQUENCE SkipFilesVariable()
"test_saved_variables_deprecated", # torch._dynamo.exc.Unsupported: UNPACK_SEQUENCE SkipFunctionVariable()
"test_set_materialize_non_diff_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone
"test_setup_context_when_forward_has_default_args", # torch._dynamo.exc.Unsupported: call_function args
"test_simple_reentrant", # torch._dynamo.exc.Unsupported: call_method SkipFilesVariable() sum [] {}
"test_simple_reentrant", # torch._dynamo.exc.Unsupported: call_method SkipFunctionVariable() sum [] {}
"test_tensor_hooks_inplace_multiple_outputs", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}
"test_lobpcg", # torch._dynamo.exc.Unsupported: 'call_function LOBPCGAutogradFunction.backward in skip_files
}
2 changes: 1 addition & 1 deletion test/test_mps.py
@@ -617,7 +617,7 @@ def mps_ops_modifier(ops):
'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int64, torch.int8],
'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int64, torch.int8],
'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
4 changes: 2 additions & 2 deletions torch/_dynamo/symbolic_convert.py
@@ -86,6 +86,7 @@
from .variables.functions import (
BaseUserFunctionVariable,
NestedUserFunctionVariable,
SkipFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
)
@@ -102,7 +103,6 @@
InlinedClosureVariable,
NullVariable,
PythonModuleVariable,
SkipFilesVariable,
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
@@ -2301,7 +2301,7 @@ def check_inlineable(func):
def inline_call_(
parent, func: VariableTracker, args: List[VariableTracker], kwargs
):
if isinstance(func, SkipFilesVariable):
if isinstance(func, SkipFunctionVariable):
unimplemented("inline with functions in skip files")
assert isinstance(
func,
90 changes: 45 additions & 45 deletions torch/_dynamo/trace_rules.py
@@ -51,7 +51,7 @@
BuiltinVariable,
FunctorchVmapHigherOrderVariable,
NestedUserFunctionVariable,
SkipFilesVariable,
SkipFunctionVariable,
TorchInGraphFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
@@ -65,7 +65,7 @@
* TorchInGraphFunctionVariable: The functions should be put into the FX graph or can be constant folded. E.g.,
- torch.add: should be put into the FX graph.
- torch.is_floating_point: constant folded.
* SkipFilesVariable: The objects should be skipped from tracing.
* SkipFunctionVariable: The objects should be skipped from tracing.
* UserFunctionVariable: The functions should be inlined.
For developers: If you add/remove a torch level API, it may trigger failures from
@@ -100,29 +100,29 @@
"torch.overrides.get_default_nowrap_functions": TorchInGraphFunctionVariable,
"torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable,
"torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable,
"torch.autograd._profiler_enabled": SkipFilesVariable,
"torch.autograd._profiler_enabled": SkipFunctionVariable,
# We graph break on RNG state setters or getters like
# `torch.get_rng_state` or `torch.set_rng_state`. These functions
# are not aten operations and therefore they are completely ignored
# by the AOT dispatcher. As a result, the AOT graph does not have
# these setter or getter functions, producing an incorrect graph
# when it comes to rng states.
"torch.default_generator#get_state": SkipFilesVariable,
"torch._C.Generator#get_state": SkipFilesVariable,
"torch.get_rng_state": SkipFilesVariable,
"torch.cuda.get_rng_state": SkipFilesVariable,
"torch.default_generator#set_state": SkipFilesVariable,
"torch._C.Generator#set_state": SkipFilesVariable,
"torch.set_rng_state": SkipFilesVariable,
"torch.cuda.set_rng_state": SkipFilesVariable,
"torch.default_generator#get_state": SkipFunctionVariable,
"torch._C.Generator#get_state": SkipFunctionVariable,
"torch.get_rng_state": SkipFunctionVariable,
"torch.cuda.get_rng_state": SkipFunctionVariable,
"torch.default_generator#set_state": SkipFunctionVariable,
"torch._C.Generator#set_state": SkipFunctionVariable,
"torch.set_rng_state": SkipFunctionVariable,
"torch.cuda.set_rng_state": SkipFunctionVariable,
# https://github.com/pytorch/pytorch/issues/107187
"torch.manual_seed": SkipFilesVariable,
"torch.manual_seed": SkipFunctionVariable,
# https://github.com/pytorch/pytorch/issues/93501
"torch.nn.utils.rnn.pack_padded_sequence": SkipFilesVariable,
"torch.nn.utils.rnn.pack_padded_sequence": SkipFunctionVariable,
# https://github.com/pytorch/pytorch/issues/99569
"torch.nn.Parameter": SkipFilesVariable,
"torch._nested_tensor_from_mask": SkipFilesVariable,
"torch._nested_from_padded": SkipFilesVariable,
"torch.nn.Parameter": SkipFunctionVariable,
"torch._nested_tensor_from_mask": SkipFunctionVariable,
"torch._nested_from_padded": SkipFunctionVariable,
# symbol operators implemented in Python
"torch.sym_not": TorchInGraphFunctionVariable,
"torch.sym_float": TorchInGraphFunctionVariable,
@@ -131,28 +131,28 @@
"torch.sym_min": TorchInGraphFunctionVariable,
"torch.sym_sqrt": TorchInGraphFunctionVariable,
"torch.sym_ite": TorchInGraphFunctionVariable,
"torch.Tensor#_make_wrapper_subclass": SkipFilesVariable,
"torch.Tensor#__init__": SkipFilesVariable,
"torch.cuda.set_device": SkipFilesVariable,
"torch.cuda.current_device": SkipFilesVariable,
"torch._C.autocast_decrement_nesting": SkipFilesVariable,
"torch._C.autocast_increment_nesting": SkipFilesVariable,
"torch.autograd.grad": SkipFilesVariable,
"torch._C.clear_autocast_cache": SkipFilesVariable,
"torch.distributions.constraints.is_dependent": SkipFilesVariable,
"torch.jit.isinstance": SkipFilesVariable,
"torch._C.set_anomaly_enabled": SkipFilesVariable,
"torch._C.set_autocast_cache_enabled": SkipFilesVariable,
"torch._C.set_autocast_cpu_dtype": SkipFilesVariable,
"torch._C.set_autocast_cpu_enabled": SkipFilesVariable,
"torch._C.set_autocast_enabled": SkipFilesVariable,
"torch._C.set_autocast_gpu_dtype": SkipFilesVariable,
"torch._C.set_autocast_ipu_dtype": SkipFilesVariable,
"torch._C.set_autocast_ipu_enabled": SkipFilesVariable,
"torch._C.set_autocast_xla_dtype": SkipFilesVariable,
"torch._C.set_autocast_xla_enabled": SkipFilesVariable,
"torch.resize_as_": SkipFilesVariable,
"torch.resize_as_sparse_": SkipFilesVariable,
"torch.Tensor#_make_wrapper_subclass": SkipFunctionVariable,
"torch.Tensor#__init__": SkipFunctionVariable,
"torch.cuda.set_device": SkipFunctionVariable,
"torch.cuda.current_device": SkipFunctionVariable,
"torch._C.autocast_decrement_nesting": SkipFunctionVariable,
"torch._C.autocast_increment_nesting": SkipFunctionVariable,
"torch.autograd.grad": SkipFunctionVariable,
"torch._C.clear_autocast_cache": SkipFunctionVariable,
"torch.distributions.constraints.is_dependent": SkipFunctionVariable,
"torch.jit.isinstance": SkipFunctionVariable,
"torch._C.set_anomaly_enabled": SkipFunctionVariable,
"torch._C.set_autocast_cache_enabled": SkipFunctionVariable,
"torch._C.set_autocast_cpu_dtype": SkipFunctionVariable,
"torch._C.set_autocast_cpu_enabled": SkipFunctionVariable,
"torch._C.set_autocast_enabled": SkipFunctionVariable,
"torch._C.set_autocast_gpu_dtype": SkipFunctionVariable,
"torch._C.set_autocast_ipu_dtype": SkipFunctionVariable,
"torch._C.set_autocast_ipu_enabled": SkipFunctionVariable,
"torch._C.set_autocast_xla_dtype": SkipFunctionVariable,
"torch._C.set_autocast_xla_enabled": SkipFunctionVariable,
"torch.resize_as_": SkipFunctionVariable,
"torch.resize_as_sparse_": SkipFunctionVariable,
"torch.get_default_device": TorchInGraphFunctionVariable,
# functorch
"torch._functorch.vmap._check_int_or_none": UserFunctionVariable,
@@ -3027,8 +3027,8 @@ def is_numpy(obj) -> bool:
We should specify inline for the functions in `manual_torch_name_rule_map` or
put the corresponding python module into MOD_INLINELIST to make dynamo inline them.
* If you call functions under skipped modules/files, Dynamo will wrap these functions
as SkipFilesVariable. There are a few functions(e.g, collections.OrderedDict) that
we have special handling at SkipFilesVariable.call_function.
as SkipFunctionVariable. There are a few functions(e.g, collections.OrderedDict) that
we have special handling at SkipFunctionVariable.call_function.
Overall: *_INLINELIST has precedence over *_SKIPLIST has precedence over DEFAULT (inline)
@@ -3313,7 +3313,7 @@ def f3(x, y):
is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py.
* If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again
and the call site is in catch_errors_wrapper.catch_errors of convert_frame.py.
* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFilesVariable in builder.py.
* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFunctionVariable in builder.py.
`is_inlined_call` is used to indicate if the current function call is inlined (f2 is inlined call if it passes check)
or not (f3 is not inlined call if f2 is skipped). Inside of the `check_verbose` function, there are more rules
@@ -3351,7 +3351,7 @@ def check_verbose(obj, is_inlined_call=False):
"inlined according trace_rules.lookup",
)
else:
assert rule == SkipFilesVariable, rule
assert rule == SkipFunctionVariable, rule
return SkipResult(
True,
"skipped according trace_rules.lookup",
@@ -3396,7 +3396,7 @@ def lookup_callable(obj):
return None
# Custom allow/disallow in graph takes precedence over the general lookup.
if is_callable_disallowed(obj):
return SkipFilesVariable
return SkipFunctionVariable
if is_callable_allowed(obj):
return TorchInGraphFunctionVariable
if is_builtin_callable(obj):
@@ -3430,7 +3430,7 @@ def lookup_inner(obj, name=None, filename=None, is_direct_call=True):
# Step 2: lookup obj's tracing rule by function name.
if is_direct_call:
if name == "patched_init":
return SkipFilesVariable
return SkipFunctionVariable
elif name == "__torch_function__":
return UserFunctionVariable

@@ -3439,6 +3439,6 @@ def lookup_inner(obj, name=None, filename=None, is_direct_call=True):
filename = getfile(obj)

if check_file(filename, is_direct_call).skipped:
return SkipFilesVariable
return SkipFunctionVariable
else:
return UserFunctionVariable
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/__init__.py
@@ -24,6 +24,7 @@
from .functions import (
FunctoolsPartialVariable,
NestedUserFunctionVariable,
SkipFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
)
@@ -63,7 +64,6 @@
NewGlobalVariable,
NumpyVariable,
PythonModuleVariable,
SkipFilesVariable,
StringFormatVariable,
SuperVariable,
TypingVariable,
Expand Down Expand Up @@ -125,7 +125,7 @@
"RepeatIteratorVariable",
"RestrictedListSubclassVariable",
"SDPAParamsVariable",
"SkipFilesVariable",
"SkipFunctionVariable",
"SliceVariable",
"StringFormatVariable",
"SuperVariable",
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/dicts.py
@@ -47,7 +47,7 @@ def is_hashable(x):
variables.EnumVariable,
variables.user_defined.UserDefinedClassVariable,
variables.UserFunctionVariable,
variables.misc.SkipFilesVariable,
variables.SkipFunctionVariable,
variables.misc.NumpyVariable,
variables.NNModuleVariable,
variables.MethodWrapperVariable,
79 changes: 77 additions & 2 deletions torch/_dynamo/variables/functions.py
@@ -1,5 +1,6 @@
# mypy: ignore-errors

import collections
import functools
import inspect
import itertools
@@ -13,8 +14,8 @@
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import get_first_attr, identity, istype, make_cell
from .base import typestr, VariableTracker
from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable

if TYPE_CHECKING:
@@ -555,6 +556,80 @@ def reconstruct(self, codegen):
return []


class SkipFunctionVariable(VariableTracker):
def __init__(self, value, reason=None, **kwargs):
super().__init__(**kwargs)
self.value = value
self.reason = reason

def python_type(self):
return type(self.value)

def as_python_constant(self):
return self.value

@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return cls(
value,
source=source,
)

@staticmethod
@functools.lru_cache(None)
def fold_through_function_to_wrapper():
return {
collections.namedtuple: variables.UserDefinedClassVariable,
}

def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
# Fold through the functions(e.g, collections.namedtuple)
# that inputs & outputs are all python constants
elif (
self.value in self.fold_through_function_to_wrapper().keys()
and check_constant_args(args, kwargs)
):
value = self.value(
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
return self.fold_through_function_to_wrapper().get(self.value)(
value, mutable_local=MutableLocal()
)
elif (
self.value is functools.wraps
and not kwargs
and len(args) == 1
and (
args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx)
)
):

def wraps(fn):
if isinstance(fn, variables.NestedUserFunctionVariable):
if args[0].source:
reconstructible = args[0].source
else:
reconstructible = args[0]
return fn.clone(wrapped_reconstructible=reconstructible)
unimplemented(f"functools.wraps({fn})")

return variables.LambdaVariable(wraps)
else:
try:
path = inspect.getfile(self.value)
except TypeError:
path = f"Builtin {self.value.__name__}"
msg = f"'skip function {self.value.__qualname__} in file {path}'"
msg += f"', {self.reason}'" if self.reason else ""
unimplemented(msg)


def _traceable_collective_remaps():
# We can't rely on importing from distributed, since it's not always built
if torch.distributed.is_available():
