In [1]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

# Switch on Pyro's validation before anything else runs -- always a good idea!
pyro.enable_validation(True)
# True whenever a 'CI' variable is present in the environment.
smoke_test = 'CI' in os.environ

def test_model(model, guide, loss):
    """Check a model/guide pair by evaluating ``loss`` on it once.

    The Pyro parameter store is cleared first, then ``loss.loss(model, guide)``
    is called. Its result is discarded and nothing is returned; the call is
    made only so that any error inside the model or guide surfaces.
    """
    pyro.clear_param_store()
    evaluate = loss.loss
    evaluate(model, guide)

In [2]:
#       |      iid     | independent | dependent
# ------+--------------+-------------+------------
# shape = sample_shape + batch_shape + event_shape

In [3]:
# Simplest case: a scalar parameter gives empty batch and event shapes, so the
# sample and its log-probability are both scalars too.
d = Bernoulli(0.5)
x = d.sample()
assert (d.batch_shape, d.event_shape) == ((), ())
assert x.shape == ()
assert d.log_prob(x).shape == ()

In [4]:
# A (3, 4) parameter tensor yields a (3, 4) batch of univariate distributions:
# the sample and the log-probability both carry the full batch shape.
d = Bernoulli(torch.ones(3, 4) * 0.5)
x = d.sample()
assert d.event_shape == ()
assert d.batch_shape == x.shape == d.log_prob(x).shape == (3, 4)

In [5]:
# expand_by([3]) prepends a batch dimension of size 3 to the original (4,) batch.
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand_by([3])
x = d.sample()
assert d.event_shape == ()
assert d.batch_shape == x.shape == d.log_prob(x).shape == (3, 4)

In [6]:
# A multivariate distribution has a non-empty event_shape: the sample is a
# length-3 vector, yet there is a single (scalar) log-probability for it.
d = MultivariateNormal(torch.zeros(3), torch.eye(3))
x = d.sample()
assert d.batch_shape == ()
assert d.event_shape == x.shape == (3,)  # x.shape == batch_shape + event_shape
assert d.log_prob(x).shape == ()         # == batch_shape

In [7]:
# Scratch check: a shape written as (3, 3) is a plain tuple, indexable by position.
a = (3, 3)
# Replaces a bare `type(a)` expression, whose value was silently discarded
# because it was not the last line of the cell.
assert isinstance(a, tuple)
print(a[1])

3


In [8]:
# independent(1) reinterprets the rightmost batch dimension as an event
# dimension: the sample keeps shape (3, 4) but log_prob is summed down to (3,).
d = Bernoulli(0.5 * torch.ones(3, 4)).independent(1)
x = d.sample()
assert (d.batch_shape, d.event_shape) == ((3,), (4,))
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)

In [9]:
# expand_by([10]) adds a new leftmost batch dimension; event_shape stays empty.
d = Bernoulli(0.5 * torch.ones(3, 4)).expand_by([10])
x = d.sample()
assert d.event_shape == ()
assert d.batch_shape == x.shape == d.log_prob(x).shape == (10, 3, 4)

In [10]:
# Expand to a (10, 3, 4) batch, then declare the two rightmost dims as event dims.
d = Bernoulli(0.5 * torch.ones(3, 4)).expand_by([10]).independent(2)
x = d.sample()
# One shape per line: batch_shape, event_shape, sample shape, log_prob shape.
print(d.batch_shape, d.event_shape, x.shape, d.log_prob(x).shape, sep="\n")

torch.Size([10])
torch.Size([3, 4])
torch.Size([10, 3, 4])
torch.Size([10])


