
Allow converting parameters of nn.Module to complex dtypes #44788

Closed
wants to merge 17 commits

Conversation

IvanYashchuk
Collaborator

This PR makes it possible to cast the parameters of nn.Module to complex dtypes.
The following code works with the proposed changes.

In [1]: import torch
In [2]: lin = torch.nn.Linear(5, 1).to(torch.complex64)
In [3]: lin(torch.zeros(3, 5, dtype=torch.complex64))
Out[3]: 
tensor([[-0.1739+0.j],
        [-0.1739+0.j],
        [-0.1739+0.j]], grad_fn=<AddmmBackward>)

Fixes #43477.

@dr-ci

dr-ci bot commented Sep 16, 2020

💊 CI failures summary and remediations

As of commit 0321193 (more details on the Dr. CI page):


  • 1/4 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)
  • 3/4 broken upstream at merge base 2f51ddb since Oct 20

🚧 3 ongoing upstream failures:

These were probably caused by upstream breakages that are not fixed yet:


ci.pytorch.org: 1 failed



@mruberry mruberry removed the request for review from apaszke September 17, 2020 03:16
@mruberry mruberry added the module: complex (Related to complex number support in PyTorch), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module), and module: nn (Related to torch.nn) labels Sep 17, 2020
@mruberry
Collaborator

mruberry commented Oct 7, 2020

Hey @IvanYashchuk, looks like this fell through the review cracks. Sorry about that and please ping us (or me) if you're not getting a response after a couple days.

This looks like a cool, simple PR. I'd like to see the test moved to TestNNDeviceType and use the device-generic test pattern so you don't have to query for CUDA in the middle of the test. The other thing I'd like to suggest is that we throw a warning if a module is converted to a complex dtype. Maybe something like, "Complex modules are a new feature, and many modules will not work as expected when using complex tensors as parameters or buffers."?

The reason I want to consider a warning like this is that I suspect that many more complicated modules just won't compute their forward or backward passes correctly with complex parameters. @anjali411, what are your thoughts?
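For reference, a rough sketch of the device-generic pattern being suggested (the test name and body here are illustrative, not the exact test added by this PR): the framework in torch.testing._internal.common_device_type instantiates one copy of each test per available device and passes device in, so the test body never has to query for CUDA.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class TestModuleComplexConversion(TestCase):
    def test_linear_to_complex(self, device):
        # The device-generic framework passes `device` in for us.
        lin = torch.nn.Linear(5, 1).to(device).to(torch.complex64)
        self.assertEqual(lin.weight.dtype, torch.complex64)
        out = lin(torch.zeros(3, 5, dtype=torch.complex64, device=device))
        self.assertEqual(out.dtype, torch.complex64)

instantiate_device_type_tests(TestModuleComplexConversion, globals())

if __name__ == '__main__':
    run_tests()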

@IvanYashchuk
Collaborator Author

Thanks, Mike! I've moved the test to TestNNDeviceType and have added a warning.
I think the forward pass with complex parameters should work correctly if the underlying functions are supported (I assume all new functions are being tested). As for the backward pass, with PR #45461 only functions that are tested are allowed to be differentiated, so there should not be silently incorrect derivatives.
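As a quick illustration of how the new warning surfaces to a user (a minimal sketch, assuming the change to nn.Module.to from this PR):

import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.nn.Linear(2, 2).to(torch.complex64)

# One UserWarning about complex modules is expected here
# (the exact wording is the one added in this PR).
print([str(w.message) for w in caught])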

torch/nn/modules/module.py (outdated review thread, resolved)
@mruberry
Collaborator

mruberry commented Oct 8, 2020

This looks good to me. I think the test failures are unrelated to your PR. Let's give @anjali411 a chance to look, too.

test/test_nn.py (outdated review thread, resolved)
Contributor

@anjali411 anjali411 left a comment


Looks good overall. Left some comments about documentation updates in various places.
Should

This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
also include complex, right?

@codecov

codecov bot commented Oct 8, 2020

Codecov Report

Merging #44788 into master will increase coverage by 0.00%.
The diff coverage is 100.00%.


@@           Coverage Diff           @@
##           master   #44788   +/-   ##
=======================================
  Coverage   68.20%   68.20%           
=======================================
  Files         410      410           
  Lines       53453    53455    +2     
=======================================
+ Hits        36458    36460    +2     
  Misses      16995    16995           
Impacted Files Coverage Δ
torch/nn/modules/module.py 92.44% <100.00%> (+0.03%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update bbb3f09...560b45a.

@IvanYashchuk
Collaborator Author

Should

This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

also include complex, right?

You mean we should update the docs of Linear? Do we then want to update docs of other modules as well?

@mruberry
Collaborator

You mean we should update the docs of Linear? Do we then want to update docs of other modules as well?

No, let's not update the documentation for any individual module at this time.

@mruberry
Collaborator

Hey @IvanYashchuk, this continues to look very good. I made two last doc requests: a minor edit to the warning message, and a separate complex example to better demonstrate complex modules. After that I think we're good to go!

@IvanYashchuk
Collaborator Author

Thank you @mruberry! I've added the requested changes.

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j, 0.2382+0.j],
Collaborator

@mruberry mruberry Oct 12, 2020


Out of curiosity, how does this get initialized?

Oh, never mind. I realize now that all the imaginary parts are zero, which is lame for the moment. Maybe in the future we'll have complex initialization functions?

Collaborator Author


weight is initialized using kaiming_uniform_. Yes, if Linear had a dtype argument, then it could initialize complex-valued weights.

@mruberry mruberry self-requested a review October 12, 2020 13:50
Collaborator

@mruberry mruberry left a comment


Nice work, @IvanYashchuk! LGTM! Let's let @anjali411 review, too.

@anjali411
Contributor

anjali411 commented Oct 12, 2020

I tried running backward on a loss obtained by applying a custom loss function to the output of a Linear model, with the PR changes, but it errors out in forward ...

>>> import torch
>>> input1 = torch.randn(2, 3, dtype=torch.cdouble, requires_grad=True)
>>> target = torch.randn(2, 3, dtype=torch.cdouble, requires_grad=True)
>>> model = torch.nn.Linear(3, 3)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> def loss(x, y):
...     return (x-y).abs().sum()
...
>>> loss(model(input1), target).backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chourdiaanjali/pytorch2/torch/nn/modules/module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chourdiaanjali/pytorch2/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/chourdiaanjali/pytorch2/torch/nn/functional.py", line 1665, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: expected scalar type Float but found ComplexDouble

It's because the weight and bias assigned here https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L85-L90 are float by default. Just to be clear, fixing this might be out of the scope of this PR, but it is something we should think about. My first thought is that maybe we should set weight and bias to None and initialize them with tensors of the correct dtype the first time forward is called. Optionally, we could also think of adding a dtype keyword argument to give the user the option to select the dtype for the weight and bias parameters.
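Purely as an illustration of the dtype keyword idea (DTypeLinear is a hypothetical name, not an existing torch.nn API and not something added by this PR), the parameters could be allocated directly in the requested dtype so no later conversion, and hence no all-zero imaginary part, is needed:

import torch

class DTypeLinear(torch.nn.Module):  # hypothetical sketch, not a PyTorch API
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        # Allocate the parameters directly in the requested dtype.
        scale = in_features ** -0.5
        self.weight = torch.nn.Parameter(
            scale * torch.randn(out_features, in_features, dtype=dtype))
        self.bias = torch.nn.Parameter(torch.zeros(out_features, dtype=dtype)) if bias else None

    def forward(self, input):
        return torch.nn.functional.linear(input, self.weight, self.bias)

lin = DTypeLinear(3, 3, dtype=torch.cdouble)
out = lin(torch.randn(2, 3, dtype=torch.cdouble))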

@mruberry
Collaborator

I tried running backward on a loss obtained by applying a custom loss function to the output of a Linear model, with the PR changes, but it errors out in forward ...

The example shows converting the linear module to complex to address this, though?

@anjali411
Contributor

The example shows converting the linear module to complex to address this, though?

Hmm, I see, I missed that. However, that means that we always initialize the weight and bias values with zero imaginary parts, since they were converted from float to complex. I think it might be ok for this PR, but it should be clearly documented. I will also look into research papers to see what the state of the art is ...

@mruberry
Collaborator

However, that means that we always initialize the weight and bias values with zero imaginary parts, since they were converted from float to complex. I think it might be ok for this PR, but it should be clearly documented. ...

Correct. We have no complex initialization functions (yet). See this previous comment #44788 (comment).

A user can develop their own initialization scheme and apply it to the module, however, just like today.
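For instance, a rough user-side sketch (not part of this PR) that reuses kaiming_uniform_ for both the real and imaginary parts of a converted Linear:

import math
import torch

lin = torch.nn.Linear(2, 2).to(torch.cdouble)
with torch.no_grad():
    real = torch.empty_like(lin.weight, dtype=torch.double)
    imag = torch.empty_like(lin.weight, dtype=torch.double)
    # Initialize real and imaginary parts separately with an existing scheme...
    torch.nn.init.kaiming_uniform_(real, a=math.sqrt(5))
    torch.nn.init.kaiming_uniform_(imag, a=math.sqrt(5))
    # ...and write the combined complex values back into the parameter.
    lin.weight.copy_(torch.complex(real, imag))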

Contributor

@anjali411 anjali411 left a comment


This PR looks good by itself.

There's more to think about (see the sketch after this list):

  1. Creating modules with complex valued Parameters without having to go via real
  2. Initialization of complex valued module Parameters
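As a small illustration of point 1: at the tensor level a Parameter can already be created directly with a complex dtype; what is missing is module constructors (and init functions) that do this for you.

import torch

# Creating a complex-valued Parameter directly, without going via a real module:
weight = torch.nn.Parameter(torch.randn(3, 3, dtype=torch.cfloat))
print(weight.dtype)  # torch.complex64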

Contributor

@facebook-github-bot facebook-github-bot left a comment


@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@albanD
Collaborator

albanD commented Oct 14, 2020

I am not convinced that this is something we actually want.
In particular, having a .to() call that requires re-initializing the Parameters one way and drops half of the information the other way does not sound right.

I am not sure that this is actually any better than having a separate class ComplexLinear() (or CLinear) that has a forward method that handles complex inputs properly and the right initialization method.

Also, I am wondering what happens when you save/load such a Linear into another Linear of a different dtype? I think you will get uninitialized memory one way and drop half of the content of the state_dict the other way!

As discussed with @anjali411 offline, we are going to open an umbrella issue on how to add complex support to torch.nn, to make sure we end up with a consistent result that works well.
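To make the first concern concrete, a minimal sketch (assuming the PR as written) of the round trip described above, where converting back to a real dtype discards the imaginary part of the parameters:

import torch

lin = torch.nn.Linear(2, 2).to(torch.cdouble)
with torch.no_grad():
    # Put some information into the imaginary part of the weight.
    lin.weight.copy_(torch.complex(lin.weight.real, torch.ones_like(lin.weight.real)))

lin.to(torch.double)  # casting back to a real dtype drops the imaginary part
print(lin.weight)     # only the real part survives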

@albanD
Collaborator

albanD commented Oct 15, 2020

For reference, the issue discussing the design: #46374

'dtypes, but got desired dtype={}'.format(dtype))
if dtype.is_complex:
warnings.warn(
"Complex modules are a new feature, and some modules might not work as expected "
Collaborator

@mruberry mruberry Oct 19, 2020


Suggested new wording to be clearer that the behavior may change in the future:

"new feature under active development whose design may change"

tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

Collaborator

@mruberry mruberry Oct 19, 2020


@anjali411 TODO: update with a discussion of populating the imaginary part. (Idea: take the param's imaginary part and use kaiming init on it, too.)

@anjali411
Contributor

@IvanYashchuk can you please rebase?

Contributor

@facebook-github-bot facebook-github-bot left a comment


@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@anjali411 merged this pull request in 6de619e.

Labels
Merged, module: complex (Related to complex number support in PyTorch), module: nn (Related to torch.nn), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Complex Linear does not work
6 participants