Commit 866a3f1

Modifies PHMC to support momentum distribution with batch size different from chain shape.

PiperOrigin-RevId: 360498534
brianwa84 authored and tensorflower-gardener committed Mar 2, 2021
1 parent 12e69e0 commit 866a3f1
Showing 3 changed files with 44 additions and 12 deletions.
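
The effect of the change, as a usage sketch (not part of the commit; the target, shapes, and step sizes are invented for illustration): `momentum_distribution` may now carry a smaller batch shape than the chain shape, and the kernel broadcasts it across chains.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Seven chains, each over a 3-dimensional state: the log-prob reduces the
# event axis, leaving a [7]-shaped batch of per-chain log-densities.
def target_log_prob(x):
  return -0.5 * tf.reduce_sum(x**2, axis=-1)

# Momentum distribution with batch shape [] and event shape [3]. Before this
# commit it would have needed to carry the chain dimension itself.
momentum_distribution = tfd.Sample(tfd.Normal(0., 0.5), sample_shape=[3])

kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
    target_log_prob, step_size=0.5, num_leapfrog_steps=2,
    momentum_distribution=momentum_distribution)

samples = tfp.mcmc.sample_chain(
    num_results=100, current_state=tf.ones([7, 3]), kernel=kernel,
    num_burnin_steps=50, trace_fn=None)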
2 changes: 2 additions & 0 deletions tensorflow_probability/python/experimental/mcmc/BUILD
@@ -267,9 +267,11 @@ multi_substrate_py_library(
     srcs_version = "PY3",
     deps = [
         # tensorflow dep,
+        "//tensorflow_probability/python/distributions:batch_broadcast",
         "//tensorflow_probability/python/distributions:independent",
         "//tensorflow_probability/python/distributions:joint_distribution_sequential",
         "//tensorflow_probability/python/distributions:normal",
+        "//tensorflow_probability/python/distributions:sample",
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:samplers",
32 changes: 22 additions & 10 deletions tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py
@@ -22,9 +22,10 @@
 # Dependency imports
 import tensorflow.compat.v2 as tf

-from tensorflow_probability.python.distributions import independent
+from tensorflow_probability.python.distributions import batch_broadcast
 from tensorflow_probability.python.distributions import joint_distribution_sequential as jds
 from tensorflow_probability.python.distributions import normal
+from tensorflow_probability.python.distributions import sample
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.mcmc import hmc
@@ -47,6 +48,9 @@ class UncalibratedPreconditionedHamiltonianMonteCarloKernelResults(
   __slots__ = ()


+DefaultStandardNormal = collections.namedtuple('DefaultStandardNormal', [])
+
+
 class PreconditionedHamiltonianMonteCarlo(hmc.HamiltonianMonteCarlo):
   """Hamiltonian Monte Carlo, with given momentum distribution.
@@ -328,7 +332,7 @@ def bootstrap_results(self, init_state):

     if (not self._store_parameters_in_results or
         self.momentum_distribution is None):
-      momentum_distribution = []
+      momentum_distribution = DefaultStandardNormal()
     else:
       momentum_distribution = self.momentum_distribution
     result = UncalibratedPreconditionedHamiltonianMonteCarloKernelResults(
@@ -452,16 +456,14 @@ def _prepare_args(target_log_prob_fn,
       step_size, dtype=target_log_prob.dtype, name='step_size')

   # Default momentum distribution is None, but if `store_parameters_in_results`
-  # is true, then `momentum_distribution` defaults to an empty list.
-  # In any other case, `momentum_distribution` must be a single distribution,
-  # so we do not have to check that the list is actually empty.
-  if momentum_distribution is None or isinstance(momentum_distribution, list):
+  # is true, then `momentum_distribution` defaults to DefaultStandardNormal().
+  if (momentum_distribution is None or
+      isinstance(momentum_distribution, DefaultStandardNormal)):
     batch_rank = ps.rank(target_log_prob)
     def _batched_isotropic_normal_like(state_part):
-      event_ndims = ps.rank(state_part) - batch_rank
-      return independent.Independent(
-          normal.Normal(ps.zeros_like(state_part), 1.),
-          reinterpreted_batch_ndims=event_ndims)
+      return sample.Sample(
+          normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
+          ps.shape(state_part)[batch_rank:])

     momentum_distribution = jds.JointDistributionSequential(
         [_batched_isotropic_normal_like(state_part)
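
What the replacement in this hunk changes, as a side-by-side sketch (illustration only; the 7-chain, 3-dimensional shapes are invented): the old default folded the chain shape into the distribution's batch shape, while the new Sample-based default has an empty batch shape and only the per-part event shape, leaving the chain shape to be attached by the broadcast step in the next hunk.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

state_part = tf.zeros([7, 3])  # chain shape [7], per-part event shape [3]
batch_rank = 1                 # rank of target_log_prob's output

# Old default: batch_shape [7], event_shape [3].
old_default = tfd.Independent(
    tfd.Normal(tf.zeros_like(state_part), 1.),
    reinterpreted_batch_ndims=1)

# New default: batch_shape [], event_shape [3].
new_default = tfd.Sample(
    tfd.Normal(tf.zeros([], dtype=state_part.dtype), 1.),
    sample_shape=tf.shape(state_part)[batch_rank:])

print(old_default.batch_shape, old_default.event_shape)  # [7] [3]
print(new_default.batch_shape, new_default.event_shape)  # [] [3]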
@@ -474,6 +476,16 @@ def _batched_isotropic_normal_like(state_part):
     momentum_distribution = jds.JointDistributionSequential(
         [momentum_distribution])

+  # If all underlying distributions are independent, we can offer some help.
+  # This code will also trigger for the output of the two blocks above.
+  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
+      not any(callable(dist_fn) for dist_fn in momentum_distribution.model)):
+    batch_shape = ps.shape(target_log_prob)
+    momentum_distribution = momentum_distribution.copy(model=[
+        batch_broadcast.BatchBroadcast(md, to_shape=batch_shape)
+        for md in momentum_distribution.model
+    ])
+
   if len(step_sizes) == 1:
     step_sizes *= len(state_parts)
   if len(state_parts) != len(step_sizes):
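The BatchBroadcast wrapper added above is what reconciles the two shapes; a minimal sketch (not from the commit; shapes invented):

import tensorflow_probability as tfp

tfd = tfp.distributions

# A momentum part with no chain dimension, as in the sketch further up.
md = tfd.Sample(tfd.Normal(0., 0.25), sample_shape=[3])
print(md.batch_shape)  # []

# _prepare_args wraps each part with to_shape=ps.shape(target_log_prob), so
# samples pick up the chain batch shape; here the chain shape is [7].
bmd = tfd.BatchBroadcast(md, to_shape=[7])
print(bmd.batch_shape)     # [7]
print(bmd.sample().shape)  # [7, 3]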
22 changes: 20 additions & 2 deletions tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py
@@ -471,8 +471,7 @@ def test_f64(self, use_default):
         1, kernel=kernel, current_state=tf.ones([], tf.float64),
         num_burnin_steps=5, trace_fn=None, seed=test_util.test_seed()))

-  # TODO(b/175787154): Enable this test
-  def DISABLED_test_f64_multichain(self, use_default):
+  def test_f64_multichain(self, use_default):
     if use_default:
       momentum_distribution = None
     else:
@@ -487,6 +486,25 @@ def DISABLED_test_f64_multichain(self, use_default):
         1, kernel=kernel, current_state=tf.ones([nchains], tf.float64),
         num_burnin_steps=5, trace_fn=None, seed=test_util.test_seed()))

+  def test_f64_multichain_multipart(self, use_default):
+    if use_default:
+      momentum_distribution = None
+    else:
+      momentum_distribution = _make_composite_tensor(
+          tfd.JointDistributionSequential([
+              tfd.Normal(0., tf.constant(.5, dtype=tf.float64)),
+              tfd.Normal(0., tf.constant(.25, dtype=tf.float64))]))
+    kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
+        lambda x, y: -x**2 - y**2, step_size=.5, num_leapfrog_steps=2,
+        momentum_distribution=momentum_distribution)
+    kernel = tfp.mcmc.SimpleStepSizeAdaptation(kernel, num_adaptation_steps=3)
+    nchains = 7
+    self.evaluate(tfp.mcmc.sample_chain(
+        1, kernel=kernel,
+        current_state=(tf.ones([nchains], tf.float64),
+                       tf.ones([nchains], tf.float64)),
+        num_burnin_steps=5, trace_fn=None, seed=test_util.test_seed()))
+
   def test_diag(self, use_default):
     """Test that a diagonal multivariate normal can be effectively sampled from.