From 89594a530ab9550806a095fcf8bd5642dcdd9261 Mon Sep 17 00:00:00 2001
From: Tony Chen
Date: Fri, 2 Sep 2022 13:17:24 -0400
Subject: [PATCH] wip

---
 src/inference/hmc_common.jl | 41 +++++++++++++++++++++++++++
 src/inference/nuts.jl       | 56 ++++++++++++++++++-------------------
 test/inference/nuts.jl      | 20 +++++++++++++
 3 files changed, 89 insertions(+), 28 deletions(-)
 create mode 100644 test/inference/nuts.jl

diff --git a/src/inference/hmc_common.jl b/src/inference/hmc_common.jl
index a62e1f1b..1689e784 100644
--- a/src/inference/hmc_common.jl
+++ b/src/inference/hmc_common.jl
@@ -9,3 +9,44 @@ function assess_momenta(momenta)
     end
     logprob
 end
+
+function add_choicemaps(a::ChoiceMap, b::ChoiceMap)
+    out = choicemap()
+
+    for (name, val) in get_values_shallow(a)
+        out[name] = val + b[name]
+    end
+
+    for (name, submap) in get_submaps_shallow(a)
+        out.internal_nodes[name] = add_choicemaps(submap, get_submap(b, name))
+    end
+
+    return out
+end
+
+function scale_choicemap(a::ChoiceMap, scale)
+    out = choicemap()
+
+    for (name, val) in get_values_shallow(a)
+        out[name] = val * scale
+    end
+
+    for (name, submap) in get_submaps_shallow(a)
+        out.internal_nodes[name] = scale_choicemap(submap, scale)
+    end
+
+    return out
+end
+
+function assess_momenta_trie(momenta_trie)
+    logprob = 0.
+    for (_, val) in get_values_shallow(momenta_trie)
+        logprob += logpdf(normal, val, 0, 1)
+    end
+
+    for (_, submap) in get_submaps_shallow(momenta_trie)
+        logprob += assess_momenta_trie(submap)
+    end
+
+    return logprob
+end
\ No newline at end of file
diff --git a/src/inference/nuts.jl b/src/inference/nuts.jl
index d7f4d429..51ce2986 100644
--- a/src/inference/nuts.jl
+++ b/src/inference/nuts.jl
@@ -15,6 +15,7 @@ end
 struct SamplerStats
     depth
     n
+    diverging
     accept
 end
 
@@ -23,34 +24,30 @@ function u_turn(values_left, values_right, momenta_left, momenta_right)
     (dot(values_right - values_left, momenta_left) >= 0)
 end
 
-function leapfrog(values, momenta, eps, integrator_state)
-    values_trie, selection, retval_grad, trace = integrator_state
+function leapfrog(values_trie, momenta_trie, eps, integrator_state)
+    selection, retval_grad, trace = integrator_state
 
-    values_trie = from_array(values_trie, values)
     (trace, _, _) = update(trace, values_trie)
     (_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
-    gradient = to_array(gradient_trie, Float64)
 
     # half step on momenta
-    momenta += (eps / 2) * gradient
+    momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))
 
     # full step on positions
-    values += eps * momenta
+    values_trie = add_choicemaps(values_trie, scale_choicemap(momenta_trie, eps))
 
     # get new gradient
-    values_trie = from_array(values_trie, values)
     (trace, _, _) = update(trace, values_trie)
     (_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
-    gradient = to_array(gradient_trie, Float64)
 
     # half step on momenta
-    momenta += (eps / 2) * gradient
-    return values, momenta, get_score(trace)
+    momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))
+    return values_trie, momenta_trie, get_score(trace)
 end
 
 function build_root(val, momenta, eps, direction, weight_init, integrator_state)
     val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)
-    weight = lp + assess_momenta(momenta)
+    weight = lp + assess_momenta_trie(momenta)
 
     diverging = weight - weight_init > 1000
 
@@ -67,9 +64,16 @@ function merge_trees(tree_left, tree_right)
     weight = logsumexp(tree_left.weight, tree_right.weight)
 
     n = tree_left.n + tree_right.n
-    stop = tree_left.stop || tree_right.stop || u_turn(
-        tree_left.val_left, tree_right.val_right, tree_left.momenta_left, tree_right.momenta_right
-    )
+
+    if u_turn(to_array(tree_left.val_left, Float64),
+        to_array(tree_right.val_right, Float64),
+        to_array(tree_left.momenta_left, Float64),
+        to_array(tree_right.momenta_right, Float64))
+    end
+    stop = tree_left.stop || tree_right.stop || u_turn(to_array(tree_left.val_left, Float64),
+        to_array(tree_right.val_right, Float64),
+        to_array(tree_left.momenta_left, Float64),
+        to_array(tree_right.momenta_right, Float64))
     diverging = tree_left.diverging || tree_right.diverging
 
     return Tree(tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
@@ -90,14 +94,10 @@ function build_tree(val, momenta, depth, eps, direction, weight_init, integrator
     if direction == 1
         other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1,
             eps, direction, weight_init, integrator_state)
+        return merge_trees(tree, other_tree)
     else
         other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1,
             eps, direction, weight_init, integrator_state)
-    end
-
-    if direction == 1
-        return merge_trees(tree, other_tree)
-    else
         return merge_trees(other_tree, tree)
     end
 end
@@ -131,16 +131,16 @@ function nuts(
 
     # values needed for a leapfrog step
     (_, values_trie, _) = choice_gradients(trace, selection, retval_grad)
-    values = to_array(values_trie, Float64)
 
-    momenta = sample_momenta(length(values))
+    momenta = sample_momenta(length(to_array(values_trie, Float64)))
+    momenta_trie = from_array(values_trie, momenta)
     prev_momenta_score = assess_momenta(momenta)
 
     weight_init = prev_model_score + prev_momenta_score
 
-    integrator_state = (values_trie, selection, retval_grad, trace)
+    integrator_state = (selection, retval_grad, trace)
 
-    tree = Tree(values, momenta, values, momenta, values, 1, -Inf, false, false)
+    tree = Tree(values_trie, momenta_trie, values_trie, momenta_trie, values_trie, 1, -Inf, false, false)
     direction = 0
 
     depth = 0
@@ -165,7 +165,7 @@ function nuts(
         depth += 1
     end
 
-    (new_trace, _, _) = update(trace, from_array(values_trie, tree.val_sample))
+    (new_trace, _, _) = update(trace, tree.val_sample)
     check && check_observations(get_choices(new_trace), observations)
 
     # assess new model score (negative potential energy)
@@ -173,17 +173,17 @@ function nuts(
 
     # assess new momenta score (negative kinetic energy)
     if direction == 1
-        new_momenta_score = assess_momenta(-tree.momenta_right)
+        new_momenta_score = assess_momenta_trie(tree.momenta_right)
     else
-        new_momenta_score = assess_momenta(-tree.momenta_left)
+        new_momenta_score = assess_momenta_trie(tree.momenta_left)
    end
 
     # accept or reject
     alpha = new_model_score + new_momenta_score - weight_init
     if log(rand()) < alpha
-        return (new_trace, SamplerStats(depth, tree.n, true))
+        return (new_trace, SamplerStats(depth, tree.n, tree.diverging, true))
     else
-        return (trace, SamplerStats(depth, tree.n, false))
+        return (trace, SamplerStats(depth, tree.n, tree.diverging, false))
     end
 end
 
diff --git a/test/inference/nuts.jl b/test/inference/nuts.jl
new file mode 100644
index 00000000..29cd51d4
--- /dev/null
+++ b/test/inference/nuts.jl
@@ -0,0 +1,20 @@
+@testset "nuts" begin
+
+    # smoke test a function without retval gradient
+    @gen function foo()
+        x = @trace(normal(0, 1), :x)
+        return x
+    end
+
+    (trace, _) = generate(foo, ())
+    (new_trace, accepted) = nuts(trace, select(:x))
+
+    # smoke test a function with retval gradient
+    @gen (grad) function foo()
+        x = @trace(normal(0, 1), :x)
+        return x
+    end
+
+    (trace, _) = generate(foo, ())
+    (new_trace, accepted) = nuts(trace, select(:x))
+end
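
Usage sketch (illustrative, not part of the patch): the snippet below shows how the updated kernel might be exercised against a toy model and how the new statistics could be inspected. The model, the observed value, and the iteration count are assumptions for illustration; only the `nuts(trace, select(...))` call form and the `SamplerStats` fields (`depth`, `n`, `diverging`, `accept`) come from the code above, and it is assumed `nuts` is reachable after `using Gen` on this branch.

using Gen

# Toy model with one continuous latent (:mu) and one observed choice (:obs).
@gen function toy_model()
    mu = @trace(normal(0, 1), :mu)
    @trace(normal(mu, 0.5), :obs)
end

# Run a few NUTS transitions and tally the acceptance/divergence statistics.
function run_nuts_demo(n_iters::Int)
    observations = choicemap((:obs, 1.2))
    (trace, _) = generate(toy_model, (), observations)
    n_accepted = 0
    n_diverging = 0
    for _ in 1:n_iters
        (trace, stats) = nuts(trace, select(:mu))
        n_accepted += stats.accept        # accept flag returned by the kernel
        n_diverging += stats.diverging    # divergence flag threaded up through merge_trees
    end
    return trace, n_accepted, n_diverging
end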