Implement StaticSVI #1562

Closed

wants to merge 5 commits into base: dev
+205 −35
@@ -6,6 +6,14 @@ SVI
    :undoc-members:
    :show-inheritance:

+StaticSVI
+---------
+
+.. automodule:: pyro.infer.static_svi
+    :members:
+    :undoc-members:
+    :show-inheritance:

ELBO
----

@@ -5,10 +5,10 @@
from torch.nn import Parameter

import pyro
-from pyro.contrib.gp.util import Parameterized
import pyro.distributions as dist
-import pyro.infer as infer
-import pyro.optim as optim
+from pyro.contrib.gp.util import Parameterized
+from pyro.infer import SVI, Trace_ELBO
+from pyro.optim import Adam
from pyro.params import param_with_module_name


@@ -121,20 +121,20 @@ def forward(self, **kwargs):
        self._call_base_model_guide = True
        return self.base_model(**kwargs)

-    def optimize(self, optimizer=optim.Adam({}), num_steps=1000):
+    def optimize(self, optimizer=None, loss=None, num_steps=1000):
        """
        A convenient method to optimize parameters for GPLVM model using
        :class:`~pyro.infer.svi.SVI`.

-        :param ~optim.PyroOptim optimizer: A Pyro optimizer.
+        :param ~pyro.optim.optim.PyroOptim optimizer: A Pyro optimizer.
+        :param ~pyro.infer.elbo.ELBO loss: A Pyro loss instance.
        :param int num_steps: Number of steps to run SVI.
        :returns: a list of losses during the training procedure
        :rtype: list
        """
-        if not isinstance(optimizer, optim.PyroOptim):
-            raise ValueError("Optimizer should be an instance of "
-                             "pyro.optim.PyroOptim class.")
-        svi = infer.SVI(self.model, self.guide, optimizer, loss=infer.Trace_ELBO())
+        optimizer = Adam({}) if optimizer is None else optimizer
+        loss = Trace_ELBO() if loss is None else loss
+        svi = SVI(self.model, self.guide, optimizer, loss=loss)
        losses = []
        for i in range(num_steps):
            losses.append(svi.step())
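With the new signature both arguments default sensibly, so call sites stay light. A minimal sketch of the intended usage, assuming `gplvm` is an already constructed `gp.models.GPLVM` instance (the variable name and learning rate are illustrative, not taken from this diff):

import pyro.optim as optim
from pyro.infer import Trace_ELBO

# rely on the defaults (Adam({}) and Trace_ELBO()) ...
losses = gplvm.optimize(num_steps=1000)

# ... or pass an explicit optimizer and loss
losses = gplvm.optimize(optimizer=optim.Adam({"lr": 0.01}),
                        loss=Trace_ELBO(),
                        num_steps=1000)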
@@ -2,7 +2,7 @@

from pyro.contrib.gp.util import Parameterized
from pyro.infer import SVI, Trace_ELBO
-from pyro.optim import Adam, PyroOptim
+from pyro.optim import Adam


def _zero_mean_function(x):
@@ -159,7 +159,6 @@ def set_data(self, X, y=None):
    + Making a two-layer Gaussian Process stochastic function:
-
        >>> gpr1 = gp.models.GPRegression(X, None, kernel, name="GPR1")
        >>> Z, _ = gpr1.model()
        >>> gpr2 = gp.models.GPRegression(Z, y, kernel, name="GPR2")
@@ -193,19 +192,14 @@ def optimize(self, optimizer=None, loss=None, num_steps=1000):
        A convenient method to optimize parameters for the Gaussian Process model
        using :class:`~pyro.infer.svi.SVI`.

-        :param PyroOptim optimizer: A Pyro optimizer.
-        :param ELBO loss: A Pyro loss instance.
+        :param ~pyro.optim.optim.PyroOptim optimizer: A Pyro optimizer.
+        :param ~pyro.infer.elbo.ELBO loss: A Pyro loss instance.
        :param int num_steps: Number of steps to run SVI.
        :returns: a list of losses during the training procedure
        :rtype: list
        """
-        if optimizer is None:
-            optimizer = Adam({})
-        if not isinstance(optimizer, PyroOptim):
-            raise ValueError("Optimizer should be an instance of "
-                             "pyro.optim.PyroOptim class.")
-        if loss is None:
-            loss = Trace_ELBO()
+        optimizer = Adam({}) if optimizer is None else optimizer
+        loss = Trace_ELBO() if loss is None else loss
        svi = SVI(self.model, self.guide, optimizer, loss=loss)
        losses = []
        for i in range(num_steps):
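The same pattern applies across the GP models; a short sketch against `GPRegression` (data, kernel, and hyperparameters are made up for illustration):

import torch
import pyro.contrib.gp as gp
import pyro.optim as optim
from pyro.infer import Trace_ELBO

X = torch.linspace(-1., 1., 20).unsqueeze(-1)
y = torch.sin(3 * X.squeeze(-1)) + 0.1 * torch.randn(20)
kernel = gp.kernels.RBF(input_dim=1)
gpr = gp.models.GPRegression(X, y, kernel)

losses = gpr.optimize(num_steps=100)  # defaults filled in inside optimize()
losses = gpr.optimize(optimizer=optim.Adam({"lr": 0.005}),  # or explicit choices
                      loss=Trace_ELBO(),
                      num_steps=100)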
@@ -7,6 +7,7 @@
from pyro.infer.importance import Importance
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.svi import SVI
+from pyro.infer.static_svi import StaticSVI
from pyro.infer.trace_mean_field_elbo import TraceMeanField_ELBO
from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO
from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
@@ -26,6 +27,7 @@
"JitTrace_ELBO",
"RenyiELBO",
"SVI",
"StaticSVI",
"TraceEnum_ELBO",
"TraceGraph_ELBO",
"TraceMeanField_ELBO",
@@ -0,0 +1,79 @@
from __future__ import absolute_import, division, print_function

import pyro
import pyro.poutine as poutine
from pyro.infer.svi import SVI
from pyro.infer.util import is_validation_enabled, torch_item


def _check_params_not_dynamically_generated(org_params, new_params):
    for p in new_params:
        if p not in org_params:
            raise ValueError("Param `{}` is not available in original traces."
                             .format(pyro.get_param_store().param_name(p)))


class StaticSVI(SVI):
    """
    An interface for stochastic variational inference where the parameters of the
    model and guide are not dynamically generated during optimization. Because the
    full parameter set is fixed after the first step, the underlying PyTorch
    optimizer can be driven through a closure, which makes closure-based
    optimizers such as LBFGS usable.

    :param callable model: the model (callable containing Pyro primitives)
    :param callable guide: the guide (callable containing Pyro primitives)
    :param ~pyro.optim.PyroOptim optim: a wrapper for a PyTorch optimizer
    :param ~pyro.infer.elbo.ELBO loss: an instance of a subclass of :class:`~pyro.infer.elbo.ELBO`.
    :param callable loss_and_grads: a function which takes the `model`, `guide`,
        and their arguments as inputs, computes the loss, runs backward, and returns the loss
    :param int num_samples: the number of samples for Monte Carlo posterior approximation
    :param int num_steps: the number of optimization steps to take in ``run()``
    """
    def __init__(self, model, guide, optim, loss, loss_and_grads=None,
                 num_samples=10, num_steps=0, **kwargs):
        super(StaticSVI, self).__init__(model, guide, optim, loss, loss_and_grads,
                                        num_samples, num_steps, **kwargs)
        self._params = None

    def _check_optim(self, optim):
        # unlike the base SVI check, LBFGS is allowed here
        if not isinstance(optim, pyro.optim.PyroOptim):
            raise ValueError("Optimizer should be an instance of pyro.optim.PyroOptim class.")

    def _setup(self, *args, **kwargs):
        # run loss_and_grads once to trace every param site, then hand the
        # collected parameters to the raw PyTorch optimizer
        with poutine.trace(param_only=True) as param_capture:
            self.loss_and_grads(self.model, self.guide, *args, **kwargs)

        self._params = set(site["value"].unconstrained()
                           for site in param_capture.trace.nodes.values())

        self._pt_optim = self.optim.pt_optim_constructor(self._params, **self.optim.pt_optim_args)

    def step(self, *args, **kwargs):
        """
        Take a gradient step on the loss function (and any auxiliary loss functions
        generated under the hood by `loss_and_grads`). Any args or kwargs are
        passed to the model and guide.

        :returns: estimate of the loss
        :rtype: float
        """
        if self._params is None:
            self._setup(*args, **kwargs)

        def closure():
            pyro.infer.util.zero_grads(self._params)

            if is_validation_enabled():
                with poutine.trace(param_only=True) as param_capture:
                    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

                params = set(site["value"].unconstrained()
                             for site in param_capture.trace.nodes.values())

                _check_params_not_dynamically_generated(self._params, params)
            else:
                loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

            return loss

        loss = self._pt_optim.step(closure)

        return torch_item(loss)
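Because the parameter set is captured once in `_setup` and every later step re-evaluates the loss inside `closure`, `StaticSVI` can drive closure-based optimizers such as LBFGS, which plain `SVI` cannot. A minimal sketch mirroring the tests below (the model, data, and step counts are illustrative, not from this PR):

import torch
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import StaticSVI, Trace_ELBO

data = 3.0 + torch.randn(100)

def model():
    # every param must appear on the first step; none may be added later
    loc = pyro.param("loc", torch.tensor(0.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide():
    pass  # no latent sample sites: a MAP-style fit of `loc`

svi = StaticSVI(model, guide, optim.LBFGS({}), loss=Trace_ELBO())
for _ in range(5):
    loss = svi.step()  # LBFGS may evaluate the closure several times per step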
@@ -11,24 +11,24 @@

class SVI(TracePosterior):
    """
-    :param model: the model (callable containing Pyro primitives)
-    :param guide: the guide (callable containing Pyro primitives)
-    :param optim: a wrapper a for a PyTorch optimizer
-    :type optim: pyro.optim.PyroOptim
-    :param loss: an instance of a subclass of :class:`~pyro.infer.elbo.ELBO`.
+    A unified interface for stochastic variational inference in Pyro. The most
+    commonly used loss is ``loss=Trace_ELBO()``. See the tutorial
+    `SVI Part I <http://pyro.ai/examples/svi_part_i.html>`_ for a discussion.
+    :param callable model: the model (callable containing Pyro primitives)
+    :param callable guide: the guide (callable containing Pyro primitives)
+    :param ~pyro.optim.PyroOptim optim: a wrapper for a PyTorch optimizer
+    :param ~pyro.infer.elbo.ELBO loss: an instance of a subclass of :class:`~pyro.infer.elbo.ELBO`.
        Pyro provides three built-in losses:
        :class:`~pyro.infer.trace_elbo.Trace_ELBO`,
        :class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`, and
        :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`.
        See the :class:`~pyro.infer.elbo.ELBO` docs to learn how to implement
        a custom loss.
-    :type loss: pyro.infer.elbo.ELBO
-    :param num_samples: the number of samples for Monte Carlo posterior approximation
-    :param num_steps: the number of optimization steps to take in ``run()``
-    A unified interface for stochastic variational inference in Pyro. The most
-    commonly used loss is ``loss=Trace_ELBO()``. See the tutorial
-    `SVI Part I <http://pyro.ai/examples/svi_part_i.html>`_ for a discussion.
+    :param callable loss_and_grads: a function which takes the `model`, `guide`,
+        and their arguments as inputs, computes the loss, runs backward, and returns the loss
+    :param int num_samples: the number of samples for Monte Carlo posterior approximation
+    :param int num_steps: the number of optimization steps to take in ``run()``
    """
    def __init__(self,
                 model,
@@ -39,6 +39,7 @@ def __init__(self,
                 num_samples=10,
                 num_steps=0,
                 **kwargs):
+        self._check_optim(optim)
        self.model = model
        self.guide = guide
        self.optim = optim
@@ -59,6 +60,12 @@ def _loss_and_grads(*args, **kwargs):
        self.loss = loss
        self.loss_and_grads = loss_and_grads

+    def _check_optim(self, optim):
+        if not isinstance(optim, pyro.optim.PyroOptim):
+            raise ValueError("Optimizer should be an instance of pyro.optim.PyroOptim class.")
+        if optim.pt_optim_constructor is torch.optim.LBFGS:
+            raise ValueError("SVI is not compatible with LBFGS optimizer.")
+
    def run(self, *args, **kwargs):
        if self.num_steps > 0:
            with poutine.block():
@@ -12,9 +12,6 @@
        continue
    if _Optim is torch.optim.Optimizer:
        continue
-    if _Optim is torch.optim.LBFGS:
-        # XXX LBFGS is not supported for SVI yet
-        continue

    _PyroOptim = (lambda _Optim: lambda optim_args: PyroOptim(_Optim, optim_args))(_Optim)
    _PyroOptim.__name__ = _name
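With the special case removed, `LBFGS` is wrapped like every other `torch.optim` class, so the following now constructs a valid `PyroOptim` (accepted by `StaticSVI`, rejected by plain `SVI` via `_check_optim`); the `max_iter` value is just an example:

import pyro.optim as optim

lbfgs = optim.LBFGS({"max_iter": 20})  # a PyroOptim wrapping torch.optim.LBFGS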
@@ -0,0 +1,83 @@
from __future__ import absolute_import, division, print_function

import pytest
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import StaticSVI, Trace_ELBO

from tests.common import assert_equal


def test_inference():
    alpha0 = torch.tensor(1.0)
    beta0 = torch.tensor(1.0)  # beta prior hyperparameter
    data = torch.tensor([0.0, 1.0, 1.0, 1.0])
    n_data = len(data)
    data_sum = data.sum()
    alpha_n = alpha0 + data_sum  # posterior alpha
    beta_n = beta0 - data_sum + torch.tensor(float(n_data))  # posterior beta
    log_alpha_n = torch.log(alpha_n)
    log_beta_n = torch.log(beta_n)

    def model():
        p_latent = pyro.sample("p_latent", dist.Beta(alpha0, beta0))
        with pyro.plate("data", n_data):
            pyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    def guide():
        alpha_q_log = pyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = pyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("p_latent", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam({"lr": .001})
(Review thread on the line above)

fritzo (Collaborator), Nov 23, 2018:
    Can you also add a test using LBFGS?

fehiepsi (Author, Collaborator), Nov 23, 2018:
    @fritzo I have added another test for it. Using LBFGS for this model is quite flaky.

    svi = StaticSVI(model, guide, adam, loss=Trace_ELBO())
    for i in range(1000):
        svi.step()

    assert_equal(pyro.param("alpha_q_log"), log_alpha_n, prec=0.04)
    assert_equal(pyro.param("beta_q_log"), log_beta_n, prec=0.04)


@pytest.mark.init(rng_seed=3)
def test_lbfgs():
    x = 1 + torch.randn(10)
    x_mean = x.mean()
    x_std = x.std()

    def model():
        mu = pyro.param("mu", torch.tensor(0.))
        sigma = pyro.param("sigma", torch.tensor(1.), constraint=constraints.positive)
        with pyro.plate("plate"):
            return pyro.sample("x", dist.Normal(mu, sigma), obs=x)

    def guide():
        pass

    lbfgs = optim.LBFGS({})
    svi = StaticSVI(model, guide, lbfgs, loss=Trace_ELBO())
    svi.step()

    assert_equal(pyro.param("mu"), x_mean, prec=0.01)
    assert_equal(pyro.param("sigma"), x_std, prec=0.05)


def test_params_not_match():
    def model(i):
        p = pyro.param(str(i), torch.tensor(float(i)))
        return pyro.sample("obs", dist.Normal(p, 1), obs=torch.tensor(0.))

    def guide(i):
        pass

    adam = optim.Adam({})
    svi = StaticSVI(model, guide, adam, loss=Trace_ELBO())
    svi.step(i=0)

    with pytest.raises(ValueError, match="Param `{}` is not available".format(1)):
        svi.step(i=1)