5 changes: 5 additions & 0 deletions test/test_nn.py
@@ -634,6 +634,11 @@ def test_zero_grad(self):
self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())

# Force set to None.
module.zero_grad(set_to_none=True)
self.assertIsNone(module.weight.grad)


def test_no_grad(self):
for dtype in [torch.bfloat16, torch.float, torch.double]:
module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
21 changes: 15 additions & 6 deletions torch/nn/modules/module.py
@@ -1313,8 +1313,14 @@ def requires_grad_(self: T, requires_grad: bool = True) -> T:
p.requires_grad_(requires_grad)
return self

def zero_grad(self) -> None:
r"""Sets gradients of all model parameters to zero."""
def zero_grad(self, set_to_none: bool = False) -> None:
r"""Sets gradients of all model parameters to zero. See similar function
under `torch.optimizer` for more contexts.

Arguments:
set_to_none (bool): instead of setting to zero, set the grads to None.
See :meth:`torch.optim.Optimizer.zero_grad` for details.
"""
if getattr(self, '_is_replica', False):
warnings.warn(
"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
@@ -1324,11 +1330,14 @@ def zero_grad(self) -> None:

for p in self.parameters():
if p.grad is not None:
if p.grad.grad_fn is not None:
p.grad.detach_()
if set_to_none:
p.grad = None
else:
p.grad.requires_grad_(False)
p.grad.zero_()
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
p.grad.requires_grad_(False)
p.grad.zero_()

def share_memory(self: T) -> T:
return self._apply(lambda t: t.share_memory_())
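To make the new flag's effect on a Module concrete, here is a small usage sketch (not part of the PR's diff). It assumes a PyTorch build that already contains this change; the nn.Linear module and tensor shapes are arbitrary choices for illustration.

import torch
import torch.nn as nn

# Illustrative sketch only; assumes a build with the set_to_none flag from this PR.
module = nn.Linear(4, 2)
module(torch.randn(3, 4)).sum().backward()

# Default path (set_to_none=False): .grad tensors stay allocated and are filled with zeros.
module.zero_grad(set_to_none=False)
print(module.weight.grad)            # tensor of zeros, same shape as weight

# New path: the .grad attributes are dropped entirely.
module.zero_grad(set_to_none=True)
print(module.weight.grad is None)    # True

This mirrors the new test above, which asserts that module.weight.grad is None after zero_grad(set_to_none=True).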
28 changes: 22 additions & 6 deletions torch/optim/optimizer.py
@@ -164,16 +164,32 @@ def update_group(group, new_group):
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})

def zero_grad(self):
r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
def zero_grad(self, set_to_none: bool = False):
r"""Set the gradients of all optimized :class:`torch.Tensor` s to zero.

Arguments:
set_to_none (bool): instead of setting to zero, set the grads to None.
This will in general have a lower memory footprint, and can modestly improve performance.
However, it changes certain behaviors. For example:
1. When the user tries to access a gradient and perform manual ops on it,
a None attribute or a Tensor full of 0s will behave differently.
2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
are guaranteed to be None for params that did not receive a gradient.
3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
(in one case it does the step with a gradient of 0 and in the other it skips
the step altogether).
"""
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
if p.grad.grad_fn is not None:
p.grad.detach_()
if set_to_none:
p.grad = None
else:
p.grad.requires_grad_(False)
p.grad.zero_()
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
p.grad.requires_grad_(False)
p.grad.zero_()

def step(self, closure):
r"""Performs a single optimization step (parameter update).
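A short training-loop sketch (again not part of the diff) illustrating points 2 and 3 of the docstring above. It assumes a PyTorch build with this change and uses an arbitrary nn.Linear model with SGD purely for illustration.

import torch
import torch.nn as nn

# Illustrative sketch only; assumes a build with the set_to_none flag from this PR.
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(8, 4)).sum().backward()
opt.step()

# set_to_none=False keeps zero-filled .grad tensors allocated, so a later
# step() still runs, just with a gradient of 0 (point 3, first case).
opt.zero_grad(set_to_none=False)
print(model.weight.grad)             # tensor of zeros

# set_to_none=True frees the .grad tensors; optimizers skip parameters whose
# .grad is None, so the step() below is a no-op (point 3, second case).
opt.zero_grad(set_to_none=True)
print(model.weight.grad is None)     # True
opt.step()                           # every parameter is skipped

Point 2 follows from the same check: after zero_grad(set_to_none=True), a subsequent backward pass only repopulates .grad for parameters that actually receive a gradient, so the rest stay None instead of holding a stale zero tensor.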
2 changes: 1 addition & 1 deletion torch/optim/optimizer.pyi
@@ -13,6 +13,6 @@ class Optimizer:
def __setstate__(self, state: dict) -> None: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, state_dict: dict) -> None: ...
def zero_grad(self) -> None: ...
def zero_grad(self, set_to_none: Optional[bool]=...) -> None: ...
def step(self, closure: Optional[Callable[[], float]]=...) -> Optional[float]: ...
def add_param_group(self, param_group: dict) -> None: ...