Add zero_grad(set_to_none=True) #42754
Conversation
[ghstack-poisoned]
Summary: Address issue #41696. Test Plan: I leveraged the MNIST example in pytorch/examples, manually changed zero_grad() to reset_grad(), and checked whether that affects the training precision. Before:
After:
💊 CI failures summary and remediations
As of commit 82594d2 (more details on the Dr. CI page):
🕵️ 9 new failures recognized by patterns
The following CI failures do not appear to be due to upstream breakages:
Optional suggestion: #41696 (comment). Regardless of which API you choose, you should apply similar treatment to
torch/optim/optimizer.py
Outdated
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
Why do we need to `detach_()`? Afaict the only way this has an effect is if something external also holds a reference to `.grad` and `.grad` was created in a `create_graph=True` backward pass.
Yeah, that sounds like a possible scenario, so I'm trying to be safe here.
My point is `detach_` may be the unsafe option. Because it affects `p.grad` in place, it silently affects anything else holding a reference to `p.grad`. Setting `p.grad = None` without `detach_` simply drops our reference. Anything else holding a reference to `grad` will not see an effect.

Admittedly, I'm not sure why the default zeroing behavior of `zero_grad` performs `detach_` then `zero_`. I assume it's to avoid building up spurious autograd history if the grad was created with `create_graph=True` and therefore requires grad. (Edit: related, it avoids a memory leak when the grad has a `grad_fn`.)

If so, in the alternative set-to-None path, we don't need `detach_`. We drop the reference to `.grad`, we don't perform any ops on it, therefore there's no danger of building spurious autograd history. And if we don't `detach_`, we don't risk silently affecting other references to `grad`.

tl;dr I think not detaching is the safe option here.

@albanD you're good with these tricky cases... Why is `detach_` used with the default zeroing behavior of `zero_grad`? Also, do you agree not detaching is the safer implementation for the set-to-None path?
The original code is trying to make sure that we don't change the Tensor referenced by `.grad` as much as possible, I think.
And yes, as mentioned below, I do agree that we don't want to detach here.
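For context, here is a minimal sketch of the two code paths under discussion, assuming the `set_to_none` flag this PR proposes; the function name and structure are illustrative, not the exact diff:

```python
import torch

def zero_grad_sketch(params, set_to_none=False):
    # Illustrative only: contrasts the default zeroing path with the
    # set-to-None path discussed above.
    for p in params:
        if p.grad is None:
            continue
        if set_to_none:
            # Drop our reference. No in-place op is performed, so anything
            # else holding the old .grad tensor is unaffected, and no
            # spurious autograd history can be built.
            p.grad = None
        else:
            # Default path: detach_ breaks any grad_fn (e.g. from a
            # create_graph=True backward), then zeros the tensor in place.
            p.grad.detach_()
            p.grad.zero_()

# Usage sketch: zero_grad_sketch(model.parameters(), set_to_none=True)
```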
Thanks for running an example. Out of curiosity, did it change the runtime and memory?
Responded about API on original issue.
Differential Revision: [D23010859](https://our.internmc.facebook.com/intern/diff/D23010859) [ghstack-poisoned]
It doesn't seem to lead to a noticeable difference in terms of run time, but again, I am running a toy example here, so it's probably hard to tell. Not sure how to measure the memory footprint, though.
Can you please modify
torch/nn/modules/module.py
Outdated
    if p.grad is not None:
        if p.grad.grad_fn is not None:
            if set_to_none:
                p.grad.detach_()
I think you can remove this `detach()` as you remove it anyways.
r"""Sets gradients of all model parameters to zero. | ||
Arguments: | ||
set_to_none (bool): instead of setting to zero, set the grad to None. |
Can you link to the `nn.optim` version of this doc that contains more details about the change of behavior?
`torch.optim`, right?
Yes, typo!
Add:
    See :meth:`torch.optim.optimizer.zero_grad` for details.
torch/optim/optimizer.py
Outdated
    if p.grad is not None:
        if p.grad.grad_fn is not None:
            if set_to_none:
                p.grad.detach_()
unneeded detach here as well.
torch/nn/modules/module.py
Outdated
    def zero_grad(self) -> None:
        r"""Sets gradients of all model parameters to zero."""
    def zero_grad(self, set_to_none=False) -> None:
nit: can you add the type for `set_to_none`: `set_to_none: bool = False`
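Putting the typing nit and the docstring cross-reference together, the `nn.Module.zero_grad` side might look roughly like the sketch below; this is illustrative only, not the exact code in this PR:

```python
def zero_grad(self, set_to_none: bool = False) -> None:
    r"""Sets gradients of all model parameters to zero.

    See :meth:`torch.optim.optimizer.zero_grad` for details.

    Arguments:
        set_to_none (bool): instead of setting to zero, set the grads to None.
    """
    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    # Break ties to autograd history (e.g. from a
                    # create_graph=True backward) before zeroing in place.
                    p.grad.detach_()
                p.grad.zero_()
```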
torch/optim/optimizer.py
Outdated
    def zero_grad(self):
        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    def zero_grad(self, set_to_none=False):
Can you update `optimizer.pyi` as well to reflect this change? Just add the new arg and type there.
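For illustration, the stub addition might look something like this (a hypothetical excerpt; the real `optimizer.pyi` has more members and may differ in detail):

```python
# Hypothetical excerpt of torch/optim/optimizer.pyi (other members omitted)
class Optimizer:
    ...
    def zero_grad(self, set_to_none: bool = ...) -> None: ...
```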
torch/optim/optimizer.py
Outdated
        A None attribute or a Tensor full of 0s will be different.
    2. User can no longer rely on checking if `.grad` is None to see if a tensor
       is touched in the backward pass
    3. `nn.optim` optimizers have a different behavior if the gradient is 0 or None
torch.optim
torch/optim/optimizer.py
Outdated
    Arguments:
        set_to_none (bool): instead of setting to zero, set the grad to None.
            This will in general have lower memory footprint, but using this
            comes with caveats, to name a few:
I wouldn't say the changes are caveats. Some are benefits, like point 3 (skipping the update is faster if the grad is None).
Instead of
> This will in general have lower memory footprint, but using this comes with caveats, to name a few:
I would say
> This is will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example:
torch/optim/optimizer.py
Outdated
            comes with caveats, to name a few:
    1. When user tries to access the gradient value and perform manual ops on it.
       A None attribute or a Tensor full of 0s will be different.
    2. User can no longer rely on checking if `.grad` is None to see if a tensor
This doesn't make sense. I'd say it's the opposite: if a user sets grads = None before backward, they CAN check if `.grad` is None to see if the tensor received a gradient in the backward pass. Let's just describe the behavior:
> If the user requests `zero_grad(set_to_none=True)` followed by a backward pass, `.grad`\ s are guaranteed to be None for params that did not receive a gradient.
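To make the described behavior concrete, here is a hypothetical usage sketch; the toy model and names are invented for illustration, and it assumes a build that includes this PR's `set_to_none` flag:

```python
import torch
import torch.nn as nn

# Toy model with two heads; only one is used in a given forward pass.
class TwoHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 2)
        self.b = nn.Linear(4, 2)

    def forward(self, x, use_b=False):
        return self.b(x) if use_b else self.a(x)

model = TwoHead()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

opt.zero_grad(set_to_none=True)
model(torch.randn(8, 4), use_b=False).sum().backward()

# After zero_grad(set_to_none=True) + backward, .grad is None exactly for
# the params that received no gradient (here, the unused head `b`).
for name, p in model.named_parameters():
    print(name, "no grad (None)" if p.grad is None else "received grad")
```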
torch/optim/optimizer.py
Outdated
    Arguments:
        set_to_none (bool): instead of setting to zero, set the grad to None.
            This is will in general have lower memory footprint, and can modestly improve performance.
`This is will` -> `This will`
torch/optim/optimizer.py
Outdated
            This is will in general have lower memory footprint, and can modestly improve performance.
            However, it changes certain behaviors. For example:
    1. When user tries to access the gradient value and perform manual ops on it.
       A None attribute or a Tensor full of 0s will be different.
Combine sentences:
> When user tries to access the gradient value and perform manual ops on it, a None attribute or a Tensor full of 0s will act differently.
            However, it changes certain behaviors. For example:
    1. When user tries to access the gradient value and perform manual ops on it.
       A None attribute or a Tensor full of 0s will be different.
    2. If the user requests `zero_grad(set_to_none=True)` followed by a backward pass, `.grad` s
Upon rereading this, I think we should switch the order of points 1 and 2. The current point 2 sets up some context for (what is currently) point 1, so I think it makes sense to say point 2 first.
Otherwise LGTM! Thanks for the PR, making this optimization a first-class citizen is very helpful.
@albanD, @vincentqb anything holding up this diff? Comments seem to be addressed (other than @mcarilli's doc suggestions, which are imo minor).
Nothing blocking on my side. Just the doc update. And make sure that the CI is happy with it after rebase.
Approving, subject to minor doc fixes.
LGTM too, I fixed two of the three doc changes that I see. I'm ok either way with the last one:
> Upon rereading this, I think we should switch the order of points 1 and 2. The current point 2 sets up some context for (what is currently) point 1, so I think it makes sense to say point 2 first.
#42754 (review) is not a big deal imo but don't forget https://github.com/pytorch/pytorch/pull/42754/files#r481400744.
Summary:
Pull Request resolved: pytorch#44423
Pull Request resolved: pytorch#42754
Test Plan: Imported from OSS
Reviewed By: mruberry
Differential Revision: D23010859
Pulled By: ngimel
fbshipit-source-id: 760279f7c9cb84d11bef51207c18bf1f362ca7ad
Stack from ghstack:
Differential Revision: D23010859