
0.2.0

fritzo released this 25 Apr 00:46

Support for PyTorch 0.4

Pyro 0.2 supports PyTorch 0.4. See the PyTorch release notes for comprehensive changes. The most important change is that Variable and Tensor have been merged, so you can now simplify:

- pyro.param("my_param", Variable(torch.ones(1), requires_grad=True))
+ pyro.param("my_param", torch.ones(1))

PyTorch distributions

PyTorch's torch.distributions library is now Pyro’s main source for distribution implementations. The Pyro team helped create this library by collaborating with Adam Paszke, Alican Bozkurt, Vishwak Srinivasan, Rachit Singh, Brooks Paige, Jan-Willem Van De Meent, and many other contributors and reviewers. See the Pyro wrapper docs for wrapped PyTorch distributions and the Pyro distribution docs for Pyro-specific distributions.

Constrained parameters

Parameters can now be constrained easily using notation like:

import pyro
import torch
from torch.distributions import constraints

pyro.param("sigma", torch.ones(10), constraint=constraints.positive)

See the torch.distributions.constraints library and all of our Pyro tutorials for example usage.
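As a rough illustration of what a constraint does: a constrained parameter can be stored as an unconstrained value that the optimizer updates freely, with the constrained value recovered through a transform onto the constraint's support. For constraints.positive that transform is exp(). The pure-Python sketch below uses a hypothetical PositiveParam class and a toy quadratic loss. It illustrates the idea only and is not Pyro's param-store code, which relies on torch.distributions transforms.

```python
import math


class PositiveParam:
    """Hypothetical sketch of a parameter constrained to (0, inf).

    The optimizer only updates the unconstrained value u; the constrained
    value is recovered as exp(u), which is always positive.
    """

    def __init__(self, init_value):
        if init_value <= 0:
            raise ValueError("initial value must be positive")
        self.unconstrained = math.log(init_value)  # inverse transform

    @property
    def value(self):
        return math.exp(self.unconstrained)  # forward transform

    def step(self, grad_wrt_value, lr):
        # Chain rule: d(value)/du = exp(u) = value.
        grad_u = grad_wrt_value * self.value
        self.unconstrained -= lr * grad_u


# Minimize the toy loss (sigma - 0.5)**2 starting from sigma = 1.0.
sigma = PositiveParam(1.0)
for _ in range(200):
    sigma.step(2 * (sigma.value - 0.5), lr=0.1)
print(sigma.value)
```

Even a very large gradient step cannot push the value below zero, because the update happens in unconstrained space.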

Arbitrary tensor shapes

Arbitrary tensor shapes and batching are now supported in Pyro. This includes support for nested batching via iarange and support for batched multivariate distributions. The iarange context and irange generator are now much more flexible and can be combined freely. With power comes complexity, so check out our tensor shapes tutorial (hint: you’ll need to use .expand_by() and .independent()).
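To make the shape hint concrete: each distribution has a batch_shape and an event_shape. As we understand these methods, .expand_by(sizes) prepends sizes to the batch shape, and .independent(n) reinterprets the rightmost n batch dimensions as event dimensions, so that log_prob sums over them. The sketch below tracks only that bookkeeping with plain tuples. The functions are hypothetical stand-ins named after the methods, not Pyro code.

```python
def expand_by(batch_shape, event_shape, sizes):
    """Hypothetical stand-in: prepend `sizes` to the batch shape."""
    return tuple(sizes) + tuple(batch_shape), tuple(event_shape)


def independent(batch_shape, event_shape, n):
    """Hypothetical stand-in: move the rightmost `n` batch dims into the event shape."""
    if not 0 <= n <= len(batch_shape):
        raise ValueError("cannot reinterpret more dims than the batch has")
    split = len(batch_shape) - n
    return tuple(batch_shape[:split]), tuple(batch_shape[split:]) + tuple(event_shape)


# Start from a scalar distribution: batch_shape=(), event_shape=().
shape = ((), ())
shape = expand_by(*shape, [3, 4])  # batch (3, 4), event ()
shape = independent(*shape, 1)     # batch (3,),   event (4,)
print(shape)
```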

Parallel enumeration

Discrete enumeration can now be parallelized. This makes it especially easy and cheap to enumerate out discrete latent variables. Check out the Gaussian Mixture Model tutorial for example usage. To use parallel enumeration, you'll need to first configure sites, then use the TraceEnum_ELBO loss:

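import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

def model(...):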
    ...

@config_enumerate(default="parallel")  # configures sites
def guide(...):
    with pyro.iarange("foo", 10):
        x = pyro.sample("x", dist.Bernoulli(0.5).expand_by([10]))
        ...

svi = SVI(model, guide, Adam({}),
          loss=TraceEnum_ELBO(max_iarange_nesting=1))  # specify loss
svi.step()
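Enumeration computes an exact sum over every value of a discrete latent variable instead of sampling it. The pure-Python sketch below does this for a two-component Gaussian mixture by looping over components. Parallel enumeration, as we understand it, performs the same sum by placing the enumerated values along an extra tensor dimension. The data, weights, and helper names are made up for illustration.

```python
import math


def log_normal(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_log_likelihood(data, weights, locs, scale):
    """Hypothetical helper: sum out each point's discrete assignment z exactly.

    Every value of z is scored (one list entry per component) and combined
    with logsumexp, so no sampling of z is needed.
    """
    total = 0.0
    for x in data:
        per_component = [math.log(w) + log_normal(x, mu, scale)
                         for w, mu in zip(weights, locs)]
        total += logsumexp(per_component)
    return total


data = [-1.2, 0.1, 2.9, 3.3]
print(marginal_log_likelihood(data, weights=[0.4, 0.6], locs=[0.0, 3.0], scale=1.0))
```

Because each point's assignment is conditionally independent of the others, the per-point sums replace a sum over all 2^N joint assignments.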

Markov chain Monte Carlo via HMC and NUTS

This release adds experimental support for gradient-based Markov chain Monte Carlo inference via Hamiltonian Monte Carlo (pyro.infer.HMC) and the No-U-Turn Sampler (pyro.infer.NUTS). See the docs and example for details.
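For intuition, HMC proposes a move by simulating Hamiltonian dynamics with a leapfrog integrator, then accepts or rejects it with a Metropolis test that corrects the discretization error. NUTS additionally chooses the trajectory length adaptively. The sketch below is plain 1-D HMC (not NUTS) on a standard normal target, with a hand-picked step size and number of steps. It is a dependency-free illustration, not the pyro.infer.HMC implementation.

```python
import math
import random


def leapfrog(q, p, grad_potential, step_size, num_steps):
    """Simulate Hamiltonian dynamics with the leapfrog integrator."""
    p = p - 0.5 * step_size * grad_potential(q)  # initial half step for momentum
    for i in range(num_steps):
        q = q + step_size * p                    # full step for position
        if i != num_steps - 1:
            p = p - step_size * grad_potential(q)
    p = p - 0.5 * step_size * grad_potential(q)  # final half step for momentum
    return q, p


def hmc(potential, grad_potential, q0, num_samples, step_size=0.2, num_steps=10, seed=0):
    """Minimal 1-D HMC with unit-mass Gaussian momentum and a Metropolis correction."""
    rng = random.Random(seed)
    q, samples = q0, []
    for _ in range(num_samples):
        p = rng.gauss(0.0, 1.0)
        q_new, p_new = leapfrog(q, p, grad_potential, step_size, num_steps)
        # Hamiltonian = potential energy + kinetic energy.
        current_h = potential(q) + 0.5 * p * p
        proposed_h = potential(q_new) + 0.5 * p_new * p_new
        if rng.random() < math.exp(min(0.0, current_h - proposed_h)):
            q = q_new
        samples.append(q)
    return samples


# Target: standard normal, so U(q) = q**2 / 2 (negative log density up to a constant).
samples = hmc(lambda q: 0.5 * q * q, lambda q: q, q0=3.0, num_samples=5000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(mean, var)
```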

Gaussian Processes

A new Gaussian Process module pyro.contrib.gp provides a framework for learning with Gaussian Processes. To get started, take a look at our Gaussian Process Tutorial. Thanks to Du Phan for this extensive contribution!
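For a zero-mean GP with kernel k and Gaussian observation noise σ², the posterior mean at test inputs x* is k(x*, X)(K + σ²I)⁻¹y, where K is the kernel matrix on the training inputs X. The dependency-free sketch below computes only that posterior mean with an RBF kernel. The hyperparameters, data, and helper names are illustrative assumptions, and pyro.contrib.gp does considerably more, such as learning kernel hyperparameters.

```python
import math


def rbf(x1, x2, variance=1.0, lengthscale=1.0):
    """Squared-exponential (RBF) kernel."""
    return variance * math.exp(-0.5 * ((x1 - x2) / lengthscale) ** 2)


def solve(a, b):
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):  # back substitution
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def gp_posterior_mean(x_train, y_train, x_test, noise=1e-2):
    """Hypothetical helper: posterior mean k(x*, X) (K + noise * I)^-1 y of a zero-mean GP."""
    k = [[rbf(a, b) + (noise if i == j else 0.0) for j, b in enumerate(x_train)]
         for i, a in enumerate(x_train)]
    alpha = solve(k, y_train)
    return [sum(rbf(xs, xt) * al for xt, al in zip(x_train, alpha)) for xs in x_test]


x_train = [-2.0, -1.0, 0.0, 1.0, 2.0]
y_train = [math.sin(x) for x in x_train]
print(gp_posterior_mean(x_train, y_train, [0.5, 1.5]))
```

Far from the training data the kernel values vanish, so the posterior mean falls back to the prior mean of zero.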

Automatic guide generation

Guides can now be created automatically with the pyro.contrib.autoguide library. These work only for models with simple structure (no irange or iarange), and are easy to use:

from pyro.contrib.autoguide import AutoDiagonalNormal
from pyro.infer import SVI

def model(...):
    ...

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)
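Conceptually, AutoDiagonalNormal places an independent Normal over each latent dimension, giving a guide with a diagonal covariance. The sketch below shows that structure with a hypothetical DiagonalNormalGuide class keyed by made-up site names. As far as we know, the real autoguide instead discovers the latent sites by running the model and works on a flattened, unconstrained latent vector whose parameters live in Pyro's param store.

```python
import math
import random


class DiagonalNormalGuide:
    """Hypothetical mean-field guide: one independent Normal per latent site.

    Each site gets a learnable loc and log-scale (the log keeps the scale positive).
    """

    def __init__(self, site_names):
        self.loc = {name: 0.0 for name in site_names}
        self.log_scale = {name: 0.0 for name in site_names}

    def sample(self, rng):
        # Reparameterized draw: z = loc + scale * eps with eps ~ Normal(0, 1).
        return {name: self.loc[name] + math.exp(self.log_scale[name]) * rng.gauss(0.0, 1.0)
                for name in self.loc}

    def log_prob(self, values):
        # Diagonal covariance: the joint log density is a sum over sites.
        total = 0.0
        for name, z in values.items():
            scale = math.exp(self.log_scale[name])
            total += (-0.5 * ((z - self.loc[name]) / scale) ** 2
                      - math.log(scale) - 0.5 * math.log(2 * math.pi))
        return total


guide = DiagonalNormalGuide(["mu", "log_sigma"])
z = guide.sample(random.Random(0))
print(sorted(z), guide.log_prob(z))
```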

Validation

Model validation is now available via three toggles:

# Turns on Pyro's validation checks globally.
pyro.enable_validation()
# Turns on validation within pyro.infer.
pyro.infer.enable_validation()
# Turns on validation for PyTorch distributions.
pyro.distributions.enable_validation()

These toggles can also be applied temporarily as context managers:

# Run the first step with validation.
with pyro.validation_enabled(True):
    svi.step()
# Skip validation on subsequent steps (this may miss NaN errors).
with pyro.validation_enabled(False):
    for i in range(1000):
        svi.step()
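A toggle that doubles as a context manager typically sets a global flag on entry and restores the previous value on exit, even if the body raises. The sketch below shows that pattern with a hypothetical flag and a hypothetical check_finite helper. It borrows the names enable_validation and validation_enabled from above but is not Pyro's source.

```python
from contextlib import contextmanager

_VALIDATION_ENABLED = False  # hypothetical global flag


def enable_validation(is_validate=True):
    """Hypothetical toggle: set the global validation flag."""
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = is_validate


def is_validation_enabled():
    return _VALIDATION_ENABLED


@contextmanager
def validation_enabled(is_validate=True):
    """Temporarily set the flag, restoring the previous value on exit."""
    old = _VALIDATION_ENABLED
    try:
        enable_validation(is_validate)
        yield
    finally:
        enable_validation(old)  # runs even if the body raised


def check_finite(x):
    """Hypothetical check that only runs while validation is on."""
    if is_validation_enabled() and x != x:  # NaN is the only value not equal to itself
        raise ValueError("encountered NaN")
    return x


with validation_enabled(True):
    check_finite(1.0)  # passes; a NaN here would raise ValueError
```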

Rejection sampling variational inference (RSVI)

We've added support for vectorized rejection sampling in a new Rejector distribution. See the docs or the RejectionStandardGamma class for example usage.
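A common choice for the standard Gamma is the Marsaglia-Tsang rejection sampler: propose v = (1 + c·x)³ from Gaussian noise x and accept it with a closed-form test. As far as we know, RejectionStandardGamma is built on this sampler for alpha ≥ 1. The sketch below is a scalar, pure-Python version with a hypothetical function name. The Pyro version is vectorized.

```python
import math
import random


def standard_gamma_rejection(alpha, rng):
    """Hypothetical Marsaglia-Tsang rejection sampler for Gamma(alpha, 1), alpha >= 1."""
    if alpha < 1:
        raise ValueError("this sketch assumes alpha >= 1")
    d = alpha - 1.0 / 3.0
    c = 1.0 / math.sqrt(9.0 * d)
    while True:
        x = rng.gauss(0.0, 1.0)      # proposal noise
        v = (1.0 + c * x) ** 3
        if v <= 0:
            continue                 # outside the proposal's support; retry
        u = 1.0 - rng.random()       # uniform on (0, 1], so log(u) is defined
        # Marsaglia-Tsang acceptance test in log space.
        if math.log(u) < 0.5 * x * x + d - d * v + d * math.log(v):
            return d * v


rng = random.Random(0)
samples = [standard_gamma_rejection(2.0, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(mean, var)  # both should be close to alpha = 2
```

Because an accepted sample d·v is a smooth function of its noise x and of alpha, gradients can flow through accepted samples. As we understand it, RSVI builds its reparameterized gradient estimator on that property and adds a correction term for the accept/reject step.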