Skip to content

Commit

Permalink
Merge ae9d56c into 219652c
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Aug 25, 2019
2 parents 219652c + ae9d56c commit ce516c4
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions pymc3/step_methods/hmc/nuts.py
Expand Up @@ -251,10 +251,18 @@ def extend(self, direction):
if direction > 0:
tree, diverging, turning = self._build_subtree(
self.right, self.depth, floatX(np.asarray(self.step_size)))
leftmost_begin, leftmost_end = self.left, self.right
rightmost_begin, rightmost_end = tree.left, tree.right
leftmost_p_sum = self.p_sum
rightmost_p_sum = tree.p_sum
self.right = tree.right
else:
tree, diverging, turning = self._build_subtree(
self.left, self.depth, floatX(np.asarray(-self.step_size)))
leftmost_begin, leftmost_end = tree.left, tree.right
rightmost_begin, rightmost_end = self.left, self.right
leftmost_p_sum = tree.p_sum
rightmost_p_sum = self.p_sum
self.left = tree.right

self.depth += 1
Expand All @@ -274,8 +282,12 @@ def extend(self, direction):
left, right = self.left, self.right
p_sum = self.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
p_sum1 = leftmost_p_sum + rightmost_begin.p
turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
p_sum2 = leftmost_end.p + rightmost_p_sum
turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)

return diverging, turning
return diverging, (turning | turning1 | turning2)

def _single_step(self, left, epsilon):
"""Perform a leapfrog step and handle error cases."""
Expand Down Expand Up @@ -323,7 +335,12 @@ def _build_subtree(self, left, depth, epsilon):

if not (diverging or turning):
p_sum = tree1.p_sum + tree2.p_sum
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
turning0 = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
p_sum1 = tree1.p_sum + tree2.left.p
turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
p_sum2 = tree1.right.p + tree2.p_sum
turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
turning = (turning0 | turning1 | turning2)

log_size = np.logaddexp(tree1.log_size, tree2.log_size)
if logbern(tree2.log_size - log_size):
Expand Down

0 comments on commit ce516c4

Please sign in to comment.