
Autograd not working for torch.exp(1j * phase) #43349

Closed
jonashaag opened this issue Aug 20, 2020 · 7 comments
Labels
high priority · module: autograd (related to torch.autograd and the autograd engine in general) · module: complex (related to complex number support in PyTorch) · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@jonashaag (Contributor) commented Aug 20, 2020

🐛 Bug

Complex tensor construction from magnitude and phase does not seem to support autograd when using mag * torch.exp(1j * phase) notation:

import torch
mag, phase = torch.tensor(5., requires_grad=True), torch.tensor(3., requires_grad=True)

complex_good = torch.view_as_complex(torch.stack([mag * torch.cos(phase), mag * torch.sin(phase)], dim=-1))
complex_good.backward()  # works

complex_bad = mag * torch.exp(1j * phase)
complex_bad.backward()

=>

.../torch/autograd/__init__.py:125: UserWarning: Complex backward is not fully supported yet and could lead to wrong gradients for functions we have not fixed yet (Triggered internally at  /opt/conda/conda-bld/pytorch_1597820903894/work/torch/csrc/autograd/python_engine.cpp:172.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../torch/tensor.py", line 214, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../torch/autograd/__init__.py", line 125, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Expected isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
...

Torch version: 1.7.0.dev20200819
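
As a stopgap until the exp path works, one workaround is to build the complex tensor from real-valued leaves so their gradients stay real. A minimal sketch, assuming a build where torch.polar and its backward are available; the .abs() reduction is only there to obtain a real scalar for backward():

import torch

# Workaround sketch (assumes torch.polar and its backward are supported):
# construct mag * exp(1j * phase) directly from real-valued leaves.
mag = torch.tensor(5., requires_grad=True)
phase = torch.tensor(3., requires_grad=True)

z = torch.polar(mag, phase)   # same value as mag * torch.exp(1j * phase)
z.abs().backward()            # reduce to a real scalar before backward()
print(mag.grad, phase.grad)   # real-valued gradients on the real leaves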

cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen @anjali411 @dylanbespalko

@pbelevich added the module: autograd, module: complex, and triaged labels on Aug 20, 2020
@albanD (Collaborator) commented Aug 21, 2020

As mentioned in the warning there, complex autograd is not fully supported right now.

My guess here is that the formula for one of these ops has not been updated yet and does not produce a complex gradient like it should.

@anjali411 you might want to open a single umbrella issue that tracks these as I expect there will be more...

@anjali411 (Contributor) commented Aug 25, 2020

@jonashaag thanks for reporting the issue! It's because MulBackward's behavior for complex is problematic: according to the current formula, when a real tensor is multiplied by a complex scalar, a complex-valued gradient is returned instead of the expected real-valued gradient.

>>> import torch
>>> x=torch.randn(4, requires_grad=True)
>>> y=x*1j
>>> y
tensor([-0.-1.1449j, 0.+0.5593j, 0.+1.7524j, 0.+0.1101j],
       grad_fn=<MulBackward0>)
>>> y.sum().backward()
/home/chourdiaanjali/pytorch2/torch/autograd/__init__.py:127: UserWarning: Complex backward is not fully supported yet and could lead to wrong gradients for functions we have not fixed yet (Triggered internally at  ../torch/csrc/autograd/python_engine.cpp:172.)
  allow_unreachable=True)  # allow_unreachable flag
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chourdiaanjali/pytorch2/torch/tensor.py", line 214, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/chourdiaanjali/pytorch2/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Expected isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

As Alban mentioned, complex backward is not fully supported yet but many of these functions will be fixed soon.
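
For concreteness, a minimal sketch of the behavior the corrected formula should produce; the reduction to a real scalar and the expected real grad dtype are my reading of the convention (real leaves get real gradients), not output from an actual fixed build:

import torch

# Expected post-fix behavior: a real leaf times a complex constant should
# still receive a real-valued gradient from a real-valued loss.
x = torch.randn(4, requires_grad=True)  # real leaf
y = x * 1j                              # complex result via MulBackward
y.sum().abs().backward()                # real scalar loss
print(x.grad.dtype)                     # expected: torch.float32, not complex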

Re Alban's suggestion ("you might want to open a single umbrella issue that tracks these"): I have been adding these issues to the Autograd Tasks subsection in #33152.

@cshewmake2 commented

Having the same issue as above. How's this going? Any way I can help?

@mountain commented

same issue +1

@IvanYashchuk (Collaborator) commented

This was fixed in #43208.

@rjkilpatrick (Contributor) commented

This is still an issue on nightly:

>>> import torch
>>> print(torch.__version__)
1.8.0.dev20201101
>>> mag = torch.tensor(5., requires_grad=True, dtype=torch.complex128)
>>> phase = torch.tensor(3., requires_grad=True, dtype=torch.complex128)
>>>
>>> complex_good = mag * (torch.cos(phase) + 1.j * torch.sin(phase))
>>> complex_good.backward() # works
>>>
>>> complex_bad = mag * torch.exp(1j * phase)
>>> complex_bad.backward()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-98d908b62931> in <module>
      8 complex_good.backward()  # works
      9 
---> 10 complex_bad = mag * torch.exp(1j * phase)
     11 complex_bad.backward()

RuntimeError: exp does not support automatic differentiation for outputs with complex dtype.

pytorch/test/test_autograd.py

Lines 4935 to 4941 in 1cc1da5

complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
                'repeat', 'expand', 'flip', 'fliplr', 'flipud', 'rot90', 'transpose',
                'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
                'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
                'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
                'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub'] + separate_complex_tests

torch.exp is not being tested for autograd backward on master.
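
A direct check that would have caught this, sketched on the assumption that gradcheck accepts complex inputs in this build:

import torch
from torch.autograd import gradcheck

# Hypothetical test addition: gradcheck exp on a complex double input
# (gradcheck expects double precision for its finite differences).
z = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
print(gradcheck(torch.exp, (z,)))  # should print True once exp's complex backward lands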

@anjali411 (Contributor) commented

Fixed in #47194.
