@zou3519 zou3519 commented Jan 20, 2021

Stack from ghstack:

This PR does two things:

  1. Add batched grad testing to OpInfo
  2. Improve the error message that `gradcheck` raises when batched
    gradient computation fails, so that it suggests workarounds.

To add batched grad testing to OpInfo, this PR:

  • adds new `check_batched_grad=True` and `check_batched_gradgrad=True`
    attributes to OpInfo. These default to True because we expect most
    operators to support batched gradient computation.
  • If `check_batched_grad=True`, then `test_fn_grad` invokes `gradcheck`
    with `check_batched_grad=True`.
  • If `check_batched_gradgrad=True`, then `test_fn_gradgrad` invokes
    `gradgradcheck` with `check_batched_grad=True`.

The improved gradcheck error message looks like the following when an
exception is thrown while computing batched gradients:
https://gist.github.com/zou3519/5a0f46f908ba036259ca5e3752fd642f
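The shape of that improvement can be sketched as follows (a hypothetical simplification: the function names and message text below are illustrative, and the real wording is in the gist linked above):

```python
# Hedged sketch: catch a failure from batched gradient computation and
# re-raise it with workaround suggestions appended, keeping the original
# exception chained so the root cause stays visible.

FAILURE_HINT = (
    "Batched gradient computation failed. As a workaround you can pass "
    "check_batched_grad=False to gradcheck, or file an issue so the "
    "operator can gain batched gradient support."
)

def compute_batched_grads(fn):
    # Placeholder for the actual vmap-based batched gradient computation;
    # here it simply calls fn.
    return fn()

def checked_batched_grads(fn):
    try:
        return compute_batched_grads(fn)
    except Exception as exc:
        # `raise ... from exc` preserves the underlying traceback.
        raise RuntimeError(FAILURE_HINT) from exc
```

The point of the pattern is that the user still sees the original exception (via the chained cause) alongside concrete suggestions for how to proceed.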

Future

  • Sometime in the future (not immediately), we will separate "batched
    grad testing" from "gradcheck" for the purposes of OpInfo, both to
    make the testing more granular and so that we can verify that the
    vmap fallback is not invoked (currently, batched gradient testing
    only checks that the output values are correct).
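For intuition on what batched gradient testing verifies (a toy example with no PyTorch involved): the gradient of f(x) = x² at several points can be computed one sample at a time, as a fallback would, or in a single vectorized pass, as vmap-style batching would, and the two routes must agree.

```python
# Toy illustration of per-sample vs batched gradients for f(x) = x**2,
# whose derivative is 2*x. This only demonstrates the agreement property
# that batched grad checking verifies; it is not PyTorch code.

def grad_f(x):
    # Analytic gradient of f(x) = x**2 at a single point.
    return 2.0 * x

def per_sample_grads(xs):
    # "Fallback" route: loop over the batch one sample at a time.
    return [grad_f(x) for x in xs]

def batched_grads(xs):
    # "Batched" route: one pass over the whole batch.
    return [2.0 * x for x in xs]

xs = [1.0, 2.0, 3.0]
assert per_sample_grads(xs) == batched_grads(xs) == [2.0, 4.0, 6.0]
```

The future work described above is about distinguishing *which* route produced the (correct) values, not just that the values match.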

Test Plan:

  • run `pytest test/test_ops.py -v -k "Gradients"`

Differential Revision: D25997703

@zou3519 zou3519 requested a review from albanD as a code owner January 20, 2021 16:58
zou3519 added a commit that referenced this pull request Jan 20, 2021
ghstack-source-id: 4bb191e
Pull Request resolved: #50818
@zou3519 zou3519 requested a review from mruberry January 20, 2021 22:13
@albanD albanD left a comment


Looks good to me.

One question I would have about all these functions for which batched grad check fails is: should we make these hard errors for now to make sure they don't return silently wrong gradients?

@zou3519 zou3519 commented Jan 21, 2021

> Looks good to me.
>
> One question I would have about all these functions for which batched grad check fails is: should we make these hard errors for now to make sure they don't return silently wrong gradients?

My understanding is that they are hard errors -- are they actually soft errors in the code right now? (Is that what happens when gradcheck runs with raise_exception=False?)
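For context on the `raise_exception` question, the gradcheck contract can be sketched like this (a simplified stand-in, not `torch.autograd.gradcheck` itself): with `raise_exception=True` (the default) a mismatch is a hard error, while `raise_exception=False` turns it into a boolean return value the caller must inspect.

```python
# Hedged sketch of gradcheck's raise_exception contract. The comparison
# is reduced to a single scalar for illustration.

def gradcheck_sketch(analytic, numerical, raise_exception=True, tol=1e-6):
    ok = abs(analytic - numerical) <= tol
    if not ok and raise_exception:
        # Default path: a mismatch is a hard error.
        raise RuntimeError("Jacobian mismatch for output 0")
    # Soft path: the caller must check the returned boolean.
    return ok
```

Note that this is about how *gradcheck* reports failures; albanD's point below is about operators that produce wrong batched gradients silently when used outside of gradcheck.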

@albanD albanD commented Jan 21, 2021

> My understanding is that they are hard errors

They are hard errors when you run gradcheck, because the gradients don't match (from my understanding; maybe I'm wrong here).
But if people try to use these operators directly, they will run without any error and produce wrong gradients; that is what I meant by silent error.

@facebook-github-bot
@zou3519 merged this pull request in 1669151.