In [11]:
def model1():
    """Demonstrate how ``iarange`` and ``independent`` determine sample shapes.

    Each sample site is followed directly by an assertion on the shape of the
    value it produced. The function returns nothing.
    """
    # --- Sites using the default iarange dimension ---------------------------
    a = pyro.sample("a", Normal(0, 1))
    assert a.shape == ()           # batch_shape == ()     event_shape == ()

    b = pyro.sample("b", Normal(torch.zeros(2), 1).independent(1))
    assert b.shape == (2,)         # batch_shape == ()     event_shape == (2,)

    with pyro.iarange("c_iarange", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
        assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()

    with pyro.iarange("d_iarange", 3):
        d = pyro.sample("d", Normal(torch.zeros(3, 4, 5), 1).independent(2))
        assert d.shape == (3, 4, 5)  # batch_shape == (3,)   event_shape == (4,5)

    # --- Sites using iaranges pinned to explicit batch dimensions ------------
    x_axis = pyro.iarange("x_axis", 3, dim=-2)
    y_axis = pyro.iarange("y_axis", 2, dim=-3)

    with x_axis:
        x = pyro.sample("x", Normal(0, 1).expand_by([3, 1]))
        assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()

    with y_axis:
        y = pyro.sample("y", Normal(0, 1).expand_by([2, 1, 1]))
        assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()

    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1).expand_by([2, 3, 1]))
        z = pyro.sample("z", Normal(0, 1).expand_by([2, 3, 1, 5]).independent(1))
        assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
        assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

# model1 serves as both model and guide for this shape check.
test_model(model=model1, guide=model1, loss=Trace_ELBO())



In [12]:
# batch dims | event dims
# -----------+-----------
#            |        a = sample("a", Normal(0, 1))
#            |2       b = sample("b", Normal(zeros(2), 1)
#            |                        .independent(1))
#            |        with iarange("c", 2):
#           2|            c = sample("c", Normal(zeros(2), 1))
#            |        with iarange("d", 3):
#           3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
#            |                       .independent(2))
#            |
#            |        x_axis = iarange("x", 3, dim=-2)
#            |        y_axis = iarange("y", 2, dim=-3)
#            |        with x_axis:
#         3 1|            x = sample("x", Normal(0, 1).expand_by([3, 1]))
#            |        with y_axis:
#       2 1 1|            y = sample("y", Normal(0, 1).expand_by([2, 1, 1]))
#            |        with x_axis, y_axis:
#       2 3 1|            xy = sample("xy", Normal(0, 1).expand_by([2, 3, 1]))
#       2 3 1|5           z = sample("z", Normal(0, 1).expand_by([2, 3, 1, 5])
#            |                       .independent(1))

In [13]:
# Float literal on purpose: newer PyTorch releases infer an integer dtype from
# integer arange arguments, and this tensor is later observed under a Normal.
data = torch.arange(100.)

def model2():
    """Observe a random minibatch of the module-level ``data`` under a Normal.

    A per-datum ``mean`` parameter is registered for the full dataset; inside
    a subsampling ``iarange`` only the 20 selected entries of both the data and
    the parameter are used.
    """
    n_data = len(data)
    mean = pyro.param("mean", torch.zeros(n_data))
    with pyro.iarange("data", n_data, subsample_size=20) as idx:
        # idx holds the indices of the current subsample.
        assert len(idx) == 20
        # Index the data and the per-datum parameter with the same indices so
        # that each observation is paired with its own mean.
        obs_batch, loc_batch = data[idx], mean[idx]
        x = pyro.sample("x", Normal(loc_batch, 1), obs=obs_batch)
        assert len(x) == 20

# model2 has no latent sample sites, so an empty guide is enough.
test_model(model=model2, guide=lambda: None, loss=Trace_ELBO())

