Allow converting parameters of nn.Module to complex dtypes #44788
Conversation
💊 CI failures summary and remediations, as of commit 0321193 (more details on the Dr. CI page):
🚧 3 ongoing upstream failures: these were probably caused by upstream breakages that are not fixed yet.
ci.pytorch.org: 1 failed. (This comment was automatically generated by Dr. CI.)
Hey @IvanYashchuk, looks like this fell through the review cracks. Sorry about that, and please ping us (or me) if you're not getting a response after a couple of days. This looks like a cool, simple PR. I'd like to see the test moved to TestNNDeviceType using the device-generic test pattern, so you don't have to query for CUDA in the middle of the test. The other thing I'd like to suggest is that we throw a warning if a module is converted to a complex dtype. Maybe something like, "Complex modules are a new feature, and many modules will not work as expected when using complex tensors as parameters or buffers."? The reason I want to consider a warning like this is that I suspect many more complicated modules just won't compute their forward or backward passes correctly with complex parameters. @anjali411, what are your thoughts?
Thanks, Mike! I've moved the test to TestNNDeviceType and have added a warning.
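For illustration, a rough sketch of the device-generic test pattern being discussed (not necessarily the exact test added in this PR; the helpers come from PyTorch's internal test infrastructure, and their use here is an assumption):

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, dtypes)

class TestNNDeviceType(TestCase):
    @dtypes(torch.cfloat, torch.cdouble)
    def test_module_to_complex(self, device, dtype):
        # The device fixture supplies 'cpu'/'cuda:0', so the test never has
        # to query for CUDA availability itself.
        module = nn.Linear(2, 2).to(device)
        module.to(dtype)
        self.assertEqual(module.weight.dtype, dtype)

instantiate_device_type_tests(TestNNDeviceType, globals())

if __name__ == '__main__':
    run_tests()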
This looks good to me. I think the test failures are pre-existing and not caused by your PR. Let's give @anjali411 a chance to look, too.
Looks good overall. Left some comments about documentation updates at various places.
Should

Referenced code (pytorch/torch/nn/modules/linear.py, line 37 in acca11b):
    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
Codecov Report
@@ Coverage Diff @@
## master #44788 +/- ##
=======================================
Coverage 68.20% 68.20%
=======================================
Files 410 410
Lines 53453 53455 +2
=======================================
+ Hits 36458 36460 +2
Misses 16995 16995
You mean we should update the docs of …
No, let's not update the documentation for any individual module at this time.
Hey @IvanYashchuk, this continues to look very good. I made two last doc requests: a minor edit to the warning message, and a separate complex example to better demonstrate complex modules. After that I think we're good to go!
Thank you @mruberry! I've added the requested changes.
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
Out of curiosity, how does this get initialized?
Oh, nevermind. I realize now that all the imaginary parts are zero. Which is lame for the moment. Maybe in the future we'll have complex initialization functions?
weight is initialized using kaiming_uniform_. Yes, if Linear had a dtype argument, then it could initialize complex-valued weights.
Nice work, @IvanYashchuk! LGTM! Let's let @anjali411 review, too.
I tried running backward on a loss obtained from a custom loss function applied to a Linear model's output with the PR changes, but it errors out in forward ...
It's because the weight and bias assigned here https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L85-L90 are float by default. Just to be clear, fixing this might be out of the scope of this PR, but it is something we should think about. My first thought is that maybe we should set weight and bias to None and initialize them with tensors of the correct dtype the first time forward is called. Optionally, we could also think of adding a dtype keyword argument to give the user the option to select the dtype for the weight and bias parameters.
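As a rough sketch of the dtype-keyword idea (hypothetical: no such argument exists in this PR, and the initialization below is just a placeholder, not a recommended scheme):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearWithDtype(nn.Module):
    # Hypothetical Linear-like module whose parameters are created directly
    # in the requested dtype instead of being converted from float afterwards.
    def __init__(self, in_features, out_features, dtype=torch.float32):
        super().__init__()
        self.weight = nn.Parameter(
            0.1 * torch.randn(out_features, in_features, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

m = LinearWithDtype(3, 2, dtype=torch.cdouble)
out = m(torch.randn(4, 3, dtype=torch.cdouble))  # complex forward pass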
The example shows converting the linear module to complex to address this, though?
Hmm, I see, I missed that. However, that means we always initialize the weight and bias values with zero imaginary parts, since they were converted from float to complex. I think it might be OK for this PR, but it should be clearly documented. I will also look into research papers to see what the state of the art is ...
Correct. We have no complex initialization functions (yet). See this previous comment #44788 (comment). A user can develop their own initialization scheme and apply it to the module, however, just like today.
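For instance, an illustrative hand-rolled scheme applied after conversion (the distribution and scaling factor here are arbitrary assumptions, not a recommendation):

import torch
import torch.nn as nn

linear = nn.Linear(2, 2).to(torch.cdouble)
with torch.no_grad():
    for p in linear.parameters():
        # Draw independent Gaussian real and imaginary parts and overwrite
        # the converted (zero-imaginary) parameter in place.
        real = torch.randn_like(p.real)
        imag = torch.randn_like(p.real)
        p.copy_(0.1 * torch.complex(real, imag))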
This PR looks good by itself.
There's more to think about:
- Creating modules with complex-valued Parameters without having to go via real
- Initialization of complex-valued module Parameters
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I am not convinced that this is something we actually want. I am not sure that this is actually any better than having a separate class for … Also, I am wondering what happens when you save/load such a Linear into another Linear of a different dtype? I think you will get uninitialized memory one way and drop half of the content of the state_dict the other way! As discussed with @anjali411 offline, we are going to open an umbrella issue on how we're going to add complex support to torch.nn, to make sure we have a consistent result that works well in the end.
For reference, the issue discussing the design: #46374
torch/nn/modules/module.py (outdated)

        'dtypes, but got desired dtype={}'.format(dtype))
    if dtype.is_complex:
        warnings.warn(
            "Complex modules are a new feature, and some modules might not work as expected "
Suggested new wording to be clearer that the behavior may change in the future:
"new feature under active development whose design may change"
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
@anjali411 TODO: update with a discussion of populating the imaginary part. (Idea: we can take the parameter's imaginary part and use kaiming init on it, too.)
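A minimal sketch of that idea (illustrative only; kaiming_uniform_ expects real tensors, so the imaginary part is drawn separately and recombined):

import math
import torch
import torch.nn as nn

linear = nn.Linear(2, 3).to(torch.cdouble)  # real part is already kaiming-initialized
with torch.no_grad():
    imag = torch.empty_like(linear.weight.real)
    nn.init.kaiming_uniform_(imag, a=math.sqrt(5))
    # Keep the existing real part and fill in the newly drawn imaginary part.
    linear.weight.copy_(torch.complex(linear.weight.real.clone(), imag))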
@IvanYashchuk can you please rebase?
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 merged this pull request in 6de619e.
This PR makes it possible to cast the parameters of nn.Module to complex dtypes.
The following code works with the proposed changes.
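The snippet itself is not reproduced on this page; a minimal sketch of the kind of usage enabled (illustrative, not necessarily the original code) is:

import torch
import torch.nn as nn

linear = nn.Linear(3, 2)
linear.to(torch.cdouble)         # previously rejected by Module.to's dtype check
print(linear.weight.dtype)       # torch.complex128

x = torch.randn(4, 3, dtype=torch.cdouble)
print(linear(x).dtype)           # torch.complex128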
Fixes #43477.