[generate_vmap_rule] Add generate_vmap_rule to autograd.Function #90966
Conversation
Design document: https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function. By setting it to True, a user promises us that their autograd.Function's {forward, backward, jvp}, if defined, only use PyTorch operations, in addition to the other limitations of autograd.Function+functorch (such as the user not capturing any Tensors being transformed over from outside of the autograd.Function).

Concretely, the approach is:
- We update `custom_function_call` to accept an additional `generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` with `generate_vmap_rule=True` is: we construct a vmapped version of the autograd.Function and dispatch on it.
- The vmapped version of the autograd.Function can be thought of as follows: if we have an autograd.Function Foo, then VmappedFoo.apply(in_dims, ...) has the same semantics as vmap(Foo.apply, in_dims...).
- VmappedFoo's forward, setup_context, and backward staticmethods are vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation.

Test Plan:
- This PR introduces additional autograd.Functions with the suffix "GenVmap" to autograd_function_db.
- There are also some minor UX tests.

Future:
- jvp support
- likely more testing to come, but please let me know if you have cases that you want me to test here.
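As a concrete sketch of the API described above (the `Cube` function here is a hypothetical example in the style of the "GenVmap" test functions; exact signatures follow the new-style autograd.Function with a separate `setup_context`):

```python
import torch
from torch.autograd import Function

class Cube(Function):
    # Promise: forward/backward use only PyTorch operations, so functorch
    # can generate the vmap rule automatically.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 3 * x ** 2

x = torch.randn(4)
# vmap(Cube.apply) dispatches to an auto-generated vmapped version of Cube,
# rather than requiring a handwritten vmap staticmethod.
batched = torch.vmap(Cube.apply)(x)
```

The key point is that no `vmap` staticmethod is defined; setting `generate_vmap_rule = True` is the opt-in.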
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90966. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit d568583. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
```diff
   TORCH_INTERNAL_ASSERT(wrapper == nullptr);
-  auto* batched = maybeGetBatchedImpl(tensor);
+  auto* batched = maybeGetBatchedImpl(result);
```
This is a bug-fix. The problem is: if you use vmap over a dead GradTensorWrapper, then upon dispatching an operator (like at::sin), sanityCheckStack fails an internal assert. (The original purpose of sanityCheckStack is to ensure that after we're done handling layers on the functorch stack, like vmap, the only tensors left are regular Tensors.)
The fix is to unwrap the GradTensorWrapper here.
In general, the dead GradTensorWrapper situation runs us into a lot of weird edge cases like this, which is why we want mode-only functorch to save us :)
Just to confirm, this (always?) happens if we do vjp(vmap(...))?
This can happen if someone invokes a functorch transform on a dead GradTensorWrapper and ends up wrapping the dead GradTensorWrapper in an alive wrapper. Some situations where it can happen:
- if we do vjp(vmap(...)) over an autograd.Function with generate_vmap_rule=True (in the backward pass we end up invoking vmap on a dead GradTensorWrapper)
- if we do vjp over an autograd.Function with a vmap staticmethod AND the backward pass calls a functorch transform. vjp saves GradTensorWrappers for backward, so in the backward pass, all of those GradTensorWrappers are dead.

This doesn't happen if we do vjp(vmap(...)) over a regular PyTorch operator, because the implementation of the operator doesn't end up calling a functorch transform.
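A minimal composition exercising the first situation above might look like the following (the `Cube` function here is a hypothetical example; the interesting part is that the backward pass runs vmap over tensors whose grad wrappers are dead by the time it executes):

```python
import torch
from torch.autograd import Function
from torch.func import vjp, vmap

class Cube(Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0])

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 3 * x ** 2

x = torch.randn(3)
out, vjp_fn = vjp(vmap(Cube.apply), x)
# The backward pass below invokes Cube.backward under vmap, over tensors
# that were saved while the vjp grad layer was alive.
grads, = vjp_fn(torch.ones_like(out))
```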
Makes sense, thanks for clarifying!
```python
return (
    torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
    torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
)
```
A little unrelated, but maybe we also want to add a custom Function where the gradient is intentionally wrong for the case where vmap is provided by the user, just to make sure the backward is preserved when we do grad(vmap(...)) for that as well. Currently, if the user-provided vmap does not call back into the original custom Function, it seems possible that the user-provided backward will not be preserved.
> Currently if the user provided vmap does not call back into the original custom Function it seems possible that the user-provided backward will not be preserved.

Yes, in this case the user-provided backward is not preserved if the user provides a vmap staticmethod that does not call back into the original custom Function. This is intentional:
- It's consistent with how aten operators work when we do grad(vmap(f))(x): say f is at::dot, and vmap transforms at::dot into at::mm. Then grad records an at::mm node, and MMBackward gets run instead of a vmapped DotBackward.
- It's on the user to provide a correct `vmap`.
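The "intentionally wrong gradient" testing idea can be sketched as follows for the `generate_vmap_rule=True` path, where the custom backward should be preserved under grad(vmap(...)). The `ScaledSquare` function is hypothetical; its backward deliberately returns 10x the true gradient so we can tell whose backward ran:

```python
import torch
from torch.autograd import Function
from torch.func import grad, vmap

class ScaledSquare(Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0])

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 20 * x  # intentionally 10x the true gradient

x = torch.randn(3)
# grad(vmap(...)): if the custom backward is preserved, we get 20 * x,
# not the numerically correct 2 * x.
g = grad(lambda t: vmap(ScaledSquare.apply)(t).sum())(x)
```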
Looks mostly good, just a few comments
test/functorch/test_ops.py (outdated):

```diff
@@ -1517,6 +1519,9 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
 xfail("_native_batch_norm_legit"),
 xfail('native_dropout_backward'),
 xfail('nn.functional.prelu'),
+xfail('CubeGenVmapAutogradFunction'),  # NYI
+xfail('SortGenVmapAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90067
```
This has been fixed now! #90067
Thank you, that was quick
Updated to NYI (this PR doesn't add jvp support for generate_vmap_rule=True)
By the way, I'm not immediately able to remove the xfail in pytorch/test/functorch/test_ops.py, line 1360 in f02e93b:

```python
xfail('NumpySortAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90067
```

The reason is that it looks like it creates a torch.int64 tangent tensor instead of passing None; see pytorch/torch/testing/_internal/autograd_function_db.py, lines 238-239 in f02e93b:

```python
assert ind_tangent is None
assert ind_inv_tangent is None
```

ctx.set_materialize_grads does make the problem go away, but... (1) morally I feel like there should never be a non-fp tangent tensor, and (2) the set_materialize_grads documentation doesn't say anything about this case.
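For context, the behavior of `set_materialize_grads` being discussed can be illustrated with a small (hypothetical) two-output Function: with materialization disabled, the grad for an unused output arrives in backward as None instead of a zero tensor of that output's dtype (which is how a non-fp zero "tangent" can otherwise appear for an integer output like sort indices):

```python
import torch
from torch.autograd import Function

class CloneTwice(Function):
    @staticmethod
    def forward(ctx, x):
        # Without this call, autograd materializes undefined incoming grads
        # as zero tensors matching each output; with it, they stay None.
        ctx.set_materialize_grads(False)
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # Only the first output was used, so grad_b arrives as None here.
        assert grad_b is None
        return grad_a

x = torch.randn(3, requires_grad=True)
a, _ = CloneTwice.apply(x)
a.sum().backward()
```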
Yeah, that looks like a bug, taking a look.
Should be fixed by #91183
```python
# - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
#   vmap(vmap( but not completely sure if it is a problem. If we
#   assigned those fields to the ctx object, the worry is that they
#   get overwritten.
```
For input_shapes and saved_tensor_bdims, isn't it less likely for them to get overwritten when you do grad(grad(vmap(...))) if you save them on ctx instead of here, because you make a new ctx for every layer of grad you have?
Currently I would think they are getting overwritten, because the two grads refer to the same vmapified autograd_function, but it just doesn't matter: both grad layers should be saving the same input_shapes + saved_tensor_bdims (because there's no vmap in between). And even if you have something like grad(vmap(grad(vmap(...)))), that doesn't matter, because then the two grads refer to different autograd_function objects.
(Actually, you might mean overwritten by what the user saves; in that case, ignore the above.)
If you do grad(grad(vmap(...))), then yeah, those are not being overwritten because each grad layer gets a different ctx object.
The concern is two consecutive vmaps, like grad(vmap(vmap(...))). I haven't actually tested this, so it might just work out of the box, but another benefit of keeping these attributes out here is that I don't need to denylist more attrs for the user. Though there are other ways around the denylist (we could shove all the extra attrs we want into a single object on the ctx).
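For reference, the nested composition under discussion looks like this (forward-only smoke test; `Cube` is a hypothetical generate_vmap_rule function, and whether the backward state composes under two vmap layers is exactly the open question above):

```python
import torch
from torch.autograd import Function
from torch.func import vmap

class Cube(Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0])

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 3 * x ** 2

x = torch.randn(2, 3)
# Two consecutive vmap layers over the same autograd.Function: each layer
# constructs its own vmapped version, so per-layer state must not collide.
out = vmap(vmap(Cube.apply))(x)
```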
LGTM
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.