In [14]:
@config_enumerate(default="parallel")
def model3():
    """Shape walkthrough for parallel enumeration of discrete sample sites.

    Sites ``a``-``d`` are discrete; ``e`` is a Normal whose location depends on
    ``d``. The trailing assertions record the shape each value is expected to
    have when run under ``TraceEnum_ELBO(max_iarange_nesting=2)``. The function
    returns nothing.
    """
    # Float literals keep these tensors floating-point on every PyTorch
    # version: newer releases infer an integer dtype from integer arange
    # arguments, which would make `p` an integer tensor (all zeros under
    # integer division) and `e_scale` a Long scale for the Normal.
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.iarange("c_iarange", 4):
        c = pyro.sample("c", Bernoulli(0.3).expand_by([4]))
        with pyro.iarange("d_iarange", 5):
            d = pyro.sample("d", Bernoulli(0.4).expand_by([5,4]))
            # Index the two locations by the sampled 0/1 value, then add a
            # trailing dim of size 1 that broadcasts against e_scale's 7.
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .independent(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 6, 1, 1   )  # 2 enumerated Bernoullis x 6 Categoricals.
    assert c.shape == (   2, 1, 1, 1, 4   )  # Only 2 Bernoullis; does not depend on a or b.
    assert d.shape == (2, 1, 1, 1, 5, 4   )  # Only two Bernoullis.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 5, 4, 1,)
    assert e_scale.shape == (                  7,)

# Two nested iaranges in model3, hence max_iarange_nesting=2.
test_model(model=model3, guide=model3, loss=TraceEnum_ELBO(max_iarange_nesting=2))

In [15]:
#      max_iarange_nesting = 2
#             |<->|
# enumeration batch event
# ------------|---|-----
#            6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
#          2 6|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
#             |   |     with pyro.iarange("c_iarange", 4):
#        2 1 1|1 4|         c = pyro.sample("c", Bernoulli(0.3).expand_by([4]))
#             |   |         with pyro.iarange("d_iarange", 5):
#      2 1 1 1|5 4|             d = pyro.sample("d", Bernoulli(0.4).expand_by([5,4]))
#      2 1 1 1|5 4|1            e_loc = locs[d.long()].unsqueeze(-1)
#             |   |7            e_scale = torch.arange(1, 8)
#      2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
#             |   |                             .independent(1))

In [18]:
# Image grid dimensions and the (x, y) coordinates of the pixels that are "on".
width, height = 8, 10
sparse_pixels = torch.LongTensor([
    [3, 2],
    [3, 5],
    [3, 9],
    [7, 1],
])
# Module-level flag read by `fun`; set to either True or False below.
enumerated = None

