In [1]:
import torch
from torch import nn

In [2]:
nn.Dropout?

Init signature: nn.Dropout(p: float = 0.5, inplace: bool = False) -> None
Docstring:
During training, randomly zeroes some of the elements of the input
tensor with probability :attr:`p` using samples from a Bernoulli
distribution. Each channel will be zeroed out independently on every forward
call.

This has proven to be an effective technique for regularization and
preventing the co-adaptation of neurons as described in the paper
`Improving neural networks by preventing co-adaptation of feature
detectors`_ .

Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
training. This means that during evaluation the module simply computes an
identity function.

Args:
    p: probability of an element to be zeroed. Default: 0.5
    inplace

In [7]:
class Dropout(nn.Module):
    """Inverted dropout layer, a from-scratch counterpart to ``nn.Dropout``.

    In training mode every element of the input is zeroed independently with
    probability ``p`` and the surviving elements are multiplied by
    ``1 / (1 - p)``. In evaluation mode the input is returned unchanged.

    Args:
        p: probability of an element being zeroed; must lie in ``[0, 1]``.
            Default: 0.5

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        # Fail at construction time instead of deep inside torch.bernoulli.
        if p < 0 or p > 1:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.p = p
        # Multiply up by this to maintain the activation total.
        # With p == 1 nothing survives, and 1 / (1 - p) would divide by zero;
        # the mask is all zeros in that case, so any finite factor yields the
        # correct all-zero output.
        self.factor = 1 / (1 - p) if p < 1 else 0.0

    def forward(self, x):
        """Apply dropout to ``x`` when training; return ``x`` as-is otherwise."""
        if self.training:
            # Probability of each element _not_ being dropped.
            probs = (1 - self.p) * torch.ones_like(x)
            # 1 where the element is kept, 0 where it is dropped.
            actives = torch.bernoulli(probs)
            x = self.factor * actives * x
        return x

## Testing

In [10]:
# Same drop probability for the hand-written layer (d) and the torch reference (dt).
p = 0.8
d = Dropout(p)
dt = nn.Dropout(p)

In [11]:
# Random 4x3 batch; each layer draws its own mask, so the two printed
# tensors are not expected to match element for element.
data = torch.rand(4, 3)

for layer in (dt, d):
    print(layer(data))

tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 4.0502, 0.0000],
        [4.0048, 0.0000, 0.0000]])
tensor([[0.0000, 0.0000, 0.7069],
        [0.0000, 2.9770, 0.0000],
        [0.0000, 0.0000, 4.2918],
        [0.0000, 3.9622, 0.0000]])


In [12]:
# Put each layer into evaluation mode (.eval() returns the module itself)
# and print what it produces for the same batch.
for layer in (dt, d):
    print(layer.eval()(data))

tensor([[0.9207, 0.9467, 0.1414],
        [0.2533, 0.5954, 0.4269],
        [0.7610, 0.8100, 0.8584],
        [0.8010, 0.7924, 0.5572]])
tensor([[0.9207, 0.9467, 0.1414],
        [0.2533, 0.5954, 0.4269],
        [0.7610, 0.8100, 0.8584],
        [0.8010, 0.7924, 0.5572]])
