Conversation

@zaxtax (Contributor) commented Jun 25, 2022

This commit addresses issues in #65711

@facebook-github-bot (Contributor) commented Jun 25, 2022

✅ No Failures (0 Pending)

As of commit 3e6f0a0 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.

@mruberry added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jun 27, 2022
@albanD (Collaborator) left a comment

Code change looks ok but CI failures are real and need to be fixed.

@zaxtax force-pushed the adam-2d-complex branch 2 times, most recently from 63dbc12 to 92b5082, on June 29, 2022 14:51
@zaxtax force-pushed the adam-2d-complex branch from 92b5082 to 63ab89c on July 21, 2022 04:15
@@ -260,6 +260,12 @@ def _single_tensor_adam(params: List[Tensor],
grad = grad.add(param, alpha=weight_decay)

# Decay the first and second moment running average coefficient
if torch.is_complex(param):
Collaborator:

The comment above needs to be moved down

Contributor Author:

Agreed!
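
For context, a minimal runnable sketch of the point being made (an illustration under assumed names and values, not the PR's actual diff): the complex parameter is viewed as real first, so the moment-decay comment then sits directly above the updates it describes.

import torch

# Hypothetical stand-ins for _single_tensor_adam's locals.
beta1, beta2 = 0.9, 0.999
param = torch.randn(2, dtype=torch.complex64)
grad = torch.randn(2, dtype=torch.complex64)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)

if torch.is_complex(param):
    # view_as_real returns a real-valued (..., 2) view, so in-place updates
    # through these views land in the original complex tensors.
    grad = torch.view_as_real(grad)
    exp_avg = torch.view_as_real(exp_avg)
    exp_avg_sq = torch.view_as_real(exp_avg_sq)
    param = torch.view_as_real(param)

# Decay the first and second moment running average coefficient
# (the comment now sits below the complex handling, next to the ops it describes).
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)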

@@ -307,6 +313,13 @@ def _single_tensor_adam(params: List[Tensor],

param.addcdiv_(exp_avg, denom, value=-step_size)

if torch.is_complex(param):
Collaborator:

This doesn't do anything, right?

Contributor Author:

Oops. I see my error.

a1_imag = a1.imag.clone().detach()
a1_real.requires_grad_()
a1_imag.requires_grad_()
a2 = torch.complex(a1_real, a1_imag)
Collaborator:

This a2 is never used?

f(a1).backward()
f(a2).backward()

assert(torch.allclose(a1.grad.real, a1_real.grad))
Collaborator:

Please use the built-in assert methods on self to play nicely with the test suite in general.

Contributor Author:

👍

@zaxtax zaxtax force-pushed the adam-2d-complex branch 4 times, most recently from b7cd8cf to 019ff97 Compare July 25, 2022 00:49
@zaxtax (Contributor Author) commented Jul 25, 2022

@albanD does this look alright?

@@ -320,6 +320,34 @@ def _test_complex_optimizer(self, optimizer_constructor):

self.assertEqual(torch.view_as_real(complex_param), real_param)

def _test_complex_2d(self, optimizer_constructor, f=None):
Collaborator:

The f argument is not actually needed, right?

Contributor Author:

I want to reuse this function for all tests involving the complex optimisers.

a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
a1_real = a1.real.clone().detach()
a1_imag = a1.imag.clone().detach()
a1_real.requires_grad_()
Collaborator:

nit: this can be inlined above if you want.

f(a1).backward()
f(a2).backward()

self.assertTrue(torch.allclose(a1.grad.real, a1_real.grad))
Collaborator:

Isn't self.assertEqual() working here? It does a close check by default, IIRC.

Contributor Author:

👍
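
For readers following the thread, here is a minimal, self-contained sketch of the gradient check being discussed. The loss f is an assumption (any scalar-valued function of the tensor works), and plain unittest is used, so the comparison is spelled out with torch.allclose; inside PyTorch's own TestCase, assertEqual performs a close comparison by default.

import unittest
import torch

class ComplexGradSanityCheck(unittest.TestCase):
    def test_complex_matches_split_real_imag(self):
        # Assumed loss for illustration; the PR's helper takes f as a parameter.
        f = lambda x: x.abs().sum()

        a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
        a1_real = a1.real.clone().detach().requires_grad_()
        a1_imag = a1.imag.clone().detach().requires_grad_()
        # Recombine the detached parts into a complex tensor.
        a2 = torch.complex(a1_real, a1_imag)

        f(a1).backward()
        f(a2).backward()

        # The complex gradient should match the gradients of the split
        # real/imaginary formulation component-wise.
        self.assertTrue(torch.allclose(a1.grad.real, a1_real.grad))
        self.assertTrue(torch.allclose(a1.grad.imag, a1_imag.grad))

if __name__ == "__main__":
    unittest.main()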

@@ -307,6 +314,13 @@ def _single_tensor_adam(params: List[Tensor],

param.addcdiv_(exp_avg, denom, value=-step_size)

if is_complex_param:
grad = torch.view_as_complex(grad)
Collaborator:

This is still not needed, right?

Contributor Author:

👍

@@ -404,4 +424,5 @@ def _multi_tensor_adam(params: List[Tensor],
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
torch._foreach_addcdiv_(params_, exp_avgs, denom, step_size)
params = [torch.view_as_complex(x) if torch.is_complex(params[i]) else x for i, x in enumerate(params_)]
Collaborator:

Not needed either.

Contributor Author:

params is the original; params_ has all complex tensors converted to reals.

Collaborator:

When is params actually modified in place?

Contributor Author:

params is never updated in place; params_ replaces it wholesale.

Collaborator:

So if it is never updated in place, there is no need to restore it here, and this line does nothing?

Contributor Author:

I might be misunderstanding something, but doesn't params need to hold the updated values? I do all the computation in params_, but that is only locally defined. For the update step to change the parameters, I assumed I have to specifically change params. Certainly, if I don't, the unit test fails for me.

Contributor Author:

I've pushed changes that remove this line, since I think I misunderstood the semantics of view_as_real and view_as_complex.

Collaborator:

Yes, the contents of these lists are modified in place!
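
A small demonstration of this point, with made-up values and step size: because torch.view_as_real returns a view, the in-place foreach update on params_ already writes into the original complex parameters, so no view_as_complex "restore" of params is needed afterwards.

import torch

p = torch.zeros(2, dtype=torch.complex64)   # a stand-in complex parameter
params_ = [torch.view_as_real(p)]           # real (..., 2) view of it

exp_avgs = [torch.ones_like(params_[0])]
denom = [torch.full_like(params_[0], 2.0)]
step_size = 0.1                             # illustrative scalar only

# Same shape of call as in the diff above.
torch._foreach_addcdiv_(params_, exp_avgs, denom, step_size)

# p itself has changed even though only the real views were touched:
print(p)  # tensor([0.0500+0.0500j, 0.0500+0.0500j])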

@zaxtax force-pushed the adam-2d-complex branch from 019ff97 to 3e6f0a0 on July 27, 2022 14:36
@albanD (Collaborator) left a comment

Change looks good.
Small nit about test organization but good otherwise.

@@ -566,27 +592,14 @@ def test_adam(self):
lambda opt: ReduceLROnPlateau(opt)],
constructor_accepts_maximize=True
)
self._test_complex_2d(optimizer)
Collaborator:

I'm a bit surprised we're rolling out a custom test here. Why can't _test_complex_optimizer() be re-used?

@zaxtax (Contributor Author), Jul 27, 2022:

They don't test quite the same things, but I can refactor these two functions into a single one. Is this a blocker for the PR being merged?

@zaxtax (Contributor Author) commented Jul 27, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator):
@pytorchbot successfully started a merge job. Check the current status here

@github-actions (Contributor):
Hey @zaxtax.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jul 28, 2022
…#80279)

Summary:
This commit addresses issues in #65711

Pull Request resolved: #80279
Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f9ef363982136f45dfb2bd4205c545cb17e59afd

Reviewed By: osalpekar

Differential Revision: D38227584

Pulled By: osalpekar

fbshipit-source-id: 48fbb9187124fe7d337e464f41f34ccc0d8b927b
Labels
cla signed · Merged · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)