torch.prod backward for complex types. #48125
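For context (a summary of the math, not text from the PR): for $y=\prod_j x_j$, the derivative with respect to a single entry is $\partial y/\partial x_i=\prod_{j\neq i}x_j=y/x_i$. Under PyTorch's Wirtinger-calculus convention for complex autograd, backward multiplies the incoming gradient by the *conjugate* of this derivative, which is why each formula in the diff below gains a `.conj()`:

$$
\text{grad}_{x_i} \;=\; \text{grad}_y \cdot \overline{\left(\frac{y}{x_i}\right)}
$$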
Changes from all commits
```diff
@@ -466,7 +466,7 @@ Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t d
   Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim);
   Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim);

-  return grad * (exclusive_normal * exclusive_reverse);
+  return grad * (exclusive_normal * exclusive_reverse).conj();
 }

 // note that the gradient for prod is equivalent to:
```
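The first hunk updates `prod_safe_zeros_backward`, which avoids dividing by the input (so zero entries are safe) by multiplying exclusive cumulative products taken from the left and from the right of each position. A minimal Python sketch of the same idea (my reconstruction for illustration, not the PR's code; the function name just mirrors the C++ one):

```python
import torch

def prod_safe_zeros_backward(grad, inp, dim):
    # exclusive_normal[..., i]  = product of entries before index i along `dim`
    # exclusive_reverse[..., i] = product of entries after index i along `dim`
    # Their product equals prod(inp, dim) / inp[..., i], but is computed
    # without division, so slices containing zeros are handled correctly.
    ones_size = list(inp.shape)
    ones_size[dim] = 1
    ones = torch.ones(ones_size, dtype=inp.dtype)

    narrow = inp.narrow(dim, 0, inp.size(dim) - 1)
    exclusive_normal = torch.cat([ones, narrow], dim).cumprod(dim)

    narrow_reverse = inp.flip(dim).narrow(dim, 0, inp.size(dim) - 1)
    exclusive_reverse = torch.cat([ones, narrow_reverse], dim).cumprod(dim).flip(dim)

    # The .conj() added in this PR: for complex inputs, backward multiplies
    # the incoming gradient by the conjugate of the local derivative.
    return grad * (exclusive_normal * exclusive_reverse).conj()
```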
```diff
@@ -482,7 +482,7 @@ Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& resu
   }
   Tensor zero_idx = (input == 0).nonzero();
   if (zero_idx.numel() == 0) {
-    return (grad * result) / input;
+    return grad * (result / input).conj();
   } else if (zero_idx.size(0) > 1) {
     return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
```
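This hunk is the fast path of the full-reduction backward when the input contains no zeros: `(grad * result) / input` becomes `grad * (result / input).conj()`. A quick numerical check of the new formula against autograd (assumes a PyTorch build that includes this change, i.e. complex autograd for `prod` is enabled):

```python
import torch

x = torch.randn(5, dtype=torch.complex128, requires_grad=True)
out = x.prod()

grad_out = torch.randn((), dtype=torch.complex128)
out.backward(grad_out)

# Fast-path formula from the diff: grad * (result / input).conj()
expected = grad_out * (out.detach() / x.detach()).conj()
print(torch.allclose(x.grad, expected))  # expected: True
```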
```diff
@@ -504,7 +504,7 @@ Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t di
   Tensor slice_zero_count = zero_mask.sum(dim, true);
   int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
   if (total_zeros == 0) {
-    return (grad * result) / input;
+    return grad * (result / input).conj();
   } else {
     return prod_safe_zeros_backward(grad, input, dim);
   }
```
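The last hunk applies the same fix to the dimension-reducing overload, which falls back to `prod_safe_zeros_backward` whenever the reduced slices contain zeros. A self-contained check on a complex input with a single zero (same PyTorch assumption as above): with exactly one zero in a slice, only that entry gets a nonzero gradient, namely the incoming gradient times the conjugated product of the remaining entries.

```python
import torch

x = torch.randn(3, 4, dtype=torch.complex128)
x[1, 2] = 0                       # row 1 triggers the safe-zeros fallback
x.requires_grad_(True)

out = x.prod(dim=1)
grad_out = torch.randn(3, dtype=torch.complex128)
out.backward(grad_out)

# Row 0 has no zeros: its gradient matches grad * (result / input).conj()
row0 = grad_out[0] * (out[0].detach() / x[0].detach()).conj()
print(torch.allclose(x.grad[0], row0))  # expected: True

# In row 1 only the zero entry has a nonzero gradient: grad times the
# conjugated product of the other entries (computed without division).
others = torch.cat([x[1, :2], x[1, 3:]]).detach().prod()
print(torch.allclose(x.grad[1, 2], grad_out[1] * others.conj()))  # expected: True
```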