Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions pymc3/step_methods/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,40 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
self.m = 1

shared = make_shared_replacements(vars, model)
self.leapfrog1_dE = leapfrog1_dE(
model.logpt, vars, shared, self.potential, profile=profile)

def create_hamiltonian(vars, shared, model):
dlogp = gradient(model.logpt, vars)
(logp, dlogp), q = join_nonshared_inputs(
[model.logpt, dlogp], vars, shared)
logp = CallableTensor(logp)
dlogp = CallableTensor(dlogp)

return Hamiltonian(logp, dlogp, self.potential), q

def create_energy_func(q):
p = tt.dvector('p')
p.tag.test_value = q.tag.test_value
E0 = energy(self.H, q, p)
E0_func = theano.function([q, p], E0)
E0_func.trust_input = True

return E0_func

self.H, q = create_hamiltonian(vars, shared, model)
self.compute_energy = create_energy_func(q)

self.leapfrog1_dE = leapfrog1_dE(self.H, q, profile=profile)

super(NUTS, self).__init__(vars, shared, **kwargs)

def astep(self, q0):
# Hamiltonian(self.logp, self.dlogp, self.potential)
H = self.leapfrog1_dE
leapfrog = self.leapfrog1_dE
Emax = self.Emax
e = self.step_size

p0 = self.potential.random()
E0 = self.compute_energy(q0, p0)

u = uniform()
q = qn = qp = q0
p = pn = pp = p0
Expand All @@ -115,10 +137,10 @@ def astep(self, q0):

if v == -1:
qn, pn, _, _, q1, n1, s1, a, na = buildtree(
H, qn, pn, u, v, j, e, Emax, q0, p0)
leapfrog, qn, pn, u, v, j, e, Emax, E0)
else:
_, _, qp, pp, q1, n1, s1, a, na = buildtree(
H, qp, pp, u, v, j, e, Emax, q0, p0)
leapfrog, qp, pp, u, v, j, e, Emax, E0)

if s1 == 1 and bern(min(1, n1 * 1. / n)):
q = q1
Expand Down Expand Up @@ -147,24 +169,23 @@ def competence(var):
return Competence.INCOMPATIBLE


def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
def buildtree(leapfrog1_dE, q, p, u, v, j, e, Emax, E0):
if j == 0:
leapfrog1_dE = H
q1, p1, dE = leapfrog1_dE(q, p, array(v * e), q0, p0)
q1, p1, dE = leapfrog1_dE(q, p, array(v * e), E0)

n1 = int(log(u) + dE <= 0)
s1 = int(log(u) + dE < Emax)
return q1, p1, q1, p1, q1, n1, s1, min(1, exp(-dE)), 1
else:
qn, pn, qp, pp, q1, n1, s1, a1, na1 = buildtree(
H, q, p, u, v, j - 1, e, Emax, q0, p0)
leapfrog1_dE, q, p, u, v, j - 1, e, Emax, E0)
if s1 == 1:
if v == -1:
qn, pn, _, _, q11, n11, s11, a11, na11 = buildtree(
H, qn, pn, u, v, j - 1, e, Emax, q0, p0)
leapfrog1_dE, qn, pn, u, v, j - 1, e, Emax, E0)
else:
_, _, qp, pp, q11, n11, s11, a11, na11 = buildtree(
H, qp, pp, u, v, j - 1, e, Emax, q0, p0)
leapfrog1_dE, qp, pp, u, v, j - 1, e, Emax, E0)

if bern(n11 * 1. / (max(n1 + n11, 1))):
q1 = q11
Expand All @@ -179,44 +200,33 @@ def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
return


def leapfrog1_dE(logp, vars, shared, pot, profile):
def leapfrog1_dE(H, q, profile):
"""Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
Parameters
----------
logp : TensorVariable
vars : list of tensor variables
shared : list of shared variables not to compute leapfrog over
pot : quadpotential
porifle : Boolean
H : Hamiltonian
q : theano.tensor
profile : Boolean

Returns
-------
theano function which returns
q_new, p_new, delta_E
q_new, p_new, dE
"""
dlogp = gradient(logp, vars)
(logp, dlogp), q = join_nonshared_inputs([logp, dlogp], vars, shared)
logp = CallableTensor(logp)
dlogp = CallableTensor(dlogp)

H = Hamiltonian(logp, dlogp, pot)

p = tt.dvector('p')
p.tag.test_value = q.tag.test_value

q0 = tt.dvector('q0')
q0.tag.test_value = q.tag.test_value
p0 = tt.dvector('p0')
p0.tag.test_value = p.tag.test_value

e = tt.dscalar('e')
e.tag.test_value = 1

q1, p1 = leapfrog(H, q, p, 1, e)
E = energy(H, q1, p1)
E0 = energy(H, q0, p0)

E0 = tt.dscalar('E0')
E0.tag.test_value = 1

dE = E - E0

f = theano.function([q, p, e, q0, p0], [q1, p1, dE], profile=profile)
f = theano.function([q, p, e, E0], [q1, p1, dE], profile=profile)
f.trust_input = True
return f