diff --git a/pymc3/step_methods/nuts.py b/pymc3/step_methods/nuts.py
index f042373745..5121681256 100644
--- a/pymc3/step_methods/nuts.py
+++ b/pymc3/step_methods/nuts.py
@@ -92,18 +92,40 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False, state
         self.m = 1
 
         shared = make_shared_replacements(vars, model)
-        self.leapfrog1_dE = leapfrog1_dE(
-            model.logpt, vars, shared, self.potential, profile=profile)
+
+        def create_hamiltonian(vars, shared, model):
+            dlogp = gradient(model.logpt, vars)
+            (logp, dlogp), q = join_nonshared_inputs(
+                [model.logpt, dlogp], vars, shared)
+            logp = CallableTensor(logp)
+            dlogp = CallableTensor(dlogp)
+
+            return Hamiltonian(logp, dlogp, self.potential), q
+
+        def create_energy_func(q):
+            p = tt.dvector('p')
+            p.tag.test_value = q.tag.test_value
+            E0 = energy(self.H, q, p)
+            E0_func = theano.function([q, p], E0)
+            E0_func.trust_input = True
+
+            return E0_func
+
+        self.H, q = create_hamiltonian(vars, shared, model)
+        self.compute_energy = create_energy_func(q)
+
+        self.leapfrog1_dE = leapfrog1_dE(self.H, q, profile=profile)
 
         super(NUTS, self).__init__(vars, shared, **kwargs)
 
     def astep(self, q0):
-        # Hamiltonian(self.logp, self.dlogp, self.potential)
-        H = self.leapfrog1_dE
+        leapfrog = self.leapfrog1_dE
         Emax = self.Emax
         e = self.step_size
 
         p0 = self.potential.random()
+        E0 = self.compute_energy(q0, p0)
+
         u = uniform()
         q = qn = qp = q0
         p = pn = pp = p0
@@ -115,10 +137,10 @@ def astep(self, q0):
 
             if v == -1:
                 qn, pn, _, _, q1, n1, s1, a, na = buildtree(
-                    H, qn, pn, u, v, j, e, Emax, q0, p0)
+                    leapfrog, qn, pn, u, v, j, e, Emax, E0)
             else:
                 _, _, qp, pp, q1, n1, s1, a, na = buildtree(
-                    H, qp, pp, u, v, j, e, Emax, q0, p0)
+                    leapfrog, qp, pp, u, v, j, e, Emax, E0)
 
             if s1 == 1 and bern(min(1, n1 * 1. / n)):
                 q = q1
@@ -147,24 +169,23 @@ def competence(var):
         return Competence.INCOMPATIBLE
 
 
-def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
+def buildtree(leapfrog1_dE, q, p, u, v, j, e, Emax, E0):
     if j == 0:
-        leapfrog1_dE = H
-        q1, p1, dE = leapfrog1_dE(q, p, array(v * e), q0, p0)
+        q1, p1, dE = leapfrog1_dE(q, p, array(v * e), E0)
 
         n1 = int(log(u) + dE <= 0)
         s1 = int(log(u) + dE < Emax)
         return q1, p1, q1, p1, q1, n1, s1, min(1, exp(-dE)), 1
     else:
         qn, pn, qp, pp, q1, n1, s1, a1, na1 = buildtree(
-            H, q, p, u, v, j - 1, e, Emax, q0, p0)
+            leapfrog1_dE, q, p, u, v, j - 1, e, Emax, E0)
         if s1 == 1:
             if v == -1:
                 qn, pn, _, _, q11, n11, s11, a11, na11 = buildtree(
-                    H, qn, pn, u, v, j - 1, e, Emax, q0, p0)
+                    leapfrog1_dE, qn, pn, u, v, j - 1, e, Emax, E0)
             else:
                 _, _, qp, pp, q11, n11, s11, a11, na11 = buildtree(
-                    H, qp, pp, u, v, j - 1, e, Emax, q0, p0)
+                    leapfrog1_dE, qp, pp, u, v, j - 1, e, Emax, E0)
 
             if bern(n11 * 1. / (max(n1 + n11, 1))):
                 q1 = q11
@@ -179,44 +200,33 @@ def buildtree(H, q, p, u, v, j, e, Emax, q0, p0):
         return
 
 
-def leapfrog1_dE(logp, vars, shared, pot, profile):
+def leapfrog1_dE(H, q, profile):
     """Computes a theano function that computes one leapfrog step and the energy
     difference between the beginning and end of the trajectory.
 
     Parameters
     ----------
-    logp : TensorVariable
-    vars : list of tensor variables
-    shared : list of shared variables not to compute leapfrog over
-    pot : quadpotential
-    porifle : Boolean
+    H : Hamiltonian
+    q : theano.tensor
+    profile : Boolean
 
     Returns
     -------
     theano function which returns
-    q_new, p_new, delta_E
+    q_new, p_new, dE
     """
-    dlogp = gradient(logp, vars)
-    (logp, dlogp), q = join_nonshared_inputs([logp, dlogp], vars, shared)
-    logp = CallableTensor(logp)
-    dlogp = CallableTensor(dlogp)
-
-    H = Hamiltonian(logp, dlogp, pot)
-
     p = tt.dvector('p')
     p.tag.test_value = q.tag.test_value
-    q0 = tt.dvector('q0')
-    q0.tag.test_value = q.tag.test_value
-    p0 = tt.dvector('p0')
-    p0.tag.test_value = p.tag.test_value
-
     e = tt.dscalar('e')
     e.tag.test_value = 1
 
     q1, p1 = leapfrog(H, q, p, 1, e)
     E = energy(H, q1, p1)
-    E0 = energy(H, q0, p0)
+
+    E0 = tt.dscalar('E0')
+    E0.tag.test_value = 1
+
     dE = E - E0
-    f = theano.function([q, p, e, q0, p0], [q1, p1, dE], profile=profile)
+    f = theano.function([q, p, e, E0], [q1, p1, dE], profile=profile)
    f.trust_input = True
    return f
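
Reviewer note, not part of the patch: the compiled leapfrog graph previously took the q0 and p0 vectors as inputs and recomputed energy(H, q0, p0) on every call, once per tree node. After this change a separate compute_energy function is compiled, evaluated once per astep, and the resulting scalar E0 is threaded through buildtree in place of the two vectors. Below is a minimal NumPy sketch of the resulting contract, assuming an identity mass matrix in place of the quadpotential; the names mirror the patch, but the simplified signatures are illustrative stand-ins, not pymc3 API.

import numpy as np

def energy(logp, q, p):
    # H(q, p) = -logp(q) + p.p / 2; the identity mass matrix here is an
    # assumption, the real code delegates to self.potential.
    return -logp(q) + 0.5 * np.dot(p, p)

def leapfrog1_dE(logp, dlogp, q, p, eps, E0):
    # One leapfrog step, then the energy difference against the cached
    # starting energy E0, the scalar that buildtree now passes down in
    # place of the full (q0, p0) vectors.
    p = p + 0.5 * eps * dlogp(q)  # half step in momentum
    q = q + eps * p               # full step in position
    p = p + 0.5 * eps * dlogp(q)  # half step in momentum
    return q, p, energy(logp, q, p) - E0

Per astep, E0 = energy(logp, q0, p0) is computed once; every subsequent leapfrog call then only needs evaluations at the moving point, and the tree criteria log(u) + dE <= 0 and log(u) + dE < Emax consume dE directly, which is what removes one energy evaluation per leapfrog call.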