-
Notifications
You must be signed in to change notification settings - Fork 21.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make propagate_real_tensor more safe (#126281)
Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7228787720582401/ There a few improvements here, which luckily fix some xfails: * In generally, it can be unsafe to call operations on Tensors under a `no_dispatch()` mode that is purely trying to disable ambient modes, because this ALSO disables tensor subclass handling. So we test to see if there is a tensor subclass and don't propagate real tensors if that's the case. Another acceptable outcome might be to try to only disable the ambient fake tensor mode, this would help us propagate real tensors through more exotic tensor types, but I'm not going to do it until someone asks for it. * We're graph breaking for wrapped tensors too late. Pull it up earlier so we do it before we try to muck around with the real tensor. * I noticed that occasionally when I do `storage.copy_(real_storage)`, the sizes mismatch. Careful code reading suggests that I should just copy in the real data when the tensor was initially allocated, so that's what I do now, eliminating the need for a storage copy. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #126281 Approved by: https://github.com/Skylion007
- Loading branch information
Showing
2 changed files
with
42 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters