Add forward AD layout check for storage numel #68631

Closed · soulitzer wants to merge 11 commits

Conversation

soulitzer (Contributor) commented Nov 19, 2021

Stack from ghstack:

This PR:

  • Adds a check that the storage numel of the base and tangent tensors are the same. This supports the case where as_strided reveals elements that aren't indexable by the input tensor (see the sketch below).
  • Skips the check when batched tensors are involved, because using as_strided to reveal elements that are not indexable by the input tensor is already disallowed under vmap.
  • Adds tests for the above two cases, as well as an edge case regarding the conj bit (what about the neg bit?).

For functorch:

  • we need to copy the batching rule implemented here

Differential Revision: D32899678
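
To make the first bullet concrete, here is a minimal sketch of the kind of primal/tangent pair the check is about. This is an illustration only; the shapes, names, and the final as_strided call are assumptions, not code from this PR or its tests:

```python
import torch
import torch.autograd.forward_ad as fwAD

# `base` has 20 storage elements, but the view `primal` indexes only 10 of
# them; a plain 10-element tangent therefore has a smaller storage numel
# than the primal it is attached to.
base = torch.randn(10, 2)
primal = base.select(1, 0)      # shape [10], backed by 20 storage elements
tangent = torch.randn(10)       # shape [10], backed by 10 storage elements

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    # as_strided can "reveal" storage locations that the primal view itself
    # does not index; comparing storage numels of base and tangent is what
    # lets forward AD treat this case consistently.
    revealed = dual.as_strided((10, 2), (2, 1))
```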

pytorch-probot bot commented Nov 19, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/69a916838083edcbc24acb70eab2ddb41e194c65/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

| Workflows | Labels (bold enabled) | Status |
| --- | --- | --- |
| **Triggered Workflows** | | |
| linux-bionic-py3.6-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla | ✅ triggered |
| linux-docs | ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux | ✅ triggered |
| linux-vulkan-bionic-py3.6-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan | ✅ triggered |
| linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3-clang5-mobile-build | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile | ✅ triggered |
| linux-xenial-py3-clang5-mobile-custom-build-static | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile | ✅ triggered |
| linux-xenial-py3.6-clang7-asan | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers | ✅ triggered |
| linux-xenial-py3.6-clang7-onnx | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx | ✅ triggered |
| linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-gcc7 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-gcc7-bazel-test | ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| win-vs2019-cpu-py3 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/win | ✅ triggered |
| win-vs2019-cuda11.3-py3 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/win | ✅ triggered |
| **Skipped Workflows** | | |
| caffe2-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |
| docker-builds | ciflow/all | 🚫 skipped |
| ios-12-5-1-arm64 | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-coreml | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-custom-ops | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-full-jit | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-metal | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64 | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64-coreml | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64-full-jit | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| libtorch-linux-xenial-cuda10.2-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| libtorch-linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| linux-bionic-cuda10.2-py3.9-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow | 🚫 skipped |
| linux-docs-push | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| macos-10-15-py3-arm64 | ciflow/all, ciflow/macos | 🚫 skipped |
| macos-10-15-py3-lite-interpreter-x86-64 | ciflow/all, ciflow/macos | 🚫 skipped |
| macos-11-py3-x86-64 | ciflow/all, ciflow/macos | 🚫 skipped |
| parallelnative-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |
| periodic-libtorch-linux-bionic-cuda11.5-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-bionic-cuda11.5-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck | 🚫 skipped |
| periodic-linux-xenial-cuda11.1-py3.6-gcc7-debug | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-win-vs2019-cuda11.1-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |

You can add a comment to the PR and tag @pytorchbot with the following commands:

```
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow
```

For more information, please take a look at the CI Flow Wiki.

facebook-github-bot (Contributor) commented Nov 19, 2021

💊 CI failures summary and remediations

As of commit 69a9168 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


soulitzer added a commit that referenced this pull request Nov 19, 2021
ghstack-source-id: 7ca455cf4b51c447c9799b42de05b745105c2d37
Pull Request resolved: #68631
soulitzer added a commit that referenced this pull request Nov 22, 2021
ghstack-source-id: ae34720e83b79d68ca64335c0eae3fa4e42c9adc
Pull Request resolved: #68631
@soulitzer soulitzer changed the title Add forward AD layout check for storage numel [WIP] Add forward AD layout check for storage numel Nov 22, 2021
- Instead of a native function to compare two tensors, maybe a native function that just retrieves the storage numel.
- Using the stride of the first batch dim for batched tensors seems like a reasonable replacement for storage numel.

[ghstack-poisoned]
soulitzer added a commit that referenced this pull request Dec 1, 2021
ghstack-source-id: 8a6d8de4ce15346689e8a29b0b9d20e828d21265
Pull Request resolved: #68631
@soulitzer soulitzer changed the title [WIP] Add forward AD layout check for storage numel Add forward AD layout check for storage numel Dec 1, 2021
Review threads (outdated, now resolved) on aten/src/ATen/BatchingRegistrations.cpp and tools/autograd/derivatives.yaml.
soulitzer added a commit that referenced this pull request Dec 1, 2021
ghstack-source-id: 3d9bf80ff54db8740e544bf69e7e5624d9f1fb41
Pull Request resolved: #68631
soulitzer added a commit that referenced this pull request Dec 2, 2021
ghstack-source-id: 933dbcae5cc2034c485e1c688ef76f59703faa7f
Pull Request resolved: #68631
```cpp
    std::string func_name) {
  if (std::find(physical_sizes.begin(), physical_sizes.end(), 0) != physical_sizes.end()) {
    // The tensor has zero numel
    return;
```
soulitzer (Contributor, Author) commented:
Not sure if this is the right approach because I'm not clear on exactly what cases this check guards against. Does this check still make sense to do if the numel of the tensor is zero?

I added this to address the zero-numel case where _new_zeros_with_same_feature_meta prepends zeros as the strides if the storage numel of the tensor is zero.

Relaxing this check also seemed to address some failures:

Sidenote: there were other batched gradgrad (and even some batched grad) tests in linalg and fft that seem to pass even without this fix. Having xfail instead of skip would be much nicer.

cc @zou3519
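
For reference, a tiny illustration of the zero-numel situation being discussed (an assumed example, not taken from the PR's tests):

```python
import torch

# A tensor with a zero-sized dimension has numel 0 and an empty storage,
# so stride/layout-based reasoning about it is degenerate.
t = torch.empty(0, 3)
print(t.numel())           # 0
print(t.storage().size())  # 0 -- no storage elements to index
print(t.stride())          # (3, 1), even though nothing is addressable
```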

soulitzer (Contributor, Author) added:

I initially understood this check to be "are the batch dimensions at the front". But I think there is a distinction between that and "are the batch dimensions at the front in layout" which is what this function actually seems to check.

Maybe this is not the check we need in _new_zeros_with_same_meta and _storage_numel?

Contributor replied:

Why do you need to modify this function?

Since we're almost done with getting vmap + forward-mode AD to compose in functorch, I want to stop supporting vmap in core because it's too much unnecessary work.

soulitzer (Contributor, Author) replied:

Good point. Maybe we wouldn't need to do it here, but do you think it makes sense to modify it in functorch though?

Contributor replied:

I thought about this some more -- if the tensor has zero numel then any batch dims are "trivially at the front in memory layout" because there is no storage. So yes it would make sense to modify it in functorch

```
@@ -85,6 +85,9 @@ namespace {
        return false;
      }
    }
    if (base._storage_numel() != other._storage_numel()) {
```
zou3519 (Contributor) commented Dec 3, 2021:

By the way, there's a huge performance problem with _new_zeros_with_same_feature_meta that arises with the invariant that base._storage_numel() == other._storage_numel()

Let's say we have x: Tensor[5], y: Tensor[3] and we construct a dual tensor make_dual(primal=x[0], tangent=BatchedTensor(y)).

When we set BatchedTensor(y) to be the tangent, we require it to have the same storage numel as the primal. So we end up coercing the tangent into a tensor of shape [3, 5].

If the user never actually (A) uses as_strided on their dual tensor to reveal elements that are not a subset of the view or (B) does certain in-place operations that are already unsupported by vmap (regular_tensor.inplace_(dual)), then they are paying for memory that they didn't ask for. (A) honestly sounds like undefined behavior to me and (B) is something vmap already doesn't support.

How important is it that we add this check? I want to say that forward-mode AD does not need this check during composition with vmap.
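
A rough reconstruction of the memory concern (illustrative only; the explicit new_zeros call below merely stands in for the coercion described above):

```python
import torch

x = torch.rand(5)   # the primal x[0] is a view into a 5-element storage
y = torch.rand(3)   # a batch of 3 scalar tangents

primal = x[0]                    # shape [], storage numel 5
print(primal.storage().size())   # 5

# Requiring base._storage_numel() == tangent._storage_numel() would force each
# of the 3 batched tangents to back 5 storage elements, i.e. a [3, 5] buffer:
coerced = y.new_zeros(3, 5)
print(coerced.storage().size())  # 15 elements instead of 3
```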

soulitzer (Contributor, Author) replied:

Thanks for the example, it does seem like it would be an easy way for users to shoot themselves in the foot.

Looks like (A) is already unsupported by vmap as well (in core and in functorch)? In that case I would agree that forward AD does not need this check in composition with vmap.

```python
import torch
from functorch import vmap
from torch._vmap_internals import _vmap

x = torch.rand(10, 2).select(1, 0)

# Both fail
vmap(lambda x: x.as_strided([1], [1], 1))(x)
_vmap(lambda x: x.as_strided([1], [1], 1))(x)
```

Also filed this issue regarding the error message: pytorch/functorch#308

Contributor replied:

> Looks like (A) is already unsupported by vmap as well (in core and in functorch)?

Nice catch! I forgot we had those checks.

soulitzer added a commit that referenced this pull request Dec 3, 2021
ghstack-source-id: 086415b7059ab79724ccc1c11cc809fa20e07c75
Pull Request resolved: #68631
albanD (Collaborator) left a comment:

Sounds mostly good. CI needs to be fixed though.

```
@@ -31,6 +31,9 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
  m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
  m.impl("resolve_neg", torch::CppFunction::makeFallthrough());

  // See test_layout_check_for_primal_with_conj_bit in test_autograd.py
  m.impl("_has_same_storage_numel", torch::CppFunction::makeFallthrough());
```
Collaborator commented:

If you have this one, you most likely also want it for neg version.
And you want a special impl for zero tensor as well? That returns true like the batched version?
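
For context, a small illustration of why the conj bit is relevant here (an assumed example; see test_layout_check_for_primal_with_conj_bit for the actual test). A conjugated view is just a flag over the same storage, so the storage-numel comparison should see through it (and, analogously, through the neg bit):

```python
import torch

z = torch.randn(4, dtype=torch.cfloat)
zc = z.conj()                            # lazy conjugation: sets the conj bit, no data copy

print(zc.is_conj())                      # True
print(z.data_ptr() == zc.data_ptr())     # True: same underlying storage
print(z.storage().size(), zc.storage().size())  # both 4
```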

soulitzer (Contributor, Author) replied:

> If you have this one, you most likely also want it for neg version.

Yup.

> And you want a special impl for zero tensor as well? That returns true like the batched version?

Also a good point. Zero tensors aren't landed yet though, right?

Collaborator replied:

> Zero tensors aren't landed yet though right?

It was, but maybe it got reverted?

Collaborator added:

It was reverted indeed.

Review thread (resolved) on torch/csrc/autograd/autograd_meta.cpp.
@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

soulitzer added a commit that referenced this pull request Dec 6, 2021
ghstack-source-id: 6d00ba54a031015312199b966724ec9e7fb03da2
Pull Request resolved: #68631
@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

soulitzer added a commit that referenced this pull request Dec 7, 2021
ghstack-source-id: a0b9f8cdc11c94b67575c30b8dd675fffa4229d8
Pull Request resolved: #68631
```yaml
# always returns true because vmapped as_strided does not support accessing
# storage locations not indexable by the input tensor.
# See the note above for more information.
- func: _has_same_storage_numel(Tensor self, Tensor other) -> bool
```
Contributor commented:

In the long term, maybe there should be two different modes for forward-mode AD? One that supports tensor subclasses, and one that supports the fast-path tensor code. If there are two different modes like that then we probably won't need a _has_same_storage_numel operator.

But the approach in this PR makes sense from a short-to-mid-term perspective. It's a bit unfortunate that we have to register all these additional implementations for various dispatch keys for this new operator.
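
For reference, a plain-Python paraphrase of what the new operator computes (an assumed sketch, not the ATen implementation or its dispatch-key-specific overrides):

```python
import torch

def has_same_storage_numel(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Compare how many elements each tensor's underlying storage holds,
    # not how many elements the views themselves index.
    return a.storage().size() == b.storage().size()

base = torch.randn(10, 2).select(1, 0)  # view of a 20-element storage
tangent = torch.randn(10)               # 10-element storage
print(has_same_storage_numel(base, tangent))  # False
```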

@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

soulitzer (Contributor, Author) commented:
The other failure should be unrelated, but test_ops is failing because we need to wait for #69558 to land.

albanD (Collaborator) left a comment:

ok!

soulitzer added a commit that referenced this pull request Dec 13, 2021
ghstack-source-id: 89549a7a48c08455c3d1184e02e8bed240d9a4ae
Pull Request resolved: #68631
@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

4 participants