Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not recalculate gradient in NUTS #1730

Merged
merged 4 commits into from
Feb 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc3/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
if theano_kwargs is None:
theano_kwargs = {}

self.H, self.compute_energy, self.leapfrog, self._vars = get_theano_hamiltonian_functions(
self.H, self.compute_energy, self.leapfrog, self.dlogp = get_theano_hamiltonian_functions(
vars, shared, model.logpt, self.potential, use_single_leapfrog, **theano_kwargs)

super(BaseHMC, self).__init__(vars, shared, blocked=blocked)
81 changes: 44 additions & 37 deletions pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from collections import namedtuple

from ..arraystep import Competence
from .base_hmc import BaseHMC
from pymc3.vartypes import continuous_types
from pymc3.theanof import floatX
from pymc3.vartypes import continuous_types

import numpy as np
import numpy.random as nr

__all__ = ['NUTS']


BinaryTree = namedtuple('BinaryTree',
'q, p, q_grad, proposal, leaf_size, is_valid_sample, p_accept, n_proposals')


def bern(p):
return np.random.uniform() < p
return nr.uniform() < p


class NUTS(BaseHMC):
Expand Down Expand Up @@ -72,36 +79,35 @@ def astep(self, q0):
u = floatX(nr.uniform())

q = qn = qp = q0
qn_grad = qp_grad = self.dlogp(q)
pn = pp = p0

tree_size, depth = 1., 0
keep_sampling = True

while keep_sampling:
direction = bern(0.5) * 2 - 1
q_edge, p_edge = {-1: (qn, pn), 1: (qp, pp)}[direction]
q_edge, p_edge, q_edge_grad = {-1: (qn, pn, qn_grad), 1: (qp, pp, qp_grad)}[direction]

q_edge, p_edge, proposal, subtree_size, is_valid_sample, a, na = buildtree(
self.leapfrog, q_edge, p_edge,
u, direction, depth,
step_size, self.Emax, start_energy)
tree = buildtree(self.leapfrog, q_edge, p_edge, q_edge_grad, u, direction,
depth, step_size, self.Emax, start_energy)

if direction == -1:
qn, pn = q_edge, p_edge
qn, pn, qn_grad = tree.q, tree.p, tree.q_grad
else:
qp, pp = q_edge, p_edge
qp, pp, qp_grad = tree.q, tree.p, tree.q_grad

if is_valid_sample and bern(min(1, subtree_size / tree_size)):
q = proposal
if tree.is_valid_sample and bern(min(1, tree.leaf_size / tree_size)):
q = tree.proposal

tree_size += subtree_size
tree_size += tree.leaf_size

span = qp - qn
keep_sampling = is_valid_sample and (span.dot(pn) >= 0) and (span.dot(pp) >= 0)
keep_sampling = tree.is_valid_sample and (span.dot(pn) >= 0) and (span.dot(pp) >= 0)
depth += 1

w = 1. / (self.m + self.t0)
self.h_bar = (1 - w) * self.h_bar + w * (self.target_accept - a * 1. / na)
self.h_bar = ((1 - w) * self.h_bar +
w * (self.target_accept - tree.p_accept * 1. / tree.n_proposals))

if self.tune:
self.log_step_size = self.mu - self.h_bar * np.sqrt(self.m) / self.gamma
Expand All @@ -119,34 +125,35 @@ def competence(var):
return Competence.INCOMPATIBLE


def buildtree(leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy):
def buildtree(leapfrog, q, p, q_grad, u, direction, depth, step_size, Emax, start_energy):
if depth == 0:
q_edge, p_edge, new_energy = leapfrog(q, p,
floatX(np.asarray(direction * step_size)))
epsilon = floatX(np.asarray(direction * step_size))
q, p, q_grad, new_energy = leapfrog(q, p, q_grad, epsilon)
energy_change = new_energy - start_energy

leaf_size = int(np.log(u) + energy_change <= 0)
is_valid_sample = (np.log(u) + energy_change < Emax)
return q_edge, p_edge, q_edge, leaf_size, is_valid_sample, min(1, np.exp(-energy_change)), 1
p_accept = min(1, np.exp(-energy_change))
return BinaryTree(q, p, q_grad, q, leaf_size, is_valid_sample, p_accept, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This use of namedtuple definitely cleans things up a bit 👍.

else:
depth -= 1

q, p, proposal, tree_size, is_valid_sample, a1, na1 = buildtree(
leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy)

if is_valid_sample:
q_edge, p_edge, new_proposal, subtree_size, is_valid_subsample, a11, na11 = buildtree(
leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy)
tree = buildtree(leapfrog, q, p, q_grad, u, direction, depth, step_size, Emax, start_energy)

tree_size += subtree_size
if bern(subtree_size * 1. / max(tree_size, 1)):
proposal = new_proposal

a1 += a11
na1 += na11
span = direction * (q_edge - q)
is_valid_sample = is_valid_subsample and (span.dot(p_edge) >= 0) and (span.dot(p) >= 0)
if tree.is_valid_sample:
subtree = buildtree(leapfrog, tree.q, tree.p, tree.q_grad, u, direction, depth,
step_size, Emax, start_energy)
if bern(subtree.leaf_size * 1. / max(subtree.leaf_size + tree.leaf_size, 1)):
proposal = subtree.proposal
else:
proposal = tree.proposal
leaf_size = subtree.leaf_size + tree.leaf_size
p_accept = subtree.p_accept + tree.p_accept
n_proposals = subtree.n_proposals + tree.n_proposals
span = direction * (subtree.q - tree.q)
is_valid_sample = (subtree.is_valid_sample and
span.dot(subtree.p) >= 0 and
span.dot(tree.p) >= 0)
q, p, q_grad = subtree.q, subtree.p, subtree.q_grad
return BinaryTree(q, p, q_grad, proposal, leaf_size, is_valid_sample, p_accept, n_proposals)
else:
q_edge, p_edge = q, p

return q_edge, p_edge, proposal, tree_size, is_valid_sample, a1, na1
return tree
20 changes: 12 additions & 8 deletions pymc3/step_methods/hmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ def _theano_hamiltonian(model_vars, shared, logpt, potential):
"""
dlogp = gradient(logpt, model_vars)
(logp, dlogp), q = join_nonshared_inputs([logpt, dlogp], model_vars, shared)
dlogp_func = theano.function(inputs=[q], outputs=dlogp)
dlogp_func.trust_input = True
logp = CallableTensor(logp)
dlogp = CallableTensor(dlogp)
return Hamiltonian(logp, dlogp, potential), q
return Hamiltonian(logp, dlogp, potential), q, dlogp_func


def _theano_energy_function(H, q, **theano_kwargs):
Expand Down Expand Up @@ -105,13 +107,13 @@ def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential,
leapfrog_integrator : theano function integrating the Hamiltonian from a point in phase space
theano_variables : dictionary of variables used in the computation graph which may be useful
"""
H, q = _theano_hamiltonian(model_vars, shared, logpt, potential)
H, q, dlogp = _theano_hamiltonian(model_vars, shared, logpt, potential)
energy_function, p = _theano_energy_function(H, q, **theano_kwargs)
if use_single_leapfrog:
leapfrog_integrator = _theano_single_leapfrog(H, q, p, **theano_kwargs)
leapfrog_integrator = _theano_single_leapfrog(H, q, p, H.dlogp(q), **theano_kwargs)
else:
leapfrog_integrator = _theano_leapfrog_integrator(H, q, p, **theano_kwargs)
return H, energy_function, leapfrog_integrator, {'q': q, 'p': p}
return H, energy_function, leapfrog_integrator, dlogp


def energy(H, q, p):
Expand Down Expand Up @@ -165,7 +167,7 @@ def full_update(p, q):
return q, p


def _theano_single_leapfrog(H, q, p, **theano_kwargs):
def _theano_single_leapfrog(H, q, p, q_grad, **theano_kwargs):
"""Leapfrog integrator for a single step.

See above for documentation. This is optimized for the case where only a single step is
Expand All @@ -174,11 +176,13 @@ def _theano_single_leapfrog(H, q, p, **theano_kwargs):
epsilon = tt.scalar('epsilon')
epsilon.tag.test_value = 1.

p_new = p + 0.5 * epsilon * H.dlogp(q) # half momentum update
p_new = p + 0.5 * epsilon * q_grad # half momentum update
q_new = q + epsilon * H.pot.velocity(p_new) # full position update
p_new += 0.5 * epsilon * H.dlogp(q_new) # half momentum update
q_new_grad = H.dlogp(q_new)
p_new += 0.5 * epsilon * q_new_grad # half momentum update
energy_new = energy(H, q_new, p_new)

f = theano.function(inputs=[q, p, epsilon], outputs=[q_new, p_new, energy_new], **theano_kwargs)
f = theano.function(inputs=[q, p, q_grad, epsilon],
outputs=[q_new, p_new, q_new_grad, energy_new], **theano_kwargs)
f.trust_input = True
return f
2 changes: 1 addition & 1 deletion pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_posterior_estimate(self):
Y_obs = Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

for step_method, params in ((NUTS, {"target_accept": 0.95}), (Slice, {}), (Metropolis, {'scaling': 10.})):
trace = sample(100000, step=step_method(**params), progressbar=False, tune=1000)
trace = sample(100000, step=step_method(**params), tune=1000)
trace_ = trace[-300::5]

# We do the same for beta - using more burnin.
Expand Down