diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 32dfb524298d..2ca5d1e72d3b 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -132,6 +132,34 @@ def _apply(self, fn): return self def apply(self, fn): + """Applies ``fn`` recursively to every submodule (as returned by ``.children()``) + as well as self. Typical use includes initializing the parameters of a model + (see also :ref:`torch-nn-init`). + + Example: + >>> def init_weights(m): + >>> print(m) + >>> if type(m) == nn.Linear: + >>> m.weight.data.fill_(1.0) + >>> print(m.weight) + >>> + >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) + >>> net.apply(init_weights) + Linear (2 -> 2) + Parameter containing: + 1 1 + 1 1 + [torch.FloatTensor of size 2x2] + Linear (2 -> 2) + Parameter containing: + 1 1 + 1 1 + [torch.FloatTensor of size 2x2] + Sequential ( + (0): Linear (2 -> 2) + (1): Linear (2 -> 2) + ) + """ for module in self.children(): module.apply(fn) fn(self)