[dist_optim] add distributed functional Adam optimizer #50624

Closed

wants to merge 5 commits

Conversation

@wanchaol (Contributor) commented on Jan 15, 2021

Stack from ghstack:

Add TorchScript compatible Adam functional optimizer to distributed optimizer

Differential Revision: D25932770

# Define a TorchScript compatible Functional Adam Optimizer
# where we use this optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
# we explicitly let the user pass gradients to the `step` function

Member:

Just to clarify: according to the comments below, the "user" here is the DistributedOptimizer API, not the RPC application user, right? The call to the optimizer should remain the same for the RPC user?

@wanchaol (author):

Yes, that's right. The DistributedOptimizer API actually passes those grads to the step function; let me update the comment to clarify.
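
To make that contract concrete, here is a minimal sketch of a functional-style optimizer whose `step()` takes the gradients explicitly instead of reading `param.grad`. The class name and the plain SGD update rule are illustrative stand-ins, not the actual Adam implementation added in this PR:

```python
from typing import List, Optional

import torch
from torch import Tensor


class ExplicitGradOptimizerSketch:
    """Illustrative only: step() receives gradients instead of using param.grad."""

    def __init__(self, params: List[Tensor], lr: float = 1e-3):
        self.lr = lr
        self.param_group = {"params": params}

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        if len(params) != len(gradients):
            raise ValueError(
                "the gradients passed in does not equal to the size of the parameters! "
                f"Params length: {len(params)}. Gradients length: {len(gradients)}"
            )
        with torch.no_grad():
            for param, gradient in zip(params, gradients):
                if gradient is not None:
                    # Plain SGD update, just to show where the passed-in gradient is used.
                    param.add_(gradient, alpha=-self.lr)
```

From the RPC application's point of view nothing changes: it still calls DistributedOptimizer.step(context_id), and it is the DistributedOptimizer machinery that gathers the gradients and forwards them to a step() like this.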

+ f"Gradients length: {len(gradients)}"
)

for param, gradient in zip(self.param_group['params'], gradients):

Member:

General question: it looks like the similar code in torch/optim/adam.py uses `for p in group['params']` and then accesses the grad with `p.grad`. I'm assuming we can't do that here since we need the grads explicitly, because dist autograd doesn't populate `p.grad`?

@wanchaol (author):

Yeah, in the distributed autograd context we don't populate `p.grad`; instead we call `dist_autograd.get_gradients(autograd_ctx_id)` to get the list of gradients locally.
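
For context, here is a minimal sketch, assuming the RPC framework is already initialized, of how gradients come out of a distributed autograd context rather than `param.grad`. The `local_train_step` helper and the `model`, `func_optim`, and `loss_fn` arguments are hypothetical stand-ins:

```python
import torch.distributed.autograd as dist_autograd


def local_train_step(model, func_optim, inputs, targets, loss_fn):
    # `func_optim` is assumed to be a functional optimizer whose step()
    # takes an explicit list of gradients, as sketched above.
    with dist_autograd.context() as context_id:
        loss = loss_fn(model(inputs), targets)
        # Gradients are accumulated inside the autograd context,
        # not written to param.grad.
        dist_autograd.backward(context_id, [loss])
        # get_gradients returns a dict mapping each parameter tensor
        # to its gradient for this context.
        grads_map = dist_autograd.get_gradients(context_id)
        params = list(model.parameters())
        gradients = [grads_map.get(p) for p in params]
        func_optim.step(gradients)
```

This is why the functional optimizer's step() has to accept the gradients as an argument instead of iterating over p.grad.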

@rohan-varma (Member) left a comment:

LGTM overall, although I mostly compared the changes to torch/optim/adam.py and torch/distributed/optim/functional_adagrad.py and checked for parity. I don't have context on the changes in torch/optim/functional.py, so please get someone to look at that.

@pritamdamania87 it would be great if you could take a look at these changes as well.

# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'].item())

Member:

I'm guessing all the logic up until the point we call F.adam aims to emulate torch/optim/adam.py, but is there any automated way to guarantee this? Could we dedupe the similar parts into helper functions and call those helpers here? Alternatively, are we guaranteed that the dist optimizer tests will raise an error if the implementations diverge at all?

@wanchaol (author):

Yes, they are similar indeed, but they differ across optimizers because each one needs its own state (and this implementation also differs from the original adam.py because of TorchScript limitations on some syntax), so it's hard to generalize across them. That said, I think we can guarantee the implementation will not diverge on the functional part, since we share the computation, and test_optim has good coverage of it. On the state-management side, do you think we should introduce some sort of flag to enable/disable the TorchScript support and compare the results in the test?
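
For reference, the shared numerical core both implementations delegate to is the standard Adam update. The sketch below is a plain single-parameter version, not the exact helper or signature in torch/optim/functional.py, showing where the recorded step count (the `state_steps` values above) enters the bias corrections:

```python
import math

import torch
from torch import Tensor


@torch.no_grad()
def adam_update_sketch(param: Tensor, grad: Tensor, exp_avg: Tensor,
                       exp_avg_sq: Tensor, step: int, lr: float = 1e-3,
                       beta1: float = 0.9, beta2: float = 0.999,
                       eps: float = 1e-8, weight_decay: float = 0.0) -> None:
    """Illustrative single-parameter Adam step; updates the buffers and param in place."""
    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)
    # Exponential moving averages of the gradient and its square.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Bias corrections depend on the per-parameter step count.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    step_size = lr / bias_correction1
    param.addcdiv_(exp_avg, denom, value=-step_size)
```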

@@ -199,65 +198,12 @@ def test_dist_optim(self):
self.assertEqual(new_w1, module1.get_w())
self.assertEqual(new_w2, module2.get_w())

@dist_init()
def test_dist_optim(self):
self._test_dist_optim_base(optim.SGD, lr=0.05)

Member:

Can we just move this to test_dist_optim_functional?
Ah, never mind, this is the regular optimizer, not the TorchScripted one.

@facebook-github-bot (Contributor):

@wanchaol merged this pull request in 5cbe1e4.

@facebook-github-bot deleted the gh/wanchaol/156/head branch on January 26, 2021 at 15:21.
Labels: cla signed, Merged, oncall: distributed