0.2.0
Support for PyTorch 0.4
Pyro 0.2 supports PyTorch 0.4. See the PyTorch release notes for a comprehensive list of changes. The most important change is that Variable and Tensor have been merged, so you can now simplify:
- pyro.param("my_param", Variable(torch.ones(1), requires_grad=True))
+ pyro.param("my_param", torch.ones(1))
PyTorch distributions
PyTorch's torch.distributions library is now Pyro’s main source for distribution implementations. The Pyro team helped create this library by collaborating with Adam Paszke, Alican Bozkurt, Vishwak Srinivasan, Rachit Singh, Brooks Paige, Jan-Willem Van De Meent, and many other contributors and reviewers. See the Pyro wrapper docs for wrapped PyTorch distributions and the Pyro distribution docs for Pyro-specific distributions.
Constrained parameters
Parameters can now be constrained easily using notation like:
import torch
import pyro
from torch.distributions import constraints
pyro.param("sigma", torch.ones(10), constraint=constraints.positive)
See the torch.distributions.constraints library and all of our Pyro tutorials for example usage.
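To see what a positive constraint does conceptually: the optimizer updates an unconstrained value, and the parameter you read back is that value pushed through a bijection onto the constrained set. For constraints.positive, torch.distributions uses an exponential map (transform_to(constraints.positive) gives an ExpTransform). The pure-Python sketch below mimics that idea with a hypothetical PositiveParam class and a made-up toy loss; it is an illustration, not Pyro's param store.

```python
import math


class PositiveParam:
    """Hypothetical stand-in for a parameter with constraint=positive.

    It stores an unconstrained value u and exposes exp(u), which is always > 0.
    """

    def __init__(self, init_value):
        if init_value <= 0:
            raise ValueError("initial value must be positive")
        # Map the constrained initial value into unconstrained space.
        self.unconstrained = math.log(init_value)

    @property
    def value(self):
        # exp maps the whole real line onto (0, inf).
        return math.exp(self.unconstrained)

    def gradient_step(self, grad_wrt_value, lr):
        # Chain rule: d value / d u = exp(u) = value.
        self.unconstrained -= lr * grad_wrt_value * self.value


sigma = PositiveParam(1.0)
# Minimize the toy loss (sigma - 0.01) ** 2; the optimizer only ever
# touches the unconstrained value, so sigma stays strictly positive.
for _ in range(200):
    sigma.gradient_step(2 * (sigma.value - 0.01), lr=0.5)
print(sigma.value)
```

The design choice is that the constraint lives in the parameterization, not in the optimizer, so any gradient-based optimizer can be used unchanged.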
Arbitrary tensor shapes
Arbitrary tensor shapes and batching are now supported in Pyro. This includes support for nested batching via iarange and support for batched multivariate distributions. The iarange context and irange generator are now much more flexible and can be combined freely. With power comes complexity, so check out our tensor shapes tutorial (hint: you'll need to use .expand_by() and .independent()).
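The two hinted methods are mostly shape bookkeeping. A distribution carries a batch_shape (independent draws) and an event_shape (the dimensions of a single draw); .expand_by(sizes) prepends sizes to the batch shape, and .independent(n) reinterprets the rightmost n batch dimensions as event dimensions. The sketch below tracks only those shapes with a hypothetical ShapeOnly class, assuming the semantics just described; it does not sample anything.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ShapeOnly:
    """Hypothetical shape tracker mirroring batch/event shape bookkeeping."""

    batch_shape: Tuple[int, ...]
    event_shape: Tuple[int, ...] = ()

    def expand_by(self, sizes):
        # New batch dimensions are added on the left.
        return ShapeOnly(tuple(sizes) + self.batch_shape, self.event_shape)

    def independent(self, n):
        # Move the rightmost n batch dims into the event shape.
        if n > len(self.batch_shape):
            raise ValueError("cannot reinterpret more dims than the batch has")
        split = len(self.batch_shape) - n
        return ShapeOnly(self.batch_shape[:split],
                         self.batch_shape[split:] + self.event_shape)

    @property
    def draw_shape(self):
        # One draw from the distribution has shape batch_shape + event_shape.
        return self.batch_shape + self.event_shape


# A scalar distribution such as Bernoulli(0.5) has empty batch and event shapes.
d = ShapeOnly(())
d = d.expand_by([10, 3])   # batch_shape (10, 3)
d = d.independent(1)       # batch_shape (10,), event_shape (3,)
print(d.batch_shape, d.event_shape, d.draw_shape)  # (10,) (3,) (10, 3)
```

Only the batch dimensions are candidates for iarange, which is why moving a dimension into the event shape matters when nesting contexts.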
Parallel enumeration
Discrete enumeration can now be parallelized. This makes it especially easy and cheap to enumerate out discrete latent variables. Check out the Gaussian Mixture Model tutorial for example usage. To use parallel enumeration, you'll need to first configure sites, then use the TraceEnum_ELBO loss:
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

def model(...):
    ...

@config_enumerate(default="parallel")  # configures sites
def guide(...):
    with pyro.iarange("foo", 10):
        x = pyro.sample("x", dist.Bernoulli(0.5).expand_by([10]))
        ...

svi = SVI(model, guide, Adam({}),
          loss=TraceEnum_ELBO(max_iarange_nesting=1))  # specify loss
svi.step()
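What enumeration buys is easiest to see on a toy mixture: instead of sampling a discrete latent, the loss sums it out exactly, and the parallel strategy evaluates all of its values at once along an extra tensor dimension rather than one at a time. The pure-Python sketch below (no Pyro; a hypothetical two-component Gaussian mixture with made-up weights and data) computes the marginal log-likelihood log p(x) = log Σ_z p(z) p(x | z) with a numerically stable logsumexp. The list comprehension stands in for the enumeration dimension.

```python
import math


def normal_log_prob(x, loc, scale):
    # Log density of Normal(loc, scale) at x.
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))


def logsumexp(values):
    # Stable log(sum(exp(v))) by factoring out the maximum.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_log_likelihood(data, weights, locs, scale):
    """Sum out the discrete assignment z of every data point exactly."""
    total = 0.0
    for x in data:
        # One term per value of z: log p(z) + log p(x | z).
        terms = [math.log(w) + normal_log_prob(x, loc, scale)
                 for w, loc in zip(weights, locs)]
        total += logsumexp(terms)
    return total


data = [-1.2, -0.8, 0.9, 1.1, 1.3]
print(marginal_log_likelihood(data, weights=[0.4, 0.6], locs=[-1.0, 1.0], scale=0.5))
```

Because the sum over z is exact, no score-function gradient is needed for the discrete site, which is where the variance reduction comes from.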
Markov chain Monte Carlo via HMC and NUTS
This release adds experimental support for gradient-based Markov chain Monte Carlo inference via Hamiltonian Monte Carlo (pyro.infer.HMC) and the No-U-Turn Sampler (pyro.infer.NUTS). See the docs and example for details.
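For intuition about what HMC does under the hood, here is a minimal pure-Python sketch for a one-dimensional standard normal target: it simulates Hamiltonian dynamics with the leapfrog integrator, then applies a Metropolis accept/reject step on the total energy. The target, step size, and number of leapfrog steps are assumptions chosen for illustration. Pyro's kernels operate on tensors and get gradients from autograd, and NUTS additionally chooses the trajectory length on the fly.

```python
import math
import random


def log_density(q):
    # Unnormalized log density of a standard normal target.
    return -0.5 * q * q


def grad_log_density(q):
    return -q


def leapfrog(q, p, step_size, num_steps):
    # Half step for momentum, alternating full steps, final half step.
    p = p + 0.5 * step_size * grad_log_density(q)
    for i in range(num_steps):
        q = q + step_size * p
        if i != num_steps - 1:
            p = p + step_size * grad_log_density(q)
    p = p + 0.5 * step_size * grad_log_density(q)
    return q, p


def hmc(num_samples, step_size=0.2, num_steps=10, seed=0):
    rng = random.Random(seed)
    q = 0.0
    samples, accepted = [], 0
    for _ in range(num_samples):
        p0 = rng.gauss(0.0, 1.0)  # resample the momentum each iteration
        q_new, p_new = leapfrog(q, p0, step_size, num_steps)
        # Hamiltonian H = potential (-log density) + kinetic energy p^2 / 2.
        h_old = -log_density(q) + 0.5 * p0 * p0
        h_new = -log_density(q_new) + 0.5 * p_new * p_new
        # Metropolis correction for the integrator's discretization error.
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            q = q_new
            accepted += 1
        samples.append(q)
    return samples, accepted / num_samples


samples, accept_rate = hmc(5000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f} var={var:.3f} accept={accept_rate:.2f}")
```

The gradient of the log density is what lets HMC make long, high-acceptance moves, which is why these samplers apply only to continuous latent variables.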
Gaussian Processes
A new Gaussian Process module pyro.contrib.gp provides a framework for learning with Gaussian Processes. To get started, take a look at our Gaussian Process Tutorial. Thanks to Du Phan for this extensive contribution!
Automatic guide generation
Guides can now be created automatically with the pyro.contrib.autoguide library. These work only for models with simple structure (no irange or iarange) and are easy to use:
from pyro.contrib.autoguide import AutoDiagonalNormal
from pyro.infer import SVI

def model(...):
    ...

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)
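A diagonal-normal guide is, roughly, a mean-field Gaussian: it packs the model's latent variables into one flat vector and learns a location and a positive scale for each coordinate. The sketch below shows only the reparameterized sampling step in pure Python, using a hypothetical DiagonalNormalGuide class and made-up site names and sizes; it is not Pyro's implementation, which also has to discover the sites by running the model.

```python
import math
import random


class DiagonalNormalGuide:
    """Hypothetical mean-field guide over a flat vector of latent values."""

    def __init__(self, site_sizes):
        # site_sizes: ordered mapping from latent site name to its length.
        self.site_sizes = dict(site_sizes)
        dim = sum(self.site_sizes.values())
        self.loc = [0.0] * dim
        self.log_scale = [0.0] * dim  # scale = exp(log_scale) stays positive

    def sample(self, rng):
        # Reparameterization: z = loc + scale * eps with eps ~ Normal(0, 1),
        # so z is a differentiable function of loc and log_scale.
        flat = [m + math.exp(s) * rng.gauss(0.0, 1.0)
                for m, s in zip(self.loc, self.log_scale)]
        # Split the flat vector back into named sites.
        out, start = {}, 0
        for name, size in self.site_sizes.items():
            out[name] = flat[start:start + size]
            start += size
        return out


guide = DiagonalNormalGuide({"weight": 3, "bias": 1})
guide.loc = [1.0, 2.0, 3.0, -1.0]
guide.log_scale = [-20.0] * 4  # tiny scale, so samples sit almost on loc
print(guide.sample(random.Random(0)))
```

The diagonal covariance is what makes this guide cheap; the price is that it cannot represent correlations between latent variables.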
Validation
Model validation is now available via three toggles:
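import pyro

# Turns on all validation (inference and distributions).
pyro.enable_validation()
# Turns on validation for inference algorithms.
pyro.infer.enable_validation()
# Turns on validation for PyTorch distributions.
pyro.distributions.enable_validation()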
Validation can also be toggled temporarily with the pyro.validation_enabled context manager:
# Run with validation in the first step.
with pyro.validation_enabled(True):
    svi.step()
# Avoid validation on subsequent steps (may miss NaN errors).
with pyro.validation_enabled(False):
    for _ in range(1000):
        svi.step()
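A context manager like pyro.validation_enabled has to restore the previous setting on exit, even if the body raises. The sketch below shows that pattern with a hypothetical module-level flag and contextlib; the function names mirror the toggles above for readability but are stand-ins, not Pyro's source.

```python
from contextlib import contextmanager

_VALIDATION_ENABLED = False  # hypothetical global flag


def enable_validation(is_validate=True):
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = is_validate


def is_validation_enabled():
    return _VALIDATION_ENABLED


@contextmanager
def validation_enabled(is_validate=True):
    # Remember the old setting, switch, and always restore it afterwards.
    old = is_validation_enabled()
    try:
        enable_validation(is_validate)
        yield
    finally:
        enable_validation(old)


with validation_enabled(True):
    print(is_validation_enabled())  # True inside the block
print(is_validation_enabled())      # False again after the block
```

Restoring the old value (rather than forcing False) lets the context manager nest correctly inside code that has already enabled validation globally.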
Rejection sampling variational inference (RSVI)
We've added support for vectorized rejection sampling in a new Rejector distribution. See the docs or the RejectionStandardGamma class for example usage.
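Rejection sampling fits variational inference because each accepted draw is a deterministic, differentiable function of the proposal noise, which is what RSVI-style reparameterization gradients build on. The sketch below is a scalar, pure-Python version of the Marsaglia-Tsang rejection sampler for a standard Gamma with alpha >= 1, a common choice for Gamma variates; the function name and the alpha >= 1 restriction are assumptions of this sketch, and it is not Pyro's vectorized Rejector.

```python
import math
import random


def standard_gamma_rejection(alpha, rng):
    """Marsaglia-Tsang rejection sampler for Gamma(alpha, 1), alpha >= 1."""
    if alpha < 1:
        raise ValueError("this sketch assumes alpha >= 1")
    d = alpha - 1.0 / 3.0
    c = 1.0 / math.sqrt(9.0 * d)
    while True:
        x = rng.gauss(0.0, 1.0)   # proposal noise
        v = (1.0 + c * x) ** 3    # deterministic transform of the noise
        if v <= 0:
            continue              # outside the support: reject
        u = rng.random()
        # Log-space acceptance test; the accepted value is d * v.
        if u > 0 and math.log(u) < 0.5 * x * x + d - d * v + d * math.log(v):
            return d * v


rng = random.Random(0)
draws = [standard_gamma_rejection(2.0, rng) for _ in range(20000)]
print(sum(draws) / len(draws))  # sample mean; Gamma(alpha, 1) has mean alpha
```

For alpha near or above 1 the acceptance rate is high, so the while loop rarely runs more than a couple of times per draw.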