torch.sign doesn't work for complex tensors #36323
Comments
Although I think it makes more sense to return the sign for both real and imag.
You're probably right that it would make more sense to do that, but because NumPy exists and is so popular, it's already created a significant expectation of what sign() should do. Maybe the behavior you describe could be implemented with an optional kwarg or under a different function name?
@rgommers do you have an opinion on this matter?
I don't think there's a clear choice here. Mathematically I'd probably expect x/|x|, but diverging from NumPy has its own cost.
"Yes, and..." to @rgommers. If we're divergent from NumPy we should error. Wikipedia defines two sign functions for complex numbers: |
That does feel like namespace pollution for something that won't be used a lot and is not difficult for a user to implement by hand in whichever way they want. I'm starting to lean more toward just erroring out and leaving it at that.
hmm. I remember @mmuckley mentioned that NumPy's complex sign behavior was surprising coming from Matlab.
Numpy has its own history of discussing the matter in its issue tracker. As mentioned in those threads, most software packages other than Numpy use x/|x|, but the implementation in Numpy is a "valid analytic continuation of the real sign function." In MRI we never use this mathematical property, but we use x/|x| all the time. I can't speak to it in other fields. It looks like from the second issue that Numpy is considering two sign functions. A few of us complex people are still in the process of coming over from Matlab, where sign gives x/|x|, so the Numpy behavior was surprising to me.
Since there are two valid implementations, I think we have consensus to disable complex torch.sign for now to avoid user confusion. Thanks for your input, everyone!
Synced with @mmuckley offline, and we think that adding a method that returns x/abs(x) for complex tensors would be the way to go. In addition, I would also like to add that we should disable torch.sign for complex tensors in the meantime. cc @mmuckley
@anjali411 I would like to work on this.
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`. TODO: add tests for backward (waiting on the gradcheck PR for complex).
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`. This PR doesn't test the correctness of the gradients; that will be done as part of auditing all the ops in the future, once we decide the autograd behavior (JAX vs TF) and add gradcheck.
Hello everyone! I get an error while training a net that uses the following layer, which computes the magnitude response of an IIR filter from its coefficients:
```python
import torch

# iir filter layer
class IIRFilter(torch.nn.Module):
    def __init__(self, nfft):
        super(IIRFilter, self).__init__()
        self.nfft = nfft

    def forward(self, b, a):
        N = a.shape[0]
        # zero-pad the numerator and denominator coefficients up to nfft bins
        bn = torch.cat([b, torch.zeros(N, self.nfft - 3)], -1)
        an = torch.cat([a, torch.zeros(N, self.nfft - 3)], -1)
        # divide the complex spectra and take the magnitude
        y = torch.view_as_complex(torch.rfft(bn, 1)) / torch.view_as_complex(torch.rfft(an, 1))
        y = y.abs()
        return y
```

Is there otherwise a way I can implement the backward pass manually for the offending operations? (I am not interested in performance... yet) Thanks in advance!
The error is caused because the backward formulas for `torch.abs` and `torch.div` aren't defined for complex tensors yet; they're being added in the `torch.sgn` PR above. In the meantime, I think defining a backward function in the class should work.
If you want to implement it manually, use the custom autograd API: https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd
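A minimal sketch of that approach (the class name is hypothetical, and since the complex autograd convention was still being settled at the time, treat the backward formula as one reasonable choice rather than the official one):

```python
import torch

class ComplexAbs(torch.autograd.Function):
    """abs() on a complex input with a hand-written backward."""

    @staticmethod
    def forward(ctx, z):
        ctx.save_for_backward(z)
        return z.abs()

    @staticmethod
    def backward(ctx, grad_output):
        (z,) = ctx.saved_tensors
        # the gradient of |z| flows back along the direction of z;
        # clamp the denominator to avoid dividing by zero at z == 0
        denom = z.abs().clamp_min(1e-12)
        return grad_output * z / denom
```

`ComplexAbs.apply(z)` could then replace `z.abs()` in the forward pass above.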
Alright, I'll check the docs on implementing the backward pass while waiting for the new release. |
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`. Also updates the backward definitions of `torch.div` and `torch.abs`. Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526)
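For reference, the behavior described in the PR looks like this in practice (on a PyTorch build that includes `torch.sgn`):

```python
import torch

z = torch.tensor([3 + 4j, 0 + 0j, -2 + 0j])
torch.sgn(z)  # x/abs(x) for nonzero entries, 0 + 0j at zero
# tensor([ 0.6000+0.8000j,  0.0000+0.0000j, -1.0000+0.0000j])
```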
hey @MishinVD, see the code I was trying to get to work at #36323 (comment). I'm currently using the nightly build of PyTorch, which implements the new `torch.fft.rfft` that returns complex tensors directly:
```python
import torch
from torch.fft import rfft

# iir filter response layer, similar to freqz
class IIRFilterResponse(torch.nn.Module):
    def __init__(self, nfft):
        super(IIRFilterResponse, self).__init__()
        self.nfft = nfft
        self.eps = torch.finfo(torch.float32).eps

    def forward(self, b, a):
        b_fft = rfft(b, n=self.nfft, dim=1)
        a_fft = rfft(a, n=self.nfft, dim=1)
        # eps keeps the division (and its gradient) finite when a_fft has zeros
        yabs = (torch.abs(b_fft) + self.eps) / (torch.abs(a_fft) + self.eps)
        assert torch.all(torch.isfinite(yabs))
        return yabs
```

Hope this helps.
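A quick usage sketch (the shapes are hypothetical, just to show that gradients flow through `abs()` of the complex spectra):

```python
layer = IIRFilterResponse(nfft=512)
b = torch.randn(8, 3, requires_grad=True)  # batch of numerator coefficients
a = torch.randn(8, 3, requires_grad=True)  # batch of denominator coefficients
layer(b, a).sum().backward()               # backprop through the complex rfft and abs
print(b.grad.shape)                        # torch.Size([8, 3])
```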
Numpy:
From the NumPy docs: "For complex inputs, the sign function returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j. complex(nan, 0) is returned for complex nan inputs."
PyTorch:
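A minimal repro sketch (the exact failure mode and error text depend on the PyTorch version in use):

```python
import torch

t = torch.tensor([1 + 1j, -2 - 3j])
torch.sign(t)  # fails (or misbehaves) for complex dtypes at the time of this issue
```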
cc @mruberry @ezyang @anjali411 @dylanbespalko