Skip to content

Commit

Permalink
Do not recalculate gradient in NUTS (#1730)
Browse the repository at this point in the history
* Add gradient as an argument

* Use named tuples for trees, fix shape

* Enable progressbar to avoid timeouts

* Use custom floatX function
  • Loading branch information
ColCarroll authored and twiecki committed Feb 2, 2017
1 parent 41ae646 commit c014311
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 47 deletions.
2 changes: 1 addition & 1 deletion pymc3/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
if theano_kwargs is None:
theano_kwargs = {}

self.H, self.compute_energy, self.leapfrog, self._vars = get_theano_hamiltonian_functions(
self.H, self.compute_energy, self.leapfrog, self.dlogp = get_theano_hamiltonian_functions(
vars, shared, model.logpt, self.potential, use_single_leapfrog, **theano_kwargs)

super(BaseHMC, self).__init__(vars, shared, blocked=blocked)
81 changes: 44 additions & 37 deletions pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from collections import namedtuple

from ..arraystep import Competence
from .base_hmc import BaseHMC
from pymc3.vartypes import continuous_types
from pymc3.theanof import floatX
from pymc3.vartypes import continuous_types

import numpy as np
import numpy.random as nr

__all__ = ['NUTS']


# Record returned by ``buildtree`` describing one (sub)tree of the trajectory.
BinaryTree = namedtuple(
    'BinaryTree',
    [
        'q',                # position at the edge of the tree after the last leapfrog step
        'p',                # momentum at that same edge
        'q_grad',           # gradient at ``q``, carried along so it is not recomputed
        'proposal',         # candidate position selected from within the tree
        'leaf_size',        # number of leaves with log(u) + energy_change <= 0
        'is_valid_sample',  # False once the energy bound or the span/momentum check fails
        'p_accept',         # sum over leaves of min(1, exp(-energy_change))
        'n_proposals',      # number of leaves that contributed to ``p_accept``
    ]
)


def bern(p):
return np.random.uniform() < p
return nr.uniform() < p


class NUTS(BaseHMC):
Expand Down Expand Up @@ -72,36 +79,35 @@ def astep(self, q0):
u = floatX(nr.uniform())

q = qn = qp = q0
qn_grad = qp_grad = self.dlogp(q)
pn = pp = p0

tree_size, depth = 1., 0
keep_sampling = True

while keep_sampling:
direction = bern(0.5) * 2 - 1
q_edge, p_edge = {-1: (qn, pn), 1: (qp, pp)}[direction]
q_edge, p_edge, q_edge_grad = {-1: (qn, pn, qn_grad), 1: (qp, pp, qp_grad)}[direction]

q_edge, p_edge, proposal, subtree_size, is_valid_sample, a, na = buildtree(
self.leapfrog, q_edge, p_edge,
u, direction, depth,
step_size, self.Emax, start_energy)
tree = buildtree(self.leapfrog, q_edge, p_edge, q_edge_grad, u, direction,
depth, step_size, self.Emax, start_energy)

if direction == -1:
qn, pn = q_edge, p_edge
qn, pn, qn_grad = tree.q, tree.p, tree.q_grad
else:
qp, pp = q_edge, p_edge
qp, pp, qp_grad = tree.q, tree.p, tree.q_grad

if is_valid_sample and bern(min(1, subtree_size / tree_size)):
q = proposal
if tree.is_valid_sample and bern(min(1, tree.leaf_size / tree_size)):
q = tree.proposal

tree_size += subtree_size
tree_size += tree.leaf_size

span = qp - qn
keep_sampling = is_valid_sample and (span.dot(pn) >= 0) and (span.dot(pp) >= 0)
keep_sampling = tree.is_valid_sample and (span.dot(pn) >= 0) and (span.dot(pp) >= 0)
depth += 1

w = 1. / (self.m + self.t0)
self.h_bar = (1 - w) * self.h_bar + w * (self.target_accept - a * 1. / na)
self.h_bar = ((1 - w) * self.h_bar +
w * (self.target_accept - tree.p_accept * 1. / tree.n_proposals))

if self.tune:
self.log_step_size = self.mu - self.h_bar * np.sqrt(self.m) / self.gamma
Expand All @@ -119,34 +125,35 @@ def competence(var):
return Competence.INCOMPATIBLE


def buildtree(leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy):
def buildtree(leapfrog, q, p, q_grad, u, direction, depth, step_size, Emax, start_energy):
if depth == 0:
q_edge, p_edge, new_energy = leapfrog(q, p,
floatX(np.asarray(direction * step_size)))
epsilon = floatX(np.asarray(direction * step_size))
q, p, q_grad, new_energy = leapfrog(q, p, q_grad, epsilon)
energy_change = new_energy - start_energy

leaf_size = int(np.log(u) + energy_change <= 0)
is_valid_sample = (np.log(u) + energy_change < Emax)
return q_edge, p_edge, q_edge, leaf_size, is_valid_sample, min(1, np.exp(-energy_change)), 1
p_accept = min(1, np.exp(-energy_change))
return BinaryTree(q, p, q_grad, q, leaf_size, is_valid_sample, p_accept, 1)
else:
depth -= 1

q, p, proposal, tree_size, is_valid_sample, a1, na1 = buildtree(
leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy)

if is_valid_sample:
q_edge, p_edge, new_proposal, subtree_size, is_valid_subsample, a11, na11 = buildtree(
leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy)
tree = buildtree(leapfrog, q, p, q_grad, u, direction, depth, step_size, Emax, start_energy)

tree_size += subtree_size
if bern(subtree_size * 1. / max(tree_size, 1)):
proposal = new_proposal

a1 += a11
na1 += na11
span = direction * (q_edge - q)
is_valid_sample = is_valid_subsample and (span.dot(p_edge) >= 0) and (span.dot(p) >= 0)
if tree.is_valid_sample:
subtree = buildtree(leapfrog, tree.q, tree.p, tree.q_grad, u, direction, depth,
step_size, Emax, start_energy)
if bern(subtree.leaf_size * 1. / max(subtree.leaf_size + tree.leaf_size, 1)):
proposal = subtree.proposal
else:
proposal = tree.proposal
leaf_size = subtree.leaf_size + tree.leaf_size
p_accept = subtree.p_accept + tree.p_accept
n_proposals = subtree.n_proposals + tree.n_proposals
span = direction * (subtree.q - tree.q)
is_valid_sample = (subtree.is_valid_sample and
span.dot(subtree.p) >= 0 and
span.dot(tree.p) >= 0)
q, p, q_grad = subtree.q, subtree.p, subtree.q_grad
return BinaryTree(q, p, q_grad, proposal, leaf_size, is_valid_sample, p_accept, n_proposals)
else:
q_edge, p_edge = q, p

return q_edge, p_edge, proposal, tree_size, is_valid_sample, a1, na1
return tree
20 changes: 12 additions & 8 deletions pymc3/step_methods/hmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ def _theano_hamiltonian(model_vars, shared, logpt, potential):
"""
dlogp = gradient(logpt, model_vars)
(logp, dlogp), q = join_nonshared_inputs([logpt, dlogp], model_vars, shared)
dlogp_func = theano.function(inputs=[q], outputs=dlogp)
dlogp_func.trust_input = True
logp = CallableTensor(logp)
dlogp = CallableTensor(dlogp)
return Hamiltonian(logp, dlogp, potential), q
return Hamiltonian(logp, dlogp, potential), q, dlogp_func


def _theano_energy_function(H, q, **theano_kwargs):
Expand Down Expand Up @@ -105,13 +107,13 @@ def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential,
leapfrog_integrator : theano function integrating the Hamiltonian from a point in phase space
dlogp : compiled theano function returning the gradient of logp evaluated at a point q
"""
H, q = _theano_hamiltonian(model_vars, shared, logpt, potential)
H, q, dlogp = _theano_hamiltonian(model_vars, shared, logpt, potential)
energy_function, p = _theano_energy_function(H, q, **theano_kwargs)
if use_single_leapfrog:
leapfrog_integrator = _theano_single_leapfrog(H, q, p, **theano_kwargs)
leapfrog_integrator = _theano_single_leapfrog(H, q, p, H.dlogp(q), **theano_kwargs)
else:
leapfrog_integrator = _theano_leapfrog_integrator(H, q, p, **theano_kwargs)
return H, energy_function, leapfrog_integrator, {'q': q, 'p': p}
return H, energy_function, leapfrog_integrator, dlogp


def energy(H, q, p):
Expand Down Expand Up @@ -165,7 +167,7 @@ def full_update(p, q):
return q, p


def _theano_single_leapfrog(H, q, p, **theano_kwargs):
def _theano_single_leapfrog(H, q, p, q_grad, **theano_kwargs):
"""Leapfrog integrator for a single step.
See above for documentation. This is optimized for the case where only a single step is
Expand All @@ -174,11 +176,13 @@ def _theano_single_leapfrog(H, q, p, **theano_kwargs):
epsilon = tt.scalar('epsilon')
epsilon.tag.test_value = 1.

p_new = p + 0.5 * epsilon * H.dlogp(q) # half momentum update
p_new = p + 0.5 * epsilon * q_grad # half momentum update
q_new = q + epsilon * H.pot.velocity(p_new) # full position update
p_new += 0.5 * epsilon * H.dlogp(q_new) # half momentum update
q_new_grad = H.dlogp(q_new)
p_new += 0.5 * epsilon * q_new_grad # half momentum update
energy_new = energy(H, q_new, p_new)

f = theano.function(inputs=[q, p, epsilon], outputs=[q_new, p_new, energy_new], **theano_kwargs)
f = theano.function(inputs=[q, p, q_grad, epsilon],
outputs=[q_new, p_new, q_new_grad, energy_new], **theano_kwargs)
f.trust_input = True
return f
2 changes: 1 addition & 1 deletion pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_posterior_estimate(self):
Y_obs = Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

for step_method, params in ((NUTS, {"target_accept": 0.95}), (Slice, {}), (Metropolis, {'scaling': 10.})):
trace = sample(100000, step=step_method(**params), progressbar=False, tune=1000)
trace = sample(100000, step=step_method(**params), tune=1000)
trace_ = trace[-300::5]

# We do the same for beta - using more burnin.
Expand Down

0 comments on commit c014311

Please sign in to comment.