Skip to content

Commit

Permalink
Added reporting infrastructure (with tests).
Browse files Browse the repository at this point in the history
  • Loading branch information
tpapp committed May 3, 2018
1 parent 092c397 commit e8e2497
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 12 deletions.
23 changes: 23 additions & 0 deletions docs/src/lowlevel.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,29 @@ ACCEPTANCE_QUANTILES
explore_local_acceptance_ratios
```

## Reporting information during runs

Samplers take an [`AbstractReport`](@ref) argument, which is then used for reporting. The interface is as follows.

```@docs
DynamicHMC.AbstractReport
DynamicHMC.report!
DynamicHMC.start_progress!
DynamicHMC.end_progress!
```

The default is
```@docs
ReportIO
```

Reporting information can be suppressed with
```@docs
ReportSilent
```

Other interfaces should define similar types.

## Utilities and miscellanea

```@docs
Expand Down
1 change: 1 addition & 0 deletions src/DynamicHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("utilities.jl")
include("hamiltonian.jl")
include("stepsize.jl")
include("buildingblocks.jl")
include("reporting.jl")
include("sampler.jl")
include("diagnostics.jl")

Expand Down
97 changes: 97 additions & 0 deletions src/reporting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
export ReportSilent, ReportIO

"""
Subtypes implement [`report!`](@ref), [`start_progress!`](@ref), and
[`end_progress!`](@ref).
"""
abstract type AbstractReport end

"""
A placeholder type for not reporting any information.
"""
struct ReportSilent <: AbstractReport end

report!(::ReportSilent, objects...) = nothing

start_progress!(::ReportSilent, ::Union{Int, Void}, ::Any) = nothing

end_progress!(::ReportSilent) = nothing

mutable struct ReportIO{TIO <: IO} <: AbstractReport
io::TIO
color::Union{Symbol, Int}
step_count::Int
total::Union{Int, Void}
last_count::Union{Int, Void}
last_time::UInt
end

"""
$SIGNATURES
Report to the given stream `io` (defaults to `STDERR`).
For progress bars, emit new information every after `step_count` steps.
`color` is used with `print_with_color`.
"""
ReportIO(; io = STDERR, color = :blue, step_count = 100) =
ReportIO(io, color, step_count, nothing, nothing, zero(UInt))

"""
$SIGNATURES
Start a progress meter for an iteration. The second argument is either
- `nothing`, if the total number of steps is unknown,
- an integer, for the total number of steps.
After calling this function, [`report!`](@ref) should be used at every step with
an integer.
"""
function start_progress!(report::ReportIO, total, msg)
if total isa Integer
msg *= " ($(total) steps)"
end
print_with_color(report.color, report.io, msg, '\n'; bold = true)
report.total = total
report.last_count = 0
report.last_time = time_ns()
nothing
end

"""
$SIGNATURES
Terminate a progress meter.
"""
function end_progress!(report::ReportIO)
print_with_color(report.color, report.io, " ...done\n"; bold = true)
report.last_count = nothing
end

"""
$SIGNATURES
Display `objects` via the appropriate mechanism.
When a single `Int` is given, it is treated as the index of the current step.
"""
function report!(report::ReportIO, count::Int)
@unpack io, step_count, color, total = report
@argcheck report.last_count isa Int "start_progress! was not called."
if count % step_count == 0
msg = "step $(count)"
if total isa Int
msg *= "/$(total)"
end
t = time_ns()
s_per_iteration = (t - report.last_time) / step_count / 1000
msg *= ", $(signif(s_per_iteration, 2)) s/step"
print_with_color(color, io, msg, '\n')
report.last_time = t
report.last_count = count
end
nothing
end
27 changes: 18 additions & 9 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export
Specification for the No-U-turn algorithm, including the random number
generator, Hamiltonian, the initial position, and various parameters.
"""
struct NUTS{Tv, Tf, TR, TH}
struct NUTS{Tv, Tf, TR, TH, Trep <: AbstractReport}
"Random number generator."
rng::TR
"Hamiltonian"
Expand All @@ -19,6 +19,8 @@ struct NUTS{Tv, Tf, TR, TH}
ϵ::Tf
"maximum depth of the tree"
max_depth::Int
"reporting"
report::Trep
end

function show(io::IO, nuts::NUTS)
Expand All @@ -36,13 +38,16 @@ Run the MCMC `sampler` for `N` iterations, returning the results as a vector,
which has elements that conform to the sampler.
"""
function mcmc(sampler::NUTS{Tv,Tf}, N::Int) where {Tv,Tf}
@unpack rng, H, q, ϵ, max_depth = sampler
@unpack rng, H, q, ϵ, max_depth, report = sampler
sample = Vector{NUTS_Transition{Tv,Tf}}(N)
start_progress!(report, N, "MCMC")
for i in 1:N
trans = NUTS_transition(rng, H, q, ϵ, max_depth)
q = trans.q
sample[i] .= trans
report!(report, i)
end
end_progress!(report)
sample
end

