Add tests for Reshape distribution; fix some bugs (#813)
* Add tests for Reshape distribution; fix some bugs

* Add more tests; check for extra_event_dims overflow

* Rebalance unit tests

* Remove extraneous .reshape(extra_event_dims=1)
fritzo authored and martinjankowiak committed Feb 26, 2018
1 parent 1bbfb48 commit 19a35d1
Showing 5 changed files with 124 additions and 8 deletions.
11 changes: 7 additions & 4 deletions pyro/distributions/torch_distribution.py
@@ -150,6 +150,9 @@ class Reshape(TorchDistribution):
     """
     def __init__(self, base_dist, sample_shape=torch.Size(), extra_event_dims=0):
         sample_shape = torch.Size(sample_shape)
+        if extra_event_dims > len(sample_shape + base_dist.batch_shape):
+            raise ValueError('Expected extra_event_dims <= len(sample_shape + base_dist.batch_shape), '
+                             'actual {} vs {}'.format(extra_event_dims, len(sample_shape + base_dist.batch_shape)))
         self.base_dist = base_dist
         self.sample_shape = sample_shape
         self.extra_event_dims = extra_event_dims
@@ -167,10 +170,10 @@ def has_enumerate_support(self):
         return self.base_dist.has_enumerate_support
 
     def sample(self, sample_shape=torch.Size()):
-        return self.base_dist.sample(self.sample_shape + sample_shape)
+        return self.base_dist.sample(sample_shape + self.sample_shape)
 
     def rsample(self, sample_shape=torch.Size()):
-        return self.base_dist.rsample(self.sample_shape + sample_shape)
+        return self.base_dist.rsample(sample_shape + self.sample_shape)
 
     def log_prob(self, value):
         return sum_rightmost(self.base_dist.log_prob(value), self.extra_event_dims)
@@ -194,8 +197,8 @@ def enumerate_support(self):
 
     @property
     def mean(self):
-        return self.base_dist.mean
+        return self.base_dist.mean.expand(self.batch_shape + self.event_shape)
 
     @property
     def variance(self):
-        return self.base_dist.variance
+        return self.base_dist.variance.expand(self.batch_shape + self.event_shape)
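
Taken together, the two fixes above change what callers observe from a reshaped distribution. A minimal sketch, assuming the pre-0.2 Pyro API exercised by the new tests (Bernoulli from pyro.distributions.torch, Variable-wrapped parameters):

import torch
from torch.autograd import Variable
from pyro.distributions.torch import Bernoulli

d = Bernoulli(Variable(0.5 * torch.ones(5)))  # batch_shape = (5,)
r = d.reshape(sample_shape=torch.Size((3,)))  # batch_shape = (3, 5)

# Fix 1: a sample_shape passed at sampling time is now prepended on the
# left, composing consistently with .reshape(sample_shape=...).
assert r.sample(torch.Size((2,))).shape == torch.Size((2, 3, 5))

# Fix 2: mean and variance are now expanded to the reshaped
# batch_shape + event_shape rather than the base distribution's shape.
assert r.mean.shape == torch.Size((3, 5))
assert r.variance.shape == torch.Size((3, 5))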
110 changes: 110 additions & 0 deletions tests/distributions/test_reshape.py
@@ -0,0 +1,110 @@
from __future__ import absolute_import, division, print_function

import pytest
import torch
from torch.autograd import Variable

from pyro.distributions.torch import Bernoulli


def test_sample_shape_order():
    shape12 = torch.Size((1, 2))
    shape34 = torch.Size((3, 4))
    d = Bernoulli(0.5)

    # .reshape(sample_shape=...) should add dimensions on the left.
    actual = d.reshape(shape34).reshape(shape12)
    expected = d.reshape(shape12 + shape34)
    assert actual.event_shape == expected.event_shape
    assert actual.batch_shape == expected.batch_shape


@pytest.mark.parametrize('batch_dim', [0, 1, 2])
@pytest.mark.parametrize('event_dim', [0, 1, 2])
def test_idempotent(batch_dim, event_dim):
    shape = torch.Size((1, 2, 3, 4))[:batch_dim + event_dim]
    batch_shape = shape[:batch_dim]
    event_shape = shape[batch_dim:]

    # Construct a base dist of desired starting shape.
    dist0 = Bernoulli(0.5).reshape(sample_shape=shape, extra_event_dims=event_dim)
    assert dist0.batch_shape == batch_shape
    assert dist0.event_shape == event_shape

    # Check that an empty .reshape() is a no-op.
    dist = dist0.reshape()
    assert dist.batch_shape == dist0.batch_shape
    assert dist.event_shape == dist0.event_shape


@pytest.mark.parametrize('sample_dim,extra_event_dims',
                         [(s, e) for s in range(4) for e in range(4 + s)])
def test_reshape(sample_dim, extra_event_dims):
    batch_dim = 3
    batch_shape, event_shape = torch.Size((5, 4, 3)), torch.Size()
    sample_shape = torch.Size((8, 7, 6))[3 - sample_dim:]
    shape = sample_shape + batch_shape + event_shape

    # Construct a base dist of desired starting shape.
    dist0 = Bernoulli(Variable(0.5 * torch.ones(batch_shape)))
    assert dist0.event_shape == event_shape
    assert dist0.batch_shape == batch_shape

    # Check that reshaping has the desired final shape.
    dist = dist0.reshape(sample_shape, extra_event_dims)
    sample = dist.sample()
    assert sample.shape == shape
    assert dist.mean.shape == shape
    assert dist.variance.shape == shape
    assert dist.log_prob(sample).shape == shape[:sample_dim + batch_dim - extra_event_dims]
    assert dist.enumerate_support().shape == torch.Size((2,)) + shape


@pytest.mark.parametrize('sample_dim,extra_event_dims',
                         [(s, e) for s in range(3) for e in range(3 + s)])
def test_reshape_reshape(sample_dim, extra_event_dims):
    batch_dim = 2
    batch_shape, event_shape = torch.Size((6, 5)), torch.Size((4, 3))
    sample_shape = torch.Size((8, 7))[2 - sample_dim:]
    shape = sample_shape + batch_shape + event_shape

    # Construct a base dist of desired starting shape.
    dist0 = Bernoulli(Variable(0.5 * torch.ones(event_shape)))
    dist1 = dist0.reshape(sample_shape=batch_shape, extra_event_dims=2)
    assert dist1.event_shape == event_shape
    assert dist1.batch_shape == batch_shape

    # Check that reshaping has the desired final shape.
    dist = dist1.reshape(sample_shape, extra_event_dims)
    sample = dist.sample()
    assert sample.shape == shape
    assert dist.mean.shape == shape
    assert dist.variance.shape == shape
    assert dist.log_prob(sample).shape == shape[:sample_dim + batch_dim - extra_event_dims]
    assert dist.enumerate_support().shape == torch.Size((2,)) + shape


@pytest.mark.parametrize('sample_dim', [0, 1, 2])
@pytest.mark.parametrize('batch_dim', [0, 1, 2])
@pytest.mark.parametrize('event_dim', [0, 1, 2])
def test_extra_event_dim_overflow(sample_dim, batch_dim, event_dim):
    shape = torch.Size(range(sample_dim + batch_dim + event_dim))
    sample_shape = shape[:sample_dim]
    batch_shape = shape[sample_dim:sample_dim + batch_dim]
    event_shape = shape[sample_dim + batch_dim:]

    # Construct a base dist of desired starting shape.
    dist0 = Bernoulli(0.5).reshape(sample_shape=batch_shape + event_shape, extra_event_dims=event_dim)
    assert dist0.batch_shape == batch_shape
    assert dist0.event_shape == event_shape

    # Check .reshape(extra_event_dims=...) for valid values.
    for extra_event_dims in range(1 + sample_dim + batch_dim):
        dist = dist0.reshape(sample_shape=sample_shape, extra_event_dims=extra_event_dims)
        assert dist.batch_shape == shape[:sample_dim + batch_dim - extra_event_dims]
        assert dist.event_shape == shape[sample_dim + batch_dim - extra_event_dims:]

    # Check .reshape(extra_event_dims=...) for invalid values.
    for extra_event_dims in range(1 + sample_dim + batch_dim, 20):
        with pytest.raises(ValueError):
            dist0.reshape(sample_shape=sample_shape, extra_event_dims=extra_event_dims)
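
The overflow guard that these tests exercise fails fast at construction time. A small sketch of the intended failure mode (same API assumptions as above):

import pytest
import torch
from pyro.distributions.torch import Bernoulli

# Bernoulli(0.5) has batch_shape = (), so sample_shape + batch_shape
# supplies only one dim; extra_event_dims=2 must raise.
d = Bernoulli(0.5)
with pytest.raises(ValueError):
    d.reshape(sample_shape=torch.Size((2,)), extra_event_dims=2)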
4 changes: 2 additions & 2 deletions tests/infer/test_enum.py
@@ -107,7 +107,7 @@ def gmm_model(data, verbose=False):
         assert z.shape[-1:] == (1,)
         z = z.long()
         if verbose:
-            logger.debug("M{} z_{} = {}".format(" " * i, i, z))
+            logger.debug("M{} z_{} = {}".format(" " * i, i, z.numpy()))
         pyro.observe("x_{}".format(i), dist.Normal(mus[z], sigma), data[i])


@@ -118,7 +118,7 @@ def gmm_guide(data, verbose=False):
         assert z.shape[-1:] == (1,)
         z = z.long()
         if verbose:
-            logger.debug("G{} z_{} = {}".format(" " * i, i, z))
+            logger.debug("G{} z_{} = {}".format(" " * i, i, z.numpy()))
 
 
 @pytest.mark.parametrize("data_size", [1, 2, 3])
3 changes: 3 additions & 0 deletions tests/infer/test_inference.py
@@ -28,6 +28,7 @@ def param_abs_error(name, target):
     return torch.sum(torch.abs(target - pyro.param(name))).data.cpu().numpy()[0]
 
 
+@pytest.mark.stage("integration", "integration_batch_1")
 class NormalNormalTests(TestCase):
 
     def setUp(self):
@@ -217,6 +218,7 @@ def guide():
         assert_equal(0.0, beta_error, prec=0.08)
 
 
+@pytest.mark.stage("integration", "integration_batch_1")
 class ExponentialGammaTests(TestCase):
     def setUp(self):
         # exponential-gamma model
@@ -273,6 +275,7 @@ def guide():
         assert_equal(0.0, beta_error, prec=0.08)
 
 
+@pytest.mark.stage("integration", "integration_batch_2")
 class BernoulliBetaTests(TestCase):
     def setUp(self):
         # bernoulli-beta model
4 changes: 2 additions & 2 deletions tests/infer/test_valid_models.py
@@ -530,8 +530,8 @@ def test_enum_discrete_parallel_nested_ok(max_iarange_nesting):
     def model():
         p2 = Variable(torch.ones(2) / 2)
         p3 = Variable(torch.ones(3) / 3)
-        x2 = pyro.sample("x2", dist.OneHotCategorical(p2).reshape(extra_event_dims=1))
-        x3 = pyro.sample("x3", dist.OneHotCategorical(p3).reshape(extra_event_dims=1))
+        x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
+        x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
         assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
         assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape
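
The removed .reshape(extra_event_dims=1) calls were extraneous because the one-hot category dimension already counts as an event dimension. A sketch using torch.distributions directly (an assumption about the wrapped behavior, not shown in this diff):

import torch
from torch.distributions import OneHotCategorical

d = OneHotCategorical(torch.ones(3) / 3)
print(d.batch_shape)  # torch.Size([])
print(d.event_shape)  # torch.Size([3]): the category dim is already an event dim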

