In [1]:
"""
Original tutorial: http://pyro.ai/examples/tensor_shapes.html

While you are learning or debugging, set pyro.enable_validation(True).
Tensors broadcast by aligning on the right: torch.ones(3,4,5) + torch.ones(5).
Distribution .sample().shape == batch_shape + event_shape.
Distribution .log_prob(x).shape == batch_shape (but not event_shape!).
Use .expand() to draw a batch of samples, or rely on plate to expand automatically.
Use my_dist.to_event(1) to declare a dimension as dependent.
Use with pyro.plate('name', size): to declare a dimension as conditionally independent.
All dimensions must be declared either dependent or conditionally independent.
Try to support batching on the left. This lets Pyro auto-parallelize.
use negative indices like x.sum(-1) rather than x.sum(2)
use ellipsis notation like pixel = image[..., i, j]
use Vindex if i,j are enumerated, pixel = Vindex(image)[..., i, j]
When debugging, examine all shapes in a trace using Trace.format_shapes().
"""

In [2]:
import torch
import pyro
import os
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.0.0')
pyro.enable_validation(True)    # <---- This is always a good idea!

# We'll ue this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

In [3]:
"""
PyTorch Tensors have a single .shape attribute, 
but Distributions have two shape attributions with special meaning: 
.batch_shape and .event_shape. These two combine to define the total shape of a sample

Indices over .batch_shape denote conditionally independent random variables, 
whereas indices over .event_shape denote dependent random variables (ie one draw from a distribution)

Note that the Distribution.sample() method also takes a sample_shape parameter
that indexes over independent identically distributed (iid) random varables, so that

x2.shape == sample_shape + batch_shape + event_shape

Meaning that we can have multiple samples for each item in the batch 
"""
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()

In [4]:
d = Bernoulli(0.5 * torch.ones(3,4)) # If this particular tensor or in general if tensor is passed as an initializer that is the resulttant distribution shape
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample((5,))
print(x)
assert x.shape == (5,3, 4)
assert d.log_prob(x).shape == (5,3, 4)
print(d.log_prob(x))
# Which is ln(0.5) == -0.6931

