Skip to content

Commit

Permalink
Add progress bar to HMC, PG and Gibbs
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Apr 14, 2017
1 parent 02ae486 commit 3080d25
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 27 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ PDMats 0.5.1
ForwardDiff
Mamba
Stats
ProgressMeter
1 change: 1 addition & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Turing

using Distributions
using ForwardDiff: Dual, npartials # for automatic differentiation
using ProgressMeter

abstract InferenceAlgorithm{P}
abstract Sampler{T<:InferenceAlgorithm}
Expand Down
17 changes: 2 additions & 15 deletions src/samplers/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,11 @@ function Base.run(model, data, spl::Sampler{Gibbs})
# initialization
task = current_task()
n = spl.gibbs.n_iters
t_start = time() # record the start time of HMC
varInfo = VarInfo()
ref_particle = nothing

# HMC steps
for i = 1:n
# Gibbs steps
@showprogress 1 "[Gibbs] Sampling..." for i = 1:n
dprintln(2, "Gibbs stepping...")

for local_spl in spl.samplers
Expand Down Expand Up @@ -92,18 +91,6 @@ function Base.run(model, data, spl::Sampler{Gibbs})
end
end
spl.samples[i].value = Sample(varInfo).value

if VERBOSITY > 0
if i == n
println("100% Done")
elseif i % floor(n / 100) == 0
print("$(i / floor(n / 100))% ")
end
end
end

if VERBOSITY > 0
println("[Gibbs]: Finshed within $(time() - t_start) seconds")
end

Chain(0, spl.samples) # wrap the result by Chain
Expand Down
7 changes: 1 addition & 6 deletions src/samplers/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,11 @@ function sample(model::Function, alg::HMC, chunk_size::Int)
# initialization
n = spl.alg.n_samples
task = current_task()
t_start = time() # record the start time of HMC
accept_num = 0 # record the accept number
varInfo = VarInfo()

# HMC steps
for i = 1:n
@showprogress 1 "[HMC] Sampling..." for i = 1:n
dprintln(2, "recording old θ...")
old_vals = deepcopy(varInfo.vals)
dprintln(2, "HMC stepping...")
Expand All @@ -136,10 +135,6 @@ function sample(model::Function, alg::HMC, chunk_size::Int)

accept_rate = accept_num / n # calculate the accept rate

if VERBOSITY > 0
println("[HMC]: Finshed with accept rate = $(accept_rate) within $(time() - t_start) seconds")
end

Chain(0, spl.samples) # wrap the result by Chain
end

Expand Down
8 changes: 2 additions & 6 deletions src/samplers/pgibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,18 @@ function sample(model, alg::PG)
global sampler = ParticleSampler{PG}(alg);
spl = sampler
n = spl.alg.n_iterations
t_start = time() # record the start time of PG
samples = Vector{Sample}()
logevidence = Vector{Float64}(n)

## custom resampling function for pgibbs
## re-inserts reteined particle after each resampling step
ref_particle = nothing
for i = 1:n
@showprogress 1 "[PG] Sampling..." for i = 1:n
ref_particle, s = step(model, spl, VarInfo(), ref_particle)
logevidence[i] = spl.particles.logE
push!(samples, Sample(1/n, s.value))
end

if VERBOSITY > 0
println("[PG]: Finshed within $(time() - t_start) seconds")
end

chain = Chain(exp(mean(logevidence)), samples)
end

Expand Down

0 comments on commit 3080d25

Please sign in to comment.