def fun(observe):
    """Shared body for ``model4``/``guide4``: sample row/column activations.

    Samples one Bernoulli per row (``x_active``) and one per column
    (``y_active``), combines them by broadcasting into per-pixel probabilities
    and, when ``observe`` is true, observes a dense pixel image built from the
    module-level ``sparse_pixels``. The expected shapes depend on the
    module-level ``enumerated`` flag. Returns nothing.
    """
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)

    # The shapes of these sites depend on whether Pyro is enumerating, so pick
    # the expected shapes once up front.
    if enumerated:
        want_x, want_y = (2, width, 1), (2, 1, 1, height)
        want_grid = (2, 2, width, height)
    else:
        want_x, want_y = (width, 1), (height,)
        want_grid = (width, height)

    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand_by([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand_by([height]))
    assert x_active.shape == want_x
    assert y_active.shape == want_y

    # First trick: rely on broadcasting, which works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    assert p.shape == want_grid

    # Second trick: ellipsis indexing addresses the two rightmost dims only,
    # leaving Pyro free to add arbitrary dimensions on the left.
    dense_pixels = torch.zeros_like(p)
    for row, col in sparse_pixels:
        dense_pixels[..., row, col] = 1
    assert dense_pixels.shape == want_grid

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def model4():
    """Model: run ``fun`` with the pixel observation enabled."""
    fun(True)

@config_enumerate(default="parallel")
def guide4():
    """Guide: run ``fun`` without the pixel observation, enumerated in parallel."""
    fun(False)

# Variant without enumeration (currently disabled).
#enumerated = False
#test_model(model4, guide4, Trace_ELBO())

# Variant with enumeration: `fun` reads the module-level `enumerated` flag to
# choose which shapes it asserts, so set it before running the check.
enumerated = True
test_model(model=model4, guide=guide4, loss=TraceEnum_ELBO(max_iarange_nesting=2))

In [24]:
# Number of samples for the ELBO estimator.
num_particles = 100
# Image grid dimensions and the (x, y) coordinates of the pixels that are "on".
width, height = 8, 10
sparse_pixels = torch.LongTensor([
    [3, 2],
    [3, 5],
    [3, 9],
    [7, 1],
])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    """Sample row/column activations with every batch dim expanded by hand.

    Both distributions are explicitly expanded to include the leading
    ``num_particles`` dimension. Returns the pair ``(x_active, y_active)``.
    """
    x_dist = Bernoulli(p_x).expand_by([num_particles, width, 1])
    y_dist = Bernoulli(p_y).expand_by([num_particles, 1, height])
    with x_axis:
        x_active = pyro.sample("x_active", x_dist)
    with y_axis:
        y_active = pyro.sample("y_active", y_dist)
    return x_active, y_active

def sample_pixel_locations_automatic_broadcasting(p_x, p_y, x_axis, y_axis):
    """Sample row/column activations from unexpanded scalar distributions.

    No ``expand_by`` is applied here; the caller is expected to supply the
    batch shape some other way (this notebook wraps the model and guide in
    ``poutine.broadcast``). Returns the pair ``(x_active, y_active)``.
    """
    x_dist = Bernoulli(p_x)
    y_dist = Bernoulli(p_y)
    with x_axis:
        x_active = pyro.sample("x_active", x_dist)
    with y_axis:
        y_active = pyro.sample("y_active", y_dist)
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    """Sample row/column activations expanded over the image axes only.

    The distributions are expanded to the ``width``/``height`` dims but not to
    the leading ``num_particles`` dim. Returns ``(x_active, y_active)``.
    """
    x_dist = Bernoulli(p_x).expand_by([width, 1])
    y_dist = Bernoulli(p_y).expand_by([height])
    with x_axis:
        x_active = pyro.sample("x_active", x_dist)
    with y_axis:
        y_active = pyro.sample("y_active", y_dist)
    return x_active, y_active

def fun(observe, sample_fn):
    """Shared model/guide body with an extra ``num_particles`` batch dimension.

    Args:
        observe: when true, the dense pixel image is observed under
            ``Bernoulli(p)``; when false only the latent sites are sampled.
        sample_fn: callable ``(p_x, p_y, x_axis, y_axis) -> (x_active, y_active)``
            that performs the two latent sample statements.

    The shape assertions assume parallel enumeration of both latent sites.
    Prints the shapes of the sampled activations and returns nothing.
    """
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)

    # Size the iarange with the shared constant rather than a literal 100, so it
    # cannot drift out of sync with the shape assertions below.
    with pyro.iarange("num_particles", num_particles, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        # Indices corresponding to "parallel" enumeration are appended
        # to the left of the "num_particles" iarange dim.
        print(x_active.shape) #assert x_active.shape  == (2, num_particles, width, 1)
        print(y_active.shape) #assert y_active.shape  == (2, 1, num_particles, 1, height)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, num_particles, width, height)

        # Ellipsis indexing touches only the two rightmost (image) dims.
        dense_pixels = torch.zeros_like(p)
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, num_particles, width, height)

        with x_axis, y_axis:
            if observe:
                pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn, broadcast=False):
    """Build a model/guide pair around ``sample_fn`` and check it once.

    Args:
        sample_fn: sampling helper forwarded to ``fun``.
        broadcast: when true, both model and guide are wrapped in
            ``poutine.broadcast`` before being checked.
    """
    def model():
        fun(True, sample_fn)

    @config_enumerate(default="parallel")
    def guide():
        fun(False, sample_fn)

    # Identity wrapper unless broadcasting was requested.
    wrap = poutine.broadcast if broadcast else (lambda fn: fn)
    test_model(wrap(model), wrap(guide), TraceEnum_ELBO(max_iarange_nesting=3))

# Only the automatic-broadcasting variant is exercised; the other two sampling
# helpers are left disabled.
#test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_automatic_broadcasting, True)
#test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting, broadcast=True)

torch.Size([2, 100, 8, 1])
torch.Size([2, 1, 100, 1, 10])
torch.Size([2, 100, 8, 1])
torch.Size([2, 1, 100, 1, 10])
