Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvpvjp] Batch norm coverage with decomposition #877

Merged
merged 1 commit into from
Jun 17, 2022
Merged

Conversation

samdow
Copy link
Contributor

@samdow samdow commented Jun 15, 2022

Adds decomposition from #675 for forward over reverse coverage. Similar to with layer norm, we needed to recompute the mean and variance so autograd propagates properly (sad) and needed to return tensors of zeros instead of None (sad)

@samdow samdow force-pushed the batch_norm_decomp branch 2 times, most recently from 3b68942 to 825f439 Compare June 16, 2022 16:23
Comment on lines +208 to +215
grad_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp
else:
grad_weight = torch.zeros(()) # should be None but doesn't work with vjp

if output_mask[2]:
grad_bias = grad_output_sum
else:
grad_bias = torch.zeros_like(grad_output_sum) # should be None but doesn't work with vjp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sad, but is it what it is

@samdow samdow merged commit 347334c into main Jun 17, 2022
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 11, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 12, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 12, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 12, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 12, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
soulitzer added a commit to pytorch/pytorch that referenced this pull request Aug 4, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Aug 4, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Aug 4, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Aug 4, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Oct 10, 2022
… and saved_var when input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
soulitzer added a commit to pytorch/pytorch that referenced this pull request Oct 10, 2022
…hen input requires grad"


We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch


[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants