Skip to content

Commit

Permalink
Add documentation for apply (#2327)
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga authored and soumith committed Aug 9, 2017
1 parent 9357b8f commit 1ac98b1
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions torch/nn/modules/module.py
Expand Up @@ -132,6 +132,34 @@ def _apply(self, fn):
return self

def apply(self, fn):
"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
as well as self. Typical use includes initializing the parameters of a model
(see also :ref:`torch-nn-init`).
Example:
>>> def init_weights(m):
>>> print(m)
>>> if type(m) == nn.Linear:
>>> m.weight.data.fill_(1.0)
>>> print(m.weight)
>>>
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear (2 -> 2)
Parameter containing:
1 1
1 1
[torch.FloatTensor of size 2x2]
Linear (2 -> 2)
Parameter containing:
1 1
1 1
[torch.FloatTensor of size 2x2]
Sequential (
(0): Linear (2 -> 2)
(1): Linear (2 -> 2)
)
"""
for module in self.children():
module.apply(fn)
fn(self)
Expand Down

0 comments on commit 1ac98b1

Please sign in to comment.