[generate_vmap_rule] Add generate_vmap_rule to autograd.Function #90966

Closed
Wants to merge 6 commits

Conversation

@zou3519 (Contributor) commented Dec 15, 2022:

Stack from ghstack:

Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a generate_vmap_rule option (default False) to autograd.Function.
By setting it to True, a user promises that their autograd.Function's
{forward, backward, jvp}, if defined, only use PyTorch operations, in addition
to respecting the other limitations of autograd.Function + functorch (such as
not capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:

  • We update custom_function_call to accept an additional
    generate_vmap_rule argument.
  • The vmap rule for custom_function_call with generate_vmap_rule=True
    is: construct a vmapped version of the autograd.Function and dispatch
    to it.
  • The vmapped version of the autograd.Function can be thought of as
    follows: if we have an autograd.Function Foo, then
    VmappedFoo.apply(in_dims, ...) has the same semantics as
    vmap(Foo.apply, in_dims)(...).
  • VmappedFoo's forward, setup_context, and backward staticmethods are
    vmapped versions of Foo's staticmethods.
  • See the design doc for more motivation and explanation (a usage
    sketch follows this list).
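
A minimal usage sketch of what this enables (not part of this PR's diff; MySin is a hypothetical example, and the vmap entry point may be functorch.vmap or torch.func.vmap depending on the release):

```python
import torch
from torch.func import vmap  # functorch.vmap on older releases

class MySin(torch.autograd.Function):
    # Promise that forward/backward only use PyTorch operations, so a
    # vmap rule can be generated automatically.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * torch.cos(x)

x = torch.randn(3, 4)
# Under the hood, vmap's rule for custom_function_call constructs a vmapped
# version of MySin (conceptually, VmappedMySin) and dispatches to it.
out = vmap(MySin.apply)(x)
assert torch.allclose(out, torch.sin(x))
```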

Test Plan:

  • This PR introduces additional autograd.Functions with the suffix "GenVmap" to
    autograd_function_db.
  • There are also some minor UX tests

Future:

  • jvp support
  • likely more testing to come, but please let me know if you have
    cases that you want me to test here.

@pytorch-bot bot commented Dec 15, 2022:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90966

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d568583:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zou3519 added a commit that referenced this pull request Dec 15, 2022
zou3519 added the "release notes: functorch" label (release notes category; pertaining to torch.func or pytorch/functorch) on Dec 15, 2022
…nction"

Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation

Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests

Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Dec 16, 2022
zou3519 requested a review from samdow on December 16, 2022
…nction"

Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation

Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests

Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Dec 16, 2022
TORCH_INTERNAL_ASSERT(wrapper == nullptr);
auto* batched = maybeGetBatchedImpl(tensor);
auto* batched = maybeGetBatchedImpl(result);
@zou3519 (Contributor, Author) commented Dec 16, 2022:
This is a bug fix. The problem: if you use vmap over a dead GradTensorWrapper, then upon dispatching an operator (like at::sin), sanityCheckStack trips an internal assert. (The original purpose of sanityCheckStack is to ensure that after we're done handling the layers on the functorch stack, like vmap, the only tensors left are regular Tensors.)

The fix is to unwrap the GradTensorWrapper here.

In general, the dead-GradTensorWrapper situation runs us into a lot of weird edge cases like this, which is why we want mode-only functorch to save us :)

Contributor:

Just to confirm, this (always?) happens if we do vjp(vmap?

@zou3519 (Contributor, Author) replied Dec 19, 2022:
This can happen if someone invokes a functorch transform on a dead GradTensorWrapper and ends up wrapping the dead GradTensorWrapper in a live wrapper. Some situations where it can happen:

  • if we do vjp(vmap over an autograd.Function with generate_vmap_rule=True (in the backward pass we end up invoking vmap on a dead GradTensorWrapper); a minimal composition is sketched below
  • if we do vjp( over an autograd.Function with a vmap staticmethod AND the backward pass calls a functorch transform. vjp saves GradTensorWrappers for backward, so in the backward pass, all of those GradTensorWrappers are dead.

This doesn't happen if we do vjp(vmap over a regular PyTorch operator, because the implementation of the operator doesn't end up calling a functorch transform.
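
A hedged sketch of the first composition above, reusing the hypothetical MySin from the earlier sketch (torch.func naming assumed):

```python
from torch.func import vjp, vmap

x = torch.randn(3, 4)
out, vjp_fn = vjp(vmap(MySin.apply), x)
# The backward pass re-enters vmap on tensors saved during the forward pass;
# by then their GradTensorWrappers are dead, which is the situation that
# used to trip sanityCheckStack.
(grad_x,) = vjp_fn(torch.ones_like(out))
```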

Contributor:
Makes sense, thanks for clarifying!

return (
    torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
    torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
)
@soulitzer (Contributor) commented Dec 16, 2022:
A little unrelated, but maybe we also want to add a custom Function whose gradient is intentionally wrong for the case where vmap is provided by the user, just to make sure the backward is preserved when we do grad(vmap for that as well. Currently, if the user-provided vmap does not call back into the original custom Function, it seems possible that the user-provided backward will not be preserved.

@zou3519 (Contributor, Author) replied:
Currently if the user provided vmap does not call back into the original custom Function it seems possible that the user-provided backward will not be preserved.

In this case, the user-provided backward is not preserved if the user provides vmap and does not call back into the original custom Function. This is intentional:

  • this is consistent with how aten operators work when we do grad(vmap(f))(x): let's say f is at::dot and vmap transforms at::dot into at::mm. Then grad records an at::mm node, and MMBackward gets run instead of a vmapped DotBackward (see the sketch below)
  • it's on the user to provide a correct vmap
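
A hedged sketch of that behavior (ScaledSin is a hypothetical example and assumes the autograd.Function vmap staticmethod supported elsewhere in this stack; exact details may differ):

```python
import torch
from torch.func import grad, vmap

class ScaledSin(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Intentionally "wrong" gradient: scaled by 2.
        return 2 * grad_output * torch.cos(x)

    @staticmethod
    def vmap(info, in_dims, x):
        # User-provided vmap rule that does NOT call back into ScaledSin.
        return torch.sin(x), 0

x = torch.randn(3)
g = grad(lambda t: vmap(ScaledSin.apply)(t).sum())(x)
# g equals cos(x), not 2 * cos(x): under grad(vmap(...)), autograd recorded
# torch.sin's backward from the user's vmap rule, so the custom backward was
# bypassed (the at::dot -> at::mm analogy above).
```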

@soulitzer (Contributor) left a review:

Looks mostly good, just a few comments

torch/_functorch/autograd_function.py (outdated; resolved)
torch/_functorch/autograd_function.py (resolved)
@@ -1517,6 +1519,9 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
xfail("_native_batch_norm_legit"),
xfail('native_dropout_backward'),
xfail('nn.functional.prelu'),

xfail('CubeGenVmapAutogradFunction'), # NYI
xfail('SortGenVmapAutogradFunction'), # https://github.com/pytorch/pytorch/issues/90067
Contributor:

This has been fixed now! #90067

@zou3519 (Contributor, Author) replied:

Thank you, that was quick

@zou3519 (Contributor, Author):

Updated to NYI (this PR doesn't add jvp support for generate_vmap_rule=True)

@zou3519 (Contributor, Author):

By the way, I'm not immediately able to remove the xfail on

xfail('NumpySortAutogradFunction'), # https://github.com/pytorch/pytorch/issues/90067

after your PR.

The reason is that it looks like it creates a torch.int64 tangent tensor instead of passing None:

assert ind_tangent is None
assert ind_inv_tangent is None

ctx.set_materialize_grads does make the problem go away, but (1) morally I feel like there should never be a non-floating-point tangent tensor, and (2) the set_materialize_grads documentation doesn't say anything about this case (a sketch of its documented behavior follows).
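
For reference, a hedged sketch of set_materialize_grads' documented behavior in backward (the class below is hypothetical and unrelated to NumpySortAutogradFunction; the jvp/tangent interaction discussed above is the undocumented part):

```python
import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Without this call, undefined gradients are materialized as
        # zero-filled tensors before backward is invoked.
        ctx.set_materialize_grads(False)
        return 2 * x, x.argsort()

    @staticmethod
    def backward(ctx, grad_out, grad_indices):
        # With materialization off, the gradient for the unused integer
        # output stays None rather than being turned into a tensor.
        assert grad_indices is None
        return 2 * grad_out

x = torch.randn(3, requires_grad=True)
out, idx = TwoOutputs.apply(x)
out.sum().backward()  # only `out` participates in the loss
```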

Contributor:

Yeah, that looks like a bug, taking a look.

Contributor:

Should be fixed by #91183

…nction"

Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation

Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests

Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Dec 19, 2022
# - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
# vmap(vmap( but not completely sure if it is a problem. If we
# assigned those fields to the ctx object, the worry is that they
# get overwritten.
@soulitzer (Contributor) commented Dec 19, 2022:

For input_shapes and saved_tensors_bdims, isn't it less likely for them to get overwritten when you do grad(grad(vmap if you save them on ctx instead of here, because you make a new ctx for every layer of grad you have?

Currently I would think that they do get overwritten, because the two grads refer to the same vmapified autograd_function, but it just doesn't matter: both grad layers should be saving the same input_shapes + saved_tensors_bdims (because there's no vmap in between). And even if you have something like grad(vmap(grad(vmap, that doesn't matter either, because now the two grads should be referring to different autograd_function objects.

(Actually, you might mean overwritten by what the user saves; in that case ignore the above.)

@zou3519 (Contributor, Author) replied:

If you do grad(grad(vmap, then yeah, those are not being overwritten, because each grad layer gets a different ctx object.

The concern is two consecutive vmaps, like grad(vmap(vmap (sketched below). I haven't actually tested this, so it might just work out of the box, but another benefit of keeping these attributes out here is that I don't need to denylist more attrs for the user. Though there are other ways around the denylist (we can shove all the extra attrs we want into a single object on the ctx).
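
A sketch of the nesting being discussed (reusing the hypothetical MySin from the earlier sketch):

```python
from torch.func import grad, vmap

x = torch.randn(2, 3, 4)
# Two consecutive vmap layers under grad: the case where stashing
# input_shapes / saved_tensors_bdims on the ctx could get overwritten.
g = grad(lambda t: vmap(vmap(MySin.apply))(t).sum())(x)
```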

…nction"

Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation

Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests

Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Dec 19, 2022
@soulitzer (Contributor) left a review:

LGTM

…nction"

Design document:
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function.
By setting it to True, a user promises to us that their autograd.Function's
{forward, backward, jvp}, if defined, only uses PyTorch operations, in addition to the other
limitations of autograd.Function+functorch (such as the user not
capturing any Tensors being transformed over from outside of the
autograd.Function).

Concretely, the approach is:
- we update `custom_function_call` to accept an additional
`generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` and `generate_vmap_rule=True`
is: we construct a vmapped version of the autograd.Function and dispatch
on it.
- The vmapped version of the autograd.Function can be thought of like
the following: if we have an autograd.Function Foo, then
VmappedFoo.apply(in_dims, ...) has the same semantics as
vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethod are
vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation

Test Plan:
- This PR introduces additional autograd.Function with the suffix "GenVmap" to
autograd_function_db.
- There are also some minor UX tests

Future:
- jvp support
- likely more testing to come, but please let me know if you have
cases that you want me to test here.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Dec 20, 2022
zou3519 added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) on Dec 20, 2022
@zou3519 (Contributor, Author) commented Dec 21, 2022:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

facebook-github-bot deleted the gh/zou3519/587/head branch on June 8, 2023
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: functorch (release notes category; pertaining to torch.func or pytorch/functorch)
3 participants