
RuntimeError: diag does not support automatic differentiation for outputs with complex dtype. #48490

Closed
rjkilpatrick opened this issue Nov 26, 2020 · 8 comments
Labels
complex_autograd · function request A request for a new function or the addition of new arguments/modes to an existing function. · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@rjkilpatrick
Contributor

🐛 Bug

Autograd doesn't work for torch.diag for complex dtypes.

To Reproduce

Steps to reproduce the behavior:

>>> import torch
>>> x = torch.ones(1, dtype=torch.cdouble, requires_grad=True)
>>> y = torch.diag(x)
Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
RuntimeError: diag does not support automatic differentiation for outputs with complex dtype.

Same holds for torch.diagflat.

Expected behaviour

The gradient should behave as for an effective reshape/scatter operation, i.e. be equivalent to:

>>> x = torch.ones(1, dtype=torch.cdouble, requires_grad=True)
>>> def diag(x):
...     # allocate with the input's dtype so complex values can be assigned
...     y = torch.zeros(len(x), len(x), dtype=x.dtype)
...     for i in range(len(x)):
...         y[i, i] = x[i]
...     return y
>>> y = diag(x)
>>> y.backward(torch.ones_like(y))  # explicit grad: implicit creation needs a real scalar output
>>> x.grad
tensor([1.+0.j], dtype=torch.complex128)
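The Python-level loop above can also be written as a vectorized, differentiable workaround using index assignment (a sketch, not part of the original report; the `idx` helper and the explicit gradient passed to `backward` are my additions):

```python
import torch

# Build a diagonal matrix without torch.diag, using index assignment,
# which is differentiable for complex dtypes.
x = torch.ones(3, dtype=torch.cdouble, requires_grad=True)
n = len(x)
y = torch.zeros(n, n, dtype=x.dtype)
idx = torch.arange(n)
y[idx, idx] = x  # place x on the main diagonal

# backward() needs an explicit gradient for non-scalar (and complex) outputs
y.backward(torch.ones_like(y))
print(x.grad)  # each entry receives gradient 1.+0.j
```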

Environment

PyTorch version: 1.8.0.dev20201124
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Enterprise
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Python version: 3.8 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.8.0.dev20201124
[pip3] torchaudio==0.8.0.dev20201124
[pip3] torchvision==0.9.0.dev20201124
[conda] blas 1.0 mkl
[conda] cpuonly 1.0 0 pytorch-nightly
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38h2bbff1b_0
[conda] mkl_fft 1.2.0 py38h45dec08_0
[conda] mkl_random 1.1.1 py38h47e9c7a_0
[conda] numpy 1.19.2 py38hadc3359_0
[conda] numpy-base 1.19.2 py38ha3acd2a_0
[conda] pytorch 1.8.0.dev20201124 py3.8_cpu_0 [cpuonly] pytorch-nightly
[conda] torchaudio 0.8.0.dev20201124 py38 pytorch-nightly
[conda] torchvision 0.9.0.dev20201124 py38_cpu [cpuonly] pytorch-nightly

@mruberry mruberry added complex_autograd, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module), and function request (A request for a new function or the addition of new arguments/modes to an existing function.) labels Nov 27, 2020
@mruberry
Collaborator

Hey @rjkilpatrick! Thanks for reporting this issue. We would accept a PR adding this support.

@Rajathbharadwaj

@rjkilpatrick @mruberry, it looks like torch.diag itself is not implemented for cdouble dtypes.

>>> x1 = torch.ones(1, dtype=torch.cdouble, requires_grad=False)
>>> y = torch.diag(x1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: "diag" not implemented for 'ComplexDouble'
>>> 

@rjkilpatrick
Contributor Author

@Rajathbharadwaj, torch.diag was implemented for complex double in #47564.
Try using pytorch-nightly for a more up-to-date build.

I will attempt a PR over the weekend.

@Rajathbharadwaj

Rajathbharadwaj commented Nov 27, 2020

Oh okay. Could one of torch.diag and torch.diagflat be deprecated? Since both essentially do the same thing, I found the duplication confusing.

@rjkilpatrick @mruberry

@mruberry
Collaborator

No, we cannot deprecate them. They're different functions and NumPy also has both:

https://numpy.org/doc/stable/reference/generated/numpy.diag.html?highlight=diag#numpy.diag
https://numpy.org/doc/stable/reference/generated/numpy.diagflat.html?highlight=diagflat#numpy.diagflat

Their names could probably be clearer, but one extracts a diagonal and the other creates a matrix with the given diagonal.
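A small illustration of the distinction (my own example, not from the thread): for a 2-D input, `torch.diag` extracts the main diagonal, while `torch.diagflat` flattens the input first and builds a matrix from it.

```python
import torch

m = torch.arange(4.).reshape(2, 2)

# diag on a 2-D tensor *extracts* the main diagonal
d = torch.diag(m)       # tensor([0., 3.])

# diagflat flattens first, then *creates* a 4x4 diagonal matrix
f = torch.diagflat(m)   # shape (4, 4), with 0., 1., 2., 3. on the diagonal
print(d, f.shape)
```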

@Rajathbharadwaj

@rjkilpatrick @mruberry I also want to work on this, any pointers on how I can get started?

@mruberry
Collaborator

@Rajathbharadwaj

@rjkilpatrick
Contributor Author

The issue is no longer present on nightly (1.11.0.dev20210926).
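Re-running the original repro on a recent build confirms this (a quick check; the explicit gradient passed to `backward` is my addition, needed because implicit gradient creation only works for real scalar outputs):

```python
import torch

x = torch.ones(1, dtype=torch.cdouble, requires_grad=True)
y = torch.diag(x)                # no longer raises for complex dtypes
y.backward(torch.ones_like(y))   # explicit grad for a complex output
print(x.grad)                    # tensor([1.+0.j], dtype=torch.complex128)
```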
