Skip to content

Commit

Permalink
wip: check for divergence
Browse files Browse the repository at this point in the history
  • Loading branch information
chentoast committed Aug 25, 2022
1 parent 5553cfb commit 2a959a8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/inference/hmc_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ function assess_momenta(momenta)
logprob += logpdf(normal, val, 0, 1)
end
logprob
end
end
75 changes: 54 additions & 21 deletions src/inference/nuts.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearAlgebra: dot

Tree = @NamedTuple begin
struct Tree
val_left
momenta_left
val_right
Expand All @@ -9,9 +9,10 @@ Tree = @NamedTuple begin
n :: Int
weight :: Float64
stop :: Bool
diverging :: Bool
end

Stats = @NamedTuple begin
struct SamplerStats
depth
n
accept
Expand Down Expand Up @@ -47,11 +48,13 @@ function leapfrog(values, momenta, eps, integrator_state)
return values, momenta, get_score(trace)
end

function build_root(val, momenta, eps, direction, integrator_state)
function build_root(val, momenta, eps, direction, weight_init, integrator_state)
val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)
weight = lp + assess_momenta(momenta)

return Tree((val, momenta, val, momenta, val, 1, weight, false))
diverging = weight - weight_init > 1000

return Tree(val, momenta, val, momenta, val, 1, weight, false, diverging)
end

function merge_trees(tree_left, tree_right)
Expand All @@ -67,26 +70,29 @@ function merge_trees(tree_left, tree_right)
stop = tree_left.stop || tree_right.stop || u_turn(
tree_left.val_left, tree_right.val_right, tree_left.momenta_left, tree_right.momenta_right
)
diverging = tree_left.diverging || tree_right.diverging

return Tree((tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
tree_right.momenta_right, sample, n, weight, stop))
return Tree(tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
tree_right.momenta_right, sample, n, weight, stop, diverging)
end

function build_tree(val, momenta, depth, eps, direction, integrator_state)
function build_tree(val, momenta, depth, eps, direction, weight_init, integrator_state)
if depth == 0
return build_root(val, momenta, eps, direction, integrator_state)
return build_root(val, momenta, eps, direction, weight_init, integrator_state)
end

tree = build_tree(val, momenta, depth - 1, eps, direction, integrator_state)
tree = build_tree(val, momenta, depth - 1, eps, direction, weight_init, integrator_state)

if tree.stop
if tree.stop || tree.diverging
return tree
end

if direction == 1
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction, integrator_state)
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction,
weight_init, integrator_state)
else
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction, integrator_state)
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction,
weight_init, integrator_state)
end

if direction == 1
Expand All @@ -96,6 +102,27 @@ function build_tree(val, momenta, depth, eps, direction, integrator_state)
end
end

"""
(new_trace, sampler_statistics) = nuts(
trace, selection::Selection;eps=0.1,
max_treedepth=15, check=false, observations=EmptyChoiceMap())
Apply a Hamiltonian Monte Carlo (HMC) update with a No U Turn stopping criterion that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a struct `sampler_statistics` containing information about the sampled trajectory.
The NUT sampler allows for sampling trajectories of dynamic lengths, removing the need to specify the length of the trajectory as a parameter.
The sample will be returned early if the height of the sampled tree exceeds `max_treedepth`.
`sampler_statistics` is a struct containing the following fields:
- depth: the depth of the trajectory tree
- n: the number of samples in the trajectory tree
- sum_alpha: the sum of the individual mh acceptance probabilities for each sample in the tree
- n_accept: how many intermediate samples were accepted
- accept: whether the sample was accepted or not
# References
Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. URL: https://doi.org/10.48550/arXiv.1701.02434
Hoffman, M. D., & Gelman, A. (2022). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. URL: https://arxiv.org/abs/1111.4246
"""
function nuts(
trace::Trace, selection::Selection; eps=0.1, max_treedepth=15,
check=false, observations=EmptyChoiceMap())
Expand All @@ -105,12 +132,15 @@ function nuts(
# values needed for a leapfrog step
(_, values_trie, _) = choice_gradients(trace, selection, retval_grad)
values = to_array(values_trie, Float64)
integrator_state = (values_trie, selection, retval_grad, trace)

momenta = sample_momenta(length(values))
prev_momenta_score = assess_momenta(momenta)

tree = Tree((values, momenta, values, momenta, values, 1, -Inf, false))
weight_init = prev_model_score + prev_momenta_score

integrator_state = (values_trie, selection, retval_grad, trace)

tree = Tree(values, momenta, values, momenta, values, 1, -Inf, false, false)

direction = 0
depth = 0
Expand All @@ -119,14 +149,16 @@ function nuts(
direction = rand([-1, 1])

if direction == 1 # going right
other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction, integrator_state)
other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction,
weight_init, integrator_state)
tree = merge_trees(tree, other_tree)
else # going left
other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction, integrator_state)
other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction,
weight_init, integrator_state)
tree = merge_trees(other_tree, tree)
end

stop = stop || tree.stop
stop = stop || tree.stop || tree.diverging
if stop
break
end
Expand All @@ -147,12 +179,13 @@ function nuts(
end

# accept or reject
alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score
alpha = new_model_score + new_momenta_score - weight_init
if log(rand()) < alpha
return (new_trace, Stats((depth, tree.n, true)))
return (new_trace, SamplerStats(depth, tree.n, true))
else
return (trace, Stats((depth, tree.n, false)))
return (trace, SamplerStats(depth, tree.n, false))
end
end

export nuts
export nuts

0 comments on commit 2a959a8

Please sign in to comment.