[dist_optim] add distributed functional Adam optimizer #50624
Conversation
Add TorchScript compatible Adam functional optimizer to distributed optimizer [ghstack-poisoned]
# Define a TorchScript compatible Functional Adam Optimizer
# where we use these optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
# we explicitly let the user pass gradients to the `step` function
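The pattern this comment describes can be sketched as follows. This is a hypothetical minimal class for illustration only (not the actual PR code), and the update rule is simplified to plain SGD for brevity; the point is that gradients are passed to `step()` explicitly rather than read from `param.grad`.

```python
import torch

# Hypothetical minimal sketch of a "functional" optimizer: gradients are
# passed explicitly to step() instead of being read from param.grad.
# The update rule is simplified to plain SGD for brevity.
class FunctionalSketchOptimizer:
    def __init__(self, params, lr=0.1):
        self.param_group = {'params': list(params)}
        self.lr = lr

    def step(self, gradients):
        with torch.no_grad():
            for param, gradient in zip(self.param_group['params'], gradients):
                if gradient is not None:
                    param.add_(gradient, alpha=-self.lr)

p = torch.ones(3)
opt = FunctionalSketchOptimizer([p])
opt.step([torch.ones(3)])   # explicit gradient; param.grad is never touched
print(p)  # tensor([0.9000, 0.9000, 0.9000])
```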
Just to clarify: according to the comments below, "user" here means the DistributedOptimizer API, not the RPC application user, right? The optimizer call should remain the same for the RPC user?
Yes, that's right. The DistributedOptimizer API actually passes those grads to the step function; let me update the comment to clarify.
    f"Gradients length: {len(gradients)}"
)

for param, gradient in zip(self.param_group['params'], gradients):
General question: it looks like the similar code in torch/optim/adam.py uses `for p in group['params']` and then accesses the grad with `p.grad`. I'm assuming we can't do this here since we need the grads explicitly, because dist autograd doesn't populate `p.grad`?
Yeah, in a distributed autograd context we don't populate `p.grad`; instead we call `dist_autograd.get_gradients(autograd_ctx_id)` to get the gradients locally.
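For illustration: `dist_autograd.get_gradients(ctx_id)` returns a dict mapping parameter tensors to gradient tensors. Since real `dist_autograd` needs an initialized RPC worker, the sketch below fakes that dict locally and just shows how such a dict can be turned into the positional gradients list a functional `step(gradients)` expects.

```python
import torch

# Hedged sketch: dist_autograd.get_gradients(ctx_id) returns a
# Dict[Tensor, Tensor] mapping parameter -> gradient. We fake that dict
# locally (real usage requires an initialized RPC worker) and build the
# positional gradients list that step(gradients) would consume.
params = [torch.zeros(2), torch.ones(2)]
grads_dict = {params[0]: torch.full((2,), 0.5)}  # second param got no gradient

# Tensors hash by identity, so looking up a parameter in the dict works.
gradients = [grads_dict.get(p) for p in params]
print(gradients[1])  # None: params without gradients map to None
```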
LGTM overall, although I mostly compared the changes to `torch/optim/adam.py` and `torch/distributed/optim/functional_adagrad.py` and checked for parity. I don't have context on the changes in `torch/optim/functional.py`, so please get someone to look at that.
@pritamdamania87 Would be great if you get a chance to take a look at these changes as well.
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'].item())
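The step-counter bookkeeping quoted above can be demonstrated in isolation. In this sketch the per-parameter state keeps `step` as a tensor (a TorchScript-friendly representation), increments it in place, and extracts the Python number with `.item()` for the functional update:

```python
import torch

# Minimal illustration of the quoted step-counter logic: 'step' is stored
# as a Tensor, incremented in place each optimizer step, and .item()
# extracts the plain Python number appended to state_steps.
state = {'step': torch.tensor(0.0)}
state_steps = []

for _ in range(3):  # simulate three optimizer steps
    state['step'] += 1
    state_steps.append(state['step'].item())

print(state_steps)  # [1.0, 2.0, 3.0]
```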
I'm guessing all the logic up until the point where we call `F.adam` is aiming to emulate `torch/optim/adam.py`, but is there any automated way to guarantee this? Could we dedupe the similar parts into helper functions and call those helper functions here? Alternatively, are we guaranteed that the dist optimizer tests will raise an error if the implementations diverge at all?
Yes, they are indeed similar, but they differ across optimizers since each optimizer needs different states (and this differs from the original adam.py as well, because of TorchScript limitations on some syntax), so it's hard to generalize across them. That said, I think we can guarantee the implementation will not diverge in the functional part, since we share the computation, and `test_optim` has good coverage of it. On the state-management side, do you think we should introduce some sort of flag to enable/disable the TorchScript support and compare the results in the test?
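One way to guard against the divergence discussed here is a parity check against a hand-computed reference. The sketch below is hypothetical test code, not from the PR: it runs a single step of the public `torch.optim.Adam` and compares against Adam's bias-corrected update formula computed by hand; the same pattern could compare the functional and regular implementations.

```python
import torch

# Hypothetical parity test: one torch.optim.Adam step vs. a hand-computed
# single-step Adam update (bias-corrected, matching torch.optim.Adam's math).
torch.manual_seed(0)
p = torch.randn(4, requires_grad=True)
p_ref = p.detach().clone()
g = torch.randn(4)
p.grad = g.clone()

lr, b1, b2, eps = 1e-3, 0.9, 0.999, 1e-8
torch.optim.Adam([p], lr=lr, betas=(b1, b2), eps=eps).step()

# Manual first step: m and v start at zero, then apply bias correction.
m = (1 - b1) * g
v = (1 - b2) * g * g
denom = (v / (1 - b2)).sqrt() + eps
expected = p_ref - lr * (m / (1 - b1)) / denom

assert torch.allclose(p.detach(), expected, atol=1e-6)
```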
@@ -199,65 +198,12 @@ def test_dist_optim(self):
        self.assertEqual(new_w1, module1.get_w())
        self.assertEqual(new_w2, module2.get_w())

    @dist_init()
    def test_dist_optim(self):
        self._test_dist_optim_base(optim.SGD, lr=0.05)
Can we just move this to `test_dist_optim_functional`?
Ah, never mind; this is the regular optimizer, not the TorchScripted one.
Add TorchScript compatible Adam functional optimizer to distributed optimizer Differential Revision: [D25932770](https://our.internmc.facebook.com/intern/diff/D25932770) [ghstack-poisoned]
Stack from ghstack:
Add TorchScript compatible Adam functional optimizer to distributed optimizer
Differential Revision: D25932770