Expand All @@ -56,14 +61,17 @@ second value.
When the last two parameters are not specified, initialize using `adapting_ϵ`.
"""
function mcmc_adapting_ϵ(sampler::NUTS{Tv,Tf}, N::Int, A_params, A) where {Tv,Tf}
@unpack rng, H, q, max_depth = sampler
@unpack rng, H, q, max_depth, report = sampler
sample = Vector{NUTS_Transition{Tv,Tf}}(N)
start_progress!(report, N, "MCMC, adapting ϵ")
for i in 1:N
trans = NUTS_transition(rng, H, q, get_ϵ(A), max_depth)
A = adapt_stepsize(A_params, A, trans.a)
q = trans.q
sample[i] .= trans
report!(report, i)
end
end_progress!(report)
sample, A
end

Expand Down Expand Up @@ -115,13 +123,14 @@ function NUTS_init(rng, ℓ, q;
κ = GaussianKE(length(q)),
p = rand(rng, κ),
max_depth = 5,
ϵ = InitialStepsizeSearch())
ϵ = InitialStepsizeSearch(),
report = ReportIO())
H = Hamiltonian(ℓ, κ)
z = phasepoint_in(H, q, p)
if !isa Float64)
ϵ = find_initial_stepsize(ϵ, H, z)
end
NUTS(rng, H, q, ϵ, max_depth)
NUTS(rng, H, q, ϵ, max_depth, report)
end

"""
Expand Down Expand Up @@ -165,9 +174,9 @@ show(io::IO, tuner::StepsizeTuner) =
print(io, "Stepsize tuner, $(tuner.N) samples")

function tune(sampler::NUTS, tuner::StepsizeTuner)
@unpack rng, H, max_depth = sampler
@unpack rng, H, max_depth, report = sampler
sample, A = mcmc_adapting_ϵ(sampler, tuner.N)
NUTS(rng, H, sample[end].q, get_ϵ(A, false), max_depth)
NUTS(rng, H, sample[end].q, get_ϵ(A, false), max_depth, report)
end

"""
Expand All @@ -192,12 +201,12 @@ end

function tune(sampler::NUTS, tuner::StepsizeCovTuner)
@unpack regularize, N = tuner
@unpack rng, H, max_depth = sampler
@unpack rng, H, max_depth, report = sampler
sample, A = mcmc_adapting_ϵ(sampler, N)
Σ = sample_cov(sample)
Σ .+= (UniformScaling(median(diag(Σ)))-Σ) * regularize/N
κ = GaussianKE(Σ)
NUTS(rng, Hamiltonian(H.ℓ, κ), sample[end].q, get_ϵ(A), max_depth)
NUTS(rng, Hamiltonian(H.ℓ, κ), sample[end].q, get_ϵ(A), max_depth, report)
end

"Sequence of tuners, applied in the given order."
Expand Down
2 changes: 2 additions & 0 deletions test/REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
Compat
Distributions
Documenter
ForwardDiff
MCMCDiagnostics
Suppressor
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ include("test-tuners.jl")
include("test-sample-normal.jl")
include("test-normal-mcmc.jl")
include("test-statistics.jl")
include("test-reporting.jl")

include("../docs/make.jl")
3 changes: 2 additions & 1 deletion test/setup-and-utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import DynamicHMC:

using Base.Test

using ArgCheck
using ArgCheck: @argcheck
using DataStructures
using DiffResults
using Distributions
import ForwardDiff: gradient
using MCMCDiagnostics
using Parameters
using StatsBase
using Suppressor

"RNG for consistent test environment"
const RNG = srand(UInt32[0x23ef614d, 0x8332e05c, 0x3c574111, 0x121aa2f4])
Expand Down
6 changes: 4 additions & 2 deletions test/test-normal-mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ end
5.57947 -0.0540131 1.78163 1.73862 -2.99741 3.6118 10.215 9.60671;
7.28634 1.79718 -0.0821483 2.55874 -1.95031 5.22626 9.60671 11.5554])
forin [ℓ0, ℓ1, ℓ2, ℓ3]
sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, length(ℓ), 1000)
sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, length(ℓ), 1000;
report = ReportSilent())
@test EBFMI(sample) 0.3
@test maximum((nuts, 1000, 3)) 1.05
zs = zvalue.([sample], mean_cov_ztests(ℓ))
Expand All @@ -117,7 +118,8 @@ end
for _ in 1:100
K = rand(2:10)
= MvNormal(randn(K), full(rand_Σ(K)))
sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, K, 1000)
sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, K, 1000;
report = ReportSilent())
@test EBFMI(sample) 0.3
@test maximum((nuts, 1000, 3)) 1.05
zs = zvalue.([sample], mean_cov_ztests(ℓ))
Expand Down
19 changes: 19 additions & 0 deletions test/test-reporting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@testset "reporting" begin
= MvNormal(zeros(3), ones(3))
@color_output false begin
output = @capture_err begin
sample, nuts = NUTS_init_tune_mcmc(RNG, ℓ, length(ℓ), 1000;
report = ReportIO())
end
end
function expectedA(msg, n)
r = "$msg \\($(n) steps\\)\\n"
for i in 100:100:n
r *= "step $(i)/$(n), \\d+\\.\\d+ s/step\\n"
end
r *= " \\.\\.\\.done\\n"
end
raw_regex = join(expectedA.(vcat(fill("MCMC, adapting ϵ", 7), ["MCMC"]),
[75, 25, 50, 100, 200, 400, 50, 1000]), "")
@test ismatch(Regex(raw_regex), output)
end

0 comments on commit e8e2497

Please sign in to comment.