tensor([[[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [1., 1., 1., 1.]],

        [[1., 1., 0., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 0.],
         [0., 1., 1., 1.]],

        [[0., 1., 0., 1.],
         [1., 1., 0., 0.],
         [1., 0., 1., 1.]],

        [[1., 0., 1., 0.],
         [0., 1., 0., 1.],
         [0., 1., 0., 0.]]])
tensor([[[-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931]],

        [[-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931]],

        [[-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931]],

        [[-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931],
         [-0.6931, -0.6931, -0.6931, -0.6931]],

        [[-0.

In [5]:
"""
Distributions over vectors like MultivariateNormal have len(event_shape) == 1.
Distributions over matrices like InverseWishart have len(event_shape) == 2
"""

'\nDistributions over vectors like MultivariateNormal have len(event_shape) == 1.\nDistributions over matrices like InverseWishart have len(event_shape) == 2\n'

In [6]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
print(x)
print(d.log_prob(x))

tensor([[0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 1.]])
tensor([[-0.1054, -0.2231, -1.2040, -0.5108],
        [-0.1054, -0.2231, -0.3567, -0.5108],
        [-0.1054, -0.2231, -0.3567, -0.9163]])


In [7]:
"""
Multivariate distributions have nonempty .event_shape. 
For these distributions, the shapes of .sample() and .log_prob(x) differ:
"""
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
print(x)
assert x.shape == (3,)            # == batch_shape + event_shape
assert d.log_prob(x).shape == ()  # == batch_shape
print(d.log_prob(x))

tensor([0.4929, 0.0339, 0.2949])
tensor(-2.9223)


In [16]:
"""
In Pyro you can treat a univariate distribution as multivariate by calling the .to_event(n)
property where n is the number of batch dimensions (from the right) to declare as dependent.
"""
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(2)
assert d.batch_shape == ()
assert d.event_shape == (3,4)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == ()
print(x)
print(d.log_prob(x))

tensor([[1., 0., 1., 1.],
        [0., 0., 0., 1.],
        [1., 1., 1., 1.]])
tensor(-8.3178)


In [17]:
"""
While you work with Pyro programs, keep in mind that samples have shape batch_shape + event_shape,
whereas .log_prob(x) values have shape batch_shape. 
You’ll need to ensure that batch_shape is carefully controlled by either trimming it down with .to_event(n)
or by declaring dimensions as independent via pyro.plate.
"""

'\nWhile you work with Pyro programs, keep in mind that samples have shape batch_shape + event_shape,\nwhereas .log_prob(x) values have shape batch_shape. \nYou’ll need to ensure that batch_shape is carefully controlled by either trimming it down with .to_event(n)\nor by declaring dimensions as independent via pyro.plate.\n'

In [22]:
"""
Often in Pyro we’ll declare some dimensions as dependent even though they are in fact independent, e.g.
"""
d = Normal(0, 1).expand([10]).to_event(1)
print(d.shape,d.batch_shape, d.event_shape)
x = d.sample()
print(x, )
assert x.shape == (10,)
"""
This is useful for two reasons: First it allows us to easily swap in a MultivariateNormal distribution later.
Second it simplifies the code a bit since we don’t need a plate as in
"""
with pyro.plate("x_plate", 10):
    x = pyro.sample("x", Normal(0, 1))  # .expand([10]) is automatic
    assert x.shape == (10,)
    
"""
The difference between these two versions is that the second version with plate
informs Pyro that it can make use of conditional independence information when estimating gradients,
whereas in the first version Pyro must assume they are dependent
(even though the normals are in fact conditionally independent).
This is analogous to d-separation in graphical models: it is always safe to add edges
and assume variables may be dependent (i.e. to widen the model class), 
but it is unsafe to assume independence when variables are actually dependent
(i.e. narrowing the model class so the true model lies outside of the class,
as in mean field). 

In practice Pyro’s SVI inference algorithm uses
reparameterized gradient estimators for Normal distributions
so both gradient estimators have the same performance.
"""

<bound method TorchDistributionMixin.shape of Independent()> torch.Size([]) torch.Size([10])
tensor([ 0.9060, -0.2277,  0.4350, -0.1837,  0.6944, -2.0546, -0.4330,  0.1646,
         1.5989, -1.5854])


In [24]:
"""
We recommend always providing an optional size argument to aid in debugging shapes

"""
with pyro.plate("my_plate", 10):
    pass
    # within this context, batch dimension -1 is independent
"""
Starting with Pyro 0.2 you can additionally nest plates, e.g. if you have per-pixel independence:

"""
with pyro.plate("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.plate("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent
        pass

In [26]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_sahpe == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1))
        z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

test_model(model1, model1, Trace_ELBO())
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

Trace Shapes:            
 Param Sites:            
Sample Sites:            
       a dist       |    
        value       |    
     log_prob       |    
       b dist       | 2  
        value       | 2  
     log_prob       |    
 c_plate dist       |    
        value     2 |    
     log_prob       |    
       c dist     2 |    
        value     2 |    
     log_prob     2 |    
 d_plate dist       |    
        value     3 |    
     log_prob       |    
       d dist     3 | 4 5
        value     3 | 4 5
     log_prob     3 |    
  x_axis dist       |    
        value     3 |    
     log_prob       |    
  y_axis dist       |    
        value     2 |    
     log_prob       |    
       x dist   3 1 |    
        value   3 1 |    
     log_prob   3 1 |    
       y dist 2 1 1 |    
        value 2 1 1 |    
     log_prob 2 1 1 |    
      xy dist 2 3 1 |    
        value 2 3 1 |    
     log_prob 2 3 1 |    
       z dist 2 3 1 | 5  
        value 2 3 1 | 5  
     log_pro

In [27]:
"""
One of the main uses of plate is to subsample data. 
This is possible within a plate because data are conditionally independent, 
so the expected value of the loss on, say,
half the data should be half the expected loss on the full data.

To subsample data, you need to inform Pyro of both the original data size
and the subsample size; Pyro will then choose a random subset
of data and yield the set of indices.
"""
data = torch.arange(100.)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.plate("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Do stuff with batch:
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10

test_model(model2, guide=lambda: None, loss=Trace_ELBO())

In [28]:
"""
Pyro 0.2 introduces the ability to enumerate discrete latent variables in parallel. 
This can significantly reduce the variance of gradient estimators when learning a posterior via SVI.
"""
@config_enumerate
def model3():
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.plate("c_plate", 4):
        c = pyro.sample("c", Bernoulli(0.3))
        with pyro.plate("d_plate", 5):
            d = pyro.sample("d", Bernoulli(0.4))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .to_event(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.
    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)
    assert e_scale.shape == (                  7,)

test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))

In [29]:
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below

def fun(observe):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    if enumerated:
        assert x_active.shape  == (2, 1, 1)
        assert y_active.shape  == (2, 1, 1, 1)
    else:
        assert x_active.shape  == (width, 1)
        assert y_active.shape  == (height,)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, 1, 1)
    else:
        assert p.shape == (width, height)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

def guide4():
    fun(observe=False)

# Test without enumeration.
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test with enumeration.
enumerated = True
test_model(model4, config_enumerate(guide4, "parallel"),
           TraceEnum_ELBO(max_plate_nesting=2))