Add forward AD layout check for storage numel #68631
Conversation
CI Flow Status ⚛️

You can add a comment to the PR and tag @pytorchbot with the following commands:

# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.
💊 CI failures summary: as of commit 69a9168, 1 job failed on ci.pytorch.org (more details on the Dr. CI page).
- Instead of a native function to compare two tensors, maybe a native function that just retrieves the storage numel.
- Using the stride of the first batch dim for batched tensors seems like a reasonable replacement for storage numel.
```cpp
    std::string func_name) {
  if (std::find(physical_sizes.begin(), physical_sizes.end(), 0) != physical_sizes.end()) {
    // The tensor has zero numel
    return;
```
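As context for the early return above: any zero in a tensor's size list makes its total element count zero, which is exactly what the `std::find` over `physical_sizes` detects. A minimal Python illustration (my own, not code from this PR):

```python
import torch

# A zero anywhere in the size list makes the total element count zero.
t = torch.zeros(3, 0, 2)
assert t.numel() == 0

# Equivalent to the std::find check over the size list:
assert 0 in list(t.shape)
```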
Not sure if this is the right approach because I'm not clear on exactly what cases this check guards against. Does this check still make sense to do if the numel of the tensor is zero?
I added this to address the zero-numel case, where `_new_zeros_with_same_feature_meta` prepends zeros as the strides if the storage numel of the tensor is zero.
Relaxing this check also seemed to address some failures:
- several OpInfos disabled for batched forward grad computation (#66357).
- batched gradgrad tests for linalg ops (they were failing previously on empty inputs)
Sidenote: there were other batched gradgrad (and even some batched grad) tests in linalg and fft that seem to pass even without this fix. Having xfail instead of skip would be much nicer.
cc @zou3519
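For readers less familiar with the distinction being discussed here: a tensor's numel counts the elements it indexes, while the storage numel counts the elements in its underlying storage, and the two diverge for views. A small illustrative sketch (my own example, not code from this PR):

```python
import torch

base = torch.rand(10)
view = base[0]  # 0-dim view into base's storage

# The view indexes a single element...
assert view.numel() == 1
# ...but it shares base's storage, which holds 10 elements.
assert view.storage().size() == 10
```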
I initially understood this check to be "are the batch dimensions at the front". But I think there is a distinction between that and "are the batch dimensions at the front in layout" which is what this function actually seems to check.
Maybe this is not the check we need in `_new_zeros_with_same_meta` and `_storage_numel`?
Why do you need to modify this function?
Since we're almost done with getting vmap + forward-mode AD to compose in functorch, I want to stop supporting vmap in core because it's too much unnecessary work.
Good point. Maybe we wouldn't need to do it here, but do you think it makes sense to modify it in functorch though?
I thought about this some more -- if the tensor has zero numel then any batch dims are "trivially at the front in memory layout" because there is no storage. So yes it would make sense to modify it in functorch
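A quick sanity check of the "no storage" observation above (my own illustration): a zero-numel tensor allocates an empty storage, so any layout condition on it holds vacuously.

```python
import torch

t = torch.empty(0, 5)
assert t.numel() == 0
# No elements are allocated, so the underlying storage is empty too.
assert t.storage().size() == 0
```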
```cpp
@@ -85,6 +85,9 @@ namespace {
      return false;
    }
  }
  if (base._storage_numel() != other._storage_numel()) {
```
By the way, there's a huge performance problem with `_new_zeros_with_same_feature_meta` that arises with the invariant that `base._storage_numel() == other._storage_numel()`.

Let's say we have `x: Tensor[5]`, `y: Tensor[3]`, and we construct a dual tensor `make_dual(primal=x[0], tangent=BatchedTensor(y))`.

When we set `BatchedTensor(y)` to be the tangent, we require it to have the same storage numel as the primal. So we end up coercing the tangent into a tensor of shape [3, 5].
If the user never actually (A) uses as_strided on their dual tensor to reveal elements that are not a subset of the view or (B) does certain in-place operations that are already unsupported by vmap (regular_tensor.inplace_(dual)), then they are paying for memory that they didn't ask for. (A) honestly sounds like undefined behavior to me and (B) is something vmap already doesn't support.
How important is it that we add this check? I want to say that forward-mode AD does not need this check during composition with vmap.
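A rough sketch of the memory cost being described (my own illustration; the actual batching-rule code differs): the primal view indexes a storage of 5 elements while the tangent's storage holds only 3, so enforcing the invariant per batch element would mean materializing roughly 3 × 5 elements.

```python
import torch

x = torch.rand(5)
y = torch.rand(3)

primal = x[0]  # view whose base storage holds 5 elements
assert primal.storage().size() == 5
assert y.storage().size() == 3

# Enforcing base._storage_numel() == other._storage_numel() per batch
# element would mean materializing something like a [3, 5] tensor:
coerced = y.unsqueeze(-1).expand(3, 5).contiguous()
assert coerced.storage().size() == 15
```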
Thanks for the example, it does seem like it would be an easy way for users to shoot themselves in the foot.
Looks like (A) is already unsupported by vmap as well (in core and in functorch)? In that case I would agree that forward AD does not need this check in composition with vmap.
```python
import torch
from functorch import vmap
from torch._vmap_internals import _vmap

x = torch.rand(10, 2).select(1, 0)

# Both fail
vmap(lambda x: x.as_strided([1], [1], 1))(x)
_vmap(lambda x: x.as_strided([1], [1], 1))(x)
```
Also filed this issue regarding the error message: pytorch/functorch#308
Looks like (A) is already unsupported by vmap as well (in core and in functorch)?
Nice catch! I forgot we had those checks
Sounds mostly good. CI needs to be fixed though.
```cpp
@@ -31,6 +31,9 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
  m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
  m.impl("resolve_neg", torch::CppFunction::makeFallthrough());

  // See test_layout_check_for_primal_with_conj_bit in test_autograd.py
  m.impl("_has_same_storage_numel", torch::CppFunction::makeFallthrough());
```
If you have this one, you most likely also want it for neg version.
And you want a special impl for zero tensor as well? That returns true like the batched version?
> If you have this one, you most likely also want it for neg version.

Yup.

> And you want a special impl for zero tensor as well? That returns true like the batched version?

Also a good point. Zero tensors aren't landed yet though, right?
> Zero tensors aren't landed yet though right?

It was, but maybe it got reverted?
It was reverted indeed.
@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR:
- Adds a check that the storage numel of the base and tangent tensors are the same. This supports the case when as_strided reveals elements that aren't indexable by the input tensor.
- Skips the check when batched tensors are involved, because using as_strided to reveal elements not indexable by the input tensor is already not allowed by vmap.
- Adds tests for the above two cases, as well as an edge case regarding the conj bit (what about the neg bit?)

For functorch:
- we need to copy the batching rule implemented here

Differential Revision: [D32899678](https://our.internmc.facebook.com/intern/diff/D32899678)
```yaml
# always returns true because vmapped as_strided does not support accessing
# storage locations not indexable by the input tensor.
# See the note above for more information.
- func: _has_same_storage_numel(Tensor self, Tensor other) -> bool
```
In the long term, maybe there should be two different modes for forward-mode AD? One that supports tensor subclasses, and one that supports the fast-path tensor code. If there are two different modes like that, then we probably won't need a `_has_same_storage_numel` operator.
But the approach in this PR makes sense from a short-to-mid-term perspective. It's a bit unfortunate that we have to register all these additional implementations for various dispatch keys for this new operator.
The other failure should be unrelated.
ok!