In [1]:
import torch

### Bernoulli Distribution

The Bernoulli distribution is a special case of the binomial distribution with n = 1, i.e., a single
success-or-failure trial is performed.

Let us reconsider the familiar dataset of photos described in the Binomial distribution notebook, in which 20% of the photos are celebrity faces.  The probability of success (picking a celebrity photo) in a single trial is 0.2 `(p=0.2)`. The probability of failure (not picking a celebrity photo) is `1 - p = 0.8`.

Formally, $$P(X=1) = p$$  $$P(X=0) = 1 - p$$ where 1 represents success and 0 represents failure

This is demonstrated by the PyTorch code below

In [2]:
from torch.distributions import Bernoulli

# Success probability of a single trial (p = 0.2), held in a float tensor
p = torch.tensor([0.2], dtype=torch.float32)

# Build the Bernoulli distribution parameterised by that success probability
bern_dist = Bernoulli(probs=p)

In [3]:
# Single-point test dataset: one observed success (X = 1)
X = torch.ones(1, dtype=torch.float)

# Function to evaluate log prob using math formula
def raw_eval(X, p):
    """Evaluate the Bernoulli log-probability directly from its definition.

    Uses P(X=1) = p and P(X=0) = 1 - p, so the result is log(p) where the
    outcome equals 1 and log(1 - p) everywhere else.

    Args:
        X: Outcome(s) to score. May be a tensor of any shape or a Python
            scalar; entries equal to 1 are treated as successes, all other
            values as failures.
        p: Success probability (tensor or Python float), broadcastable
            against X.

    Returns:
        Tensor of element-wise log-probabilities.
    """
    X = torch.as_tensor(X)
    p = torch.as_tensor(p)
    # Select per element rather than with a Python `if`: a multi-element
    # tensor cannot be converted to a single bool, so the scalar branch
    # would raise for batched X.
    prob = torch.where(X == 1, p, 1 - p)
    return torch.log(prob)

# Log-probability of X computed two ways: through the PyTorch distribution
# object and through the explicit formula
log_prob = bern_dist.log_prob(X)
raw_eval_log_prob = raw_eval(X, p)

# Report both values (library result first, then the hand-written formula)
print("Log Prob: {:.3f}".format(log_prob[0].item()))
print("Raw eval Log Prob: {:.3f}".format(raw_eval_log_prob[0].item()))

# The two evaluations must agree up to numerical tolerance
assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)

Log Prob: -1.609
Raw eval Log Prob: -1.609


In [4]:
# How many independent draws to take from the distribution
num_samples = 100_000

# Sample num_samples outcomes from the Bernoulli distribution
samples = bern_dist.sample(sample_shape=(num_samples,))

In [5]:
# Absolute tolerance for comparing sample statistics with the analytical
# values. With p = 0.2 and 100000 draws the standard error of the sample
# mean is about sqrt(0.16 / 100000) ~= 0.0013, so 0.01 sits comfortably
# above sampling noise while still being able to detect a wrong estimate.
# (A tolerance of 0.2 could never fail for the variance: a Bernoulli sample
# variance lies in roughly [0, 0.25], always within 0.2 of 0.16.)
stat_atol = 1e-2

# The mean obtained from the samples
sample_mean = samples.mean()
print("Sample Mean: {}".format(sample_mean))

# The mean of the distribution from Pytorch
dist_mean = bern_dist.mean
print("Dist Mean: {:.3f}".format(dist_mean[0]))

# As expected, the two means approximately match
assert torch.isclose(sample_mean, dist_mean, atol=stat_atol)

# The variance obtained from the same samples used for the mean
# (no need to draw a second batch)
sample_var = samples.var()

# The variance of the distribution from Pytorch
dist_var = bern_dist.variance

# As expected, the two variances approximately match
assert torch.isclose(sample_var, dist_var, atol=stat_atol)

Sample Mean: 0.2002899944782257
Dist Mean: 0.200
