
Commit

wip
chentoast committed Sep 2, 2022
1 parent 2a959a8 commit 89594a5
Showing 3 changed files with 89 additions and 28 deletions.
41 changes: 41 additions & 0 deletions src/inference/hmc_common.jl
@@ -9,3 +9,44 @@ function assess_momenta(momenta)
end
logprob
end

function add_choicemaps(a::ChoiceMap, b::ChoiceMap)
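    # Return a new choicemap whose leaf values are the elementwise sums of the
    # corresponding leaves of `a` and `b` (assumes both share the same address structure).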
out = choicemap()

for (name, val) in get_values_shallow(a)
out[name] = val + b[name]
end

for (name, submap) in get_submaps_shallow(a)
out.internal_nodes[name] = add_choicemaps(submap, get_submap(b, name))
end

return out
end

function scale_choicemap(a::ChoiceMap, scale)
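    # Return a copy of `a` with every leaf value multiplied by `scale`.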
out = choicemap()

for (name, val) in get_values_shallow(a)
out[name] = val * scale
end

for (name, submap) in get_submaps_shallow(a)
out.internal_nodes[name] = scale_choicemap(submap, scale)
end

return out
end

function assess_momenta_trie(momenta_trie)
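    # Sum the standard-normal log density over every leaf value of `momenta_trie`,
    # recursing into submaps.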
logprob = 0.
for (_, val) in get_values_shallow(momenta_trie)
logprob += logpdf(normal, val, 0, 1)
end

for (_, submap) in get_submaps_shallow(momenta_trie)
logprob += assess_momenta_trie(submap)
end

return logprob
end
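For illustration, a minimal usage sketch of the three helpers above, assuming they are in scope together with Gen's choicemap API (choicemap, set_submap!); the addresses :x, :sub, and :y are hypothetical.

using Gen

a = choicemap((:x, 1.0)); set_submap!(a, :sub, choicemap((:y, 2.0)))
b = choicemap((:x, 0.5)); set_submap!(b, :sub, choicemap((:y, -1.0)))

c = add_choicemaps(a, b)      # c[:x] == 1.5, c[:sub => :y] == 1.0
d = scale_choicemap(a, 0.5)   # d[:x] == 0.5, d[:sub => :y] == 1.0
lp = assess_momenta_trie(c)   # sum of standard-normal logpdfs over all leaves of c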
56 changes: 28 additions & 28 deletions src/inference/nuts.jl
@@ -15,6 +15,7 @@ end
struct SamplerStats
depth
n
diverging
accept
end

@@ -23,34 +24,30 @@ function u_turn(values_left, values_right, momenta_left, momenta_right)
(dot(values_right - values_left, momenta_left) >= 0)
end

function leapfrog(values, momenta, eps, integrator_state)
values_trie, selection, retval_grad, trace = integrator_state
function leapfrog(values_trie, momenta_trie, eps, integrator_state)
selection, retval_grad, trace = integrator_state

values_trie = from_array(values_trie, values)
(trace, _, _) = update(trace, values_trie)
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
gradient = to_array(gradient_trie, Float64)

# half step on momenta
momenta += (eps / 2) * gradient
momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))

# full step on positions
values += eps * momenta
values_trie = add_choicemaps(values_trie, scale_choicemap(momenta_trie, eps))

# get new gradient
values_trie = from_array(values_trie, values)
(trace, _, _) = update(trace, values_trie)
(_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)
gradient = to_array(gradient_trie, Float64)

# half step on momenta
momenta += (eps / 2) * gradient
return values, momenta, get_score(trace)
momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))
return values_trie, momenta_trie, get_score(trace)
end
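For reference, the trie-based update above follows the standard leapfrog scheme. A minimal array-form sketch, where grad is a hypothetical stand-in for the gradient of the log density:

# Sketch only; `grad(q)` stands in for the log-density gradient at position `q`.
function leapfrog_arrays(q::Vector{Float64}, p::Vector{Float64}, eps::Float64, grad)
    p = p .+ (eps / 2) .* grad(q)   # half step on momenta
    q = q .+ eps .* p               # full step on positions
    p = p .+ (eps / 2) .* grad(q)   # half step on momenta
    return q, p
end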

function build_root(val, momenta, eps, direction, weight_init, integrator_state)
val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)
weight = lp + assess_momenta(momenta)
weight = lp + assess_momenta_trie(momenta)

diverging = weight - weight_init > 1000

@@ -67,9 +64,16 @@ function merge_trees(tree_left, tree_right)

weight = logsumexp(tree_left.weight, tree_right.weight)
n = tree_left.n + tree_right.n
stop = tree_left.stop || tree_right.stop || u_turn(
tree_left.val_left, tree_right.val_right, tree_left.momenta_left, tree_right.momenta_right
)

stop = tree_left.stop || tree_right.stop || u_turn(to_array(tree_left.val_left, Float64),
to_array(tree_right.val_right, Float64),
to_array(tree_left.momenta_left, Float64),
to_array(tree_right.momenta_right, Float64))
diverging = tree_left.diverging || tree_right.diverging

return Tree(tree_left.val_left, tree_left.momenta_left, tree_right.val_right,
Expand All @@ -90,14 +94,10 @@ function build_tree(val, momenta, depth, eps, direction, weight_init, integrator
if direction == 1
other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction,
weight_init, integrator_state)
return merge_trees(tree, other_tree)
else
other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction,
weight_init, integrator_state)
end

if direction == 1
return merge_trees(tree, other_tree)
else
return merge_trees(other_tree, tree)
end
end
@@ -131,16 +131,16 @@ function nuts(

# values needed for a leapfrog step
(_, values_trie, _) = choice_gradients(trace, selection, retval_grad)
values = to_array(values_trie, Float64)

momenta = sample_momenta(length(values))
momenta = sample_momenta(length(to_array(values_trie, Float64)))
momenta_trie = from_array(values_trie, momenta)
prev_momenta_score = assess_momenta(momenta)

weight_init = prev_model_score + prev_momenta_score

integrator_state = (values_trie, selection, retval_grad, trace)
integrator_state = (selection, retval_grad, trace)

tree = Tree(values, momenta, values, momenta, values, 1, -Inf, false, false)
tree = Tree(values_trie, momenta_trie, values_trie, momenta_trie, values_trie, 1, -Inf, false, false)

direction = 0
depth = 0
@@ -165,25 +165,25 @@
depth += 1
end

(new_trace, _, _) = update(trace, from_array(values_trie, tree.val_sample))
(new_trace, _, _) = update(trace, tree.val_sample)
check && check_observations(get_choices(new_trace), observations)

# assess new model score (negative potential energy)
new_model_score = get_score(new_trace)

# assess new momenta score (negative kinetic energy)
if direction == 1
new_momenta_score = assess_momenta(-tree.momenta_right)
new_momenta_score = assess_momenta_trie(tree.momenta_right)
else
new_momenta_score = assess_momenta(-tree.momenta_left)
new_momenta_score = assess_momenta_trie(tree.momenta_left)
end

# accept or reject
alpha = new_model_score + new_momenta_score - weight_init
if log(rand()) < alpha
return (new_trace, SamplerStats(depth, tree.n, true))
return (new_trace, SamplerStats(depth, tree.n, tree.diverging, true))
else
return (trace, SamplerStats(depth, tree.n, false))
return (trace, SamplerStats(depth, tree.n, tree.diverging, false))
end
end

20 changes: 20 additions & 0 deletions test/inference/nuts.jl
@@ -0,0 +1,20 @@
@testset "nuts" begin

# smoke test a function without retval gradient
@gen function foo()
x = @trace(normal(0, 1), :x)
return x
end

(trace, _) = generate(foo, ())
(new_trace, accepted) = nuts(trace, select(:x))

# smoke test a function with retval gradient
@gen (grad) function foo()
x = @trace(normal(0, 1), :x)
return x
end

(trace, _) = generate(foo, ())
(new_trace, accepted) = nuts(trace, select(:x))
end
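For illustration, a sketch of iterating the kernel to collect samples, assuming the nuts and foo defined above; the chain length of 100 is arbitrary:

# Sketch only: run a short NUTS chain on `foo` and record the sampled value of :x each step.
function run_nuts_chain(n::Int)
    (trace, _) = generate(foo, ())
    xs = Float64[]
    for _ in 1:n
        (trace, _) = nuts(trace, select(:x))
        push!(xs, trace[:x])
    end
    return xs
end

xs = run_nuts_chain(100)   # draws of :x, approximately standard normal at stationarity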
