diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index ad56670addc4..79f29d2d4633 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -3425,7 +3425,6 @@ def g(x):
         with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
             transform(g)(x)
 
-    @xfailIfTorchDynamo
     def test_vjp_doesnt_support_saved_tensor_hooks(self, device):
         def f(x):
             return torch.sin(x).sum()
@@ -3442,7 +3441,6 @@ def g(x):
         with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
             vjp(g, x)
 
-    @xfailIfTorchDynamo
    def test_jvp_doesnt_support_saved_tensor_hooks(self, device):
        def f(x):
            return torch.sin(x).sum()
diff --git a/test/test_torch.py b/test/test_torch.py
index 7ee49086e956..c63437a208ba 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6740,6 +6740,7 @@ def test_sobolengine_fast_forward(self, scramble: bool = False):
     def test_sobolengine_fast_forward_scrambled(self):
         self.test_sobolengine_fast_forward(scramble=True)
 
+    @skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
     def test_sobolengine_distribution(self, scramble=False):
         d = 50
         engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456)
@@ -6754,6 +6755,7 @@ def test_sobolengine_distribution(self, scramble=False):
             np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2
         )
 
+    @skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
     def test_sobolengine_distribution_scrambled(self):
         self.test_sobolengine_distribution(scramble=True)
 
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index cc4bc1309868..7b63c43bc39a 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -26,7 +26,7 @@
 from torch.utils._traceback import format_traceback_short
 
 from . import config, exc
-from .allowed_functions import is_allowed, is_numpy
+from .allowed_functions import is_numpy
 from .backends.registry import CompilerFn
 from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
 from .bytecode_transformation import (
@@ -180,7 +180,9 @@ def has_tensor_in_frame(frame):
     for co_name in frame.f_code.co_names:
         if co_name in frame.f_globals:
             obj = frame.f_globals[co_name]
-            if is_allowed(obj):
+            if isinstance(obj, types.ModuleType) and (
+                obj.__name__.startswith("torch.") or obj is torch
+            ):
                 return True
             # ... or a global import of numpy.*
             if np and config.trace_numpy and (obj is np or is_numpy(obj)):
@@ -220,7 +222,7 @@ def has_tensor(obj):
         elif istype(obj, (str, int, float, type(None), bool)):
             seen_ids[obj_id] = False
             return seen_ids[obj_id]
-        elif is_namedtuple(obj):
+        elif is_namedtuple(obj) and hasattr(obj, "_fields"):
             seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
             return seen_ids[obj_id]
         else:
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index e79cada9120e..90cabf0877f3 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -394,6 +394,8 @@ def EQUALS_MATCH(self, guard: Guard):
             torch.Size,
             torch.device,
             torch.dtype,
+            torch.memory_format,
+            torch.layout,
             *np_types,
         )
         if istype(val, dict):
diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py
index aaaaa1d3fe59..05ca85515c6d 100644
--- a/torch/_dynamo/replay_record.py
+++ b/torch/_dynamo/replay_record.py
@@ -18,6 +18,7 @@ class ModuleRecord:
 @dataclasses.dataclass
 class DummyModule:
     name: str
+    is_torch: bool = False
 
 
 @dataclasses.dataclass
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index bea9efc38059..48d562154480 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -25,7 +25,7 @@
 from torch._guards import Checkpointable, tracing, TracingContext
 
 from . import config, exc, logging as torchdynamo_logging, skipfiles, variables
-from .allowed_functions import is_allowed, is_builtin_constant, is_forbidden
+from .allowed_functions import is_builtin_constant, is_forbidden
 from .bytecode_analysis import (
     get_indexof,
     JUMP_OPNAMES,
@@ -111,7 +111,6 @@
     SymNodeVariable,
     TensorVariable,
 )
-from .variables.torch import TorchVariable
 from .variables.user_defined import (
     RemovableHandleVariable,
     UserDefinedClassVariable,
@@ -1022,9 +1021,7 @@ def IMPORT_NAME(self, inst):
         if config.replay_record_enabled:
             self.exec_recorder.add_local_mod(recorded_name, value)
 
-        if is_allowed(value):
-            self.push(TorchVariable(value, source=source))
-        elif istype(value, (types.ModuleType, DummyModule)):
+        if istype(value, (types.ModuleType, DummyModule)):
             self.push(PythonModuleVariable(value, source=source))
         else:
             unimplemented(f"IMPORT_NAME {typestr(value)}")
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 63ae469c9340..b13a17c32fa9 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -21,7 +21,7 @@
 """
 Map of torch objects to their tracing rules (Dynamo variables).
 
-* TorchVariable: The functions should be put into the FX graph or can be constant folded. E.g.,
+* TorchInGraphFunctionVariable: The functions should be put into the FX graph or can be constant folded. E.g.,
   - torch.add: should be put into the FX graph.
   - torch.is_floating_point: constant folded.
 * TorchCtxManagerClassVariable: The context manager classes are supported by Dynamo. E.g., torch.no_grad
@@ -42,7 +42,6 @@
 * Remove the entry represented the API from this map or test/dynamo/test_trace_rules.ignored_torch_name_rule_set
   depends on where it is.
 
-TODO: Add torch object names mapping to TorchVariable for in graph and constant fold functions.
 TODO: We would consolidate the skipfiles.check rules into trace_rules.lookup later.
 TODO: We would support explictly list objects treated as skip/inline after the skipfiles.check
and trace_rules.lookup consolidation is done. Then the explicit listing of skip/inline objects have
@@ -93,6 +92,15 @@
     "torch.nn.Parameter": SkipFilesVariable,
     "torch._nested_tensor_from_mask": SkipFilesVariable,
     "torch._nested_from_padded": SkipFilesVariable,
+    # symbol operators implemented in Python
+    "torch.sym_not": TorchInGraphFunctionVariable,
+    "torch.sym_float": TorchInGraphFunctionVariable,
+    "torch.sym_int": TorchInGraphFunctionVariable,
+    "torch.sym_max": TorchInGraphFunctionVariable,
+    "torch.sym_min": TorchInGraphFunctionVariable,
+    "torch.sym_sqrt": TorchInGraphFunctionVariable,
+    "torch.sym_ite": TorchInGraphFunctionVariable,
+    "torch.Tensor#_make_wrapper_subclass": SkipFilesVariable,
 }
 
 
@@ -2812,7 +2820,7 @@ def lookup(obj):
     # Custom allow/disallow in graph takes precedence over the `torch_name_rule_map`.
     if id(obj) in _disallowed_function_ids:
         return None
-    if is_user_defined_allowed(obj):
+    if callable(obj) and is_user_defined_allowed(obj):
         return TorchInGraphFunctionVariable
     # Unwrap if the function is wrapped by functools.lru_cache or functools.wraps.
     if isinstance(obj, functools._lru_cache_wrapper) or (
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 7dd2d8f75b63..6ae96459bdae 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -908,6 +908,8 @@ def is_safe_constant(v):
             type(type),
             torch.device,
             torch.dtype,
+            torch.memory_format,
+            torch.layout,
         ),
     )
 
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index 1e2f646f4d9b..4f211383bc2a 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -68,11 +68,7 @@
     TensorVariable,
     UnspecializedPythonVariable,
 )
-from .torch import (
-    TorchCtxManagerClassVariable,
-    TorchInGraphFunctionVariable,
-    TorchVariable,
-)
+from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
 from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
 
 __all__ = [
@@ -120,7 +116,6 @@
     "TensorVariable",
     "TorchCtxManagerClassVariable",
     "TorchInGraphFunctionVariable",
-    "TorchVariable",
     "TupleVariable",
     "UnknownVariable",
     "UnspecializedNNModuleVariable",
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 732825bdf468..ca5d408c815b 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -38,12 +38,7 @@
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch.utils.weak import TensorWeakRef
 
 from .. import config, mutation_guard, replay_record, skipfiles, trace_rules
-from ..allowed_functions import (
-    is_allowed,
-    is_builtin_callable,
-    is_numpy,
-    is_user_defined_allowed,
-)
+from ..allowed_functions import is_builtin_callable, is_numpy, is_user_defined_allowed
 from ..device_interface import get_registered_device_interfaces
 from ..exc import InternalTorchDynamoError, unimplemented
@@ -151,7 +146,7 @@
     TensorVariable,
     UnspecializedPythonVariable,
 )
-from .torch import torch_special_class_types, TorchVariable
+from .torch import TorchInGraphFunctionVariable
 from .torch_function import build_torch_function_fn, TensorWithTFOverrideVariable
 from .user_defined import (
     KeyedJaggedTensorVariable,
@@ -320,6 +315,8 @@ def _type_dispatch(cls):
                 torch.Size,
                 torch.device,
                 torch.dtype,
+                torch.memory_format,
+                torch.layout,
             ),
             cls.wrap_literal,
         ),
@@ -475,7 +472,7 @@ def index_source(key):
         elif ConstantVariable.is_literal(value):  # non-atomic literals
             return self.wrap_literal(value)
         elif istype(value, frozenset) and (
-            all(is_allowed(x) or ConstantVariable.is_literal(x) for x in value)
+            all(ConstantVariable.is_literal(x) for x in value)
         ):
             # For frozenset, we can guard by object ID instead of value
             # equality, this allows us to handle non-literal values
@@ -588,13 +585,6 @@ def index_source(key):
         elif isinstance(value, HigherOrderOperator):
             self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
             return TorchHigherOrderOperatorVariable.make(value, source=self.source)
-        elif type(value).__name__ == "builtin_function_or_method" and isinstance(
-            value.__self__, torch_special_class_types
-        ):
-            self.install_guards(GuardBuilder.FUNCTION_MATCH)
-            return TorchVariable(
-                value,
-            )
         elif isinstance(value, _StreamBase):
             self.install_guards(GuardBuilder.ID_MATCH)
             return StreamVariable(
@@ -716,18 +706,18 @@ def index_source(key):
             return trace_rules.lookup(value).create_with_source(
                 value, source=self.source
             )
-        elif istype(value, (types.ModuleType, replay_record.DummyModule)):
+        elif (
+            istype(value, (types.ModuleType, replay_record.DummyModule))
+            # type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
+            # type(torch.ops) -> <class 'torch._ops._Ops'>
+            or value in [torch.backends.cudnn, torch.ops]
+            or isinstance(value, torch._ops._OpNamespace)
+        ):
             self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return PythonModuleVariable(
                 value,
                 source=self.source,
             )
-        elif is_allowed(value):
-            self.install_guards(GuardBuilder.FUNCTION_MATCH)
-            return TorchVariable(
-                value,
-                source=self.source,
-            )
         elif (
             is_function(value)
             and skipfiles.check(value, is_inlined_call=True)
@@ -747,7 +737,7 @@ def index_source(key):
                 source=self.source,
             )
         elif isinstance(value, types.MethodType) and isinstance(
-            value.__self__, torch.nn.Module
+            value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
         ):
             # don't let MethodTypes fall through to UserDefinedObject,
             # which doesn't support 'CALL_FUNCTION'
@@ -1478,7 +1469,7 @@ def _clone_input(value):
         and isinstance(proxy.node.target.__self__, torch._C.Generator)
         or proxy.node.target == torch.random.set_rng_state
     ):
-        return TorchVariable(proxy.node.target)
+        return TorchInGraphFunctionVariable(proxy.node.target)
     elif (
         proxy.node.target == torch._C._DisableFuncTorch
         or proxy.node.target == torch.cuda._is_in_bad_fork
@@ -1898,10 +1889,10 @@ def __call__(self, tx, value) -> VariableTracker:
             return SourcelessBuilder.wrap_constant_literal(value)
         elif is_builtin_callable(value):
             return BuiltinVariable(value)
-        elif is_allowed(value):
+        elif trace_rules.lookup(value) is not None:
             if is_user_defined_allowed(value):
                 self.tx.output.has_user_defined_allowed_in_graph = True
-            return TorchVariable(value)
+            return trace_rules.lookup(value)(value)
         elif isinstance(value, types.FunctionType):
             return UserFunctionVariable(value)
         elif isinstance(value, enum.Enum):
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 4ef121f4f201..237fc32531a9 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -703,7 +703,9 @@ def _call_min_max_binary(self, tx, a, b):
 
             # result of an item call is a scalar convert to a tensor
             if isinstance(a, FakeItemVariable):
-                a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})
+                a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function(
+                    tx, [a], {}
+                )
 
             # Dynamic input does not get resolved, rather, gets stored as call_function
             if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
@@ -726,7 +728,7 @@ def _call_min_max_binary(self, tx, a, b):
                     fn = variables.NumpyVariable(np.clip)
                 else:
-                    fn = variables.TorchVariable(torch.clamp)
+                    fn = variables.TorchInGraphFunctionVariable(torch.clamp)
                 kwargs = {"min": b} if (self.fn is max) else {"max": b}
                 result = fn.call_function(tx, [a], kwargs)
             else:
@@ -737,7 +739,7 @@ def _call_min_max_binary(self, tx, a, b):
                     fn = variables.NumpyVariable(fn)
                 else:
                     fn = {max: torch.maximum, min: torch.minimum}[self.fn]
-                    fn = variables.TorchVariable(fn)
+                    fn = variables.TorchInGraphFunctionVariable(fn)
                 result = fn.call_function(tx, [a, b], {})
 
             # return unspec if both a, b are unspec or const
@@ -1122,7 +1124,6 @@ def call_getattr(
             GetAttrVariable,
             PythonModuleVariable,
             TorchInGraphFunctionVariable,
-            TorchVariable,
             UserFunctionVariable,
         )
         from .builder import SourcelessBuilder, VariableBuilder
@@ -1235,11 +1236,19 @@ def _grad_changed(old, new):
             except NotImplementedError:
                 return GetAttrVariable(obj, name, **options)
         elif isinstance(obj, TorchInGraphFunctionVariable):
+            # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
             member = getattr(obj.value, name)
-            if trace_rules.lookup(member) is not None:
-                return trace_rules.lookup(member)(member, **options)
-        elif isinstance(obj, TorchVariable):
-            member = getattr(obj.value, name)
+            if trace_rules.is_aten_op_or_tensor_method(member):
+                return TorchInGraphFunctionVariable(member, **options)
+        elif isinstance(obj, (PythonModuleVariable, DummyModule)):
+            if obj.is_torch:
+                member = getattr(obj.value, name)
+            else:
+                member = obj.value.__dict__[name]
+
+            if config.replay_record_enabled:
+                tx.exec_recorder.record_module_access(obj.value, name, member)
+
             if is_utils_checkpoint(member):
                 options["source"] = source
                 return build_checkpoint_variable(**options)
@@ -1249,13 +1258,6 @@ def _grad_changed(old, new):
                 return VariableBuilder(tx, source)(member)
             else:
                 return SourcelessBuilder()(tx, member)
-        elif isinstance(obj, (PythonModuleVariable, DummyModule)):
-            member = obj.value.__dict__[name]
-
-            if config.replay_record_enabled:
-                tx.exec_recorder.record_module_access(obj.value, name, member)
-
-            return VariableBuilder(tx, source)(member)
         elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
             return ConstantVariable.create(getattr(obj.fn, name))
         else:
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 1e0754bd0ec4..cd8d1e1e29e2 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -1175,7 +1175,7 @@ def call_function(
                 "kwargs have not been implemented for torch.autograd.Function"
             )
 
-        from . import TorchVariable
+        from . import TorchInGraphFunctionVariable
 
         always_restore = self.value.__name__ == "trampoline_autograd_bwd"
         if (
@@ -1184,7 +1184,7 @@ def call_function(
         ):
             fn = UserFunctionVariable(self.value, source=self.source)
         else:
-            fn = TorchVariable(self.value)
+            fn = TorchInGraphFunctionVariable(self.value)
         # TODO(jansel): BUG!!! we aren't copying on the line below, so the post-pre check below is pointless
         pre_guards = tx.output.guards
         # In eager-mode PyTorch, if we only compute first-order gradients,
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index c7d66a6c1665..c24d3b6c0c73 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -8,7 +8,6 @@
 import torch.nn
 
 from .. import skipfiles, variables
-from ..allowed_functions import is_allowed
 from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported
 from ..guards import GuardBuilder, install_guard
 from ..mutation_guard import GenerationTracker
@@ -289,7 +288,9 @@ def call_function(
             # If we are tracing the higher order op, we want Dynamo to step
             # inside the module call so that Dynamo can see the underlying
             # parameters and buffers and raise them as inputs to the graph.
-            if tx.output.is_root_tracer() and is_allowed(mod.__class__):
+            if tx.output.is_root_tracer() and mod.__module__.startswith(
+                ("torch.nn.", "torch.ao.")
+            ):
                 if nnmodule_has_hooks(
                     mod, check_forward_hooks=True, check_backward_hooks=True
                 ):
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index bac958fbd989..ab31253681fa 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -216,7 +216,7 @@ def dynamic_getattr(self, tx, name):
         return VariableBuilder(tx, attr_source)(real_value)
 
     def var_getattr(self, tx, name):
-        from . import ConstantVariable, TorchVariable
+        from . import ConstantVariable, UserDefinedClassVariable
 
         if tx.strict_checks_enabled:
             if name in self._strict_mode_banned_ops():
@@ -230,7 +230,7 @@ def var_getattr(self, tx, name):
         elif name == "device" and self.device is not None:
             result = ConstantVariable.create(self.device)
         elif name == "layout" and self.layout is not None:
-            result = TorchVariable(self.layout)
+            result = ConstantVariable.create(self.layout)
         elif name == "is_cuda" and self.device is not None:
             result = ConstantVariable.create(self.device.type == "cuda")
         elif name == "shape" and self.size is not None:
@@ -249,7 +249,7 @@ def var_getattr(self, tx, name):
         elif name == "data":
             result = self.call_method(tx, "detach", [], {})
         if name == "__class__":
-            return TorchVariable(self.python_type())
+            return UserDefinedClassVariable(self.python_type())
 
         # Add a guard for type matching, these guards are checked before tensor guards
         # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
@@ -351,7 +351,7 @@ def call_method(
         if tx.strict_checks_enabled:
             if name in self._strict_mode_banned_ops():
                 unimplemented(f"Illegal method invocation {name} in strict mode")
-        from . import ConstantVariable, TorchVariable, TupleVariable
+        from . import ConstantVariable, TorchInGraphFunctionVariable, TupleVariable
         from .builder import wrap_fx_proxy
 
         kwargs = dict(kwargs)
@@ -625,7 +625,7 @@ def has_bool_key(v):
         elif (
             name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
         ):
-            result = TorchVariable(torch.mul).call_function(
+            result = TorchInGraphFunctionVariable(torch.mul).call_function(
                 tx, args + [kwargs["alpha"]], {}
             )
             return self.call_method(tx, "add_", [result], {})
@@ -635,8 +635,8 @@ def has_bool_key(v):
             and len(kwargs) == 1
             and "value" in kwargs
         ):
-            result = TorchVariable(torch.div).call_function(tx, args, {})
-            result = TorchVariable(torch.mul).call_function(
+            result = TorchInGraphFunctionVariable(torch.div).call_function(tx, args, {})
+            result = TorchInGraphFunctionVariable(torch.mul).call_function(
                 tx, [result, kwargs["value"]], {}
             )
             return self.call_method(tx, "add_", [result], {})
@@ -645,8 +645,12 @@ def has_bool_key(v):
             # without dealing with unbacked symbool. Roughly the code we translate is:
             # def __contains__(self, x):
             #     return (x == self).any().item()
-            result = TorchVariable(torch.eq).call_function(tx, [self, args[0]], {})
-            result = TorchVariable(torch.any).call_function(tx, [result], {})
+            result = TorchInGraphFunctionVariable(torch.eq).call_function(
+                tx, [self, args[0]], {}
+            )
+            result = TorchInGraphFunctionVariable(torch.any).call_function(
+                tx, [result], {}
+            )
             return result.call_method(tx, "item", [], {})
         elif name == "redistribute":
             # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
@@ -678,7 +682,7 @@ def redistribute_fn_with_prim_types(x):
                 (
                     variables.functions.FunctoolsPartialVariable,
                     variables.UserFunctionVariable,
-                    variables.TorchVariable,
+                    variables.TorchInGraphFunctionVariable,
                     variables.NNModuleVariable,
                 ),
             ):
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 8e445ecf15c1..faeb805695e3 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -3,7 +3,6 @@
 import math
 import re
-import types
 from typing import Dict, List
 
 from torch._streambase import _StreamBase
 
@@ -29,10 +28,8 @@
     check_constant_args,
     check_unspec_python_args,
     has_torch_function,
-    istype,
     product,
     proxy_args_kwargs,
-    tensortype_to_dtype,
 )
 from .base import VariableTracker
 from .ctx_manager import (
@@ -48,8 +45,6 @@
 
 log = logging.getLogger(__name__)
 
-torch_special_class_types = (torch._C.Generator,)
-
 REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
     torch.onnx.operators.shape_as_tensor,
     torch._shape_as_tensor,
@@ -59,12 +54,9 @@
     torch._assert,
     torch._utils._get_device_index,
     torch.cuda.is_available,
-    torch.device,
     torch.distributed.is_available,
-    torch.finfo,
     torch.get_autocast_gpu_dtype,
     torch.get_default_dtype,
-    torch.iinfo,
     torch.is_autocast_cache_enabled,
     torch.is_autocast_cpu_enabled,
     torch.is_autocast_enabled,
@@ -649,174 +641,3 @@ def handle_ntuple(value):
             return variables.LambdaVariable(handle_ntuple)
         else:
             return handle_ntuple(args[0])
-
-
-class TorchVariable(BaseTorchVariable):
-    """Points to a module, classes or functions in torch.*"""
-
-    def __init__(self, value, **kwargs):
-        assert not isinstance(
-            value, (torch.dtype, torch.device)
-        ), "should use ConstantVariable"
-
-        super().__init__(value, **kwargs)
-
-        # the remainder of this is just optional debug checks
-        try:
-            self_should_be_none = getattr(self.value, "__self__", None)
-        except RuntimeError as e:
-            assert "No such operator" in str(e), str(e)
-            self_should_be_none = None
-        except AssertionError as e:
-            assert "Unknown attribute" in str(e), str(e)
-            self_should_be_none = None
-
-        if self_should_be_none is None:
-            pass
-        elif isinstance(self_should_be_none, types.ModuleType):
-            # weird ones like torch.nn.functional.avg_pool2d have __self__
-            name = self_should_be_none.__name__
-            assert re.match(r"^(torch|math)([.]|$)", name), f"__self__ set to {name}"
-        elif isinstance(
-            self_should_be_none, type(torch._C._get_tracing_state.__self__)
-        ):
-            # some _C functions have __self__ as a null capsule
-            pass
-        elif isinstance(self_should_be_none, torch_special_class_types):
-            pass
-        else:
-            raise AssertionError(f"{value} found with __self__ set")
-
-    def __repr__(self):
-        return f"TorchVariable({self.value})"
-
-    def python_type(self):
-        if isinstance(self.value, (torch.Tensor, torch.nn.Module, torch.device)):
-            return type(self.value)
-        if isinstance(self.value, type):
-            return type
-        return super().python_type()
-
-    def call_function(
VariableTracker]" - ) -> "VariableTracker": - from . import ConstantVariable - - from .builder import wrap_fx_proxy - - constant_args = check_constant_args(args, kwargs) - unspec_python_args = check_unspec_python_args(args, kwargs) - - if self.can_constant_fold_through() and (constant_args or unspec_python_args): - # constant fold - return ConstantVariable.create( - self.as_python_constant()( - *[x.as_python_constant() for x in args], - **{k: v.as_python_constant() for k, v in kwargs.items()}, - ), - ) - elif istype(self.value, type) and issubclass(self.value, torch.nn.Module): - if self.value is torch.nn.CrossEntropyLoss: - return self._call_cross_entropy_loss(tx, args, kwargs) - else: - return variables.UserDefinedClassVariable( - self.value, source=self.source - ).call_function(tx, args, kwargs) - elif can_dispatch_torch_function(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - else: - # torch.LongTensor cannot accept a list of FakeTensors. - # So we stack the list of FakeTensors instead. - if ( - np - and self.value in tensortype_to_dtype - and len(args) == 1 - and isinstance(args[0], ListVariable) - and len(args[0].items) > 1 - and all(isinstance(x, variables.TensorVariable) for x in args[0].items) - ): - # Stack FakeTensor - stacked = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - torch.stack, - *proxy_args_kwargs(args, kwargs), - ), - ) - args = [stacked] - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - self.value, - *proxy_args_kwargs(args, kwargs), - ), - ) - - return tensor_variable - - def _call_cross_entropy_loss(self, tx, args, kwargs): - """ - functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', - label_smoothing=0.0 - - non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', - label_smoothing=0.0 - - non functional loss call: input, target, optional_output - """ - from . import ConstantVariable - - def normalize_args( - weight=ConstantVariable.create(None), - size_average=ConstantVariable.create(None), - ignore_index=ConstantVariable.create(-100), - reduce=ConstantVariable.create(None), - reduction=ConstantVariable.create("mean"), - label_smoothing=ConstantVariable.create(0.0), - ): - return ( - weight, - size_average, - ignore_index, - reduce, - reduction, - label_smoothing, - ) - - ( - weight, - size_average, - ignore_index, - reduce_arg, - reduction, - label_smoothing, - ) = normalize_args(*args, **kwargs) - - def fake_cross_entropy_loss(input, target): - from .builder import wrap_fx_proxy - - return wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - torch.nn.functional.cross_entropy, - *proxy_args_kwargs( - [ - input, - target, - weight, - size_average, - ignore_index, - reduce_arg, - reduction, - label_smoothing, - ], - {}, - ), - ), - ) - - return variables.LambdaVariable(fake_cross_entropy_loss) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 3d54481954cd..42db67294fd3 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -10,13 +10,17 @@ import types from typing import Dict, List +try: + import numpy as np +except ModuleNotFoundError: + np = None + import torch._dynamo.config import torch.nn from torch._guards import TracingContext from .. 
 from .. import variables
-from ..allowed_functions import is_allowed
 from ..exc import unimplemented
 from ..guards import GuardBuilder, install_guard
 from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource
@@ -31,6 +35,8 @@
     istype,
     namedtuple_fields,
     object_has_getattribute,
+    proxy_args_kwargs,
+    tensortype_to_dtype,
 )
 from .base import MutableLocal, VariableTracker
 from .ctx_manager import GenericContextWrappingVariable, NullContextVariable
@@ -52,10 +58,42 @@ def as_python_constant(self):
     def python_type(self):
         return type(self.value)
 
+    def as_proxy(self):
+        return self.value
+
+    def __repr__(self):
+        return f"UserDefinedClassVariable({self.value})"
+
+    @staticmethod
+    @functools.lru_cache(None)
+    def _constant_fold_classes():
+        return {
+            torch.device,
+            torch.finfo,
+            torch.iinfo,
+            torch.Size,
+        }
+
+    @staticmethod
+    @functools.lru_cache(None)
+    def _in_graph_classes():
+        return set(tensortype_to_dtype.keys()) | {
+            torch.Tensor,
+            torch.cuda.Stream,
+            torch.cuda.Event,
+        }
+
+    def can_constant_fold_through(self):
+        return self.value in self._constant_fold_classes()
+
     def var_getattr(self, tx, name: str) -> "VariableTracker":
+        from .. import trace_rules
         from . import ConstantVariable
         from .builder import VariableBuilder
 
+        if name == "__name__":
+            return ConstantVariable.create(self.value.__name__)
+
         source = AttrSource(self.source, name) if self.source is not None else None
         try:
             obj = inspect.getattr_static(self.value, name)
@@ -63,9 +101,11 @@ def var_getattr(self, tx, name: str) -> "VariableTracker":
             obj = None
 
         if isinstance(obj, staticmethod):
-            return variables.UserFunctionVariable(
-                obj.__get__(self.value), source=source
-            )
+            func = obj.__get__(self.value)
+            if trace_rules.lookup(func) is not None:
+                return trace_rules.lookup(func).create_with_source(func, source=source)
+            else:
+                return variables.UserFunctionVariable(func, source=source)
         elif isinstance(obj, classmethod):
             return variables.UserMethodVariable(obj.__func__, self, source=source)
         elif source and inspect.ismemberdescriptor(obj):
@@ -79,16 +119,81 @@ def var_getattr(self, tx, name: str) -> "VariableTracker":
         if self.value is collections.OrderedDict and name == "fromkeys":
             return super().var_getattr(tx, name)
 
-        if name in getattr(self.value, "__dict__", {}) or ConstantVariable.is_literal(
-            obj
+        if name in getattr(self.value, "__dict__", {}) or (
+            self.value.__module__.startswith("torch.")
+            or self.value.__module__ == "torch"
         ):
             if source:
                 return VariableBuilder(tx, source)(obj)
-            elif ConstantVariable.is_literal(obj):
-                return ConstantVariable.create(obj)
+        elif ConstantVariable.is_literal(obj):
+            return ConstantVariable.create(obj)
 
         return super().var_getattr(tx, name)
 
+    def _call_cross_entropy_loss(self, tx, args, kwargs):
+        """
+        functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
+        label_smoothing=0.0
+
+        non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
+        label_smoothing=0.0
+
+        non functional loss call: input, target, optional_output
+        """
+        from . import ConstantVariable
+
+        def normalize_args(
+            weight=ConstantVariable.create(None),
+            size_average=ConstantVariable.create(None),
+            ignore_index=ConstantVariable.create(-100),
+            reduce=ConstantVariable.create(None),
+            reduction=ConstantVariable.create("mean"),
+            label_smoothing=ConstantVariable.create(0.0),
+        ):
+            return (
+                weight,
+                size_average,
+                ignore_index,
+                reduce,
+                reduction,
+                label_smoothing,
+            )
+
+        (
+            weight,
+            size_average,
+            ignore_index,
+            reduce_arg,
+            reduction,
+            label_smoothing,
+        ) = normalize_args(*args, **kwargs)
+
+        def fake_cross_entropy_loss(input, target):
+            from .builder import wrap_fx_proxy
+
+            return wrap_fx_proxy(
+                tx=tx,
+                proxy=tx.output.create_proxy(
+                    "call_function",
+                    torch.nn.functional.cross_entropy,
+                    *proxy_args_kwargs(
+                        [
+                            input,
+                            target,
+                            weight,
+                            size_average,
+                            ignore_index,
+                            reduce_arg,
+                            reduction,
+                            label_smoothing,
+                        ],
+                        {},
+                    ),
+                ),
+            )
+
+        return variables.LambdaVariable(fake_cross_entropy_loss)
+
     def call_method(
         self,
         tx,
@@ -127,10 +232,22 @@ def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
         from ..side_effects import SideEffects
-        from .builder import SourcelessBuilder
+        from .builder import SourcelessBuilder, wrap_fx_proxy
         from .builtin import BuiltinVariable
 
-        if self.value is contextlib.nullcontext:
+        constant_args = check_constant_args(args, kwargs)
+
+        if self.can_constant_fold_through() and constant_args:
+            # constant fold
+            return variables.ConstantVariable.create(
+                self.as_python_constant()(
+                    *[x.as_python_constant() for x in args],
+                    **{k: v.as_python_constant() for k, v in kwargs.items()},
+                ),
+            )
+        elif self.value is torch.nn.CrossEntropyLoss:
+            return self._call_cross_entropy_loss(tx, args, kwargs)
+        elif self.value is contextlib.nullcontext:
             return NullContextVariable()
         elif self.value is collections.OrderedDict:
             return BuiltinVariable.call_custom_dict(
@@ -245,6 +362,38 @@ def call_function(
                 user_cls_source=self.source,
                 mutable_local=MutableLocal(),
             )
+        elif self.value in self._in_graph_classes():
+            # torch.LongTensor cannot accept a list of FakeTensors.
+            # So we stack the list of FakeTensors instead.
+            if (
+                np
+                and self.value in tensortype_to_dtype
+                and len(args) == 1
+                and isinstance(args[0], variables.ListVariable)
+                and len(args[0].items) > 1
+                and all(isinstance(x, variables.TensorVariable) for x in args[0].items)
+            ):
+                # Stack FakeTensor
+                stacked = wrap_fx_proxy(
+                    tx=tx,
+                    proxy=tx.output.create_proxy(
+                        "call_function",
+                        torch.stack,
+                        *proxy_args_kwargs(args, kwargs),
+                    ),
+                )
+                args = [stacked]
+
+            tensor_variable = wrap_fx_proxy(
+                tx=tx,
+                proxy=tx.output.create_proxy(
+                    "call_function",
+                    self.value,
+                    *proxy_args_kwargs(args, kwargs),
+                ),
+            )
+
+            return tensor_variable
 
         return super().call_function(tx, args, kwargs)
 
@@ -476,7 +625,8 @@ def call_function(
             ).call_function(tx, [var], kwargs)
         elif (
             istype(self.value, functools.partial)
-            and is_allowed(self.value.func)
+            and trace_rules.lookup(self.value.func)
+            == variables.TorchInGraphFunctionVariable
             and all(
                 variables.ConstantVariable.is_literal(v)
                 for v in itertools.chain(self.value.args, self.value.keywords.values())
@@ -506,9 +656,9 @@ def call_function(
                 return build_checkpoint_variable().call_function(
                     tx, partial_args, partial_kwargs
                 )
-            return variables.TorchVariable(self.value.func).call_function(
-                tx, partial_args, partial_kwargs
-            )
+            return variables.TorchInGraphFunctionVariable(
+                self.value.func
+            ).call_function(tx, partial_args, partial_kwargs)
         elif callable(self.value):
             if self.source:
                 install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH))
@@ -536,6 +686,7 @@ def _getattr_static(self, name):
         return subobj
 
     def var_getattr(self, tx, name):
+        from .. import trace_rules
         from . import ConstantVariable
         from .builder import VariableBuilder
 
@@ -568,9 +719,11 @@ class NO_SUCH_SUBOBJ:
                     subobj.__get__.__func__, subobj_var, source=source
                 ).call_function(tx, [self], {})
         elif isinstance(subobj, staticmethod):
-            return variables.UserFunctionVariable(
-                subobj.__get__(self.value), source=source
-            )
+            func = subobj.__get__(self.value)
+            if trace_rules.lookup(func) is not None:
+                return trace_rules.lookup(func).create_with_source(func, source=source)
+            else:
+                return variables.UserFunctionVariable(func, source=source)
         elif isinstance(subobj, classmethod):
             return variables.UserMethodVariable(subobj.__func__, self, source=source)
         elif isinstance(subobj, types.FunctionType) or (
@@ -600,9 +753,12 @@ class NO_SUCH_SUBOBJ:
         elif inspect.isfunction(dynamic_subobj):
             if is_utils_checkpoint(func):
                 return build_checkpoint_variable(source=source)
-            elif is_allowed(func):
-                return variables.TorchVariable(func, source=source)
-            return variables.UserFunctionVariable(func, source=source)
+            elif trace_rules.lookup(func) is not None:
+                return trace_rules.lookup(func).create_with_source(
+                    func, source=source
+                )
+            else:
+                return variables.UserFunctionVariable(func, source=source)
 
         if (
             name in getattr(value, "__dict__", {})
diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py
index b3755e79564a..3e5e628b0522 100644
--- a/torch/distributed/_tensor/__init__.py
+++ b/torch/distributed/_tensor/__init__.py
@@ -340,7 +340,3 @@ def zeros(
         device_mesh=device_mesh,
         placements=placements,
     )
-
-
-if not torch._running_with_deploy():
-    import torch.distributed._tensor._dynamo_utils
diff --git a/torch/distributed/_tensor/_dynamo_utils.py b/torch/distributed/_tensor/_dynamo_utils.py
deleted file mode 100644
index f5c73edbe9d1..000000000000
--- a/torch/distributed/_tensor/_dynamo_utils.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from torch._dynamo import allow_in_graph
-from torch.distributed._tensor.api import DTensor
-
-# dynamo/torch.compile utils for
-allow_in_graph(DTensor)
-allow_in_graph(DTensor.from_local)
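# ---------------------------------------------------------------------------
# Not part of the patch above: a minimal, hedged sketch (for review only) of
# the dispatch pattern this diff converges on. It assumes a PyTorch checkout
# with this change applied, where torch._dynamo.trace_rules.lookup(obj)
# returns a VariableTracker class (e.g. TorchInGraphFunctionVariable for the
# "torch.sym_*" entries added to torch_name_rule_map above) or None when no
# explicit rule matches, in which case VariableBuilder/SourcelessBuilder fall
# back to their remaining branches (UserFunctionVariable,
# PythonModuleVariable, ...). The helper name `rule_for` is hypothetical and
# exists only for this sketch.
import torch
from torch._dynamo import trace_rules


def rule_for(obj):
    # None means "no explicit trace rule" for this object.
    return trace_rules.lookup(obj)


if __name__ == "__main__":
    # Expected, per the map entries added above: TorchInGraphFunctionVariable.
    print(rule_for(torch.sym_not))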