-
Notifications
You must be signed in to change notification settings - Fork 159
/
hmc.jl
78 lines (63 loc) · 2.69 KB
/
hmc.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Draw `n` momentum components, each an independent sample from a standard
# normal (`random(normal, 0, 1)`), collected into a `Vector{Float64}`.
# A non-positive `n` yields an empty vector.
function sample_momenta(n::Int)
    draws = (random(normal, 0, 1) for _ in 1:n)
    return collect(Float64, draws)
end
# Sum of `logpdf(normal, p, 0, 1)` over every component `p` of `momenta`,
# i.e. their joint log density under independent standard normals.
# Accumulation runs strictly left to right starting from 0.0, so an empty
# `momenta` scores 0.0.
function assess_momenta(momenta)
    return foldl((total, p) -> total + logpdf(normal, p, 0, 1), momenta; init=0.0)
end
"""
    (new_trace, accepted) = hmc(
        trace, selection::Selection; L=10, eps=0.1,
        check=false, observations=EmptyChoiceMap())

Perform one Hamiltonian Monte Carlo (HMC) move that proposes new values for
the addresses in `selection`. Return the resulting trace (identical to `trace`
when the proposal is rejected) together with a `Bool` saying whether the
proposal was accepted.

Momenta are drawn from independent standard normals, and Hamilton's equations
are integrated with the leapfrog scheme for `L` steps of size `eps`; see
equations (5.18)-(5.20) of Neal (2011). When `check` is `true`, the choices of
the final proposed trace are passed to `check_observations` together with
`observations`.

# References
Neal, Radford M. (2011), "MCMC Using Hamiltonian Dynamics", Handbook of Markov
Chain Monte Carlo, pp. 113-162. URL: http://www.mcmchandbook.net/HandbookChapter5.pdf
"""
function hmc(
        trace::Trace, selection::Selection; L=10, eps=0.1,
        check=false, observations=EmptyChoiceMap())
    # Model log score of the starting point (negative potential energy).
    score_before = get_score(trace)
    args = get_args(trace)

    # Only pass a return-value gradient when the generative function accepts one.
    if accepts_output_grad(get_gen_fn(trace))
        retval_grad = zero(get_retval(trace))
    else
        retval_grad = nothing
    end

    # Arguments are held fixed throughout; only selected choices move.
    argdiffs = map(args) do _
        NoChange()
    end

    # Initial positions, gradients and freshly sampled momenta.
    proposal = trace
    (_, position_map, grad_map) = choice_gradients(proposal, selection, retval_grad)
    q = to_array(position_map, Float64)
    grad = to_array(grad_map, Float64)
    p = sample_momenta(length(q))
    momenta_score_before = assess_momenta(p)

    # Leapfrog integration: half momentum step, full position step, refresh
    # the gradient at the new position, then another half momentum step.
    half_eps = eps / 2
    for _ in 1:L
        p += half_eps * grad
        q += eps * p
        position_map = from_array(position_map, q)
        (proposal, _, _) = update(proposal, args, argdiffs, position_map)
        (_, _, grad_map) = choice_gradients(proposal, selection, retval_grad)
        grad = to_array(grad_map, Float64)
        p += half_eps * grad
    end

    if check
        check_observations(get_choices(proposal), observations)
    end

    # Model log score at the end point (negative potential energy).
    score_after = get_score(proposal)
    # Momenta are negated before scoring, matching the reversed trajectory.
    momenta_score_after = assess_momenta(-p)

    # Metropolis-Hastings log acceptance ratio; keep this exact evaluation order.
    log_accept_ratio = score_after - score_before + momenta_score_after - momenta_score_before
    return log(rand()) < log_accept_ratio ? (proposal, true) : (trace, false)
end
# Kernel-DSL trait methods for `hmc`:
# `check_is_kernel` reports true, `is_custom_primitive_kernel` reports false,
# and `reversal` maps `hmc` to itself (the move is its own reversal).
function check_is_kernel(::typeof(hmc))
    return true
end

function is_custom_primitive_kernel(::typeof(hmc))
    return false
end

function reversal(::typeof(hmc))
    return hmc
end

export hmc