Make propagate_real_tensor more safe (#126281)
Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7228787720582401/

There are a few improvements here, which luckily fix some xfails:

* In general, it can be unsafe to call operations on Tensors under a `no_dispatch()` mode that is purely trying to disable ambient modes, because this ALSO disables tensor subclass handling. So we test whether there is a tensor subclass and don't propagate real tensors if that's the case. Another acceptable outcome might be to disable only the ambient fake tensor mode; this would help us propagate real tensors through more exotic tensor types, but I'm not going to do it until someone asks for it. (A sketch of the guard follows this list.)
* We're graph breaking for wrapped tensors too late. Pull the check earlier so it runs before we try to muck around with the real tensor.
* I noticed that occasionally when I do `storage.copy_(real_storage)`, the sizes mismatch. Careful code reading suggests that I should just copy in the real data when the tensor is initially allocated, so that's what I do now, eliminating the need for a storage copy.
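
As a rough sketch of the first bullet (not the exact call sites this PR touches), the helper below guards a real-data copy the same way the new `_safe_copy` in the diff does: anything whose type is not exactly `torch.Tensor` (a FakeTensor, a functorch wrapper, or any other subclass) is skipped, because `no_dispatch()` bypasses the `__torch_dispatch__` handling those types rely on. The helper name `propagate_real_data` and the `no_dispatch` import path are assumptions for illustration.

```python
import torch
from torch.utils._mode_utils import no_dispatch  # assumed import path


def propagate_real_data(dst: torch.Tensor, src) -> None:
    # Only a plain torch.Tensor source is safe to touch here: subclasses
    # depend on __torch_dispatch__, which no_dispatch() disables, so copying
    # from them could silently read fake/meta data instead of real values.
    if type(src) is not torch.Tensor:
        return
    # no_dispatch() keeps an ambient fake tensor mode from fakeifying the
    # copy itself; no_grad() keeps the copy out of autograd.
    with torch.no_grad(), no_dispatch():
        dst.copy_(src)
```

Called with a real tensor source this is an ordinary copy; called with a FakeTensor (a subclass of `torch.Tensor`, so `type(src) is not torch.Tensor` holds) it is a no-op, which is the "don't propagate real tensors" behavior described above.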

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #126281
Approved by: https://github.com/Skylion007
ezyang authored and pytorchmergebot committed May 15, 2024 (1 parent b2d9b80, commit 3ae1182)
Showing 2 changed files with 42 additions and 36 deletions.
5 changes: 0 additions & 5 deletions test/test_fake_tensor.py
@@ -868,9 +868,6 @@ def test__adaptive_avg_pool2d_backward(self):
== torch.channels_last
)

# Propagate real tensors doesn't work when original input arguments are
# fake
@expectedFailurePropagateRealTensors
def test_export_numpy(self):
class MyNumpyModel(torch.nn.Module):
def forward(self, input):
@@ -1477,7 +1474,6 @@ def to_fake_tensor(x):
failed = True
self.assertTrue(failed)

@expectedFailurePropagateRealTensors # Propagate real tensors doesn't work with fake-on-fake
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
class OptionalArgumentInBetween(torch.nn.Module):
def __init__(self):
@@ -1510,7 +1506,6 @@ def forward(self, value, another_value=None, another_optional_value=None):
value, None, another_optional_value
)

@expectedFailurePropagateRealTensors # TODO: not sure about this one, kinda strange
def test_unbacked_shape_realloc(self):
def f(x):
return x.nonzero()
73 changes: 42 additions & 31 deletions torch/_subclasses/meta_utils.py
@@ -465,6 +465,30 @@ def shape(self):
return self.size


# A more faithful reproduction would do a copy on the entire
# storage, but this needs to be done carefully because the
# underlying storage could have larger extent than is implied
# by size/stride. The real fix is to properly call
# meta_storage recursively here.
#
# These "safe" functions are intended to be used under no_dispatch() mode.
# The no_dispatch() here is intended to prevent ambient fake tensor mode from
# fakeifying the operation. But if we are given an honest to goodness
# FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
# to do this would be to not use no_dispatch and instead just disable fake
# tensor mode only (allowing for subclass dispatch to occur)
def _safe_copy(dst, src):
if type(src) is not torch.Tensor:
return
dst.copy_(src)


def _safe_clone(src):
if type(src) is not torch.Tensor:
return None
return src.clone()


# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure. The operation model is you allocate
# one of these, and then call it repeatedly on all the tensors you want to
@@ -513,6 +537,8 @@ def meta_storage(self, s: MetaStorageDesc, callback):
lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
).untyped_storage()
if self.copy_data:
# NB: no_dispatch is needed because internally storage copy is
# implemented as Tensor operations
with torch.no_grad(), no_dispatch():
assert s.data is not None
r_s.real_storage = s.data.clone()
@@ -677,13 +703,6 @@ def transform(attr, inner_t):
),
)
)
# Note [Inaccessible data is not copied]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# A more faithful reproduction would do a copy on the entire
# storage, but this needs to be done carefully because the
# underlying storage could have larger extent than is implied
# by size/stride. The real fix is to properly call
# meta_storage recursively here.
if self.copy_data:
with torch.no_grad(), no_dispatch():
r.real_tensor = torch.empty_strided(
@@ -693,9 +712,7 @@ def transform(attr, inner_t):
device=inner_t.device,
)
assert inner_t.data is not None
r.real_tensor.copy_(
inner_t.data
) # Note [Inaccessible data is not copied]
_safe_copy(r.real_tensor, inner_t.data)
return r

transformed_tensors_dict = {
@@ -943,7 +960,7 @@ def tensor_visitor_fn(
# Pray that sparse clone doesn't lose information
assert t.data is not None
with torch.no_grad(), no_dispatch():
r.real_tensor = t.data.clone()
r.real_tensor = _safe_clone(t.data)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
# Note [is_coalesced is dispatched]
# Strangely enough, is_coalesced() is a dispatched operator,
@@ -995,7 +1012,7 @@ def tensor_visitor_fn(
# Pray sparse clone doesn't lose information
assert t.data is not None
with torch.no_grad(), no_dispatch():
r.real_tensor = t.data.clone()
r.real_tensor = _safe_clone(t.data)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = True
@@ -1033,9 +1050,7 @@ def tensor_visitor_fn(
t.size, t.stride, dtype=t.dtype, device=t.device
)
assert t.data is not None
r.real_tensor.copy_(
t.data
) # Note [Inaccessible data is not copied]
_safe_copy(r.real_tensor, t.data)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = True
@@ -1135,10 +1150,7 @@ def _to_fake_tensor(t: MetaTensorDesc):
device=t.device,
)
assert t.data is not None
# Note [Inaccessible data is not copied]
r.real_tensor.copy_( # type: ignore[attr-defined]
t.data
)
_safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined]
return r

r = _to_fake_tensor(t)
@@ -1273,6 +1285,13 @@ def is_c_of_r(complex_dtype, real_dtype):
else:
is_leaf = t.is_leaf

# Graph-Break for wrapped tensors
if (
not (t.is_batchedtensor or t.is_gradtrackingtensor)
and t.is_functorch_wrapped
) or t.is_legacy_batchedtensor:
return NotImplemented

(
sizes,
strides,
@@ -1301,6 +1320,7 @@ def is_c_of_r(complex_dtype, real_dtype):
r.real_tensor = torch.empty_strided(
t.size, t.stride, dtype=t.dtype, device=t.device
)
_safe_copy(r.real_tensor, t.data)

assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
@@ -1320,13 +1340,6 @@ def is_c_of_r(complex_dtype, real_dtype):
1,
)(r)

# Graph-Break for wrapped tensors
if (
not (t.is_batchedtensor or t.is_gradtrackingtensor)
and t.is_functorch_wrapped
) or t.is_legacy_batchedtensor:
return NotImplemented

s = t.storage
assert s is not None
if s.id not in self.storage_memo and (
@@ -1339,11 +1352,9 @@ def is_c_of_r(complex_dtype, real_dtype):
# You're normal and happy, install the fresh storage into the memo
self.set_storage_memo(s, r.untyped_storage())
if self.copy_data:
with torch.no_grad(), no_dispatch():
r.real_tensor.untyped_storage().copy_(s.data)
r.untyped_storage().real_storage = (
r.real_tensor.untyped_storage()
)
r.untyped_storage().real_storage = (
r.real_tensor.untyped_storage()
)
else:
# You're in crazy town; somehow you gave us a tensor
# that wasn't a view, but had nonzero storage offset,
