Add batched grad testing to OpInfo #50818
Conversation
Stack from ghstack:

This PR does two things:

1. Add batched grad testing to OpInfo.
2. Improve the error message from `gradcheck` when batched gradient computation fails, so that it includes suggestions for workarounds.

To add batched grad testing to OpInfo, this PR:

- Adds new `check_batched_grad=True` and `check_batched_gradgrad=True` attributes to OpInfo. These are True by default because we expect most operators to support batched gradient computation.
- If `check_batched_grad=True`, then `test_fn_grad` invokes gradcheck with `check_batched_grad=True`.
- If `check_batched_gradgrad=True`, then `test_fn_gradgrad` invokes gradgradcheck with `check_batched_grad=True`.

The improved gradcheck error message looks like the following when an exception is thrown while computing batched gradients: https://gist.github.com/zou3519/5a0f46f908ba036259ca5e3752fd642f

Future

- Sometime in the not-so-near future, we will separate "batched grad testing" from "gradcheck" for the purposes of OpInfo, both to make the testing more granular and so that we can test that the vmap fallback doesn't get invoked (currently, batched gradient testing only tests that the output values are correct).

Test Plan:

- Run tests: `pytest test/test_ops.py -v -k "Gradients"`

ghstack-source-id: 4bb191e
Differential Revision: D25997703
Pull Request resolved: #50818
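To make the two checks described above concrete, here is a minimal sketch of the invocations, using `torch.sin` as a stand-in operator (the operator and inputs are illustrative assumptions, not taken from this PR):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Numerical gradient checks are typically run in double precision.
x = torch.randn(3, 3, dtype=torch.double, requires_grad=True)

# Roughly what test_fn_grad does when the OpInfo has check_batched_grad=True:
# compare analytical vs. numerical gradients, and additionally compute the
# gradients under vmap and compare them against the regular ones.
gradcheck(torch.sin, (x,), check_batched_grad=True)

# Roughly what test_fn_gradgrad does when check_batched_gradgrad=True:
# the same idea, one derivative order higher.
gradgradcheck(torch.sin, (x,), check_batched_grad=True)
```

An operator that does not yet support batched gradients would opt out in its OpInfo entry, e.g. by passing `check_batched_grad=False` to the OpInfo constructor.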
Looks good to me.
One question I have about all these functions for which the batched grad check fails: should we make these hard errors for now, to make sure they don't silently return wrong gradients?
My understanding is that they are hard errors -- are they actually soft errors in the code right now? (Is that what happens when gradcheck runs with raise_exception=False?)
They are hard errors when you run gradcheck, because the gradients don't match (from my understanding; maybe I'm wrong here).
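For context on the `raise_exception` question above, a sketch of the difference, assuming standard `gradcheck` behavior: with the default `raise_exception=True`, a failed check raises an exception (carrying the improved message from this PR when the batched-grad portion fails), while `raise_exception=False` reports the same failure as a `False` return value instead:

```python
import torch
from torch.autograd import gradcheck

x = torch.randn(2, dtype=torch.double, requires_grad=True)

# Default ("hard") mode: any failed check raises an exception.
gradcheck(torch.exp, (x,), check_batched_grad=True)

# "Soft" mode: a failure would be reported as a False return value instead.
ok = gradcheck(torch.exp, (x,), check_batched_grad=True, raise_exception=False)
assert ok  # torch.exp supports batched gradients, so this check passes
```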