
Merge branch 'master' into fix-numerical-bug
Hong Ge committed Apr 26, 2017
2 parents 99e41b8 + b69b799 commit 5842bbd
Showing 24 changed files with 625 additions and 119 deletions.
4 changes: 4 additions & 0 deletions benchmarks/gauss.run.jl
@@ -12,3 +12,7 @@ print_log(logd)
bench_res = tbenchmark("Gibbs(200, HMC(10, 0.25, 5, :mu), PG(20, 10, :lam))", "gaussmodel", "gaussdata")
logd = build_logd("Gaussian Model", bench_res...)
print_log(logd)

bench_res = tbenchmark("Gibbs(200, eNUTS(5, 0.5, :mu), PG(50, 10, :lam))", "gaussmodel", "gaussdata")
logd = build_logd("Gaussian Model", bench_res...)
print_log(logd)
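For comparison, a standalone eNUTS benchmark in the same style would presumably look like the sketch below; the call is hypothetical (not part of this commit) and assumes the eNUTS(n_samples, step_size) constructor from src/samplers/enuts.jl.

# Hypothetical: 200 eNUTS samples with step size 0.5, sampling all variables.
bench_res = tbenchmark("eNUTS(200, 0.5)", "gaussmodel", "gaussdata")
logd = build_logd("Gaussian Model", bench_res...)
print_log(logd)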
2 changes: 1 addition & 1 deletion src/Turing.jl
@@ -26,7 +26,7 @@ import Base: ~, convert, promote_rule
#################

# Turing essentials - modelling macros and inference algorithms
-export @model, @~, InferenceAlgorithm, HMC, HMCDA, IS, SMC, PG, Gibbs, sample, Chain, Sample, Sampler, setchunksize
+export @model, @~, InferenceAlgorithm, HMC, HMCDA, eNUTS, NUTS, IS, SMC, PG, Gibbs, sample, Chain, Sample, Sampler, setchunksize
export VarName, VarInfo, randr, randoc, retain, groupvals
export Dual

7 changes: 5 additions & 2 deletions src/core/ad.jl
@@ -18,6 +18,7 @@ function gradient(_vi::VarInfo, model::Function, spl=nothing)
vi = deepcopy(_vi)
# Initialisation
val∇E = Dict{Tuple, Vector{Float64}}()

# Split keys (and their values) into chunks of size CHUNKSIZE, CHUNKSIZE, ..., m (the final remainder chunk)
dprintln(4, "making chunks...")
prior_key_chunks = []
@@ -49,6 +50,7 @@ function gradient(_vi::VarInfo, model::Function, spl=nothing)
if length(key_chunk) != 0
push!(prior_key_chunks, (key_chunk, prior_dim)) # push the last chunk
end

# chunk-wise forward AD
for (key_chunk, prior_dim) in prior_key_chunks
# Set dual part correspondingly
@@ -70,12 +72,13 @@ function gradient(_vi::VarInfo, model::Function, spl=nothing)
dprintln(5, "make dual done")
else # other variables (not carrying gradient info)
for i = 1:l # NOTE: direct assignment is avoided here because we don't want the reference held by val_vect to change (Mv and Mat support)
-val_vect[i] = reals[i]
+val_vect[i] = Dual{prior_dim, Float64}(reals[i])
end
end
end
# Run the model
dprintln(4, "run model...")
+vi.logjoint = Dual{prior_dim, Float64}(0)
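# NOTE: seeding logjoint as a Dual with prior_dim partials (rather than a plain Dual(0)) keeps gradient information when dual-valued log-densities are accumulated; the reset further down does the same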
vi = runmodel(model, vi, spl)
# Collect gradient
dprintln(4, "collect dual...")
@@ -94,7 +97,7 @@ function gradient(_vi::VarInfo, model::Function, spl=nothing)
val∇E[k] = g
end
# Reset logjoint
-vi.logjoint = Dual(0)
+vi.logjoint = Dual{prior_dim, Float64}(0)
end
# Return
return val∇E
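For readers unfamiliar with chunk-wise forward-mode AD: splitting the parameter vector into blocks of at most CHUNKSIZE dimensions bounds how many dual-number partials each run of the model must carry. Below is a minimal standalone sketch of just the partitioning step, under the assumption of a global CHUNKSIZE as used above; make_chunks is hypothetical, not Turing code.

const CHUNKSIZE = 10  # assumed default; Turing exposes setchunksize to change it

# Partition variables (with their dimensions) into chunks whose total
# dimension is at most CHUNKSIZE, pushing the final partial chunk as well.
function make_chunks(keys::Vector{Symbol}, dims::Vector{Int})
    chunks = Any[]
    chunk, dim = Symbol[], 0
    for (k, d) in zip(keys, dims)
        if dim + d > CHUNKSIZE && !isempty(chunk)
            push!(chunks, (chunk, dim))   # this chunk is full: flush it
            chunk, dim = Symbol[], 0
        end
        push!(chunk, k)
        dim += d
    end
    isempty(chunk) || push!(chunks, (chunk, dim))  # push the last chunk
    chunks
end

For example, make_chunks([:mu, :lam, :s], [4, 8, 3]) yields ([:mu], 4), ([:lam], 8), ([:s], 3) with CHUNKSIZE = 10.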
88 changes: 88 additions & 0 deletions src/samplers/enuts.jl
@@ -0,0 +1,88 @@
immutable eNUTS <: InferenceAlgorithm
n_samples :: Int # number of samples
step_size :: Float64 # leapfrog step size
space :: Set # sampling space, empty means all
group_id :: Int

eNUTS(step_size::Float64) = new(1, step_size, Set(), 0)
eNUTS(n_samples::Int, step_size::Float64) = new(n_samples, step_size, Set(), 0)
eNUTS(n_samples::Int, step_size::Float64, space...) = new(n_samples, step_size, isa(space, Symbol) ? Set([space]) : Set(space), 0)
eNUTS(alg::eNUTS, new_group_id::Int) = new(alg.n_samples, alg.step_size, alg.space, new_group_id)
end

function step(model, spl::Sampler{eNUTS}, vi::VarInfo, is_first::Bool)
if is_first
true, vi
else
ϵ = spl.alg.step_size

dprintln(2, "sampling momentum...")
p = sample_momentum(vi, spl)

dprintln(3, "X -> R...")
vi = link(vi, spl)

dprintln(3, "sample slice variable u")
u = rand() * exp(-find_H(p, model, vi, spl))
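# u is the NUTS slice variable: u | θ, p ~ Uniform(0, exp(-H(θ, p))); see Hoffman & Gelman (2014), Algorithm 3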

θm, θp, rm, rp, j, vi_new, n, s = deepcopy(vi), deepcopy(vi), deepcopy(p), deepcopy(p), 0, deepcopy(vi), 1, 1
while s == 1
v_j = rand([-1, 1]) # NOTE: this variable does not actually depend on j;
# it is named `v_j` only to be consistent with the paper
if v_j == -1
θm, rm, _, _, θ′, n′, s′ = build_tree(θm, rm, u, v_j, j, ϵ, model, spl)
else
_, _, θp, rp, θ′, n′, s′ = build_tree(θp, rp, u, v_j, j, ϵ, model, spl)
end
if s′ == 1
if rand() < min(1, n′ / n)
vi_new = deepcopy(θ′)
end
end
n = n + n′
s = s′ * (direction(θm, θp, rm, model, spl) >= 0 ? 1 : 0) * (direction(θm, θp, rp, model, spl) >= 0 ? 1 : 0)
j = j + 1
end

dprintln(3, "R -> X...")
vi_new = invlink(vi_new, spl)

cleandual!(vi_new)

true, vi_new
end
end

function build_tree(θ, r, u, v, j, ϵ, model, spl)
doc"""
- θ : model parameter
- r : momentum variable
- u : slice variable
- v : direction ∈ {-1, 1}
- j : depth
- ϵ : leapfrog step size
"""
if j == 0
# Base case - take one leapfrog step in the direction v.
θ′, r′ = leapfrog(θ, r, 1, v * ϵ, model, spl)
n′ = u <= exp(-find_H(r′, model, θ′, spl)) ? 1 : 0
s′ = u < exp(Δ_max - find_H(r′, model, θ′, spl)) ? 1 : 0
return deepcopy(θ′), deepcopy(r′), deepcopy(θ′), deepcopy(r′), deepcopy(θ′), n′, s′
else
# Recursion - build the left and right subtrees.
θm, rm, θp, rp, θ′, n′, s′ = build_tree(θ, r, u, v, j - 1, ϵ, model, spl)
if s′ == 1
if v == -1
θm, rm, _, _, θ′′, n′′, s′′ = build_tree(θm, rm, u, v, j - 1, ϵ, model, spl)
else
_, _, θp, rp, θ′′, n′′, s′′ = build_tree(θp, rp, u, v, j - 1, ϵ, model, spl)
end
if rand() < n′′ / (n′ + n′′)
θ′ = deepcopy(θ′′)
end
s′ = s′′ * (direction(θm, θp, rm, model, spl) >= 0 ? 1 : 0) * (direction(θm, θp, rp, model, spl) >= 0 ? 1 : 0)
n′ = n′ + n′′
end
return θm, rm, θp, rp, θ′, n′, s′
end
end
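The direction calls above encode the no-U-turn criterion on VarInfo objects; on plain parameter vectors it reduces to the dot-product test below, shown as a hypothetical sketch (no_u_turn is not part of this commit).

# Continue doubling only while the trajectory endpoints are still moving
# apart along the momentum at each end: (θ⁺ - θ⁻)·r ≥ 0.
no_u_turn(θminus::Vector{Float64}, θplus::Vector{Float64}, r::Vector{Float64}) =
    dot(θplus - θminus, r) >= 0.0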
10 changes: 6 additions & 4 deletions src/samplers/gibbs.jl
@@ -15,7 +15,7 @@ function Sampler(alg::Gibbs)

for i in 1:n_samplers
sub_alg = alg.algs[i]
-if isa(sub_alg, HMC) || isa(sub_alg, HMCDA)
+if isa(sub_alg, Hamiltonian)
samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i))
elseif isa(sub_alg, PG)
samplers[i] = Sampler(PG(sub_alg, i))
@@ -46,7 +46,7 @@ function sample(model::Function, alg::Gibbs)
sub_sample_n = [] # record #samples for each sampler
for i in 1:length(alg.algs)
sub_alg = alg.algs[i]
-if isa(sub_alg, HMC) || isa(sub_alg, HMCDA)
+if isa(sub_alg, Hamiltonian)
push!(sub_sample_n, sub_alg.n_samples)
elseif isa(sub_alg, PG)
push!(sub_sample_n, sub_alg.n_iterations)
@@ -79,7 +79,7 @@ function sample(model::Function, alg::Gibbs)
# dprintln(2, "Sampler stepping...")
dprintln(2, "$(typeof(local_spl)) stepping...")
# println(varInfo)
-if isa(local_spl, Sampler{HMC}) || isa(local_spl, Sampler{HMCDA})
+if isa(local_spl.alg, Hamiltonian)

for _ = 1:local_spl.alg.n_samples
dprintln(2, "recording old θ...")
@@ -95,7 +95,7 @@ function sample(model::Function, alg::Gibbs)
i_thin += 1
end
end
-elseif isa(local_spl, Sampler{PG})
+elseif isa(local_spl.alg, PG)
# Update new VarInfo to the reference particle
varInfo.index = 0
varInfo.num_produce = 0
@@ -113,6 +113,8 @@ function sample(model::Function, alg::Gibbs)
end
end
varInfo = ref_particle.vi
+else
+error("[GibbsSampler] unsupported base sampler $local_spl")
end

end
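The Hamiltonian branch above is what lets eNUTS run inside a Gibbs sweep, as exercised by the benchmark at the top of this commit. A hypothetical usage sketch follows; the model/data names and calling convention are assumed, not taken from this diff.

# eNUTS drives :mu and PG drives :lam; the Gibbs sweep alternates between them for 200 rounds.
alg   = Gibbs(200, eNUTS(5, 0.5, :mu), PG(50, 10, :lam))
chain = sample(gaussmodel(gaussdata), alg)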
21 changes: 12 additions & 9 deletions src/samplers/hmc.jl
@@ -38,7 +38,9 @@ immutable HMC <: InferenceAlgorithm
HMC(alg::HMC, new_group_id::Int) = new(alg.n_samples, alg.lf_size, alg.lf_num, alg.space, new_group_id)
end

-Sampler(alg::Union{HMC,HMCDA}) = begin
+typealias Hamiltonian Union{HMC,HMCDA,eNUTS,NUTS}
+
+Sampler(alg::Hamiltonian) = begin
info = Dict{Symbol, Any}()
Sampler(alg, info)
end
@@ -83,16 +85,17 @@ function step(model, spl::Sampler{HMC}, vi::VarInfo, is_first::Bool)
end
end

-sample(model::Function, alg::Union{HMC, HMCDA}) = sample(model, alg, CHUNKSIZE)
+sample(model::Function, alg::Hamiltonian) = sample(model, alg, CHUNKSIZE)

# NOTE: in the previous code, `sample` would call `run`; this is
# now simplified: `sample` and `run` are merged into one function.
function sample{T<:Union{HMC, HMCDA}}(model::Function, alg::T, chunk_size::Int)
function sample{T<:Hamiltonian}(model::Function, alg::T, chunk_size::Int)
global CHUNKSIZE = chunk_size;
spl = Sampler(alg);
alg_str = isa(alg, HMC) ? "HMC" : "HMCDA"


alg_str = isa(alg, HMC) ? "HMC" :
isa(alg, HMCDA) ? "HMCDA" :
isa(alg, eNUTS) ? "eNUTS" :
isa(alg, NUTS) ? "NUTS" : "Hamiltonian"

# initialization
n = spl.alg.n_samples
@@ -125,7 +128,7 @@ function sample{T<:Union{HMC, HMCDA}}(model::Function, alg::T, chunk_size::Int)
Chain(0, samples) # wrap the result by Chain
end

function assume{T<:Union{HMC,HMCDA}}(spl::Sampler{T}, dist::Distribution, vn::VarName, vi::VarInfo)
function assume{T<:Hamiltonian}(spl::Sampler{T}, dist::Distribution, vn::VarName, vi::VarInfo)
# Step 1 - Generate or replay variable
dprintln(2, "assuming...")
r = rand(vi, vn, dist, spl)
@@ -135,7 +138,7 @@ function assume{T<:Union{HMC,HMCDA}}(spl::Sampler{T}, dist::Distribution, vn::VarName, vi::VarInfo)
end

# NOTE: TRY TO REMOVE Void through defining a special type for gradient based algs.
function observe{T<:Union{HMC,HMCDA}}(spl::Sampler{T}, d::Distribution, value, vi::VarInfo)
function observe{T<:Hamiltonian}(spl::Sampler{T}, d::Distribution, value, vi::VarInfo)
dprintln(2, "observing...")
if length(value) == 1
vi.logjoint += logpdf(d, Dual(value))
@@ -145,7 +148,7 @@ function observe{T<:Union{HMC,HMCDA}}(spl::Sampler{T}, d::Distribution, value, vi::VarInfo)
dprintln(2, "observe done")
end

rand{T<:Union{HMC,HMCDA}}(vi::VarInfo, vn::VarName, dist::Distribution, spl::Sampler{T}) = begin
rand{T<:Hamiltonian}(vi::VarInfo, vn::VarName, dist::Distribution, spl::Sampler{T}) = begin
isempty(spl.alg.space) || vn.sym in spl.alg.space ?
randr(vi, vn, dist, spl, false) :
randr(vi, vn, dist)
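The Hamiltonian typealias introduced above is what collapses the assume/observe/rand definitions to a single parametric method each. A hypothetical illustration of the dispatch pattern (describe is invented for this example; Julia 0.5 syntax, matching this commit):

typealias Hamiltonian Union{HMC,HMCDA,eNUTS,NUTS}

# One method covers every gradient-based sampler; others keep their own.
describe{T<:Hamiltonian}(spl::Sampler{T}) = "gradient-based: $(typeof(spl.alg))"
describe(spl::Sampler{PG})                = "particle-based: PG"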
43 changes: 0 additions & 43 deletions src/samplers/hmcda.jl
@@ -20,51 +20,8 @@ immutable HMCDA <: InferenceAlgorithm

end

-function find_good_eps(model::Function, spl::Sampler{HMCDA}, vi::VarInfo)
-  ϵ, p = 1.0, sample_momentum(vi, spl)      # set initial epsilon and momentum
-  jointd = exp(-find_H(p, model, vi, spl))  # calculate p(Θ, p) = exp(-H(Θ, p))
-
-  # println("[HMCDA] grad: ", grad)
-  # println("[HMCDA] p: ", p)
-  # println("[HMCDA] vi: ", vi)
-  vi_prime, p_prime = leapfrog(vi, p, 1, ϵ, model, spl) # take a trial leapfrog step
-
-  jointd_prime = exp(-find_H(p_prime, model, vi_prime, spl)) # calculate new p(Θ, p)
-
-  # println("[HMCDA] jointd: ", jointd)
-  # println("[HMCDA] jointd_prime: ", jointd_prime)
-
-  # This trick prevents the log-joint or its gradient from being infinite
-  # Ref: https://github.com/mfouesneau/NUTS/blob/master/nuts.py#L111
-  # QUES: will this lead to some bias in the sampler?
-  while isnan(jointd_prime)
-    ϵ *= 0.5
-    # println("[HMCDA] current ϵ: ", ϵ)
-    # println("[HMCDA] jointd_prime: ", jointd_prime)
-    # println("[HMCDA] vi_prime: ", vi_prime)
-    vi_prime, p_prime = leapfrog(vi, p, 1, ϵ, model, spl)
-    jointd_prime = exp(-find_H(p_prime, model, vi_prime, spl))
-  end
-  ϵ_bar = ϵ
-
-  # Heuristically find an optimal ϵ
-  a = 2.0 * (jointd_prime / jointd > 0.5 ? 1 : 0) - 1
-  while (jointd_prime / jointd)^a > 2.0^(-a)
-    # println("[HMCDA] current ϵ: ", ϵ)
-    # println("[HMCDA] jointd_prime: ", jointd_prime)
-    # println("[HMCDA] vi_prime: ", vi_prime)
-    ϵ = 2.0^a * ϵ
-    vi_prime, p_prime = leapfrog(vi, p, 1, ϵ, model, spl)
-    jointd_prime = exp(-find_H(p_prime, model, vi_prime, spl))
-  end
-
-  println("[HMCDA] found initial ϵ: ", ϵ)
-  ϵ_bar, ϵ
-end
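To unpack the deleted heuristic: a is +1 when the first leapfrog step's acceptance ratio exceeds 0.5 and -1 otherwise, so the loop doubles ϵ while the ratio stays above 0.5 and halves it while the ratio stays below, stopping once it crosses 0.5 (Algorithm 4 of Hoffman & Gelman, 2014). A minimal sketch of the same rule on plain numbers, with ratio(ϵ) a hypothetical stand-in for jointd_prime / jointd:

# Double or halve ϵ until the acceptance ratio crosses 0.5.
function find_good_eps_sketch(ratio::Function, ϵ0::Float64 = 1.0)
    ϵ = ϵ0
    a = ratio(ϵ) > 0.5 ? 1 : -1    # +1: ratio still high, grow ϵ; -1: shrink it
    while ratio(ϵ)^a > 2.0^(-a)    # holds while the ratio has not yet crossed 0.5
        ϵ = 2.0^a * ϵ
    end
    ϵ
end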

function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
if is_first

vi_0 = deepcopy(vi)

vi = link(vi, spl)
