Support tracing base torch_function impl (#111731)
Pull Request resolved: #111731
Approved by: https://github.com/jansel
ghstack dependencies: #111730
mlazos authored and pytorchmergebot committed Oct 23, 2023
1 parent 0b424ee commit fb88760
Showing 10 changed files with 81 additions and 93 deletions.
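For context, the "base impl" in the title is torch.Tensor.__torch_function__ from torch/_tensor.py. The sketch below is a paraphrase for illustration only (wrapped in a hypothetical subclass so it is self-contained), not the verbatim source; each file touched by this commit maps onto one piece of this body: the issubclass check (builtin.py), the get_default_nowrap_functions set and its method-wrapper elements (torch.py, builder.py, lists.py), and the _convert re-wrapping helper (allowed_functions.py, skipfiles.py, symbolic_convert.py).

# Paraphrased from torch/_tensor.py; illustrative, not verbatim.
import torch
from torch._C import DisableTorchFunctionSubclass
from torch.overrides import get_default_nowrap_functions
from torch._tensor import _convert


class _BaseImplSketch(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Defer (NotImplemented) unless cls is a subclass of every overriding type.
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        with DisableTorchFunctionSubclass():
            ret = func(*args, **kwargs)
            # Attribute getters in the nowrap set are returned as-is.
            if func in get_default_nowrap_functions():
                return ret
            # _convert re-wraps plain tensor results into the subclass type.
            return _convert(ret, cls)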
1 change: 1 addition & 0 deletions torch/_dynamo/allowed_functions.py
@@ -121,6 +121,7 @@ def _disallowed_function_ids():
torch._C._dynamo.eval_frame.unsupported,
torch.Tensor.__init__,
torch.resize_as_,
torch._tensor._convert,
]

# extract all dtypes from torch
1 change: 1 addition & 0 deletions torch/_dynamo/skipfiles.py
@@ -212,6 +212,7 @@ def _module_dir(m: types.ModuleType):
"torch.utils._contextlib",
"torch.utils._foreach_utils",
"torch.utils._pytree",
"torch._tensor",
}


14 changes: 11 additions & 3 deletions torch/_dynamo/symbolic_convert.py
@@ -2252,10 +2252,18 @@ def check_inlineable(func):
unimplemented("Patched init cannot be inlined.")

try:
if id(func.get_function()) in allowed_functions._disallowed_function_ids:
unimplemented(f"inlining disallowed: {func.get_function()}")
func_value = func.get_function()
except NotImplementedError:
pass # closures
func_value = None

if (
func.get_name() == "__torch_function__"
or func_value is torch._tensor._convert
):
return skipfiles.SkipResult(False, "Allow __torch_function__")

if func_value and id(func_value) in allowed_functions._disallowed_function_ids:
unimplemented(f"inlining disallowed: {func_value}")

result = skipfiles.check_verbose(func, allow_torch=True)
if result.skipped:
13 changes: 8 additions & 5 deletions torch/_dynamo/variables/builder.py
@@ -138,7 +138,7 @@
UnspecializedPythonVariable,
)
from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
from .torch_function import TensorWithTFOverrideVariable
from .torch_function import build_torch_function_fn, TensorWithTFOverrideVariable
from .user_defined import (
KeyedJaggedTensorVariable,
UserDefinedClassVariable,
@@ -1083,10 +1083,9 @@ def wrap_tensor(self, value: torch.Tensor):
)
options = {}
if type(value) in config.traceable_tensor_subclasses:
options["torch_function_fn"] = VariableBuilder(
self.tx,
AttrSource(AttrSource(self.source, "__torch_function__"), "__func__"),
)(value.__torch_function__.__func__)
options["torch_function_fn"] = build_torch_function_fn(
self.tx, value, self.source
)
options["guards"] = self.make_guards(GuardBuilder.TYPE_MATCH)
else:
options["guards"] = set()
@@ -1842,6 +1841,10 @@ def __call__(self, tx, value) -> VariableTracker:
dict,
mutable_local=MutableLocal(),
)
elif isinstance(value, set):
return SetVariable(
[self(tx, x) for x in value], mutable_local=MutableLocal()
)
elif isinstance(value, (tuple, list)):
cls = BaseListVariable.cls_for(type(value))
return cls([self(tx, x) for x in value], mutable_local=MutableLocal())
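As a usage sketch of the wrap_tensor path above (hypothetical subclass name; the registration step is assumed from the config.traceable_tensor_subclasses check):

import torch
import torch._dynamo


class LoggingTensor(torch.Tensor):
    # Hypothetical subclass: only a __torch_function__ override that defers
    # to the base impl this commit teaches Dynamo to trace.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs or {})


# Assumption: the subclass must be registered so wrap_tensor takes the branch above.
torch._dynamo.config.traceable_tensor_subclasses.add(LoggingTensor)


@torch.compile(backend="eager")
def fn(x):
    return torch.sin(x) + x


fn(torch.ones(3).as_subclass(LoggingTensor))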
7 changes: 7 additions & 0 deletions torch/_dynamo/variables/builtin.py
@@ -1002,6 +1002,13 @@ def call_isinstance(self, tx, arg, isinstance_type):
val = arg_type is isinstance_type
return variables.ConstantVariable.create(val)

def call_issubclass(self, tx, left_ty, right_ty):
"""Checks if first arg is subclass of right arg"""
left_ty = left_ty.as_python_constant()
right_ty = right_ty.as_python_constant()

return variables.ConstantVariable(issubclass(left_ty, right_ty))

def call_super(self, tx, a, b):
return variables.SuperVariable(a, b)

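call_issubclass is what lets the subclass check inside the base __torch_function__ body fold to a constant during tracing. A minimal sketch of the shape of call it handles, with both arguments resolvable to Python classes (hypothetical helper name):

import torch


def dispatch_allowed(cls, types):
    # Mirrors the check in torch.Tensor.__torch_function__; each issubclass
    # call becomes a ConstantVariable when Dynamo inlines this.
    return all(issubclass(cls, t) for t in types)


print(dispatch_allowed(torch.Tensor, (torch.Tensor,)))  # True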
3 changes: 3 additions & 0 deletions torch/_dynamo/variables/lists.py
@@ -819,6 +819,7 @@ def reconstruct(self, codegen):
# Note - this is only used for producing a set
def _as_set_element(self, vt):
from .base import VariableTracker
from .misc import MethodWrapperVariable
from .tensor import TensorVariable

assert isinstance(vt, VariableTracker)
@@ -832,6 +833,8 @@ def _as_set_element(self, vt):
return SetVariable.SetElement(vt, fake_tensor)
if isinstance(vt, ConstantVariable):
return SetVariable.SetElement(vt, vt.value)
if isinstance(vt, MethodWrapperVariable):
return SetVariable.SetElement(vt, vt.as_python_constant())

unimplemented(f"Sets with {type(vt)} NYI")

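The MethodWrapperVariable case exists because get_default_nowrap_functions() is a set of C-level attribute getters (method-wrapper objects), which the traced base impl tests membership against. A small illustration; the exact contents are version-dependent:

import torch

nowrap = torch.overrides.get_default_nowrap_functions()
# Elements are method-wrapper objects (e.g. torch.Tensor._base.__get__),
# not plain functions, hence the new SetElement case above.
for fn in nowrap:
    print(type(fn).__name__, fn)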
60 changes: 1 addition & 59 deletions torch/_dynamo/variables/misc.py
@@ -113,65 +113,7 @@ def call_method(
)
inner_fn, source = self._resolved_getattr_and_source(self, name)

# This variable is True when it corresponds to user code such as
#
# super().__torch_function__(...)
#
# and the super().__torch_function__ attribute resolves
# to torch.Tensor.__torch_function__.
is_original_tensor_torch_function = (
name == "__torch_function__"
# for now, only support one level of inheritance
and len(self.objvar.value.__mro__) > 1
and self.objvar.value.__mro__[1] == torch.Tensor
)
if is_original_tensor_torch_function:
# Instead of tracing inside torch.Tensor.__torch_function__,
# record the `call_function` or `call_method` call into the graph.
from . import ConstantVariable, ConstDictVariable, TorchVariable
from .builder import wrap_fx_proxy

original_torch_or_getattr_variable = args[0]
new_args = args[2].items
# TODO (mlazos): this is a hack to handle kwargs properly for a test
# we assume that kwargs is either a dict or None.
# Rather than inserting the base torch function impl into the graph
# we should trace it properly. We should be able to remove all of this
# code starting from "is_original_tensor_torch_function" above.
if not isinstance(args[3], ConstantVariable):
new_kwargs = args[3].items
options = VariableTracker.propagate(self, new_args, new_kwargs)
else:
new_kwargs = ConstDictVariable(dict(), dict).items
options = VariableTracker.propagate(self, new_args, [args[3]])
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunctionSubclass():
if isinstance(args[0], TorchVariable):
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
original_torch_or_getattr_variable.value,
*proxy_args_kwargs(new_args, new_kwargs),
),
**options,
)
elif isinstance(args[0], GetAttrVariable):
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_method",
original_torch_or_getattr_variable.name,
*proxy_args_kwargs(new_args, new_kwargs),
),
**options,
)
else:
unimplemented(
f"GetAttrVariable.call_function original __torch_function__ {args}"
)
elif inner_fn is object.__init__:
if inner_fn is object.__init__:
return LambdaVariable(identity, **options)
elif inner_fn is torch.nn.Module.__init__:
objvar = self.objvar
24 changes: 24 additions & 0 deletions torch/_dynamo/variables/tensor.py
@@ -348,6 +348,7 @@ def call_method(
unimplemented(f"Illegal method invocation {name} in strict mode")
from . import ConstantVariable, TorchVariable, TupleVariable
from .builder import wrap_fx_proxy
from .user_defined import UserDefinedClassVariable

kwargs = dict(kwargs)
options = VariableTracker.propagate(self, args, kwargs.values())
@@ -478,6 +479,29 @@ def make_const_size_variable(x, **options):
),
**options,
)
elif (
name == "as_subclass"
and len(args) == 1
and isinstance(args[0], UserDefinedClassVariable)
):
from .builder import VariableBuilder
from .torch_function import TensorWithTFOverrideVariable

# [Note: __torch_function__] coerce this tensor variable into a TensorWithTFOverrideVariable
# in eager, this is just a type change. This isn't sound if a __torch_function__ tensor subclass
# defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
# It is up to the user whether this is correct behavior or not.
py_cls = args[0].as_python_constant()
torch_fn = VariableBuilder(
tx,
AttrSource(
AttrSource(args[0].source, "__torch_function__"), "__func__"
),
)(py_cls.__torch_function__.__func__)

return TensorWithTFOverrideVariable.from_tensor_var(
tx, self, py_cls, torch_fn
)
elif name == "get_device" and isinstance(self.device, torch.device):
index = self.device.index if self.device.type != "cpu" else -1
constant_result = ConstantVariable.create(index, **options)
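A sketch of user code that exercises the new as_subclass branch (hypothetical subclass; as in the builder.py example above, registration in traceable_tensor_subclasses is assumed):

import torch
import torch._dynamo


class WrapperTensor(torch.Tensor):
    # Hypothetical subclass with only a __torch_function__ override.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs or {})


torch._dynamo.config.traceable_tensor_subclasses.add(WrapperTensor)


@torch.compile(backend="eager")
def fn(x):
    # Handled by the branch above: the TensorVariable for x is coerced into a
    # TensorWithTFOverrideVariable carrying WrapperTensor's __torch_function__.
    y = x.as_subclass(WrapperTensor)
    return torch.mul(y, 2)


fn(torch.randn(4))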
37 changes: 12 additions & 25 deletions torch/_dynamo/variables/torch.py
@@ -45,11 +45,7 @@
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import ListVariable, TupleVariable
from .torch_function import (
can_dispatch_torch_function,
dispatch_torch_function,
TensorWithTFOverrideVariable,
)
from .torch_function import can_dispatch_torch_function, dispatch_torch_function

log = logging.getLogger(__name__)

@@ -249,6 +245,16 @@ def call_function(
self.value,
source=self.source,
).call_function(tx, args, kwargs)
if self.value is torch.overrides.get_default_nowrap_functions:
# [Note: __torch_function__] we return empty here because we restrict
# the set of functions that we trace __torch_function__ on to
# functions outside of the actual set. Implementing this properly will require implementing
# some variable types to track and compare tensor getset descriptors
from .builder import SourcelessBuilder

return SourcelessBuilder()(
tx, torch.overrides.get_default_nowrap_functions()
).add_options(options)
elif self.value in config.constant_functions:
assert not args and not kwargs
return ConstantVariable.create(
@@ -428,26 +434,7 @@ def call_function(
else:
unimplemented(f"torch.from_numpy(<{type(t)}>)")
elif can_dispatch_torch_function(tx, args, kwargs):
unwrapped = dispatch_torch_function(tx, self, args, kwargs)
# The wrapping here follows the logic in
# `torch.Tensor.__torch_function__`.
# TODO: This shouldn't be here as well, this should be traced in the base torch function
# impl
if self.value in torch.overrides.get_default_nowrap_functions():
return unwrapped

# TODO: It's not correct to always rewrap args[0]; with multiple subclasses the dispatch
# may be on the second or later argument. Fix this to respect what dispatch_torch_function says
# the dispatch should be.
# TODO: We also should not be rewrapping unconditionally, it's possible that
# the return value *MAY NOT* be a torch function override tensor.
# The solution here is to trace the base torch function impl
return TensorWithTFOverrideVariable.from_tensor_var(
tx,
unwrapped,
args[0].class_type,
args[0].torch_function_fn,
)
return dispatch_torch_function(tx, self, args, kwargs)
elif self.value in [
torch.amp.autocast_mode.autocast,
torch.cuda.amp.autocast,
14 changes: 13 additions & 1 deletion torch/_dynamo/variables/torch_function.py
@@ -3,7 +3,7 @@
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.utils._pytree import tree_flatten
from ..exc import unimplemented
from ..source import GlobalSource
from ..source import AttrSource, GlobalSource
from ..utils import is_tensor_base_attr_getter
from .base import VariableTracker
from .constant import ConstantVariable
@@ -54,6 +54,18 @@ def call_torch_function(
return tx.inline_user_function_return(torch_function_var, tf_args, kwargs)


def build_torch_function_fn(tx, value, source):
from .builder import SourcelessBuilder, VariableBuilder

if source:
return VariableBuilder(
tx,
AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
)(value.__torch_function__.__func__)
else:
return SourcelessBuilder()(tx, value.__torch_function__.__func__)


def can_dispatch_torch_function(tx, args, kwargs):
if tx.output.torch_function_enabled:
all_args = tree_flatten(args)[0] + tree_flatten(kwargs)[0]
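Both branches of build_torch_function_fn ultimately wrap value.__torch_function__.__func__, i.e. the plain function behind the classmethod. A quick illustration of that access pattern with a hypothetical subclass:

import torch


class MySub(torch.Tensor):
    pass  # no override, so the base impl is inherited


# __torch_function__ is a classmethod; __func__ is the underlying function,
# shared with torch.Tensor when the subclass does not override it.
assert MySub.__torch_function__.__func__ is torch.Tensor.__torch_function__.__func__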
