[Dynamo][10/N] Remove TorchVariable and is_allowed (#116312)
After this refactor:
* The ```TorchVariable``` definition and all references to it are removed.
* All ```is_allowed``` references except one are removed.
  - The only remaining one is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```, which is called when users apply the ```disallow_in_graph``` decorator to a function. Since the lists in ```trace_rules``` now decide a function's trace rule, the decorator would only be applied to custom functions rather than torch functions; I'll defer this to a separate decorator refactor PR. A sketch of the current usage follows below.
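
For context, a minimal sketch of the usage that exercises this remaining call site, mirroring the example in the PyTorch docs for ```disallow_in_graph``` (illustrative only, not part of this PR; the input shape is arbitrary):

```python
import torch
import torch._dynamo

# Registering torch.sub here goes through _disallow_in_graph_helper,
# the one remaining is_allowed call site mentioned above.
torch._dynamo.disallow_in_graph(torch.sub)

@torch.compile(backend="eager")
def fn(x):
    x = torch.add(x, 1)
    x = torch.sub(x, 1)  # graph break forced here
    return torch.add(x, 1)

# Expected result: two graphs, each containing a single torch.add op.
fn(torch.randn(4))
```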

Pull Request resolved: #116312
Approved by: https://github.com/jansel
yanboliang authored and pytorchmergebot committed Dec 27, 2023
1 parent 87da0e1 commit f657b2b
Showing 18 changed files with 255 additions and 283 deletions.
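
At a high level, the diff below replaces the ```is_allowed``` + ```TorchVariable``` pair with ```trace_rules.lookup```, which returns a VariableTracker class such as ```TorchInGraphFunctionVariable```. A rough, simplified sketch of the new dispatch (```wrap_callable``` is a hypothetical helper; assumes a torch build that includes this refactor; the real logic lives in ```torch/_dynamo/variables/builder.py```):

```python
import torch
from torch._dynamo import trace_rules

def wrap_callable(value):
    # lookup returns a VariableTracker subclass (or None) based on the
    # torch_name_rule_map and related lists in trace_rules.
    rule = trace_rules.lookup(value)
    if rule is not None:
        # Mirrors SourcelessBuilder after this PR: trace_rules.lookup(value)(value)
        return rule(value)
    return None  # fall through to the other wrapping paths in builder.py

# torch.add is expected to map to TorchInGraphFunctionVariable;
# torch.no_grad to TorchCtxManagerClassVariable.
print(type(wrap_callable(torch.add)).__name__)
```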
2 changes: 0 additions & 2 deletions test/functorch/test_eager_transforms.py
@@ -3425,7 +3425,6 @@ def g(x):
with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
transform(g)(x)

@xfailIfTorchDynamo
def test_vjp_doesnt_support_saved_tensor_hooks(self, device):
def f(x):
return torch.sin(x).sum()
@@ -3442,7 +3441,6 @@ def g(x):
with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"):
vjp(g, x)

@xfailIfTorchDynamo
def test_jvp_doesnt_support_saved_tensor_hooks(self, device):
def f(x):
return torch.sin(x).sum()
2 changes: 2 additions & 0 deletions test/test_torch.py
@@ -6740,6 +6740,7 @@ def test_sobolengine_fast_forward(self, scramble: bool = False):
def test_sobolengine_fast_forward_scrambled(self):
self.test_sobolengine_fast_forward(scramble=True)

@skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
def test_sobolengine_distribution(self, scramble=False):
d = 50
engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456)
@@ -6754,6 +6755,7 @@ def test_sobolengine_distribution(self, scramble=False):
np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2
)

@skipIfTorchDynamo("np.float64 restored as float32 after graph break.")
def test_sobolengine_distribution_scrambled(self):
self.test_sobolengine_distribution(scramble=True)

8 changes: 5 additions & 3 deletions torch/_dynamo/convert_frame.py
@@ -26,7 +26,7 @@
from torch.utils._traceback import format_traceback_short

from . import config, exc
from .allowed_functions import is_allowed, is_numpy
from .allowed_functions import is_numpy
from .backends.registry import CompilerFn
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import (
@@ -180,7 +180,9 @@ def has_tensor_in_frame(frame):
for co_name in frame.f_code.co_names:
if co_name in frame.f_globals:
obj = frame.f_globals[co_name]
if is_allowed(obj):
if isinstance(obj, types.ModuleType) and (
obj.__name__.startswith("torch.") or obj is torch
):
return True
# ... or a global import of numpy.*
if np and config.trace_numpy and (obj is np or is_numpy(obj)):
@@ -220,7 +222,7 @@ def has_tensor(obj):
elif istype(obj, (str, int, float, type(None), bool)):
seen_ids[obj_id] = False
return seen_ids[obj_id]
elif is_namedtuple(obj):
elif is_namedtuple(obj) and hasattr(obj, "_fields"):
seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
return seen_ids[obj_id]
else:
2 changes: 2 additions & 0 deletions torch/_dynamo/guards.py
@@ -394,6 +394,8 @@ def EQUALS_MATCH(self, guard: Guard):
torch.Size,
torch.device,
torch.dtype,
torch.memory_format,
torch.layout,
*np_types,
)
if istype(val, dict):
1 change: 1 addition & 0 deletions torch/_dynamo/replay_record.py
@@ -18,6 +18,7 @@ class ModuleRecord:
@dataclasses.dataclass
class DummyModule:
name: str
is_torch: bool = False


@dataclasses.dataclass
7 changes: 2 additions & 5 deletions torch/_dynamo/symbolic_convert.py
@@ -25,7 +25,7 @@
from torch._guards import Checkpointable, tracing, TracingContext

from . import config, exc, logging as torchdynamo_logging, skipfiles, variables
from .allowed_functions import is_allowed, is_builtin_constant, is_forbidden
from .allowed_functions import is_builtin_constant, is_forbidden
from .bytecode_analysis import (
get_indexof,
JUMP_OPNAMES,
@@ -111,7 +111,6 @@
SymNodeVariable,
TensorVariable,
)
from .variables.torch import TorchVariable
from .variables.user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
@@ -1022,9 +1021,7 @@ def IMPORT_NAME(self, inst):
if config.replay_record_enabled:
self.exec_recorder.add_local_mod(recorded_name, value)

if is_allowed(value):
self.push(TorchVariable(value, source=source))
elif istype(value, (types.ModuleType, DummyModule)):
if istype(value, (types.ModuleType, DummyModule)):
self.push(PythonModuleVariable(value, source=source))
else:
unimplemented(f"IMPORT_NAME {typestr(value)}")
14 changes: 11 additions & 3 deletions torch/_dynamo/trace_rules.py
@@ -21,7 +21,7 @@

"""
Map of torch objects to their tracing rules (Dynamo variables).
* TorchVariable: The functions should be put into the FX graph or can be constant folded. E.g.,
* TorchInGraphFunctionVariable: The functions should be put into the FX graph or can be constant folded. E.g.,
- torch.add: should be put into the FX graph.
- torch.is_floating_point: constant folded.
* TorchCtxManagerClassVariable: The context manager classes are supported by Dynamo. E.g., torch.no_grad
@@ -42,7 +42,6 @@
* Remove the entry represented the API from this map or test/dynamo/test_trace_rules.ignored_torch_name_rule_set
depends on where it is.
TODO: Add torch object names mapping to TorchVariable for in graph and constant fold functions.
TODO: We would consolidate the skipfiles.check rules into trace_rules.lookup later.
TODO: We would support explictly list objects treated as skip/inline after the skipfiles.check
and trace_rules.lookup consolidation is done. Then the explicit listing of skip/inline objects have
@@ -93,6 +92,15 @@
"torch.nn.Parameter": SkipFilesVariable,
"torch._nested_tensor_from_mask": SkipFilesVariable,
"torch._nested_from_padded": SkipFilesVariable,
# symbol operators implemented in Python
"torch.sym_not": TorchInGraphFunctionVariable,
"torch.sym_float": TorchInGraphFunctionVariable,
"torch.sym_int": TorchInGraphFunctionVariable,
"torch.sym_max": TorchInGraphFunctionVariable,
"torch.sym_min": TorchInGraphFunctionVariable,
"torch.sym_sqrt": TorchInGraphFunctionVariable,
"torch.sym_ite": TorchInGraphFunctionVariable,
"torch.Tensor#_make_wrapper_subclass": SkipFilesVariable,
}


@@ -2812,7 +2820,7 @@ def lookup(obj):
# Custom allow/disallow in graph takes precedence over the `torch_name_rule_map`.
if id(obj) in _disallowed_function_ids:
return None
if is_user_defined_allowed(obj):
if callable(obj) and is_user_defined_allowed(obj):
return TorchInGraphFunctionVariable
# Unwrap if the function is wrapped by functools.lru_cache or functools.wraps.
if isinstance(obj, functools._lru_cache_wrapper) or (
2 changes: 2 additions & 0 deletions torch/_dynamo/utils.py
@@ -908,6 +908,8 @@ def is_safe_constant(v):
type(type),
torch.device,
torch.dtype,
torch.memory_format,
torch.layout,
),
)

7 changes: 1 addition & 6 deletions torch/_dynamo/variables/__init__.py
@@ -68,11 +68,7 @@
TensorVariable,
UnspecializedPythonVariable,
)
from .torch import (
TorchCtxManagerClassVariable,
TorchInGraphFunctionVariable,
TorchVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable

__all__ = [
@@ -120,7 +116,6 @@
"TensorVariable",
"TorchCtxManagerClassVariable",
"TorchInGraphFunctionVariable",
"TorchVariable",
"TupleVariable",
"UnknownVariable",
"UnspecializedNNModuleVariable",
43 changes: 17 additions & 26 deletions torch/_dynamo/variables/builder.py
@@ -38,12 +38,7 @@
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef
from .. import config, mutation_guard, replay_record, skipfiles, trace_rules
from ..allowed_functions import (
is_allowed,
is_builtin_callable,
is_numpy,
is_user_defined_allowed,
)
from ..allowed_functions import is_builtin_callable, is_numpy, is_user_defined_allowed

from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
@@ -151,7 +146,7 @@
TensorVariable,
UnspecializedPythonVariable,
)
from .torch import torch_special_class_types, TorchVariable
from .torch import TorchInGraphFunctionVariable
from .torch_function import build_torch_function_fn, TensorWithTFOverrideVariable
from .user_defined import (
KeyedJaggedTensorVariable,
@@ -320,6 +315,8 @@ def _type_dispatch(cls):
torch.Size,
torch.device,
torch.dtype,
torch.memory_format,
torch.layout,
),
cls.wrap_literal,
),
@@ -475,7 +472,7 @@ def index_source(key):
elif ConstantVariable.is_literal(value): # non-atomic literals
return self.wrap_literal(value)
elif istype(value, frozenset) and (
all(is_allowed(x) or ConstantVariable.is_literal(x) for x in value)
ConstantVariable.is_literal(x) for x in value
):
# For frozenset, we can guard by object ID instead of value
# equality, this allows us to handle non-literal values
@@ -588,13 +585,6 @@ def index_source(key):
elif isinstance(value, HigherOrderOperator):
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
return TorchHigherOrderOperatorVariable.make(value, source=self.source)
elif type(value).__name__ == "builtin_function_or_method" and isinstance(
value.__self__, torch_special_class_types
):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchVariable(
value,
)
elif isinstance(value, _StreamBase):
self.install_guards(GuardBuilder.ID_MATCH)
return StreamVariable(
@@ -716,18 +706,18 @@ def index_source(key):
return trace_rules.lookup(value).create_with_source(
value, source=self.source
)
elif istype(value, (types.ModuleType, replay_record.DummyModule)):
elif (
istype(value, (types.ModuleType, replay_record.DummyModule))
# type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
# type(torch.ops) -> <class 'torch._ops._Ops'>
or value in [torch.backends.cudnn, torch.ops]
or isinstance(value, torch._ops._OpNamespace)
):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonModuleVariable(
value,
source=self.source,
)
elif is_allowed(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchVariable(
value,
source=self.source,
)
elif (
is_function(value)
and skipfiles.check(value, is_inlined_call=True)
@@ -747,7 +737,7 @@ def index_source(key):
source=self.source,
)
elif isinstance(value, types.MethodType) and isinstance(
value.__self__, torch.nn.Module
value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
):
# don't let MethodTypes fall through to UserDefinedObject,
# which doesn't support 'CALL_FUNCTION'
@@ -800,6 +790,7 @@ def index_source(key):
),
)
else:
# breakpoint()
self.install_guards(GuardBuilder.TYPE_MATCH)
result = UserDefinedObjectVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
@@ -1478,7 +1469,7 @@ def _clone_input(value):
and isinstance(proxy.node.target.__self__, torch._C.Generator)
or proxy.node.target == torch.random.set_rng_state
):
return TorchVariable(proxy.node.target)
return TorchInGraphFunctionVariable(proxy.node.target)
elif (
proxy.node.target == torch._C._DisableFuncTorch
or proxy.node.target == torch.cuda._is_in_bad_fork
@@ -1898,10 +1889,10 @@ def __call__(self, tx, value) -> VariableTracker:
return SourcelessBuilder.wrap_constant_literal(value)
elif is_builtin_callable(value):
return BuiltinVariable(value)
elif is_allowed(value):
elif trace_rules.lookup(value) is not None:
if is_user_defined_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True
return TorchVariable(value)
return trace_rules.lookup(value)(value)
elif isinstance(value, types.FunctionType):
return UserFunctionVariable(value)
elif isinstance(value, enum.Enum):
32 changes: 17 additions & 15 deletions torch/_dynamo/variables/builtin.py
@@ -703,7 +703,9 @@ def _call_min_max_binary(self, tx, a, b):

# result of an item call is a scalar convert to a tensor
if isinstance(a, FakeItemVariable):
a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})
a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function(
tx, [a], {}
)

# Dynamic input does not get resolved, rather, gets stored as call_function
if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
@@ -726,7 +728,7 @@ def _call_min_max_binary(self, tx, a, b):

fn = variables.NumpyVariable(np.clip)
else:
fn = variables.TorchVariable(torch.clamp)
fn = variables.TorchInGraphFunctionVariable(torch.clamp)
kwargs = {"min": b} if (self.fn is max) else {"max": b}
result = fn.call_function(tx, [a], kwargs)
else:
Expand All @@ -737,7 +739,7 @@ def _call_min_max_binary(self, tx, a, b):
fn = variables.NumpyVariable(fn)
else:
fn = {max: torch.maximum, min: torch.minimum}[self.fn]
fn = variables.TorchVariable(fn)
fn = variables.TorchInGraphFunctionVariable(fn)
result = fn.call_function(tx, [a, b], {})

# return unspec if both a, b are unspec or const
@@ -1122,7 +1124,6 @@ def call_getattr(
GetAttrVariable,
PythonModuleVariable,
TorchInGraphFunctionVariable,
TorchVariable,
UserFunctionVariable,
)
from .builder import SourcelessBuilder, VariableBuilder
@@ -1235,11 +1236,19 @@ def _grad_changed(old, new):
except NotImplementedError:
return GetAttrVariable(obj, name, **options)
elif isinstance(obj, TorchInGraphFunctionVariable):
# Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
member = getattr(obj.value, name)
if trace_rules.lookup(member) is not None:
return trace_rules.lookup(member)(member, **options)
elif isinstance(obj, TorchVariable):
member = getattr(obj.value, name)
if trace_rules.is_aten_op_or_tensor_method(member):
return TorchInGraphFunctionVariable(member, **options)
elif isinstance(obj, (PythonModuleVariable, DummyModule)):
if obj.is_torch:
member = getattr(obj.value, name)
else:
member = obj.value.__dict__[name]

if config.replay_record_enabled:
tx.exec_recorder.record_module_access(obj.value, name, member)

if is_utils_checkpoint(member):
options["source"] = source
return build_checkpoint_variable(**options)
Expand All @@ -1249,13 +1258,6 @@ def _grad_changed(old, new):
return VariableBuilder(tx, source)(member)
else:
return SourcelessBuilder()(tx, member)
elif isinstance(obj, (PythonModuleVariable, DummyModule)):
member = obj.value.__dict__[name]

if config.replay_record_enabled:
tx.exec_recorder.record_module_access(obj.value, name, member)

return VariableBuilder(tx, source)(member)
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
return ConstantVariable.create(getattr(obj.fn, name))
else:
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -1175,7 +1175,7 @@ def call_function(
"kwargs have not been implemented for torch.autograd.Function"
)

from . import TorchVariable
from . import TorchInGraphFunctionVariable

always_restore = self.value.__name__ == "trampoline_autograd_bwd"
if (
@@ -1184,7 +1184,7 @@
):
fn = UserFunctionVariable(self.value, source=self.source)
else:
fn = TorchVariable(self.value)
fn = TorchInGraphFunctionVariable(self.value)
# TODO(jansel): BUG!!! we aren't copying on the line below, so the post-pre check below is pointless
pre_guards = tx.output.guards
# In eager-mode PyTorch, if we only compute first-order gradients,
