functorch.grad support for autograd.Function #89860
Conversation
Happy to split this PR more if it helps. This PR adds functorch.grad support for autograd.Function. There's a lot going on; here is the high-level picture, and there are more details as comments in the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is necessary because every layer of functorch needs to see the autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is a PyOperator. If functorch transforms are active, we wrap the autograd.Function in a `custom_function_call` PyOperator, where we are able to define various rules for functorch transforms.
- `custom_function_call` has a rule for the functorch grad transform.

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction (which works with a single level of functorch transform) and autograd.Function (which works with multiple levels). This is necessary because functorch's grad rule needs some way of specifying a backward pass for that level only.
- This PR changes autograd.Function's apply to either call `custom_function_call` (if functorch is active) or super().apply (if functorch isn't active).

Testing
- Most of this PR is just testing. It creates an autograd.Function OpInfo database that then gets passed to the functorch grad-based tests (grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests, this is the easiest way to test various autograd.Functions with functorch.

Future
- jvp and vmap support are coming next
- better error message (functorch only supports autograd.Function that has the optional setup_context staticmethod)
- documentation to come when we remove the feature flag
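For concreteness, here is a minimal sketch of what this enables, assuming the setup_context-style autograd.Function mentioned under Future and the functorch-era import path; the names in the example are illustrative, not code from this PR:

```python
import torch
from functorch import grad  # assumption: functorch-era API; later also exposed as torch.func.grad

class Square(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # setup_context-style Function: forward does not take ctx.
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.tensor(3.0)
# With a functorch transform active, Square.apply dispatches through the
# custom_function_call PyOperator described above rather than plain autograd.
print(grad(Square.apply)(x))  # tensor(6.)
```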
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89860. ✅ No failures as of commit 3f953d1. This comment was automatically generated by Dr. CI and updates every 15 minutes.
return autograd_function.apply(*args[1:], **kwargs)

# "custom_function_call"
Naming question: I refer to autograd.Function as "AutogradFunction" sometimes, and sometimes as "CustomFunction". What do you guys want me to call it?
Should the instance of CustomFunctionPyOperator be private? I see it being called as part of `apply`, but is there a case where someone would call it directly?
Everything inside torch/_functorch/autograd_function.py is private, including custom_function_call. Users should not be calling it directly.
Developers should interact with custom_function_call directly, rather than with CustomFunctionPyOperator, if that is the question. custom_function_call is an instance of a PyOperator; CustomFunctionPyOperator is just the class.
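To make the class-vs-instance point concrete, a heavily simplified sketch (not the real implementation; the actual class derives from PyOperator and consults the functorch interpreter stack):

```python
class CustomFunctionPyOperator:  # simplified stand-in for the real PyOperator subclass
    def __call__(self, autograd_function, *args, **kwargs):
        # The real operator looks up the rule registered for the currently
        # active functorch transform; with nothing active it falls through
        # to the ordinary autograd.Function.apply.
        return autograd_function.apply(*args, **kwargs)

# The single module-level instance is what developers dispatch through.
custom_function_call = CustomFunctionPyOperator()
```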
torch/autograd/function.py (outdated)
# TODO: fix circular import
from torch._functorch.autograd_function import custom_function_call
I would prefer to handle this in a follow-up:
- Due to the feature flag, this is not exercised in the fast path.
- It might involve a refactor of torch/_ops.py.
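For context, a rough sketch of the dispatch in apply that this import serves; the flag check is written here as torch._C._are_functorch_transforms_active(), and the whole thing is simplified relative to the actual diff:

```python
class Function(_SingleLevelFunction):
    @classmethod
    def apply(cls, *args, **kwargs):
        if not torch._C._are_functorch_transforms_active():
            # Fast path: no functorch transform is active.
            return super().apply(*args, **kwargs)
        # TODO: fix circular import
        from torch._functorch.autograd_function import custom_function_call
        return custom_function_call(cls, *args, **kwargs)
```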
There appear to be some issues with dynamo, but other than that, this should be ready for review.
Pretty cool!
}

Tensor JvpInterpreterPtr::lift(const Tensor& tensor) const {
  return materializeGradWrappers(tensor, level());
  return base_lift(tensor, level());
When does this happen? Why didn't we need it before?
This is necessary to support functorch.vjp. functorch.vjp constructs a graph using GradWrapperTensors which then all become dead tensor wrappers.
In C++, the functorch grad transform always unwraps dead tensor wrappers (pytorch/aten/src/ATen/functorch/ADInterpreters.cpp, lines 60 to 68 in a6caa9c):

// if is a grad transform, and the operation is in-place, and the mutated
// argument is not currently wrapped in a TensorWrapper, then we need to
// error out otherwise the result is silently incorrect
checkForInvalidMutationOnCaptures(op, stack, current_level);
// materialize live GradWrappers
auto maybeTransformGradWrappers = [&](const Tensor& tensor) {
  return materializeGradWrappers(tensor, current_level);
};
In Python, we want to be consistent with that behavior (pytorch/torch/_functorch/pyfunctorch.py, lines 107 to 110 in a6caa9c):

def process(self, op, args, kwargs):
    kernel = op.functorch_table[TransformType.Grad]
    args, kwargs = self.lift(args, kwargs)
    return kernel(self, *args, **kwargs)
Previously we didn't need it because I didn't actually test functorch.vjp; this PR exercises the functorch vjp tests and so we hit this case.
EDIT: I commented directly on the thread.
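A small usage sketch of the scenario described above, assuming the functorch-era vjp API; f is a stand-in, not code from this PR:

```python
import torch
from functorch import vjp  # assumption: functorch-era API; later torch.func.vjp

def f(x):
    return x.sin()

x = torch.randn(3)
# vjp runs f under a grad transform and records a graph; per the discussion
# above, the GradWrapperTensors captured in that graph become dead tensor
# wrappers once vjp returns.
out, vjp_fn = vjp(f, x)

grad_y = torch.randn_like(out)
# Replaying the backward pass happens after the original grad level has
# exited, so lifted inputs must be handled consistently with the C++ rule.
(grad_x,) = vjp_fn(grad_y)
```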
❌ 🤖 pytorchbot command failed.
@pytorchbot merge -f "all test failures look unrelated (they didn't fail on the topmost PR)
❌ 🤖 pytorchbot command failed.
@pytorchbot merge -f "all test failures look unrelated (they didn't fail on the topmost PR)"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#89860. Approved by: https://github.com/soulitzer
grad_y = torch.randn_like(x)

def h(x, grad_y):
    _, vjp_fn = vjp(f, x)
I think you forgot to define f for this one
Lol, looks like the expected failure masked this. Thanks for catching.
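For readers reconstructing the test, a self-contained sketch of the pattern with f actually defined; f here is a plain stand-in, not the autograd.Function from the test suite:

```python
import torch
from functorch import grad, vjp  # assumption: functorch-era API

def f(x):  # the missing definition the comment refers to (stand-in)
    return x.sin()

def h(x, grad_y):
    _, vjp_fn = vjp(f, x)
    return vjp_fn(grad_y)[0].sum()

x = torch.randn(3)
grad_y = torch.randn_like(x)
# Differentiating h exercises grad-of-vjp, i.e. the vjpvjp-style composition
# that the autogenerated tests cover.
grad(h)(x, grad_y)
```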
Fixes #90225. Uses what was originally in #89860. Pull Request resolved: #91222. Approved by: https://github.com/zou3519