
Add an AutoStructured guide and StructuredReparam #2812

Merged · 17 commits · Apr 25, 2021
8 changes: 8 additions & 0 deletions docs/source/infer.autoguide.rst
@@ -109,6 +109,14 @@ AutoDiscreteParallel
:special-members: __call__
:show-inheritance:

AutoStructured
--------------------
.. autoclass:: pyro.infer.autoguide.AutoStructured
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:

.. _autoguide-initialization:

Initialization
9 changes: 9 additions & 0 deletions docs/source/infer.reparam.rst
@@ -130,3 +130,12 @@ Neural Transport
:member-order: bysource
:special-members: __call__
:show-inheritance:

Structured Preconditioning
--------------------------
.. automodule:: pyro.infer.reparam.structured
:members:
:undoc-members:
:member-order: bysource
:special-members: __call__
:show-inheritance:
2 changes: 2 additions & 0 deletions pyro/infer/autoguide/__init__.py
@@ -15,6 +15,7 @@
AutoMultivariateNormal,
AutoNormal,
AutoNormalizingFlow,
AutoStructured,
)
from pyro.infer.autoguide.initialization import (
init_to_feasible,
@@ -41,6 +42,7 @@
'AutoMultivariateNormal',
'AutoNormal',
'AutoNormalizingFlow',
'AutoStructured',
'init_to_feasible',
'init_to_generated',
'init_to_mean',
292 changes: 287 additions & 5 deletions pyro/infer/autoguide/guides.py
@@ -19,7 +19,9 @@ def model():
import operator
import warnings
import weakref
from contextlib import ExitStack # python 3
from collections import defaultdict
from contextlib import ExitStack
from typing import Callable, Dict, Union

import torch
from torch import nn
@@ -454,7 +456,6 @@ def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)

self._event_dims = {}
self._cond_indep_stacks = {}
self.locs = PyroModule()
self.scales = PyroModule()

@@ -466,9 +467,6 @@ def _setup_prototype(self, *args, **kwargs):
event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
self._event_dims[name] = event_dim

# Collect independence contexts.
self._cond_indep_stacks[name] = site["cond_indep_stack"]

# If subsampling, repeat init_value to full size.
for frame in site["cond_indep_stack"]:
full_size = getattr(frame, "full_size", frame.size)
@@ -1190,3 +1188,287 @@ def forward(self, *args, **kwargs):
result[name] = pyro.sample(name, discrete_dist, infer={"enumerate": "parallel"})

return result


def _config_auxiliary(msg):
return {"is_auxiliary": True}


class AutoStructured(AutoGuide):
"""
Structured guide whose conditional distributions are Delta, Normal, or
MultivariateNormal, or are given by a callable, and whose latent variables
can depend on each other either linearly (in unconstrained space) or via
shearing by a callable.

Usage::

def model(data):
x = pyro.sample("x", dist.LogNormal(0, 1))
with pyro.plate("plate", len(data)):
y = pyro.sample("y", dist.Normal(0, 1))
pyro.sample("z", dist.Normal(y, x), obs=data)

guide = AutoStructured(
model=model,
conditionals={"x": "normal", "y": "normal"},
dependencies={"x": {"y": "linear"}},
)

Once trained, this guide can be used with
:class:`~pyro.infer.reparam.structured.StructuredReparam` to precondition a
model for use in HMC and NUTS inference.
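
For example, once the guide above is trained (a minimal sketch: the SVI
training loop is elided, and the ``NUTS``/``MCMC`` settings are merely
illustrative)::

    # ... train guide via SVI ...
    reparam_model = StructuredReparam(guide).reparam(model)
    kernel = NUTS(reparam_model)
    mcmc = MCMC(kernel, num_samples=500, warmup_steps=500)
    mcmc.run(data)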

.. note:: If you declare a dependency of a high-dimensional downstream
variable on a low-dimensional upstream variable, you may want to use
a lower learning rate for that weight, e.g.::

def optim_config(param_name):
config = {"lr": 0.01}
if "deps.my_downstream.my_upstream" in param_name:
config["lr"] *= 0.1
return config

adam = pyro.optim.Adam(optim_config)

:param callable model: A Pyro model.
:param conditionals: Family of distributions with which to model each latent
variable's conditional posterior. This should be a dict mapping each
latent variable name to either one of the strings "delta", "normal", or
"mvn", or to a callable that returns a sample from a zero-mean (or
approximately centered) noise distribution (such callables typically
call ``pyro.param()`` and ``pyro.sample()`` internally).
:param dependencies: Dict mapping each site name to a dict of its upstream
dependencies; each inner dict maps an upstream site name to either the
string "linear" or a callable that maps a *flattened* upstream
perturbation to a *flattened* downstream perturbation. The string
"linear" is equivalent to ``nn.Linear(upstream.numel(),
downstream.numel(), bias=False)``. Dependencies must not contain
cycles or self-loops.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
:param callable create_plates: An optional function inputting the same
``*args, **kwargs`` as ``model()`` and returning a :class:`pyro.plate`
or iterable of plates. Plates not returned will be created
automatically as usual. This is useful for data subsampling.
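
For example, conditionals and dependencies may be given by callables (an
illustrative sketch; ``my_conditional`` and the parameter name
``"my_scale"`` are hypothetical, and shapes assume site "x" has 2 and
site "y" has 3 unconstrained elements)::

    def my_conditional():
        # Sample zero-mean noise over the flattened unconstrained site "x".
        scale = pyro.param("my_scale", torch.ones(2),
                           constraint=constraints.positive)
        return pyro.sample("x_noise", dist.Normal(0., scale).to_event(1))

    guide = AutoStructured(
        model,
        conditionals={"x": my_conditional, "y": "mvn"},
        dependencies={"y": {"x": torch.nn.Linear(2, 3, bias=False)}},
    )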
"""

scale_constraint = constraints.softplus_positive
scale_tril_constraint = constraints.softplus_lower_cholesky

def __init__(
self,
model,
*,
conditionals: Dict[str, Union[str, Callable]] = "normal",
dependencies: Dict[str, Dict[str, Union[str, Callable]]] = "linear",
init_loc_fn=init_to_feasible,
init_scale=0.1,
create_plates=None,
):
assert isinstance(conditionals, dict)
for name, fn in conditionals.items():
assert isinstance(name, str)
assert isinstance(fn, str) or callable(fn)
assert isinstance(dependencies, dict)
for downstream, deps in dependencies.items():
assert downstream in conditionals
assert isinstance(deps, dict)
for upstream, dep in deps.items():
assert upstream in conditionals
assert upstream != downstream
assert isinstance(dep, str) or callable(dep)
if conditionals[upstream] == "delta":
raise ValueError(
f"Site {repr(downstream)} cannot depend on "
f"upstream point-estimated site {repr(upstream)}"
)
self.conditionals = conditionals
self.dependencies = dependencies

if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError(f"Expected init_scale > 0. but got {init_scale}")
self._init_scale = init_scale
model = InitMessenger(init_loc_fn)(model)
super().__init__(model, create_plates=create_plates)

def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)

self.locs = PyroModule()
self.scales = PyroModule()
self.scale_trils = PyroModule()
self.conds = PyroModule()
self.deps = PyroModule()
self._unconstrained_shapes = {}

# Collect unconstrained shapes.
init_locs = {}
numel = {}
for name, site in self.prototype_trace.iter_stochastic_nodes():
with helpful_support_errors(site):
init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
self._unconstrained_shapes[name] = init_loc.shape
numel[name] = init_loc.numel()
init_locs[name] = init_loc.reshape(-1)

# Initialize guide params.
children = defaultdict(list)
num_pending = {}
for name, site in self.prototype_trace.iter_stochastic_nodes():
# Initialize location parameters.
init_loc = init_locs[name]
_deep_setattr(self.locs, name, PyroParam(init_loc))

# Initialize parameters of conditional distributions.
conditional = self.conditionals[name]
if callable(conditional):
_deep_setattr(self.conds, name, conditional)
else:
if conditional not in ("delta", "normal", "mvn"):
raise ValueError(f"Unsupported conditional type: {conditional}")
if conditional in ("normal", "mvn"):
init_scale = torch.full_like(init_loc, self._init_scale)
_deep_setattr(self.scales, name,
PyroParam(init_scale, self.scale_constraint))
if conditional == "mvn":
init_scale_tril = eye_like(init_loc, init_loc.numel())
_deep_setattr(self.scale_trils, name,
PyroParam(init_scale_tril, self.scale_tril_constraint))

# Initialize dependencies on upstream variables.
num_pending[name] = 0
deps = PyroModule()
_deep_setattr(self.deps, name, deps)
for upstream, dep in self.dependencies.get(name, {}).items():
assert upstream in self.prototype_trace.nodes
children[upstream].append(name)
num_pending[name] += 1
if isinstance(dep, str) and dep == "linear":
dep = torch.nn.Linear(numel[upstream], numel[name], bias=False)
dep.weight.data.zero_()
elif not callable(dep):
raise ValueError(
f"Expected either the string 'linear' or a callable, but got {dep}"
)
_deep_setattr(deps, upstream, dep)

# Topologically sort sites.
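# Kahn-style selection: repeatedly pop a site with zero unmet
# dependencies, tie-breaking by name so the ordering is deterministic;
# the assertion below trips on dependency cycles.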
self._sorted_sites = []
while num_pending:
name, count = min(num_pending.items(), key=lambda kv: (kv[1], kv[0]))
assert count == 0, f"cyclic dependency: {name}"
del num_pending[name]
for child in children[name]:
num_pending[child] -= 1
self._sorted_sites.append((name, self.prototype_trace.nodes[name]))

@poutine.infer_config(config_fn=_config_auxiliary)
def get_deltas(self, save_params=None):
deltas = {}
aux_values = {}
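# Log-density corrections are needed unless the current poutine mask is
# False, e.g. they can be skipped inside poutine.mask(mask=False).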
compute_density = poutine.get_mask() is not False
for name, site in self._sorted_sites:
if save_params is not None and name not in save_params:
continue

# Sample zero-mean blockwise independent Delta/Normal/MVN.
log_density = 0.0
loc = _deep_getattr(self.locs, name)
zero = torch.zeros_like(loc)
conditional = self.conditionals[name]
if callable(conditional):
aux_value = _deep_getattr(self.conds, name)()
elif conditional == "delta":
aux_value = zero
elif conditional == "normal":
aux_value = pyro.sample(
name + "_aux",
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = _deep_getattr(self.scales, name)
aux_value = aux_value * scale
if compute_density:
log_density = log_density - scale.log().sum(-1)
elif conditional == "mvn":
# This overparametrizes by learning both scale and scale_tril,
# enabling faster learning of the more-global scale parameter.
aux_value = pyro.sample(
name + "_aux",
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = _deep_getattr(self.scales, name)
scale_tril = _deep_getattr(self.scale_trils, name)
aux_value = aux_value @ scale_tril.T * scale
if compute_density:
log_density = (
log_density - scale.log().sum(-1)
- scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
)
else:
raise ValueError(f"Unsupported conditional type: {conditional}")

# Accumulate upstream dependencies.
# Note: by accumulating upstream dependencies before updating the
# aux_values dict, we encode a block-sparse structure of the
# precision matrix; if we had instead accumulated after updating
# aux_values, we would encode a block-sparse structure of the
# covariance matrix.
# Note: these shear transforms have no effect on the Jacobian
# determinant, and can therefore be excluded from the log_density
# computation below, even for nonlinear dep().
deps = _deep_getattr(self.deps, name)
for upstream in self.dependencies.get(name, {}):
dep = _deep_getattr(deps, upstream)
aux_value = aux_value + dep(aux_values[upstream])
aux_values[name] = aux_value

# Shift by loc, reshape, and transform to constrained space.
unconstrained = aux_value + loc
shape = self._unconstrained_shapes[name]
if torch._C._get_tracing_state() or unconstrained.shape != shape:
sample_shape = unconstrained.shape[:-1]
unconstrained = unconstrained.reshape(sample_shape + shape)
transform = biject_to(site["fn"].support)
value = transform(unconstrained)

# Create a Delta distribution.
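# The accumulated log_density carries the Jacobian corrections (noise
# scaling and the constraining transform) so that this Delta's log_prob,
# together with the auxiliary sites' log-probs, yields the guide density.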
if compute_density and conditional != "delta":
ldj = transform.inv.log_abs_det_jacobian(value, unconstrained)
ldj = sum_rightmost(ldj, ldj.dim() - value.dim() + site["fn"].event_dim)
log_density = log_density + ldj
deltas[name] = dist.Delta(value, log_density, site["fn"].event_dim)

return deltas

def forward(self, *args, **kwargs):
if self.prototype_trace is None:
self._setup_prototype(*args, **kwargs)

deltas = self.get_deltas()
plates = self._create_plates(*args, **kwargs)
result = {}
for name, site in self._sorted_sites:
with ExitStack() as stack:
for frame in site["cond_indep_stack"]:
if frame.vectorized:
stack.enter_context(plates[frame.name])
result[name] = pyro.sample(name, deltas[name])

return result

@torch.no_grad()
def median(self, *args, **kwargs):
result = {}
for name, site in self._sorted_sites:
loc = _deep_getattr(self.locs, name).detach()
shape = self._unconstrained_shapes[name]
if loc.shape != shape:
sample_shape = loc.shape[:-1]
loc = loc.reshape(sample_shape + shape)
result[name] = biject_to(site["fn"].support)(loc)
return result
2 changes: 2 additions & 0 deletions pyro/infer/reparam/__init__.py
@@ -12,6 +12,7 @@
from .softmax import GumbelSoftmaxReparam
from .split import SplitReparam
from .stable import LatentStableReparam, StableReparam, SymmetricStableReparam
from .structured import StructuredReparam
from .studentt import StudentTReparam
from .transform import TransformReparam
from .unit_jacobian import UnitJacobianReparam
@@ -28,6 +29,7 @@
"ProjectedNormalReparam",
"SplitReparam",
"StableReparam",
"StructuredReparam",
"StudentTReparam",
"SymmetricStableReparam",
"TransformReparam",