Fix prod double backward when there are 2+ zeros (#113969)
Pull Request resolved: #113969
Approved by: https://github.com/albanD
guilhermeleobas authored and pytorchmergebot committed Nov 21, 2023
1 parent 85ce8a6 commit 77f16eb
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torch/csrc/autograd/FunctionsManual.cpp
@@ -825,7 +825,7 @@ Tensor prod_backward(
   Tensor zero_idx = (input == 0).nonzero();
   if (zero_idx.sym_numel() == 0) {
     return grad * (result / input).conj();
-  } else if (zero_idx.size(0) > 1) {
+  } else if (!at::GradMode::is_enabled() && zero_idx.size(0) > 1) {
     return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
     return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0)
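
Why the GradMode check: when the input has two or more zeros, the shortcut of returning a plain zeros tensor is fine for the first derivative, but it discards the graph needed for double backward, where the second derivative of prod is nonzero at the pair of zero entries. With grad mode enabled (i.e. the backward itself is being recorded for a higher-order gradient), the fix falls through to the differentiable prod_safe_zeros_backward path instead. A minimal sketch of the behavior this restores, assuming a 1-D double-precision input with exactly two zeros (not part of the commit):

import torch

# Two zeros in the input: every first-order gradient entry of prod() is zero,
# because the product of the remaining elements always contains a zero.
x = torch.tensor([2., 3., 0., 0.], dtype=torch.double, requires_grad=True)

(g,) = torch.autograd.grad(x.prod(), x, create_graph=True)
print(g)   # tensor([0., 0., 0., 0.], ...)

# Second-order gradient of sum(g) w.r.t. x: d^2 prod / (dx_2 dx_3) = 2 * 3,
# so the entries at the two zero positions should mathematically be 6, not 0.
(gg,) = torch.autograd.grad(g.sum(), x)
print(gg)  # expected: tensor([0., 0., 6., 6.])
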
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -6000,6 +6000,7 @@ def prod_single_zero():

     yield SampleInput(make_arg((3, 0)), args=(1,))
     yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True})
+    yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad))

     # test zero scalar tensor
     zero = make_arg(())
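
The new sample adds a 1-D input with two zeros so that the existing autograd tests cover the fixed branch. A rough standalone equivalent of what that sample exercises (assuming the usual gradcheck/gradgradcheck pattern, not the exact OpInfo harness):

import torch
from torch.autograd import gradcheck, gradgradcheck

t = torch.tensor([2., 3, 0, 0], dtype=torch.double, requires_grad=True)
assert gradcheck(torch.prod, (t,))       # first-order gradients vs. finite differences
assert gradgradcheck(torch.prod, (t,))   # second-order gradients, the path changed above
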
