# Tensor shapes in Pyro 0.2

This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should familiarize yourself with [PyTorch broadcasting semantics](http://pytorch.org/docs/master/notes/broadcasting.html).

#### Summary:
- While you are learning or debugging, set `pyro.enable_validation(True)`.
- Tensors broadcast by aligning on the right: `torch.ones(3,4,5) + torch.ones(5)`.
- Distribution `.sample().shape == batch_shape + event_shape`.
- Distribution `.log_prob(x).shape == batch_shape` (but not `event_shape`!).
- Use `my_dist.reshape([2,3,4])` to draw a batch of samples.
- Use `my_dist.reshape(extra_event_dims=1)` to declare a dimension as dependent.
- Use `with pyro.iarange('name', size):` to declare a dimension as independent.
- All dimensions must be declared either dependent or independent.
- Try to support batching on the left. This lets Pyro auto-parallelize.
  - use negative indices like `x.sum(-1)` rather than `x.sum(2)`
  - use ellipsis notation like `pixel = image[..., i, j]`
  
#### Table of Contents
- [Distribution shapes](#dist-shapes)
  - [Examples](#dist-examples)
  - [Reshaping distributions](#dist-reshape)
- [Declaring independence with `iarange`](#iarange)
- [Subsampling inside `iarange`](#subsampling)
- [Broadcasting to allow parallel enumeration](#enumerate)

In [1]:
import os
import torch
import pyro
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.infer import ELBO, config_enumerate
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)    # <---- This is always a good idea!

def test_model(model, guide=None, **kwargs):
    pyro.clear_param_store()
    ELBO.make(**kwargs).loss(model, model if guide is None else guide)

## Distribution shapes: `batch_shape` and `event_shape` <a class="anchor" id="dist-shapes"></a>

PyTorch `Tensor`s have a single `.shape` attribute, but `Distribution`s have two shape attributes with special meaning: `.batch_shape` and `.event_shape`. These two combine to define the total shape of a sample:
```py
x = d.sample()
assert x.shape == d.batch_shape + d.event_shape
```
Indices over `.batch_shape` denote independent random variables, whereas indices over `.event_shape` denote dependent random variables. Because the dependent random variables define probability together, the `.log_prob()` method only produces a single number for each event of shape `.event_shape`. Thus the total shape of `.log_prob()` is `.batch_shape`:
```py
assert d.log_prob(x).shape == d.batch_shape
```
Note that the `.sample()` method also takes a `sample_shape` parameter that indexes over independent identically distributed (iid) random variables, so that
```py
x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + d.batch_shape + d.event_shape
```
In summary
```
      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape
```
Note that univariate distributions have empty event shape (because each number is an independent event). Distributions over vectors like `MultivariateNormal` have `len(event_shape) == 1`. Distributions over matrices like `InverseWishart` have `len(event_shape) == 2`.
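
The shape rule above can be checked with plain tuples. The sketch below is shape bookkeeping only: it calls neither PyTorch nor Pyro, and the helper names `sample_shape_of` and `log_prob_shape_of` are made up for illustration.

```python
def sample_shape_of(batch_shape, event_shape, sample_shape=()):
    """Shape of a draw: iid dims, then independent dims, then dependent dims."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


def log_prob_shape_of(batch_shape, event_shape, sample_shape=()):
    """Shape of log_prob for such a draw: each event collapses to one number."""
    return tuple(sample_shape) + tuple(batch_shape)


# A single vector-valued distribution, e.g. a MultivariateNormal over 3 dims.
mvn_batch, mvn_event = (), (3,)
print(sample_shape_of(mvn_batch, mvn_event))                      # (3,)
print(log_prob_shape_of(mvn_batch, mvn_event))                    # ()
print(sample_shape_of(mvn_batch, mvn_event, sample_shape=(10,)))  # (10, 3)

# A batch of univariate distributions, e.g. Bernoulli with (3, 4) parameters.
print(sample_shape_of((3, 4), ()))                                # (3, 4)
print(log_prob_shape_of((3, 4), ()))                              # (3, 4)
```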

### Examples <a class="anchor" id="dist-examples"></a>

The simplest distribution shape is a single univariate distribution.

In [2]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()

Distributions can be batched by passing in batched parameters.

In [3]:
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

Another way to batch distributions is via the `.reshape()` method, which adds new batch dimensions on the left. This only works if parameters are identical along those leftmost dimensions.

In [4]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).reshape([3])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

Multivariate distributions have nonempty `.event_shape`. For these distributions, the shapes of `.sample()` and `.log_prob(x)` differ:

In [5]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,)            # == batch_shape + event_shape
assert d.log_prob(x).shape == ()  # == batch_shape

### Reshaping distributions <a class="anchor" id="dist-reshape"></a>

In Pyro you can treat a univariate distribution as multivariate by calling the `.reshape(extra_event_dims=_)` method.

In [6]:
d = Bernoulli(0.5 * torch.ones(3,4)).reshape(extra_event_dims=1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)

While you work with Pyro programs, keep in mind that samples have shape `batch_shape + event_shape`, whereas `.log_prob(x)` values have shape `batch_shape`. You'll need to ensure that `batch_shape` is carefully controlled by either trimming it down with `.reshape(extra_event_dims=n)` or by declaring dimensions as independent via `pyro.iarange`.
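
To make the effect of `.reshape()` on the two shapes concrete, here is a plain-tuple sketch of the bookkeeping as this tutorial describes it. The function `reshape_shapes` is hypothetical and is not Pyro's implementation; it only reproduces the shapes asserted in the examples of this tutorial.

```python
def reshape_shapes(batch_shape, event_shape, sample_shape=(), extra_event_dims=0):
    """Return (batch_shape, event_shape) after a reshape.

    New dims from sample_shape are added on the left of batch_shape, then the
    rightmost extra_event_dims batch dims are reinterpreted as event dims.
    """
    full_batch = tuple(sample_shape) + tuple(batch_shape)
    if not 0 <= extra_event_dims <= len(full_batch):
        raise ValueError("extra_event_dims exceeds the number of batch dims")
    split = len(full_batch) - extra_event_dims
    return full_batch[:split], full_batch[split:] + tuple(event_shape)


# Bernoulli with 4 parameters, .reshape([3])
print(reshape_shapes((4,), (), sample_shape=(3,)))      # ((3, 4), ())
# Bernoulli with (3, 4) parameters, .reshape(extra_event_dims=1)
print(reshape_shapes((3, 4), (), extra_event_dims=1))   # ((3,), (4,))
# Normal with (3, 4, 5) parameters, .reshape(extra_event_dims=2)
print(reshape_shapes((3, 4, 5), (), extra_event_dims=2))  # ((3,), (4, 5))
```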

## Declaring independent dims with `iarange` <a class="anchor" id="iarange"></a>

Pyro models can use the context manager [pyro.iarange](http://docs.pyro.ai/en/dev/primitives.html#pyro.__init__.iarange) to declare that certain batch dimensions are independent. Inference algorithms can then take advantage of this independence to e.g. construct lower variance gradient estimators or to enumerate in linear space rather than exponential space. An example of an independent dimension is the index over data in a minibatch: each datum should be independent of all others.

The simplest way to declare a dimension as independent is to declare the rightmost batch dimension as independent via a simple `with` statement:
```py
with pyro.iarange("my_iarange"):
    # within this context, batch dimension -1 is independent
```
We recommend always providing the optional size argument to aid in debugging shapes:
```py
with pyro.iarange("my_iarange", len(my_data)):
    # within this context, batch dimension -1 is independent
```
Starting with Pyro 0.2 you can additionally nest `iarange`s, e.g. if you have per-pixel independence:
```py
with pyro.iarange("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.iarange("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent
```
Note that we always count from the right by using negative indices like -2, -1.

Finally, if you want to mix and match `iarange`s for e.g. noise that depends only on `x`, some noise that depends only on `y`, and some noise that depends on both, you can declare multiple `iarange`s and use them as reusable context managers. In this case Pyro cannot automatically allocate a dimension, so you need to provide a `dim` argument (again counting from the right):
```py
x_axis = pyro.iarange("x_axis", 3, dim=-2)
y_axis = pyro.iarange("y_axis", 2, dim=-3)
with x_axis:
    # within this context, batch dimension -2 is independent
with y_axis:
    # within this context, batch dimension -3 is independent
with x_axis, y_axis:
    # within this context, batch dimensions -3 and -2 are independent
```
Let's take a closer look at batch sizes within `iarange`s.

In [7]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).reshape(extra_event_dims=1))
    with pyro.iarange("c_iarange", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.iarange("d_iarange", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).reshape(extra_event_dims=2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5) 

    x_axis = pyro.iarange("x_axis", 3, dim=-2)
    y_axis = pyro.iarange("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1).reshape([3, 1]))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1).reshape([2, 1, 1]))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1).reshape([2, 3, 1]))
        z = pyro.sample("z", Normal(0, 1).reshape([2, 3, 1, 5], 1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)
    
test_model(model1)

It is helpful to visualize the `.shape`s of each sample site by aligning them at the boundary between `batch_shape` and `event_shape`: dimensions to the right will be summed out in `.log_prob()` and dimensions to the left will remain. 
```
batch dims | event dims
-----------+-----------
           |        a = sample("a", Normal(0, 1))
           |2       b = sample("b", Normal(zeros(2), 1)
           |                        .reshape(extra_event_dims=1))
           |        with iarange("c", 2):
          2|            c = sample("c", Normal(zeros(2), 1))
           |        with iarange("d", 3):
          3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .reshape(extra_event_dims=2))
           |
           |        x_axis = iarange("x", 3, dim=-2)
           |        y_axis = iarange("y", 2, dim=-3)
           |        with x_axis:
        3 1|            x = sample("x", Normal(0, 1).reshape([3, 1]))
           |        with y_axis:
      2 1 1|            y = sample("y", Normal(0, 1).reshape([2, 1, 1]))
           |        with x_axis, y_axis:
      2 3 1|            xy = sample("xy", Normal(0, 1).reshape([2, 3, 1]))
      2 3 1|5           z = sample("z", Normal(0, 1).reshape([2, 3, 1, 5], 1))
```
As an exercise, try to tabulate the shapes of sample sites in one of your own programs.
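
If you would rather generate such a table than draw it by hand, a small formatter is enough. The helper `tabulate_shapes` below is a hypothetical convenience written in plain Python; it takes the batch and event shapes you have already worked out and aligns them at the boundary.

```python
def tabulate_shapes(sites):
    """Format (name, batch_shape, event_shape) triples, aligned at batch|event."""
    batch_cols = [" ".join(map(str, batch)) for _, batch, _ in sites]
    event_cols = [" ".join(map(str, event)) for _, _, event in sites]
    batch_width = max(len(col) for col in batch_cols)
    event_width = max(len(col) for col in event_cols)
    return [
        f"{batch.rjust(batch_width)}|{event.ljust(event_width)}  {name}"
        for (name, _, _), batch, event in zip(sites, batch_cols, event_cols)
    ]


# Shapes copied from the assertions in model1().
model1_sites = [
    ("a", (), ()),
    ("b", (), (2,)),
    ("c", (2,), ()),
    ("d", (3,), (4, 5)),
    ("x", (3, 1), ()),
    ("y", (2, 1, 1), ()),
    ("xy", (2, 3, 1), ()),
    ("z", (2, 3, 1), (5,)),
]
table = tabulate_shapes(model1_sites)
print("\n".join(table))
```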

## Subsampling tensors inside an `iarange`    <a class="anchor" id="subsampling"></a>

One of the main uses of [iarange](http://docs.pyro.ai/en/dev/primitives.html#pyro.__init__.iarange) is to subsample data. This is possible within an `iarange` because data are independent, so the expected value of the loss on half the data should be half the expected value of the loss on the full data.

To subsample data, you need to inform Pyro of both the original data size and the subsample size; Pyro will then choose a random subset of data and yield the set of indices.

In [8]:
data = torch.arange(100)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.iarange("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Do stuff with batch:
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10
        
test_model(model2, guide=lambda: None)
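
Why is subsampling valid? The following plain-Python sketch checks the argument exhaustively on a tiny example. It assumes the loss is a plain sum of per-datum terms and that a subsample is rescaled by `size / subsample_size`; the numbers and the helper `scaled_subsample_sum` are made up, and nothing here describes Pyro's internals.

```python
from itertools import combinations


def scaled_subsample_sum(values, indices):
    """Sum the selected values and rescale by full_size / subsample_size."""
    scale = len(values) / len(indices)
    return scale * sum(values[i] for i in indices)


per_datum_loss = [1.0, 2.0, 3.0, 4.0]  # made-up per-datum loss terms

# Enumerate every possible subsample of size 2 (each is equally likely).
subsets = list(combinations(range(len(per_datum_loss)), 2))
estimates = [scaled_subsample_sum(per_datum_loss, s) for s in subsets]

# Averaged over all subsamples, the rescaled estimate matches the full loss.
average = sum(estimates) / len(estimates)
print(average, sum(per_datum_loss))
```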

## Broadcasting to allow parallel enumeration <a class="anchor" id="enumerate"></a>

Pyro 0.2 introduces the ability to enumerate discrete latent variables in parallel. This can significantly reduce the variance of gradient estimators when fitting an SVI model.

To use discrete enumeration, Pyro needs to allocate tensor dimensions that it can use for enumeration. To avoid conflicting with other dimensions that we want to use for `iarange`s, we need to declare a budget of the maximum number of tensor dimensions we'll use. This budget is called `max_iarange_nesting` and is an argument to [SVI](http://docs.pyro.ai/en/dev/inference_algos.html) (the argument is simply passed through to [TraceEnum_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.traceenum_elbo.TraceEnum_ELBO)).

To understand `max_iarange_nesting` and how Pyro allocates dimensions for enumeration, let's revisit `model1()` from above. This time we'll map out three types of dimensions:
enumeration dimensions on the left (Pyro takes control of these), batch dimensions in the middle, and event dimensions on the right.

```
      max_iarange_nesting = 3
           |<--->|
enumeration|batch|event
-----------+-----+-----
           |. . .|      a = sample("a", Normal(0, 1))
           |. . .|2     b = sample("b", Normal(zeros(2), 1)
           |     |                      .reshape(extra_event_dims=1))
           |     |      with iarange("c", 2):
           |. . 2|          c = sample("c", Normal(zeros(2), 1))
           |     |      with iarange("d", 3):
           |. . 3|4 5       d = sample("d", Normal(zeros(3,4,5), 1)
           |     |                     .reshape(extra_event_dims=2))
           |     |
           |     |      x_axis = iarange("x", 3, dim=-2)
           |     |      y_axis = iarange("y", 2, dim=-3)
           |     |      with x_axis:
           |. 3 1|          x = sample("x", Normal(0, 1).reshape([3,1]))
           |     |      with y_axis:
           |2 1 1|          y = sample("y", Normal(0, 1).reshape([2,1,1]))
           |     |      with x_axis, y_axis:
           |2 3 1|          xy = sample("xy", Normal(0, 1).reshape([2,3,1]))
           |2 3 1|5         z = sample("z", Normal(0, 1).reshape([2,3,1,5], 1))
```
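
The budget for `model1()` can be read off the batch shapes in the map: it is the length of the longest one. The sketch below does this with plain tuples; `required_nesting` and `check_budget` are hypothetical helpers, not part of Pyro, and they only illustrate the counting.

```python
def required_nesting(batch_shapes):
    """Smallest budget that covers every batch dim used by the listed sites."""
    return max((len(shape) for shape in batch_shapes), default=0)


def check_budget(batch_shapes, max_iarange_nesting):
    """Return the number of spare dims, or raise if the budget is too small."""
    needed = required_nesting(batch_shapes)
    if max_iarange_nesting < needed:
        raise ValueError(
            f"max_iarange_nesting={max_iarange_nesting} but {needed} batch dims are used"
        )
    return max_iarange_nesting - needed


# Batch shapes of a, b, c, d, x, y, xy, z in model1().
model1_batch_shapes = [(), (), (2,), (3,), (3, 1), (2, 1, 1), (2, 3, 1), (2, 3, 1)]
print(required_nesting(model1_batch_shapes))  # 3
print(check_budget(model1_batch_shapes, 3))   # 0 spare dims
```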
Note that we can overprovision with `max_iarange_nesting=4`, but we cannot underprovision with `max_iarange_nesting=2` (Pyro will raise an error). Let's see how this works in practice.

In [9]:
@config_enumerate(default="parallel")
def model3():
    p = pyro.param("p", torch.arange(6) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.iarange("c_iarange", 4):
        c = pyro.sample("c", Bernoulli(0.3).reshape([4]))
        with pyro.iarange("d_iarange", 5):
            d = pyro.sample("d", Bernoulli(0.4).reshape([5,4]))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1, 8)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .reshape(extra_event_dims=1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 6, 1, 1   )  # 2 enumerated Bernoullis x 6 Categoricals.
    assert c.shape == (   2, 1, 1, 1, 4   )  # Only 2 Bernoullis; does not depend on a or b.
    assert d.shape == (2, 1, 1, 1, 5, 4   )  # Only two Bernoullis.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 5, 4, 1,)
    assert e_scale.shape == (                  7,)
            
test_model(model3, max_iarange_nesting=2, enum_discrete=True)

Let's take a closer look at those dimensions. First note that Pyro allocates enumeration dims from right to left, starting just to the left of the `max_iarange_nesting` batch dims: with `max_iarange_nesting=2`, Pyro allocates dim -3 to enumerate `a`, then dim -4 to enumerate `b`, then dim -5 to enumerate `c`, and finally dim -6 to enumerate `d`. Next note that variables only have extent (size > 1) in dimensions they depend on. This helps keep tensors small and computation cheap. We can draw a similar map of the tensor dimensions:
```
     max_iarange_nesting = 2
            |<->|
enumeration batch event
------------|---|-----
           6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
         2 6|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
            |   |     with pyro.iarange("c_iarange", 4):
       2 1 1|1 4|         c = pyro.sample("c", Bernoulli(0.3).reshape([4]))
            |   |         with pyro.iarange("d_iarange", 5):
     2 1 1 1|5 4|             d = pyro.sample("d", Bernoulli(0.4).reshape([5,4]))
     2 1 1 1|5 4|1            e_loc = locs[d.long()].unsqueeze(-1)
            |   |7            e_scale = torch.arange(1, 8)
     2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
            |   |                             .reshape(extra_event_dims=1))
```
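
The shapes asserted in `model3()` can be reproduced with tuple arithmetic alone. The sketch below is an illustration under our own assumptions: the helpers `enum_dim`, `enumerated_values_shape`, and `broadcast_shape` are invented for this purpose and use no PyTorch or Pyro calls. Each enumerated site gets the next dim to the left of the batch dims, and its values broadcast against whatever it depends on.

```python
from math import prod


def enum_dim(order, max_iarange_nesting):
    """Dim of the order-th enumerated site (0-based), left of all batch dims."""
    return -(max_iarange_nesting + 1 + order)


def enumerated_values_shape(num_values, dim):
    """Enumerated values have extent only at `dim` and size 1 to its right."""
    return (num_values,) + (1,) * (-dim - 1)


def broadcast_shape(*shapes):
    """Broadcast shapes by aligning them on the right."""
    result = []  # sizes stored right-to-left
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i == len(result):
                result.append(size)
            elif result[i] == 1:
                result[i] = size
            elif size not in (1, result[i]):
                raise ValueError(f"cannot broadcast {shapes}")
    return tuple(reversed(result))


nesting = 2
a_dim, b_dim, c_dim, d_dim = (enum_dim(k, nesting) for k in range(4))
print(a_dim, b_dim, c_dim, d_dim)  # -3 -4 -5 -6

a_shape = enumerated_values_shape(6, a_dim)                           # (6, 1, 1)
b_shape = broadcast_shape(enumerated_values_shape(2, b_dim), a_shape)  # (2, 6, 1, 1)
c_shape = broadcast_shape(enumerated_values_shape(2, c_dim), (4,))     # (2, 1, 1, 1, 4)
d_shape = broadcast_shape(enumerated_values_shape(2, d_dim), (5, 4))   # (2, 1, 1, 1, 5, 4)

# A tensor that depends on all four sites broadcasts to the joint shape.
joint = broadcast_shape(a_shape, b_shape, c_shape, d_shape)
print(joint, prod(joint))  # (2, 2, 2, 6, 5, 4) 960
```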