## Table of contents

* Distribution shapes
  * Examples
  * Reshaping distributions
  * It is always safe to assume dependence
* Declaring independence with plate
* Subsampling inside plate
* Broadcasting to allow Parallel Enumeration
  * Writing parallelizable code
  * Automatic broadcasting inside pyro.plate


In [1]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')

# we'll use this helper to check our models are correct
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

### Distribution shapes: batch_shape and event_shape

PyTorch Tensors have a single .shape attribute, but Distributions have two shape attributes with special meaning: .batch_shape and .event_shape. Concatenated, these two define the total shape of a sample.

```
x = d.sample()
assert x.shape == d.batch_shape + d.event_shape

```

Indices over .batch_shape denote conditionally independent random variables, whereas indices over .event_shape denote dependent random variables (i.e. one draw from a distribution). Because the dependent random variables jointly define a single probability, the .log_prob() method produces only one number for each event of shape .event_shape. Thus the total shape of .log_prob() is .batch_shape:

```
assert d.log_prob(x).shape == d.batch_shape
```
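
To make the "single number per event" point concrete, here is a minimal sketch. It uses torch.distributions directly (an assumption made for self-containment; Pyro's wrappers follow the same shape semantics) and picks an identity covariance purely for illustration, so that the joint log-density happens to equal the sum of per-coordinate Normal log-densities.

```python
import torch
from torch.distributions import MultivariateNormal, Normal

# Illustrative choice: identity covariance, so the joint density factorizes.
d = MultivariateNormal(torch.zeros(3), torch.eye(3))
x = d.sample()                                # shape == event_shape == (3,)

joint = d.log_prob(x)                         # one number for the whole event
per_coordinate = Normal(0., 1.).log_prob(x)   # one number per coordinate

assert joint.shape == d.batch_shape           # log_prob collapses event_shape
assert per_coordinate.shape == (3,)           # univariate: nothing is collapsed
assert torch.allclose(joint, per_coordinate.sum(), atol=1e-4)
```

With a non-diagonal covariance the last equality would no longer hold, but the shapes would be unchanged: log_prob always reduces over the event dimensions.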

Note that the Distribution.sample() method also takes a sample_shape parameter that indexes over independent identically distributed (iid) random variables, so that:

```
x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + d.batch_shape + d.event_shape

```

In summary:

```
      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape

```
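
The three-part decomposition above can be checked on a concrete distribution. The following is a small sketch, assuming torch.distributions (whose shape semantics Pyro's distributions share); the sizes 5, 2 and 3 are arbitrary values chosen so each part of the shape is easy to spot.

```python
import torch
from torch.distributions import MultivariateNormal

# A batch of 2 independent Gaussians, each over 3-dimensional vectors.
d = MultivariateNormal(torch.zeros(2, 3), torch.eye(3))
assert d.batch_shape == (2,)
assert d.event_shape == (3,)

# Draw 5 iid samples from the whole batch.
sample_shape = torch.Size([5])
x2 = d.sample(sample_shape)

# sample:   iid + independent + dependent
assert x2.shape == (5, 2, 3)
# log_prob: the dependent (event) dimensions are reduced away
assert d.log_prob(x2).shape == (5, 2)
```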

For example, univariate distributions have an empty event shape (because each number is an independent event). Distributions over vectors, like MultivariateNormal, have len(event_shape) == 1. Distributions over matrices, like InverseWishart, have len(event_shape) == 2.
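
As a quick check of these event-shape ranks, the sketch below compares a scalar, a vector, and a matrix distribution. It assumes torch.distributions, and it substitutes LKJCholesky for the matrix case only because that class ships with PyTorch; it is a stand-in for InverseWishart, not the same distribution.

```python
import torch
from torch.distributions import LKJCholesky, MultivariateNormal, Normal

scalar = Normal(0., 1.)                                     # over numbers
vector = MultivariateNormal(torch.zeros(3), torch.eye(3))   # over vectors
matrix = LKJCholesky(3, concentration=1.0)                  # over 3x3 matrices (stand-in)

assert len(scalar.event_shape) == 0
assert len(vector.event_shape) == 1
assert len(matrix.event_shape) == 2
assert matrix.event_shape == (3, 3)
```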

### Examples
The simplest distribution shape is a single univariate distribution.


In [3]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()

In [4]:
'''
Distributions can be batched by passing in batched parameters.
'''

d = Bernoulli(0.5 * torch.ones(3, 4))
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

In [6]:
'''
Another way to batch distributions is via the .expand() method. This only works if parameters are identical along the leftmost dimensions.
'''

d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
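
To see what "identical along the leftmost dimensions" means in practice, here is a minimal sketch, again assuming torch.distributions rather than the Pyro wrappers. It checks that every row of the expanded parameter is a copy of the original, and that a target shape which cannot be obtained by adding leftmost dimensions is rejected.

```python
import torch
from torch.distributions import Bernoulli

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
d = Bernoulli(probs).expand([3, 4])
assert d.batch_shape == (3, 4)

# Each of the 3 rows shares the same 4 parameters.
rows_identical = all(torch.equal(d.probs[i], probs) for i in range(3))
assert rows_identical

# Expanding to (4, 3) would need the size-4 dimension on the left,
# which broadcasting cannot do, so the call raises.
try:
    Bernoulli(probs).expand([4, 3])
    expand_failed = False
except (RuntimeError, ValueError):
    expand_failed = True
assert expand_failed
```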

In [None]:
'''
Multivariate distributions have nonempty .event_shape. For these distributions, the shapes of .sample() and .log_prob(x) differ:
'''

d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3, )
x = d.sample()
assert x.shape == (3, )             # == batch_shape + event_shape
assert d.log_prob(x).shape == ()    # == batch_shape