[custom_op] explicit autograd API #101824
Conversation
This PR adds an explicit API for registering a backward formula for a CustomOp. In the end state, we will likely have both this explicit API and a magic API (which is sugar on top of the explicit API), since different groups of users prefer different ones.

Concretely, to define a backward formula for a CustomOp:
- a user must provide us a "save for backward" function that accepts (inputs, output) and returns exactly what they want saved for backward
- a user must provide us a "backward" function that accepts (ctx, saved, *grads) and returns us the grad_inputs. The grad_inputs are returned as a dict mapping str to a gradient.

Please see the changes in custom_op_db.py for examples of the API; a rough sketch is also shown below.

There are a number of pieces to this PR and I'm happy to split it if it helps. They are:
- The actual APIs for specifying the two functions (impl_save_for_backward, impl_backward)
- The autograd kernel: we take the functions the user gives us and construct an autograd.Function object that we then register to the Autograd dispatch key
- Indirection for the autograd kernel. We add a layer of indirection so that one can swap out the autograd kernel. This is necessary because, by default, we register an "autograd not implemented" kernel as the Autograd implementation and then swap it for the actual kernel when the user provides it.

Test Plan:
- We apply this API to give backward formulas for things in custom_op_db. We then hook up custom_op_db to the Autograd OpInfo tests.
- Various tests in test_python_dispatch.py to check error cases.
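For orientation, a rough sketch of the end-to-end flow described above for a toy op. The `mylib::foo` name, the import path, and the attribute-style access on `inputs` are assumptions of this sketch (the `inputs` namedtuple is inferred from the `namedtuple_args` usage later in this thread); see custom_op_db.py in the PR for the real examples.

```python
import torch
from torch._custom_op import custom_op  # assumed import path for this prototype API

@custom_op('mylib::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
    ...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
    return x.sin()

# "save for backward": receives (inputs, output) and returns exactly what
# should be saved for the backward pass.
@foo.impl_save_for_backward()
def foo_save_for_backward(inputs, output):
    return inputs.x

# "backward": receives (ctx, saved, *grads) and returns grad_inputs as a
# dict mapping input name (str) to a gradient (Tensor or None).
@foo.impl_backward()
def foo_backward(ctx, saved, grad):
    return {'x': grad * saved.cos()}
```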
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101824
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 7442ecd.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM. But perhaps you want to have one of the autograd experts take a look over.
torch/_custom_op/autograd.py (outdated)
```python
        return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

    generated_cls = gen_autograd_function(
        forward_op._opname + 'CustomOp', forward, backward)
```
nit: lower case vs camel case?
```python
@numpy_mul.impl_backward()
def numpy_mul_backward(ctx, saved, grad_out):
    grad_x = grad_out * saved['y'] if saved['x_requires_grad'] else None
```
The ctx should really be telling you what to compute gradients for. This is fine as a test, but we should not recommend that users do this. We should just expose needs_input_grad on the ctx.
Agree, I didn't add needs_input_grad in this PR because it was getting long.
I can add it as a follow-up, unless you want to see it in this PR?
No need to add it here as long as this is not in any user-facing example!
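For context only (and explicitly not a user-facing recommendation for this PR), a hypothetical sketch of what the numpy_mul backward could look like if the CustomOp ctx exposed needs_input_grad the way autograd.Function does. The name-keyed indexing and the saved 'x'/'y' entries are assumptions of this sketch.

```python
# Hypothetical: assumes ctx.needs_input_grad is exposed and keyed by input
# name, and that both 'x' and 'y' were saved for backward.
@numpy_mul.impl_backward()
def numpy_mul_backward(ctx, saved, grad_out):
    grad_x = grad_out * saved['y'] if ctx.needs_input_grad['x'] else None
    grad_y = grad_out * saved['x'] if ctx.needs_input_grad['y'] else None
    return {'x': grad_x, 'y': grad_y}
```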
```python
save_for_backward_fn_inputs = namedtuple_args(schema, args)
to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

save_pytree_for_backward(ctx, (to_save, args_info))
```
nit: we should be able to skip this when grad mode is disabled?
Sure
It is a little complicated to do this because autograd.Function's forward sets grad mode to False. This means we would need to shepherd the knowledge of whether grad mode is disabled from somewhere else.
Since this is a nit, I'm just going to add it to the wishlist.
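To illustrate the complication (a standalone autograd.Function example, not code from this PR): grad mode is always reported as disabled inside forward, so a naive check there cannot tell whether the caller had gradients enabled.

```python
import torch

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Prints False even when the caller has grad mode enabled,
        # because forward runs with grad mode turned off.
        print("grad enabled inside forward:", torch.is_grad_enabled())
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

x = torch.randn(3, requires_grad=True)
print("grad enabled outside:", torch.is_grad_enabled())  # True
Identity.apply(x)                                        # inside forward: False
with torch.no_grad():
    Identity.apply(x)                                    # also False inside forward
```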
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
were declared to be Tensors in the CustomOp definition must be accounted
for in the dict. The gradient may be a Tensor or None.
TODO(rzou): Add example when this PR is closer to landing.
Sorry, I remembered this after I typed the merge command. Will do as a follow-up.
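For reference, the kind of example that could eventually go in that docstring might look like the following (the op and its input names are hypothetical). The point is that every input declared as a Tensor gets a key in the returned dict, with None for inputs that need no gradient.

```python
# Hypothetical op with two Tensor inputs, 'x' and 'y'.
@my_op.impl_backward()
def my_op_backward(ctx, saved, grad_out):
    return {
        'x': grad_out * saved['y'],  # gradient for 'x'
        'y': None,                   # 'y' does not need a gradient
    }
```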
Regarding the syntax, could the current:

```python
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
    ...

@foo.impl(['cpu', 'cuda'])
def foo_impl(x):
    return x.sin()

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
    return grad * saved.cos()
```

become something like the proposed:

```python
@custom_op_with_forward(f'{TestCustomOp.test_ns}::foo', device_types=['cpu', 'cuda'])
def foo(x: torch.Tensor) -> torch.Tensor:
    return x.sin()

@foo.impl_backward()
def foo_backward(ctx, saved, grad):
    return grad * saved.cos()
```