Adopt strict batch shape semantics for distributions (#806)
* Sketch EnumeratePoutine

* Fix dimension logic in EnumerateMessenger

* Add more test examples

* Refactor ELBO

* Attempt to get batch shapes correct for enum_discrete in trace_elbo

* Simplify Trace_ELBO

* Drop special-case for enum_discrete in Trace_ELBO

* Replace enum_discrete kwarg with enumerate_discrete() function

* Completely eliminate enum_discrete kwarg

* Fix bugs in tests/infer/test_enum.py

* Rename enumerate_discrete to config_enumerate

* Add analytic KL tests for parallel enumeration

* Add test for sum_rightmost()

* Skip slow tests on travis

* Add another gradient test for enumeration

* Add TODOs for more tests

* Add failing checks for strict shape semantics

* Add variously-sized categoricals test

* Remove excruciatingly slow test

* Fix scalar error

* Flake8

* Fix zero_grads()

* Get test_valid_models.py to pass tests

* flake8

* Remove name arg to _iter_discrete_filter

* Updates per review

* Updates per review

* Remove xfailing death test that should no longer error

* Fix more tests

* Fix some integration tests

* flake8

* Remove PoissonGammaTests

* Use TorchDistribution.mask() in dmm example

* Fix shaping errors in dmm.py

* fixes to examples and tutorials

* fix integration tests

* fix xfail marker; vae test

* address comment; only use iarange for obs in test_inference

* address comment
fritzo authored and martinjankowiak committed Mar 1, 2018
1 parent 4ff08b8 commit fe32be3
Showing 19 changed files with 329 additions and 318 deletions.
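
The theme running through every file below: a distribution's batch dimensions (independent, declared by iarange/irange) are now kept strictly separate from its event dimensions (dependent, declared via .reshape(extra_event_dims=...)). A minimal sketch of the convention, assuming the 0.2-era Pyro distribution API that these diffs use:

import torch
import pyro.distributions as dist

# A batch of 128 datapoints, each with a 50-dimensional latent.
z_mu = torch.zeros(128, 50)
z_sigma = torch.ones(128, 50)

d = dist.Normal(z_mu, z_sigma)
assert d.batch_shape == (128, 50)  # every dim is a batch dim by default
assert d.event_shape == ()

# reshape(extra_event_dims=1) reinterprets the rightmost batch dim as an
# event dim: log-prob sums over the 50 dependent dimensions, leaving one
# term per datapoint to match an enclosing iarange of size 128.
d = d.reshape(extra_event_dims=1)
assert d.batch_shape == (128,)
assert d.event_shape == (50,)
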
9 changes: 5 additions & 4 deletions examples/bayesian_regression.py
@@ -27,7 +27,9 @@ def build_linear_dataset(N, p, noise_std=0.01):
     y = np.matmul(X, w) + np.repeat(1, N) + np.random.normal(0, noise_std, size=N)
     y = y.reshape(N, 1)
     X, y = Variable(torch.Tensor(X)), Variable(torch.Tensor(y))
-    return torch.cat((X, y), 1)
+    data = torch.cat((X, y), 1)
+    assert data.shape == (N, p + 1)
+    return data


 # NN with one linear layer

@@ -65,9 +67,8 @@ def model(data):
     x_data = data[:, :-1]
     y_data = data[:, -1]
     # run the regressor forward conditioned on inputs
-    prediction_mean = lifted_reg_model(x_data)
-    pyro.sample("obs",
-                Normal(prediction_mean, Variable(torch.ones(data.size(0))).type_as(data)),
+    prediction_mean = lifted_reg_model(x_data).squeeze(-1)
+    pyro.sample("obs", Normal(prediction_mean, 1),
                 obs=y_data)


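Why the squeeze in the hunk above: the one-output regressor returns shape (N, 1) while y_data = data[:, -1] has shape (N,), and strict semantics require the observed value to match the distribution's shape exactly instead of silently broadcasting. A sketch of the shape arithmetic (tensors here are stand-ins):

import torch

N, p = 100, 2
x_data = torch.randn(N, p)
w = torch.randn(p, 1)

prediction_mean = x_data.matmul(w)             # one-output net yields (N, 1)
assert prediction_mean.shape == (N, 1)

prediction_mean = prediction_mean.squeeze(-1)  # drop the trailing singleton
assert prediction_mean.shape == (N,)           # now matches y_data
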
43 changes: 26 additions & 17 deletions examples/dmm/dmm.py
@@ -192,20 +192,24 @@ def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,

             # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
             z_mu, z_sigma = self.trans(z_prev)
-            # then sample z_t according to dist.Normal(z_mu, z_sigma)
-            with poutine.scale(None, annealing_factor):
-                z_t = pyro.sample("z_%d" % t,
-                                  dist.Normal(z_mu, z_sigma)
-                                  .mask(mini_batch_mask[:, t - 1:t]))
-
-            # compute the probabilities that parameterize the bernoulli likelihood
-            emission_probs_t = self.emitter(z_t)
-            # the next statement instructs pyro to observe x_t according to the
-            # bernoulli distribution p(x_t|z_t)
-            pyro.sample("obs_x_%d" % t,
-                        dist.Bernoulli(emission_probs_t)
-                        .mask(mini_batch_mask[:, t - 1:t]),
-                        obs=mini_batch[:, t - 1, :])
+            with pyro.iarange("z_minibatch_%d" % t, len(mini_batch)):
+
+                # then sample z_t according to dist.Normal(z_mu, z_sigma)
+                with poutine.scale(None, annealing_factor):
+                    z_t = pyro.sample("z_%d" % t,
+                                      dist.Normal(z_mu, z_sigma)
+                                      .mask(mini_batch_mask[:, t - 1:t])
+                                      .reshape(extra_event_dims=1))
+
+                # compute the probabilities that parameterize the bernoulli likelihood
+                emission_probs_t = self.emitter(z_t)
+                # the next statement instructs pyro to observe x_t according to the
+                # bernoulli distribution p(x_t|z_t)
+                pyro.sample("obs_x_%d" % t,
+                            dist.Bernoulli(emission_probs_t)
+                            .mask(mini_batch_mask[:, t - 1:t])
+                            .reshape(extra_event_dims=1),
+                            obs=mini_batch[:, t - 1, :])
             # the latent sampled at this time step will be conditioned upon
             # in the next time step so keep track of it
             z_prev = z_t

@@ -228,7 +232,7 @@ def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
         # reverse the time-ordering in the hidden state and un-pack it
         rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
         # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
-        z_prev = self.z_q_0
+        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

         # sample the latents z one time step at a time
         for t in range(1, T_max + 1):

@@ -242,10 +246,15 @@ def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
                 z_dist = TransformedDistribution(dist.Normal(z_mu, z_sigma), self.iafs)
             else:
                 z_dist = dist.Normal(z_mu, z_sigma)
+            assert z_dist.event_shape == ()
+            assert z_dist.batch_shape == (len(mini_batch), self.z_q_0.size(0))

             # sample z_t from the distribution z_dist
-            with pyro.poutine.scale(None, annealing_factor * mini_batch_mask[:, t - 1:t]):
-                z_t = pyro.sample("z_%d" % t, z_dist)
+            with pyro.iarange("z_minibatch_%d" % t, len(mini_batch)):
+                with pyro.poutine.scale(None, annealing_factor):
+                    z_t = pyro.sample("z_%d" % t,
+                                      z_dist.mask(mini_batch_mask[:, t - 1:t])
+                                      .reshape(extra_event_dims=1))
             # the latent sampled at this time step will be conditioned upon in the next time step
             # so keep track of it
             z_prev = z_t
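
Restating the per-timestep pattern introduced above as a standalone sketch (0.2-era API assumed): mask() multiplies each sequence's log-prob by its validity flag so padded timesteps contribute nothing, and the extra event dim collapses the latent dimension into one term per sequence, matching the iarange:

import torch
import pyro.distributions as dist

batch_size, z_dim = 20, 100
z_mu = torch.zeros(batch_size, z_dim)
z_sigma = torch.ones(batch_size, z_dim)
mask_t = torch.ones(batch_size, 1)  # stands in for mini_batch_mask[:, t - 1:t]

d = dist.Normal(z_mu, z_sigma).mask(mask_t).reshape(extra_event_dims=1)
assert d.batch_shape == (batch_size,)  # one log-prob term per sequence
assert d.event_shape == (z_dim,)
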
30 changes: 16 additions & 14 deletions examples/vae.py
@@ -87,25 +87,27 @@ def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
     def model(self, x):
         # register PyTorch module `decoder` with Pyro
         pyro.module("decoder", self.decoder)
-        # setup hyperparameters for prior p(z)
-        # the type_as ensures we get cuda Tensors if x is on gpu
-        z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
-        z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
-        # sample from prior (value will be sampled by guide when computing the ELBO)
-        z = pyro.sample("latent", dist.Normal(z_mu, z_sigma))
-        # decode the latent code z
-        mu_img = self.decoder.forward(z)
-        # score against actual images
-        pyro.sample("obs", dist.Bernoulli(mu_img), obs=x.view(-1, 784))
+        with pyro.iarange("data", x.size(0)):
+            # setup hyperparameters for prior p(z)
+            # the type_as ensures we get cuda Tensors if x is on gpu
+            z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
+            z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
+            # sample from prior (value will be sampled by guide when computing the ELBO)
+            z = pyro.sample("latent", dist.Normal(z_mu, z_sigma).reshape(extra_event_dims=1))
+            # decode the latent code z
+            mu_img = self.decoder.forward(z)
+            # score against actual images
+            pyro.sample("obs", dist.Bernoulli(mu_img).reshape(extra_event_dims=1), obs=x.view(-1, 784))

     # define the guide (i.e. variational distribution) q(z|x)
     def guide(self, x):
         # register PyTorch module `encoder` with Pyro
         pyro.module("encoder", self.encoder)
-        # use the encoder to get the parameters used to define q(z|x)
-        z_mu, z_sigma = self.encoder.forward(x)
-        # sample the latent code z
-        pyro.sample("latent", dist.Normal(z_mu, z_sigma))
+        with pyro.iarange("data", x.size(0)):
+            # use the encoder to get the parameters used to define q(z|x)
+            z_mu, z_sigma = self.encoder.forward(x)
+            # sample the latent code z
+            pyro.sample("latent", dist.Normal(z_mu, z_sigma).reshape(extra_event_dims=1))

     # define a helper function for reconstructing images
     def reconstruct_img(self, x):
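
The same idea applied to the VAE's observation site: all 784 pixels of an image form a single event, while the minibatch dimension remains a batch dimension governed by the iarange. A sketch under the same era API:

import torch
import pyro.distributions as dist

mu_img = torch.rand(256, 784)  # decoder output for a minibatch of images
d = dist.Bernoulli(mu_img).reshape(extra_event_dims=1)
assert d.batch_shape == (256,)  # independent datapoints, declared by iarange
assert d.event_shape == (784,)  # the pixels of one image are a single event
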
19 changes: 10 additions & 9 deletions examples/vae_comparison.py
@@ -16,7 +16,6 @@
 from pyro.distributions import Normal, Bernoulli
 from pyro.infer import SVI
 from pyro.optim import Adam
-from pyro.shim import torch_no_grad
 from pyro.util import ng_zeros, ng_ones

 """

@@ -126,7 +125,7 @@ def test(self, epoch):
         self.set_train(is_train=False)
         test_loss = 0
         for i, (x, _) in enumerate(self.test_loader):
-            with torch_no_grad():
+            with torch.no_grad():
                 x = Variable(x)
             recon_x = self.model_eval(x)[0]
             test_loss += self.compute_loss_and_gradient(x)

@@ -187,16 +186,18 @@ def __init__(self, *args, **kwargs):
     def model(self, data):
         decoder = pyro.module('decoder', self.vae_decoder)
         z_mean, z_std = ng_zeros([data.size(0), 20]), ng_ones([data.size(0), 20])
-        z = pyro.sample('latent', Normal(z_mean, z_std))
-        img = decoder.forward(z)
-        pyro.sample('obs',
-                    Bernoulli(img),
-                    obs=data.view(-1, 784))
+        with pyro.iarange('data', data.size(0)):
+            z = pyro.sample('latent', Normal(z_mean, z_std).reshape(extra_event_dims=1))
+            img = decoder.forward(z)
+            pyro.sample('obs',
+                        Bernoulli(img).reshape(extra_event_dims=2),
+                        obs=data.view(-1, 784))

     def guide(self, data):
         encoder = pyro.module('encoder', self.vae_encoder)
-        z_mean, z_var = encoder.forward(data)
-        pyro.sample('latent', Normal(z_mean, z_var.sqrt()))
+        with pyro.iarange('data', data.size(0)):
+            z_mean, z_var = encoder.forward(data)
+            pyro.sample('latent', Normal(z_mean, z_var.sqrt()).reshape(extra_event_dims=1))

     def compute_loss_and_gradient(self, x):
         if self.mode == TRAIN:

19 changes: 9 additions & 10 deletions pyro/__init__.py
@@ -148,8 +148,8 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=None):
     if size is None:
         assert subsample_size is None
         assert subsample is None
-        size = 1
-        subsample_size = 1
+        size = -1  # This is PyTorch convention for "arbitrary size"
+        subsample_size = -1
     elif subsample is None:
         names = [name]
         names += [str(f.counter) for f in _PYRO_STACK if isinstance(f, poutine.IndepMessenger)]

@@ -162,8 +162,7 @@ def _subsample(name, size=None, subsample_size=None, subsample=None, use_cuda=None):
                 subsample_size, len(subsample)) +
                 " Did you accidentally use different subsample_size in the model and guide?")

-    scale = size / subsample_size
-    return subsample, scale
+    return size, subsample_size, subsample


 @contextlib.contextmanager

@@ -232,12 +231,12 @@ def iarange(name, size=None, subsample_size=None, subsample=None, use_cuda=None):
     See `SVI Part II <http://pyro.ai/examples/svi_part_ii.html>`_ for an
     extended discussion.
     """
-    subsample, scale = _subsample(name, size, subsample_size, subsample, use_cuda)
+    size, subsample_size, subsample = _subsample(name, size, subsample_size, subsample, use_cuda)
     if not am_i_wrapped():
         yield subsample
     else:
-        with poutine.scale(None, scale):
-            with poutine.indep(name, vectorized=True):
+        with poutine.scale(None, size / subsample_size):
+            with poutine.indep(name, vectorized=True, size=subsample_size):
                 yield subsample


@@ -267,13 +266,13 @@ def irange(name, size, subsample_size=None, subsample=None, use_cuda=None):

     See `SVI Part II <http://pyro.ai/examples/svi_part_ii.html>`_ for an extended discussion.
     """
-    subsample, scale = _subsample(name, size, subsample_size, subsample, use_cuda)
+    size, subsample_size, subsample = _subsample(name, size, subsample_size, subsample, use_cuda)
     if not am_i_wrapped():
         for i in subsample:
             yield i.item() if isinstance(i, Variable) else i
     else:
-        indep_context = poutine.indep(name, vectorized=False)
-        with poutine.scale(None, scale):
+        indep_context = poutine.indep(name, vectorized=False, size=subsample_size)
+        with poutine.scale(None, size / subsample_size):
             for i in subsample:
                 indep_context.next_context()
                 with indep_context:
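
_subsample now returns the raw sizes and lets iarange/irange apply the scaling themselves. The arithmetic is the usual unbiased-minibatch correction, sketched here:

import torch

size, subsample_size = 10000, 100
subsample = torch.randperm(size)[:subsample_size]  # what _subsample draws

# Each subsampled point stands in for size / subsample_size points, so
# log-probs inside the iarange are scaled up to keep the ELBO unbiased.
scale = size / subsample_size
assert scale == 100.0
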
5 changes: 3 additions & 2 deletions pyro/contrib/gp/likelihoods/gaussian.py
@@ -26,5 +26,6 @@ def __init__(self, variance=None):

     def forward(self, f, obs=None):
         variance = self.get_param("variance").expand_as(f)
-
-        return pyro.sample("y", dist.Normal(f, variance), obs=obs)
+        event_dims = f.dim()
+        return pyro.sample("y", dist.Normal(f, variance).reshape(extra_event_dims=event_dims),
+                           obs=obs)
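
With event_dims = f.dim(), every dimension of the GP function value becomes an event dimension, so the likelihood site contributes a single joint log-prob rather than one term per input point. A sketch of the resulting shapes (era API assumed):

import torch
import pyro.distributions as dist

f = torch.zeros(5, 3)  # a function value over a 5x3 grid of inputs, say
d = dist.Normal(f, torch.ones(5, 3)).reshape(extra_event_dims=f.dim())
assert d.batch_shape == ()      # no batch dims left
assert d.event_shape == (5, 3)  # the whole tensor is one event
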
2 changes: 1 addition & 1 deletion pyro/distributions/testing/rejection_gamma.py
@@ -18,7 +18,7 @@ def __init__(self, alpha):
         if alpha.data.min() < 1:
             raise NotImplementedError('alpha < 1 is not supported')
         self.alpha = alpha
-        self._standard_gamma = Gamma(alpha, alpha.new([1]).expand_as(alpha))
+        self._standard_gamma = Gamma(alpha, alpha.new([1]).squeeze().expand_as(alpha))
         # The following are Marsaglia & Tsang's variable names.
         self._d = self.alpha - 1.0 / 3.0
         self._c = 1.0 / torch.sqrt(9.0 * self._d)

4 changes: 3 additions & 1 deletion pyro/infer/trace_elbo.py
@@ -13,7 +13,7 @@
 from pyro.infer.util import torch_backward, torch_data_sum, torch_sum
 from pyro.poutine.enumerate_poutine import EnumeratePoutine
 from pyro.poutine.util import prune_subsample_sites
-from pyro.util import check_model_guide_match, is_nan
+from pyro.util import check_model_guide_match, check_site_shape, is_nan


 class Trace_ELBO(ELBO):

@@ -43,10 +43,12 @@ def _get_traces(self, model, guide, *args, **kwargs):
             model_trace.compute_batch_log_pdf()
             for site in model_trace.nodes.values():
                 if site["type"] == "sample":
+                    check_site_shape(site, self.max_iarange_nesting)
                     log_r = log_r + sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)
             guide_trace.compute_score_parts()
             for site in guide_trace.nodes.values():
                 if site["type"] == "sample":
+                    check_site_shape(site, self.max_iarange_nesting)
                     log_r = log_r - sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)

             weight = scale / self.num_particles

8 changes: 7 additions & 1 deletion pyro/infer/tracegraph_elbo.py
@@ -15,7 +15,7 @@
 from pyro.infer.util import MultiViewTensor as MVT
 from pyro.infer.util import torch_backward, torch_data_sum
 from pyro.poutine.util import prune_subsample_sites
-from pyro.util import check_model_guide_match, detach_iterable
+from pyro.util import check_model_guide_match, check_site_shape, detach_iterable


 def _get_baseline_options(site):

@@ -268,7 +268,13 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
         # have the trace compute all the individual (batch) log pdf terms
         # and score function terms (if present) so that they are available below
         model_trace.compute_batch_log_pdf()
+        for site in model_trace.nodes.values():
+            if site["type"] == "sample":
+                check_site_shape(site, self.max_iarange_nesting)
         guide_trace.compute_score_parts()
+        for site in guide_trace.nodes.values():
+            if site["type"] == "sample":
+                check_site_shape(site, self.max_iarange_nesting)

         # compute elbo for reparameterized nodes
         non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
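
Both ELBOs now validate every sample site. The real logic lives in pyro.util.check_site_shape; the sketch below only guesses at the invariant it enforces, namely that a site's batch log-pdf cannot carry more batch dimensions than max_iarange_nesting allows:

def check_site_shape_sketch(batch_log_pdf, max_iarange_nesting):
    # Hypothetical stand-in for pyro.util.check_site_shape; details may differ.
    if batch_log_pdf.dim() > max_iarange_nesting:
        raise ValueError(
            "site has {} batch dims but max_iarange_nesting is {}; wrap the "
            "extra dims in pyro.iarange or declare them as event dims via "
            ".reshape(extra_event_dims=...)".format(
                batch_log_pdf.dim(), max_iarange_nesting))
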
5 changes: 3 additions & 2 deletions pyro/poutine/__init__.py
@@ -139,17 +139,18 @@ def infer_config(fn, config_fn):
     return InferConfigPoutine(fn, config_fn)


-def indep(name, vectorized):
+def indep(name, vectorized, size):
     """
     :param str name: a name for subsample sites
     :param bool vectorized: True for ``iarange``, False for ``irange``
+    :param int size: Size of the subsampled batch, or -1 if unknown
     :rtype: pyro.poutine.IndepMessenger

     Alias for IndepMessenger constructor.

     Used internally by ``iarange`` and ``irange``.
     """
-    return IndepMessenger(name=name, vectorized=vectorized)
+    return IndepMessenger(name=name, vectorized=vectorized, size=size)


 def scale(null, scale):

7 changes: 4 additions & 3 deletions pyro/poutine/indep_poutine.py
@@ -4,7 +4,7 @@

 from .poutine import Messenger

-CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "counter", "vectorized"])
+CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "counter", "vectorized", "size"])


 class IndepMessenger(Messenger):

@@ -14,7 +14,7 @@ class IndepMessenger(Messenger):
     a ``cond_indep_stack`` at each sample/observe site for consumption by
     ``TracePoutine``.
     """
-    def __init__(self, name, vectorized):
+    def __init__(self, name, vectorized, size):
         """
         Constructor: basically default, but store a counter to keep track of
         which ``irange`` branch we're in.

@@ -23,6 +23,7 @@ def __init__(self, name, vectorized):
         self.name = name
         self.counter = 0
         self.vectorized = vectorized
+        self.size = size

     def next_context(self):
         """

@@ -31,5 +32,5 @@ def next_context(self):
         self.counter += 1

     def _process_message(self, msg):
-        msg["cond_indep_stack"].insert(0, CondIndepStackFrame(self.name, self.counter, self.vectorized))
+        msg["cond_indep_stack"].insert(0, CondIndepStackFrame(self.name, self.counter, self.vectorized, self.size))
         return None
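
Each iarange/irange frame on the cond_indep_stack now records its subsampled size (-1 when unknown, per _subsample above), which is what lets downstream shape checks match a site's batch dimensions against the enclosing independence contexts:

from pyro.poutine.indep_poutine import CondIndepStackFrame

frame = CondIndepStackFrame(name="data", counter=0, vectorized=True, size=100)
assert frame.size == 100  # available later for shape validation
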
20 changes: 0 additions & 20 deletions pyro/shim.py
@@ -1,6 +1,5 @@
 from __future__ import absolute_import, division, print_function

-import contextlib
 import re

 import torch

@@ -17,22 +16,3 @@ def parse_torch_version():
     major, minor, patch = map(int, match.group(1).split("."))
     extra_stuff = match.group(2)
     return major, minor, patch, extra_stuff
-
-
-# Polyfill to bridge the change of .volatile between PyTorch 0.3 and 0.4.
-try:
-    # These work in PyTorch 0.4 prerelease.
-    torch_no_grad = torch.no_grad
-
-    def is_volatile(variable):
-        return False
-
-except AttributeError:
-    # These work in PyTorch 0.3 and earlier.
-
-    @contextlib.contextmanager
-    def torch_no_grad():
-        yield
-
-    def is_volatile(variable):
-        return variable.volatile
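
With the 0.3 polyfill gone, callers use torch.no_grad() directly, as vae_comparison.py above now does. A usage sketch assuming PyTorch 0.4+:

import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():  # no autograd graph is recorded inside the context
    y = x * 2
assert not y.requires_grad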
