Add documentation for FeatureAlphaDropout #36295

Closed
5 changes: 5 additions & 0 deletions docs/source/nn.functional.rst
@@ -322,6 +322,11 @@ Dropout functions

.. autofunction:: alpha_dropout

:hidden:`feature_alpha_dropout`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: feature_alpha_dropout

:hidden:`dropout2d`
~~~~~~~~~~~~~~~~~~~

2 changes: 1 addition & 1 deletion test/test_torch.py
@@ -188,7 +188,7 @@ def test_namespace(ns, *skips):
'sparse_resize_and_clear_',
)
test_namespace(torch.nn)
-        test_namespace(torch.nn.functional, 'assert_int_or_pair', 'feature_alpha_dropout')
+        test_namespace(torch.nn.functional, 'assert_int_or_pair')
# TODO: add torch.* tests when we have proper namespacing on ATen functions
# test_namespace(torch)

19 changes: 19 additions & 0 deletions torch/nn/functional.py
@@ -1051,6 +1051,25 @@ def dropout3d(input, p=0.5, training=True, inplace=False):

def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor
r"""
Randomly masks out entire channels (a channel is a feature map,
e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of
setting activations to zero, as in regular Dropout, the activations are set
to the negative saturation value of the SELU activation function.

Each element will be masked independently on every forward call with
probability :attr:`p` using samples from a Bernoulli distribution.
The elements to be masked are randomized on every forward call, and scaled
and shifted to maintain zero mean and unit variance.

See :class:`~torch.nn.FeatureAlphaDropout` for details.

Args:
p: probability of a channel to be masked. Default: 0.5
training: apply dropout if ``True``. Default: ``False``
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
"""
if not torch.jit.is_scripting():
if type(input) is not Tensor and has_torch_function((input,)):
return handle_torch_function(
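As a quick illustration of the behaviour documented above (this sketch is not part of the diff; the shapes, seed, and p value are arbitrary), the functional form can be exercised directly. Masked channels end up filled with a constant rather than with zeros:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Any input with (N, C, ...) dimensions works; the Bernoulli mask is drawn
# once per (sample, channel) pair and broadcast over the remaining dims.
x = torch.randn(4, 8, 16)

out = F.feature_alpha_dropout(x, p=0.5, training=True)

# A masked channel is set to the SELU negative saturation value and then
# scaled and shifted together with the rest of the output, so it collapses
# to a single constant; its per-channel standard deviation is ~0.
per_channel_std = out.std(dim=-1)               # shape (4, 8)
print((per_channel_std < 1e-6).sum(dim=1))      # masked channels per sample

# With the default training=False the call returns the input unchanged.
assert torch.equal(F.feature_alpha_dropout(x, p=0.5), x)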
43 changes: 43 additions & 0 deletions torch/nn/modules/dropout.py
@@ -181,6 +181,49 @@ def forward(self, input):


class FeatureAlphaDropout(_DropoutNd):
r"""Randomly masks out entire channels (a channel is a feature map,
e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of
setting activations to zero, as in regular Dropout, the activations are set
to the negative saturation value of the SELU activation function. More details
can be found in the paper `Self-Normalizing Neural Networks`_ .

Each element will be masked independently for each sample on every forward
call with probability :attr:`p` using samples from a Bernoulli distribution.
The elements to be masked are randomized on every forward call, and scaled
and shifted to maintain zero mean and unit variance.

Usually the input comes from :class:`nn.AlphaDropout` modules.

As described in the paper
`Efficient Object Localization Using Convolutional Networks`_ ,
if adjacent pixels within feature maps are strongly correlated
(as is normally the case in early convolution layers) then i.i.d. dropout
will not regularize the activations and will otherwise just result
in an effective learning rate decrease.

In this case, :class:`nn.FeatureAlphaDropout` will help promote independence between
feature maps and should be used instead.

Args:
p (float, optional): probability of a channel to be masked. Default: 0.5
inplace (bool, optional): If set to ``True``, will do this operation
in-place

Shape:
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)

Examples::

>>> m = nn.FeatureAlphaDropout(p=0.2)
>>> input = torch.randn(20, 16, 4, 32, 32)
>>> output = m(input)

.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
.. _Efficient Object Localization Using Convolutional Networks:
http://arxiv.org/abs/1411.4280
"""

def forward(self, input):
return F.feature_alpha_dropout(input, self.p, self.training)
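As a usage note (again not part of this diff; the layer sizes and p value are arbitrary), the module version slots into a SELU-based stack the same way Dropout2d would in a ReLU network, dropping whole feature maps while keeping the activations compatible with self-normalization:

import torch
from torch import nn

# A small self-normalizing 3D conv stack: SELU activations paired with
# FeatureAlphaDropout so that entire feature maps are dropped while the
# zero-mean / unit-variance property SELU relies on is preserved.
model = nn.Sequential(
    nn.Conv3d(16, 32, kernel_size=3, padding=1),
    nn.SELU(),
    nn.FeatureAlphaDropout(p=0.2),
    nn.Conv3d(32, 64, kernel_size=3, padding=1),
    nn.SELU(),
)

x = torch.randn(20, 16, 4, 32, 32)   # the (N, C, D, H, W) shape from the Examples block above
print(model(x).shape)                # torch.Size([20, 64, 4, 32, 32])

model.eval()                         # dropout layers are a no-op in eval mode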