135 changes: 65 additions & 70 deletions torch/_subclasses/fake_tensor.py
@@ -7,7 +7,6 @@
import traceback
import weakref
from dataclasses import dataclass
from functools import partial
from typing import (
Any,
Callable,
@@ -46,7 +45,7 @@
TorchDispatchMode,
)

from torch.utils._pytree import PyTree, tree_map, tree_map_only
from torch.utils._pytree import PyTree, tree_map
from torch.utils._stats import count, count_label
from torch.utils.weak import WeakIdRef

@@ -511,7 +510,8 @@ def is_symbolic(x):
is_symbolic(x) for x in itertools.chain(args, kwargs.values())
)
if not require_dynamic:
return run_fallback_kernel(fake_mode, func, args, kwargs, None)
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

raise UnsupportedOperatorException(func)
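
For context, a minimal sketch of the tree_flatten/tree_unflatten round trip that this change leans on throughout the file; the argument values below are invented for illustration.

import torch
import torch.utils._pytree as pytree

args = (torch.randn(2), [torch.randn(3), 1])
kwargs = {"alpha": 2.0}

# Flatten the nested (args, kwargs) structure into a flat list of leaves
# plus a spec that records the nesting.
flat_args, args_spec = pytree.tree_flatten((args, kwargs))

# Work directly on the flat list (no tree_map_only needed)...
processed = [a * 2 if isinstance(a, torch.Tensor) else a for a in flat_args]

# ...and rebuild (args, kwargs) only where the nested form is required.
new_args, new_kwargs = pytree.tree_unflatten(processed, args_spec)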

@@ -1206,7 +1206,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return func(*args, **kwargs)

@staticmethod
def _find_common_device(func, args, kwargs) -> Tuple[torch.device, bool]:
def _find_common_device(func, flat_args) -> Tuple[torch.device, bool]:
# Returns: (common_device, has_scalar_only_inputs)

# cpu - zero-dim tensors can be called in cuda kernels,
@@ -1253,8 +1253,8 @@ def merge_devices(t):
f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
)

pytree.tree_map_(merge_devices, args)
pytree.tree_map_(merge_devices, kwargs)
for arg in flat_args:
merge_devices(arg)

# some functions that allow Python numbers to bind to Tensors
# if we have failed to find a device, and we're running one of these operators,
@@ -1434,19 +1434,23 @@ def dispatch(self, func, types, args=(), kwargs=None):
with in_kernel_invocation_manager(self):
return func(*args, **kwargs)

flat_args, args_spec = pytree.tree_flatten((args, kwargs))

flat_arg_fake_tensors = [
t
for t in tree_flatten_only(FakeTensor, (args, kwargs))
if self.is_our_fake(t)
t for t in flat_args if isinstance(t, FakeTensor) and self.is_our_fake(t)
]
flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs))
has_symbolic_sizes = (
any(i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors)
or len(flat_symints) > 0
)
has_symbolic_sizes = any(
i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors
) or any(isinstance(a, torch.SymInt) for a in flat_args)

converter = self.fake_tensor_converter

def maybe_to_constant(t):
if isinstance(t, FakeTensor) and self.is_our_fake(t):
return t.constant
else:
return t

# To constant propagate through these functions:
# 1, If this is a lift, the input tensor is guaranteed to be a
# constant, so we keep a copy of the original argument along so
@@ -1460,11 +1464,8 @@ def dispatch(self, func, types, args=(), kwargs=None):
assert all(
t.constant is not None for t in flat_arg_fake_tensors
), f"{func} should not have fake inputs without constants"
const_args, const_kwargs = pytree.tree_map_only(
FakeTensor,
lambda t: t.constant if self.is_our_fake(t) else t,
(args, kwargs),
)
const_flat_args = [maybe_to_constant(a) for a in flat_args]
const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec)
out = func(*const_args, **const_kwargs)
if type(out) is torch.Tensor and self.may_turn_const(out):
# NB: not in_kernel_invocation_manager because we're doing real
@@ -1481,7 +1482,7 @@ def dispatch(self, func, types, args=(), kwargs=None):
# tensor, it might be related to this line. Though I'm not sure
# how you'll know to read this comment, as this line won't show up
# in the stack trace.
unrecognized_types = self.check_for_subclass(args, kwargs)
unrecognized_types = self.check_for_subclass(flat_args)
if unrecognized_types:
not_implemented_log.debug(
"FakeTensorMode unrecognized subclass(es): %s", unrecognized_types
@@ -1503,11 +1504,10 @@ def dispatch(self, func, types, args=(), kwargs=None):

# Recompute flat_arg_fake_tensors here again in case some of the inputs
# were real tensors and fakified in validate_and_convert_non_fake_tensors
(
args,
kwargs,
flat_arg_fake_tensors,
) = self.validate_and_convert_non_fake_tensors(func, converter, args, kwargs)
(flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
func, converter, flat_args, args_spec
)
del args, kwargs # Invalidated

# The current constant handling only supports tracing systems
# (aot autograd, torchdynamo) where each operation is run consecutively.
@@ -1527,20 +1527,17 @@ def dispatch(self, func, types, args=(), kwargs=None):
and len(flat_arg_fake_tensors) != 0
and not has_symbolic_sizes
):
const_args, const_kwargs = pytree.tree_map_only(
FakeTensor,
lambda t: t.constant if self.is_our_fake(t) else t,
(args, kwargs),
)
const_flat_args = [maybe_to_constant(a) for a in flat_args]
const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec)

# NB: not in_kernel_invocation_manager(self) as we want to do REAL
# compute
with no_dispatch():
out = func(*const_args, **const_kwargs)

all_constant = pytree.tree_all_only(
torch.Tensor, lambda t: self.may_turn_const(t), out
)
flat_out = pytree.tree_leaves(out)
flat_out_tensors = [t for t in flat_out if isinstance(t, torch.Tensor)]
all_constant = all(self.may_turn_const(t) for t in flat_out_tensors)

if all_constant:
return pytree.tree_map_only(
@@ -1551,11 +1548,12 @@ def dispatch(self, func, types, args=(), kwargs=None):

# we weren't able to turn outputs to constants,
# so invalidate all constants that might be aliases of the outputs
for ten in tree_flatten_only(torch.Tensor, out):
for ten in flat_out_tensors:
converter.invalidate_constant_aliases(ten)

# we are falling through to running non constant tensors, any input constant that
# is written to must be invalidated
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
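
For context, the constant-folding branch above flattens whatever structure the op returns with pytree.tree_leaves before filtering for tensors; a small self-contained sketch of that pattern follows, with a made-up output structure and a stand-in predicate for may_turn_const.

import torch
import torch.utils._pytree as pytree

out = {"values": (torch.ones(2), [torch.zeros(3), 4])}

flat_out = pytree.tree_leaves(out)  # [tensor(...), tensor(...), 4]
flat_out_tensors = [t for t in flat_out if isinstance(t, torch.Tensor)]

# Stand-in predicate for self.may_turn_const: only small tensors qualify here.
all_constant = all(t.numel() <= 8 for t in flat_out_tensors)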

# Try for fastpath
@@ -1659,7 +1657,7 @@ def maybe_run_unsafe_fallback(error=None):
raise UnsupportedOperatorException(func)
if error is None:
error = UnsupportedOperatorException(func)
return run_fallback_kernel(self, func, args, kwargs, error)
return run_fallback_kernel(self, func, flat_args, args_spec, error)

# Optimization: If there is no Meta kernel, it takes a surprisingly long
# amount of time to catch the NotImplementedError, so we check it here.
@@ -1677,7 +1675,9 @@ def maybe_run_unsafe_fallback(error=None):
except NotImplementedError as not_implemented_error:
return maybe_run_unsafe_fallback(not_implemented_error)

return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs)
return self.wrap_meta_outputs_with_default_device_logic(
r, func, flat_args, device=kwargs.get("device")
)

# [subclass inputs]
# Suppose we enable fake tensor mode. This means that fake tensor
@@ -1689,19 +1689,20 @@ def maybe_run_unsafe_fallback(error=None):
# fake tensor is not supported. What we actually wanted to happen
# was to give the subclass a chance to figure out what it wants to
# before erroring out. Returning NotImplemented here allows this.
def check_for_subclass(self, args, kwargs):
def check_for_subclass(self, flat_args):
def check(x):
return (
not isinstance(x, FakeTensor)
isinstance(x, torch.Tensor)
and not isinstance(x, FakeTensor)
and type(x) is not torch.Tensor
and type(x) is not torch.nn.Parameter
)

return [
type(x) for x in tree_flatten_only(torch.Tensor, (args, kwargs)) if check(x)
]
return [type(x) for x in flat_args if check(x)]

def validate_and_convert_non_fake_tensors(self, func, converter, args, kwargs):
def validate_and_convert_non_fake_tensors(
self, func, converter, flat_args, args_spec
):
"""
Checks whether the tensors in the list are fake tensors.
If not, try to convert them to fake tensors.
@@ -1710,15 +1711,20 @@ def validate_and_convert_non_fake_tensors(self, func, converter, args, kwargs):
flat_arg_fake_tensors = []

def validate(x):
if not isinstance(x, torch.Tensor):
return x

nonlocal flat_arg_fake_tensors
if not self.is_our_fake(x):
if torch.Tag.inplace_view in func.tags:
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
raise Exception(
f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}"
)
if not self.allow_non_fake_inputs:
if isinstance(x, FakeTensor) and x.fake_mode is not self:
raise AssertionError("Mixing fake modes NYI")
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
raise Exception(
f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode "
f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}"
@@ -1729,38 +1735,25 @@ def validate(x):
flat_arg_fake_tensors.append(x)
return x

args, kwargs = tree_map_only(
torch.Tensor,
validate,
(args, kwargs),
)
return args, kwargs, flat_arg_fake_tensors

def wrap_meta_outputs_with_default_device_logic(self, r, func, args, kwargs):
wrap = self.gen_wrap_fn(func, args, kwargs)

# if device is specified, use that
if kwargs.get("device", None):
return tree_map(partial(wrap, device=kwargs["device"]), r)
validated_args = [validate(a) for a in flat_args]
return validated_args, flat_arg_fake_tensors

return tree_map(partial(wrap), r)

def gen_wrap_fn(self, func, args, kwargs):
def wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device):
converter = self.fake_tensor_converter

# Lazily initialized, in case there are no tensor returns
common_device = None
has_scalar_only_inputs = False

def wrap(e, device=None):
def wrap(e):
nonlocal common_device
nonlocal has_scalar_only_inputs

if isinstance(e, torch.Tensor) and common_device is None:
(
common_device,
has_scalar_only_inputs,
) = FakeTensor._find_common_device(func, args, kwargs)
) = FakeTensor._find_common_device(func, flat_args)

if self.is_our_fake(e):
torch._check(
Expand All @@ -1785,7 +1778,7 @@ def wrap(e, device=None):
else:
return e

return wrap
return tree_map(wrap, r)

def cpp_meta_supports_symint(self, func):
if torch.Tag.view_copy in func.tags:
@@ -1820,8 +1813,8 @@ def invalidate_written_to_constants(
self, func, flat_arg_fake_tensors, args, kwargs
):
any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
if any_constant and get_schema_info(func).is_mutable():
schema_info = get_schema_info(func)
schema_info = get_schema_info(func)
if any_constant and schema_info.is_mutable():
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@@ -1868,7 +1861,9 @@ def from_tensor(


# NB: returns fake tensors
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
def run_fallback_kernel(
fake_mode, func, flat_args, args_spec, orig_not_implemented_exception
):
# these should all be supported, just to be safe
# avoid fallback for operators which inplace modify metadata
# because the input fake tensors would be unmodified
@@ -1890,15 +1885,15 @@ def to_real_tensor(e):
return out
return e

args = tree_map(to_real_tensor, args)
kwargs = tree_map(to_real_tensor, kwargs)
flat_args = [to_real_tensor(a) for a in flat_args]
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)

r = func(*args, **kwargs)

tensor_impls = set()
storages = set()

for e in pytree.arg_tree_leaves(*args, **kwargs):
for e in flat_args:
if isinstance(e, torch.Tensor):
if not e.is_sparse:
storages.add(e._typed_storage()._cdata)
@@ -1907,15 +1902,15 @@ def to_real_tensor(e):
# proper aliasing/metadata relationship between outputs and inputs will
# not be set up, bc of conversion to device, unless we can reuse an
# input impl
for e in pytree.tree_leaves(r):

def map_out(e):
if id(e) not in inp_impls and (
isinstance(e, torch.Tensor)
and not e.is_sparse
and e._typed_storage()._cdata in storages
):
raise orig_not_implemented_exception

def map_out(e):
if isinstance(e, torch.Tensor):
if id(e) in inp_impls:
return inp_impls[id(e)]
@@ -1924,7 +1919,7 @@ def map_out(e):
else:
return e

return tree_map(map_out, r)
return pytree.tree_map(map_out, r)
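
For context, map_out above preserves aliasing by returning the original input impl whenever an output is the very same tensor object as an input; below is a stripped-down sketch of that bookkeeping, assuming plain tensors and a caller-supplied conversion function. The names here are illustrative, not the actual fake_tensor API; in the diff above the conversion step is the fake-tensor converter.

import torch

def remap_outputs(flat_inputs, flat_outputs, convert):
    # Remember each input tensor by identity so that an output which aliases
    # an input is mapped back to the original impl rather than re-converted.
    inp_impls = {id(t): t for t in flat_inputs if isinstance(t, torch.Tensor)}

    def map_out(e):
        if isinstance(e, torch.Tensor):
            if id(e) in inp_impls:
                return inp_impls[id(e)]
            return convert(e)
        return e

    return [map_out(o) for o in flat_outputs]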


# Just for use to allow copying a module to fake tensors,