Make Adam optimizer differentiable #82205
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit db97708. Looks good so far; there are no failures yet. (Automated comment.)
Force-pushed from df14687 to 9af1df4.
Nice!
torch/optim/adam.py (Outdated)
@@ -278,7 +284,10 @@ def _single_tensor_adam(params: List[Tensor],
    if amsgrad:
        # Maintains the maximum of all 2nd moment running avg. till now
        torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
        if differentiable:
            max_exp_avg_sqs[i][:] = torch.maximum(max_exp_avg_sqs[i].clone(), exp_avg_sq)
nit: max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs[i], exp_avg_sq))
Also the perf hit compared to the out= call below is really minor. I think we can unconditionally use this new version for the sake of simplicity (same below). cc @jbschlosser what do you think?
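For illustration, a quick sanity check (not from the PR) that the suggested copy_ form matches the existing out= form numerically; buf_out, buf_copy and new are made-up stand-ins for max_exp_avg_sqs[i] and exp_avg_sq:

    import torch

    # Hypothetical stand-ins for max_exp_avg_sqs[i] and exp_avg_sq.
    buf_out = torch.rand(5)
    buf_copy = buf_out.clone()
    new = torch.rand(5)

    # Existing form: write the elementwise max back through out=.
    torch.maximum(buf_out, new, out=buf_out)
    # Suggested form: compute out of place, then copy_ back into the buffer.
    buf_copy.copy_(torch.maximum(buf_copy, new))

    assert torch.equal(buf_out, buf_copy)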
One thing here is that we need to clone max_exp_avg_sqs for this to work... Wondering if it's OK to keep this.
Ho right maximum saves both its inputs.
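A minimal repro (not from the PR) of why the clone is needed: maximum() saves both of its inputs for backward, so an in-place write into one of them bumps its version counter and backward hits the same kind of error quoted further down this thread:

    import torch

    # Non-leaf buffer with grad history, like the cloned optimizer state in the test.
    buf = torch.rand(3, dtype=torch.float64, requires_grad=True).clone()
    other = torch.rand(3, dtype=torch.float64, requires_grad=True)

    out = torch.maximum(buf, other)  # saves both buf and other for backward
    buf.copy_(out)                   # in-place update of a saved input
    try:
        out.sum().backward()
    except RuntimeError as e:
        print(e)  # "... has been modified by an inplace operation ..."

Cloning first, e.g. buf.copy_(torch.maximum(buf.clone(), other)), gives maximum() a saved input that is never written to afterwards, so backward goes through.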
@@ -2693,9 +2693,9 @@ def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    p.grad = grad
    opt_differentiable_state = {k: v.clone() for k, v in opt_differentiable_state.items()}
    opt = opt_class([p], **kwargs)
    opt.state.update(opt_differentiable_state)
    opt.state[p].update(opt_differentiable_state)
This was my mistake in the SGD PR; state is a dict of dicts where the key is the param.
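For illustration (not part of the PR), the layout being described: Optimizer.state maps each parameter tensor to its own per-parameter dict, so extra entries have to be merged into opt.state[p] rather than opt.state:

    import torch

    p = torch.rand(10, requires_grad=True)
    opt = torch.optim.Adam([p])
    p.sum().backward()
    opt.step()

    # state is keyed by the parameter; the inner dict holds the buffers.
    print(list(opt.state[p].keys()))  # ['step', 'exp_avg', 'exp_avg_sq']
    # So the test has to update the inner dict, as in the diff above:
    opt.state[p].update({'exp_avg': torch.zeros_like(p)})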
Force-pushed from b8a6f60 to 37c23c1.
    opt.step()
    return (p,) + tuple(opt_differentiable_state.values())
    return (p,) + tuple(v for v in opt_differentiable_state.values() if v.requires_grad)
Is this actually needed? Gradcheck will filter out things that don't require gradients already
Actually it didn't, lol; step was returned and gradcheck detected errors with it.
    state = {}
    p = torch.rand(10, requires_grad=True, dtype=torch.float64)
    grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
    state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64)
not requires grad?
This causes the gradcheck to fail if it's enabled...
Ho this is the step. Why is this float and not long?
https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py#L140-L141
It was defined as float in the optimizer itself
OK. Can you leave a comment here that mentions that step is not a continuous variable (even though we define it as a float) and so it shouldn't require gradients?
Btw, can we have a nice assert in the adam() function to ensure that step never requires grad?
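A minimal sketch of such a guard, assuming the functional adam() receives the step counters as a state_steps list of tensors as in torch/optim/adam.py; the helper name is made up:

    from typing import List
    import torch

    def _assert_steps_do_not_require_grad(state_steps: List[torch.Tensor]) -> None:
        # step is a discrete iteration counter (stored as a float tensor),
        # not a continuous variable, so it must never require gradients.
        if any(step.requires_grad for step in state_steps):
            raise RuntimeError("`state_steps` must not require gradients")

    _assert_steps_do_not_require_grad([torch.tensor(10.0)])  # passes: step does not require grad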
torch/optim/adam.py (Outdated)
@@ -278,7 +280,7 @@ def _single_tensor_adam(params: List[Tensor],
    if amsgrad:
        # Maintains the maximum of all 2nd moment running avg. till now
        torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
        max_exp_avg_sqs[i][:] = torch.maximum(max_exp_avg_sqs[i].clone(), exp_avg_sq)
This is a very personal preference, but I'm not very confident with advanced-indexing assignment in general; can this be done with copy_() to make sure of the behavior we get?
But we should also avoid the clone in general.
Suggested change:
    # from:
    max_exp_avg_sqs[i][:] = torch.maximum(max_exp_avg_sqs[i].clone(), exp_avg_sq)
    # to:
    prev_exp_avg = max_exp_avg_sqs[i].clone() if differentiable else max_exp_avg_sqs[i]
    max_exp_avg_sqs[i].copy_(torch.maximum(prev_exp_avg, exp_avg_sq))
The problem here is that if we pass the differentiable argument, the JIT and other pieces of code will complain, and we will make this non-compatible, since torchscript does not allow arguments with default values.
I wonder if we can store the flag in thread-local storage, set by the differentiable context manager, for this check?
Error if we don't clone:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [10]], which is output 0 of torch::autograd::CopyBackwards, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
It is only kwarg-only arguments with default values, IIRC?
You can make it positional as done in Lines 172 to 175 in 1dfcad8:
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = None,
    foreach: bool = None,
Force-pushed from 37c23c1 to 36f09c9.
Nit on a clone, but SGTM otherwise. The CI needs fixing though, no?
torch/optim/adam.py (Outdated)
@@ -305,7 +315,7 @@ def _single_tensor_adam(params: List[Tensor],
    if amsgrad:
        # Maintains the maximum of all 2nd moment running avg. till now
        torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
        torch.maximum(max_exp_avg_sqs[i].clone(), exp_avg_sq, out=max_exp_avg_sqs[i])
Only if differentiable?
This was not needed; the differentiable case is now handled in the outer if. :)
@@ -328,7 +338,7 @@ def _multi_tensor_adam(params: List[Tensor],
                       weight_decay: float,
                       eps: float,
                       maximize: bool,
                       capturable: bool):
                       differentiable: bool):
The foreach ops don't support autograd. So we should have an assert here that differentiable is False (same for sgd btw).
This can be done in a separate PR though.
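A minimal sketch of the requested assert, assuming differentiable is threaded into _multi_tensor_adam alongside the existing flags (the signature here is approximate and the foreach body is elided):

    def _multi_tensor_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
                           state_steps, *, amsgrad, beta1, beta2, lr, weight_decay,
                           eps, maximize, capturable, differentiable):
        # The torch._foreach_* kernels used in this path do not support autograd,
        # so differentiable mode must stay on the _single_tensor_adam path.
        assert not differentiable, "_foreach ops don't support autograd"
        ...  # foreach implementation unchanged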
Added the assert to this one, will fix SGD too in the next PR :)
Force-pushed from 36f09c9 to b647df0.
SGTM! Just a small nit about the error message.
torch/optim/adam.py (Outdated)
@@ -152,7 +154,8 @@ def step(self, closure=None):
    if group['amsgrad']:
        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

    if group['differentiable']:
        assert not state['step'].requires_grad
Nit: make this a RuntimeError and add a nice error message.
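For illustration, one way the bare assert in the hunk above could become a RuntimeError with a readable message (a drop-in for those two lines; the exact wording is up to the PR):

    if group['differentiable'] and state['step'].requires_grad:
        raise RuntimeError('`requires_grad` is not supported for `step` '
                           'in differentiable mode')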
done!
Force-pushed from b647df0 to db97708.
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Hey @emcastillo. |
Summary: Continues #80938. Pull Request resolved: #82205. Approved by: https://github.com/albanD
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5aab57e112d244f0cf3bbab30db640e52a0c2c44
Reviewed By: seemethere
Differential Revision: D38788436
fbshipit-source-id: e35677b92267d068e044693acb9a7fcc96ed59c5
Blocked by #82205. Pull Request resolved: #83578. Approved by: https://github.com/albanD
Summary: Blocked by #82205. Pull Request resolved: #83578. Approved by: https://github.com/albanD
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/f0eb841d209f251d6a735827d4b903962d0d31b8
Reviewed By: seemethere
Differential Revision: D38911145
Pulled By: seemethere
fbshipit-source-id: 10bff92beba31fed5adacdf453b680c5acd8b19c
Continues #80938