Modifying Adam to support complex numbers as 2d real numbers #80279
Conversation
Dr. CI (automated): ✅ No failures (0 pending) as of commit 3e6f0a0. 💚 Looks good so far! There are no failures yet.
Code change looks ok but CI failures are real and need to be fixed.
Force-pushed from 63dbc12 to 92b5082.
```
@@ -260,6 +260,12 @@ def _single_tensor_adam(params: List[Tensor],
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        if torch.is_complex(param):
```
The comment above needs to be moved down
Agreed!
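For context, a hedged sketch of the approach this hunk introduces (illustrative only, not the PR's code; the shapes and hyperparameter values below are made up): the complex parameter, its gradient, and the Adam state are viewed as real tensors with a trailing dimension of 2, so the existing real-valued moment updates apply unchanged.

```python
import torch

beta1, beta2 = 0.9, 0.999
param = torch.randn(3, dtype=torch.complex64)
grad = torch.randn(3, dtype=torch.complex64)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)

# view_as_real returns a view sharing storage with the complex tensor, so any
# in-place update on the view still lands in the original parameter.
if torch.is_complex(param):
    param, grad = torch.view_as_real(param), torch.view_as_real(grad)
    exp_avg, exp_avg_sq = torch.view_as_real(exp_avg), torch.view_as_real(exp_avg_sq)

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
```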
torch/optim/adam.py (Outdated)
```
@@ -307,6 +313,13 @@ def _single_tensor_adam(params: List[Tensor],

        param.addcdiv_(exp_avg, denom, value=-step_size)

        if torch.is_complex(param):
```
This doesn't do anything right?
Oops. I see my error.
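Why those trailing conversions are a no-op, as a small standalone sketch (names here are illustrative, not the PR's): the in-place update already wrote through the storage shared between the real view and the complex parameter, so rebinding a local Python name afterwards changes nothing.

```python
import torch

param = torch.randn(2, dtype=torch.complex64)       # the actual parameter object
param_r = torch.view_as_real(param)                 # real view, same storage
param_r.addcdiv_(torch.ones_like(param_r),
                 torch.full_like(param_r, 2.0), value=-0.1)

print(param)                               # already reflects the update via shared storage
param_r = torch.view_as_complex(param_r)   # only rebinds the local name; no data moves
```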
test/test_optim.py (Outdated)
```
        a1_imag = a1.imag.clone().detach()
        a1_real.requires_grad_()
        a1_imag.requires_grad_()
        a2 = torch.complex(a1_real, a1_imag)
```
This a2 is never used?
test/test_optim.py (Outdated)
```
        f(a1).backward()
        f(a2).backward()

        assert(torch.allclose(a1.grad.real, a1_real.grad))
```
Please use the built-in asserts on `self` to play nicely with the test suite in general.
👍
Force-pushed from b7cd8cf to 019ff97.
@albanD does this look alright?
```
@@ -320,6 +320,34 @@ def _test_complex_optimizer(self, optimizer_constructor):

        self.assertEqual(torch.view_as_real(complex_param), real_param)

    def _test_complex_2d(self, optimizer_constructor, f=None):
```
The `f` argument is not actually needed, right?
I want to reuse this function for all tests involving the complex optimisers.
```
        a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
        a1_real = a1.real.clone().detach()
        a1_imag = a1.imag.clone().detach()
        a1_real.requires_grad_()
```
nit: this can be inlined above if you want.
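A hedged sketch of what the helper could look like with that nit applied (the loss function, the number of steps, and the optimizer-construction details are assumptions, not the PR's code; it presumes the usual test/test_optim.py imports): optimize the complex parameter and, in parallel, its real and imaginary parts as two real parameters, and check that the two trajectories agree.

```python
def _test_complex_2d(self, optimizer_constructor):
    # Assumed: optimizer_constructor takes a list of parameters,
    # e.g. lambda params: torch.optim.Adam(params, lr=1e-2)
    a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
    a1_real = a1.real.clone().detach().requires_grad_()
    a1_imag = a1.imag.clone().detach().requires_grad_()
    optim1 = optimizer_constructor([a1])
    optim2 = optimizer_constructor([a1_real, a1_imag])

    for _ in range(3):
        optim1.zero_grad()
        optim2.zero_grad()
        a1.abs().sum().backward()
        torch.complex(a1_real, a1_imag).abs().sum().backward()

        # For a real-valued loss, grad.real / grad.imag are the partials w.r.t.
        # the real and imaginary parts, so the two runs should see the same grads.
        self.assertEqual(a1.grad.real, a1_real.grad)
        self.assertEqual(a1.grad.imag, a1_imag.grad)

        optim1.step()
        optim2.step()

        self.assertEqual(a1.real, a1_real)
        self.assertEqual(a1.imag, a1_imag)
```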
test/test_optim.py (Outdated)
```
        f(a1).backward()
        f(a2).backward()

        self.assertTrue(torch.allclose(a1.grad.real, a1_real.grad))
```
Isn't `self.assertEqual()` working here? It is doing a close check by default, IIRC.
👍
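A small standalone illustration of that point (the module path is the one the PyTorch test suite already uses; the test class and values are made up): `assertEqual` on this `TestCase` compares tensors with dtype-appropriate tolerances rather than bitwise, so it can stand in for explicit `torch.allclose` asserts.

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

class ExampleTest(TestCase):
    def test_assert_equal_is_approximate(self):
        a = torch.tensor([1.0, 2.0])
        b = a + 1e-7                # well within the default float32 tolerance
        self.assertEqual(a, b)      # passes: approximate, allclose-style comparison

if __name__ == "__main__":
    run_tests()
```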
torch/optim/adam.py (Outdated)
```
@@ -307,6 +314,13 @@ def _single_tensor_adam(params: List[Tensor],

        param.addcdiv_(exp_avg, denom, value=-step_size)

        if is_complex_param:
            grad = torch.view_as_complex(grad)
```
This is still not needed right?
👍
torch/optim/adam.py (Outdated)
```
@@ -404,4 +424,5 @@ def _multi_tensor_adam(params: List[Tensor],
     torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
     denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

-    torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
+    torch._foreach_addcdiv_(params_, exp_avgs, denom, step_size)
+    params = [torch.view_as_complex(x) if torch.is_complex(params[i]) else x for i, x in enumerate(params_)]
```
Not needed either.
`params` is the original; `params_` has all complex tensors converted to reals.
When is `params` actually modified inplace?
`params` is never updated inplace. `params_` replaces it wholesale.
So if it is never updated inplace, there is no need to restore it here and this line does nothing?
I might be misunderstanding something, but doesn't `params` need to hold the updated values? I do all the computation in `params_`, but that is only locally defined. For the update step to change the parameters, I assumed I have to specifically change `params`. Certainly, if I don't, the unit test fails for me.
I've pushed changes that remove this line since I think I misunderstood the semantics of `view_as_real` and `view_as_complex`.
Yes, the contents of these lists are modified inplace!
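A minimal standalone sketch of why no restore step is needed (illustrative, not the PR's code): `view_as_real` produces views that alias the original complex parameters, so in-place `_foreach_*` updates on `params_` are immediately visible through `params`.

```python
import torch

params = [torch.zeros(2, dtype=torch.complex64)]
# Same idea as in the PR: build real views of any complex parameters.
params_ = [torch.view_as_real(p) if torch.is_complex(p) else p for p in params]

torch._foreach_add_(params_, 1.0)   # in-place update on the real views

print(params[0])   # tensor([1.+1.j, 1.+1.j]) -- already updated, no "restore" needed
```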
Change looks good.
Small nit about test organization but good otherwise.
```
@@ -566,27 +592,14 @@ def test_adam(self):
             lambda opt: ReduceLROnPlateau(opt)],
            constructor_accepts_maximize=True
        )
        self._test_complex_2d(optimizer)
```
I'm a bit surprised we're rolling out a custom test here. Why can't `_test_complex_optimizer()` be re-used?
They don't test quite the same things, but I can refactor these two functions into a single one. Is this a blocker for the PR being merged?
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Hey @zaxtax.
…#80279)
Summary: This commit addresses issues in #65711
Pull Request resolved: #80279
Approved by: https://github.com/albanD
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f9ef363982136f45dfb2bd4205c545cb17e59afd
Reviewed By: osalpekar
Differential Revision: D38227584
Pulled By: osalpekar
fbshipit-source-id: 48fbb9187124fe7d337e464f41f34ccc0d8b927b