# Tensor shapes in Pyro 0.2

This tutorial introduces Pyro's organization of tensor dimensions.

Summary:
- Tensors broadcast by aligning on the right, so missing dimensions are filled in on the left: `torch.ones(3,4,5) + torch.ones(5)`.
- Distribution `.sample()` has shape `batch_shape + event_shape`.
- Distribution `.log_prob(x)` has shape `batch_shape` (but not `event_shape`!).
- Use `my_dist.reshape([2,3,4])` to draw a batch of samples.
- Use `my_dist.reshape(extra_event_dims=1)` to declare a dimension as dependent.
- Use `with iarange('name', size):` to declare a dimension as independent.
- All dimensions must be declared either dependent or independent.
- Try to support batching on the left. This lets Pyro auto-parallelize.
  - use negative indices like `x.sum(-1)` rather than `x.sum(2)`
  - use ellipsis notation like `pixel = image[..., i, j]`
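
The broadcasting bullet and the advice about negative indices both come down to counting dimensions from the right. The following is a minimal plain-Python sketch of that rule. The helper names `broadcast_shape` and `normalize_dim` are hypothetical (they are not PyTorch or Pyro functions), and the code works on shape tuples only, so it runs without either library.

```python
def broadcast_shape(a, b):
    """Broadcast two shape tuples by aligning them on the right."""
    ndim = max(len(a), len(b))
    # Pad the shorter shape with 1s on the left so both have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError("shapes {} and {} do not broadcast".format(a, b))
        # A size-1 dimension stretches to match the other size.
        result.append(y if x == 1 else x)
    return tuple(result)


def normalize_dim(dim, ndim):
    """Convert a possibly negative dimension index into a nonnegative one."""
    if not -ndim <= dim < ndim:
        raise IndexError("dim {} out of range for {} dimensions".format(dim, ndim))
    return dim % ndim


# The example from the summary: (3, 4, 5) + (5,) broadcasts to (3, 4, 5).
print(broadcast_shape((3, 4, 5), (5,)))

# dim=-1 always names the last dimension, even after a batch dimension
# is added on the left; a positive index such as 2 would not.
print(normalize_dim(-1, 3), normalize_dim(-1, 4))
```

Because new batch dimensions appear on the left, an index such as `-1` keeps pointing at the same logical dimension, which is why the summary recommends negative indices.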
  

In [1]:
import torch
import pyro
from pyro.distributions import Bernoulli, MultivariateNormal, Normal

## Distribution shapes: `batch_shape` and `event_shape`

The simplest distribution shape is a single univariate distribution.

In [2]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()

Distributions can be batched by passing in batched parameters.

In [3]:
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

Another way to batch distributions is via the `.reshape()` method, which adds new dimensions on the left of the existing `batch_shape`.
This only works if the parameters are identical along those new leftmost dimensions.

In [4]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).reshape([3])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
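
The cell above turns a `batch_shape` of `(4,)` into `(3, 4)`. As a rough sketch of that bookkeeping, assuming the sample shape is simply prepended to the existing `batch_shape`, the plain-Python code below uses a hypothetical helper `reshaped_batch_shape` and nested lists in place of tensors.

```python
def reshaped_batch_shape(sample_shape, batch_shape):
    """Shape bookkeeping only: new dimensions are prepended on the left."""
    return tuple(sample_shape) + tuple(batch_shape)


probs = [0.1, 0.2, 0.3, 0.4]

# "Identical along the leftmost dimension" means every one of the 3 rows
# carries the same 4 parameters.
expanded_probs = [list(probs) for _ in range(3)]

print(reshaped_batch_shape([3], (4,)))            # (3, 4)
print(len(expanded_probs), len(expanded_probs[0]))
```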

Multivariate distributions have nonempty `.event_shape`. For these distributions, the shapes of `.sample()` and `.log_prob(x)` differ:

In [5]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,)            # == batch_shape + event_shape
assert d.log_prob(x).shape == ()  # == batch_shape
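
To see why `.log_prob(x)` drops the event dimension, consider the special case used above: zero mean and identity covariance. In that case the coordinates are independent standard normals, so the joint log-density is the sum of three univariate log-densities and collapses a length-3 event into one number. The sketch below checks this with the standard library only; the function names are made up for illustration and this is not how Pyro evaluates `MultivariateNormal`.

```python
import math


def standard_normal_log_prob(x):
    """Log-density of a univariate Normal(0, 1) at x."""
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


def mvn_identity_log_prob(xs):
    """Log-density of N(0, I) at the vector xs (valid only for identity covariance)."""
    return sum(standard_normal_log_prob(x) for x in xs)


x = [0.5, -1.0, 2.0]            # one event with event_shape == (3,)
lp = mvn_identity_log_prob(x)   # a single float, i.e. shape ()
print(len(x), lp)
```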

In Pyro you can treat a univariate distribution as multivariate by calling the `.reshape(extra_event_dims=_)` method, which reinterprets the rightmost batch dimensions as event dimensions.

In [6]:
d = Bernoulli(0.5 * torch.ones(3,4)).reshape(extra_event_dims=1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)
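
The effect of `extra_event_dims` on the two shapes can be sketched with tuples: the rightmost batch dimensions move over to the event side, and log-probabilities are summed over them. The helper `with_extra_event_dims` below is hypothetical and only mirrors the shapes asserted in the cell above; nested lists stand in for tensors.

```python
import math


def with_extra_event_dims(batch_shape, event_shape, extra_event_dims):
    """Move the rightmost `extra_event_dims` batch dims into the event shape."""
    batch_shape, event_shape = tuple(batch_shape), tuple(event_shape)
    if not 0 <= extra_event_dims <= len(batch_shape):
        raise ValueError("extra_event_dims exceeds the number of batch dims")
    split = len(batch_shape) - extra_event_dims
    return batch_shape[:split], batch_shape[split:] + event_shape


batch_shape, event_shape = with_extra_event_dims((3, 4), (), 1)
print(batch_shape, event_shape)   # (3,) (4,)

# Elementwise log-probabilities of a 3 x 4 grid of Bernoulli(0.5) outcomes:
# every outcome has probability 0.5, so every entry is log(0.5).
elementwise = [[math.log(0.5)] * 4 for _ in range(3)]

# Summing over the dimension that became an event dimension leaves
# one value per remaining batch element.
row_log_probs = [sum(row) for row in elementwise]
print(len(row_log_probs))
```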

While you work with Pyro programs, keep in mind that samples have shape `batch_shape + event_shape`, whereas `.log_prob(x)` values have shape `batch_shape`. You'll need to ensure that `batch_shape` is carefully controlled by either trimming it down with `.reshape(extra_event_dims=_)` or by declaring dimensions as independent via `pyro.iarange`.

## Declare independent dims with `iarange`

In [7]:
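# Short names used in this cell and in the diagrams that follow.
from pyro import sample, iarange
from torch import zeros

def model():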
    a = sample("a", Normal(0, 1))
    b = sample("b", Normal(zeros(2), 1).reshape(extra_event_dims=1))
    with iarange("c", 2):
        c = sample("c", Normal(zeros(2), 1))
    with iarange("d", 3):
        d = sample("d", Normal(zeros(3,4,5), 1).reshape(extra_event_dims=2))
    assert a.shape == ()
    assert b.shape == (2,)
    assert c.shape == (2,)
    assert d.shape == (3,4,5)

    x_axis = iarange("x", 3, dim=-2)
    y_axis = iarange("y", 2, dim=-3)
    with x_axis:
        x = sample("x", Normal(0, 1).reshape([3, 1]))
    with y_axis:
        y = sample("y", Normal(0, 1).reshape([2, 1, 1]))
    with x_axis, y_axis:
        xy = sample("xy", Normal(0, 1).reshape([2, 3, 1]))
        z = sample("z", Normal(0, 1).reshape([2, 3, 1, 5], 1))
    assert x.shape == (3, 1)
    assert y.shape == (2, 1, 1)
    assert xy.shape == (2, 3, 1)
    assert z.shape == (2, 3, 1, 5)

```
batch dims | event dims
-----------+-----------
           |        a = sample("a", Normal(0, 1))
           |2       b = sample("b", Normal(zeros(2), 1)
           |                        .reshape(extra_event_dims=1))
           |        with iarange("c", 2):
          2|            c = sample("c", Normal(zeros(2), 1))
           |        with iarange("d", 3):
          3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .reshape(extra_event_dims=2))
           |        x_axis = iarange("x", 3, dim=-2)
           |        y_axis = iarange("y", 2, dim=-3)
           |        with x_axis:
        3 1|            x = sample("x", Normal(0, 1)
           |                       .reshape([3, 1]))
           |        with y_axis:
      2 1 1|            y = sample("y", Normal(0, 1)
           |                       .reshape([2, 1, 1]))
           |        with x_axis, y_axis:
      2 3 1|            xy = sample("xy", Normal(0, 1)
           |                        .reshape([2, 3, 1]))
      2 3 1|5           z = sample("z", Normal(0, 1)
           |                       .reshape([2, 3, 1, 5], 1))
```
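
The table above can be checked mechanically. The sketch below splits each site's sample shape into batch and event parts and verifies that every enclosing `iarange` size appears at its batch dimension. It is plain Python on tuples, the helper names are hypothetical, and it assumes that an `iarange` created without `dim` occupies the rightmost batch dimension (`dim=-1`), which is what the table shows for `c` and `d`.

```python
def split_shape(sample_shape, event_dims):
    """Split a sample shape into (batch_shape, event_shape)."""
    sample_shape = tuple(sample_shape)
    split = len(sample_shape) - event_dims
    return sample_shape[:split], sample_shape[split:]


def matches_iaranges(batch_shape, iaranges):
    """Check that each (size, dim) pair agrees with batch_shape[dim]."""
    for size, dim in iaranges:
        if dim >= 0 or -dim > len(batch_shape):
            return False
        if batch_shape[dim] != size:
            return False
    return True


# name: (sample shape, number of event dims, enclosing iaranges as (size, dim))
sites = {
    "a": ((), 0, []),
    "b": ((2,), 1, []),
    "c": ((2,), 0, [(2, -1)]),
    "d": ((3, 4, 5), 2, [(3, -1)]),
    "x": ((3, 1), 0, [(3, -2)]),
    "y": ((2, 1, 1), 0, [(2, -3)]),
    "xy": ((2, 3, 1), 0, [(3, -2), (2, -3)]),
    "z": ((2, 3, 1, 5), 1, [(3, -2), (2, -3)]),
}

results = {}
for name, (shape, event_dims, iaranges) in sites.items():
    batch, event = split_shape(shape, event_dims)
    results[name] = (batch, event, matches_iaranges(batch, iaranges))
    print(name, batch, event, results[name][2])
```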

## Broadcasting to allow parallel enumeration

```
       max_iarange_nesting = 3
            |<--->|
enumeration |batch| event
------------+-----+--------
            |. . .|        a = sample("a", Normal(0, 1))
            |. . .|2       b = sample("b", Normal(zeros(2), 1)
            |     |                        .reshape(extra_event_dims=1))
            |     |        with iarange("c", 2):
            |. . 2|            c = sample("c", Normal(zeros(2), 1))
            |     |        with iarange("d", 3):
            |. . 3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
            |     |                       .reshape(extra_event_dims=2))
            |     |        x_axis = iarange("x", 3, dim=-2)
            |     |        y_axis = iarange("y", 2, dim=-3)
            |     |        with x_axis:
            |. 3 .|            x = sample("x", Normal(0, 1)
            |     |                       .reshape([3, 1]))
            |     |        with y_axis:
            |2 . .|            y = sample("y", Normal(0, 1)
            |     |                       .reshape([2, 1, 1]))
            |     |        with x_axis, y_axis:
            |2 3 .|            xy = sample("xy", Normal(0, 1)
            |     |                        .reshape([2, 3, 1]))
            |2 3 .|5           z = sample("z", Normal(0, 1)
            |     |                       .reshape([2, 3, 1, 5], 1))

```
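
One way to read the batch column of the diagram above: each site's `batch_shape` is padded on the left to a fixed width of three slots, and a `.` marks a slot that is absent or has size 1 and is therefore free to broadcast. The sketch below reproduces the column under that reading. The helper `render_batch_column` is hypothetical, and the position computed for the first enumeration dimension is an assumption read off the diagram (enumeration sits immediately to the left of the batch slots), not a statement about Pyro's internals.

```python
def render_batch_column(batch_shape, max_iarange_nesting):
    """Render a batch shape as fixed-width slots, using '.' for size 1 or absent."""
    batch_shape = tuple(batch_shape)
    if len(batch_shape) > max_iarange_nesting:
        raise ValueError("batch shape has more dims than max_iarange_nesting")
    padded = (1,) * (max_iarange_nesting - len(batch_shape)) + batch_shape
    return " ".join("." if n == 1 else str(n) for n in padded)


nesting = 3  # three batch slots, as drawn in the diagram

batch_shapes = {
    "a": (), "c": (2,), "d": (3,),
    "x": (3, 1), "y": (2, 1, 1), "xy": (2, 3, 1),
}
for name, shape in batch_shapes.items():
    print("|{}|  {}".format(render_batch_column(shape, nesting), name))

# Assumption: the first dimension available for enumeration is the one
# just left of the reserved batch slots.
first_enum_dim = -nesting - 1
print(first_enum_dim)
```

Reserving a fixed number of batch slots is what lets dimensions further left be used without colliding with any `iarange` dimension.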