/
reporting.jl
189 lines (147 loc) · 5.34 KB
/
reporting.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import ProgressMeter
#####
##### Reporting progress.
#####
"""
$(TYPEDEF)

A placeholder reporter for not reporting any information.

Every [`report`](@ref) call on it returns `nothing` without displaying anything. Also,
[`make_mcmc_reporter`](@ref) returns it unchanged, so it can stand in wherever a reporter is
expected.
"""
struct NoProgressReport end
"""
$(SIGNATURES)

Send a progress report to `reporter`.

The second argument is either

1. a string, which is displayed verbatim (every reporter supports this), or
2. the index of a step in an MCMC chain, for reporters that know the total number of steps
   (see [`make_mcmc_reporter`](@ref)).

The `meta` keyword arguments are additional key-value pairs attached to the report.

Here a *step* means one NUTS transition, not a single leapfrog step.

A `NoProgressReport` silently discards whatever it is given.
"""
function report(reporter::NoProgressReport, step::Union{AbstractString,Integer}; meta...)
    return nothing
end
"""
$(SIGNATURES)

Return a reporter for progress reports over a known number of `total_steps`.

The result may be `reporter` itself or a related object. `meta` is displayed as key-value
pairs.

## Arguments:

- `reporter::NoProgressReport`: the original reporter
- `total_steps`: total number of steps

## Keyword arguments:

- `currently_warmup::Bool`: `true` during warmup, `false` during MCMC sampling
- `meta`: key-value pairs that will be displayed by the reporter

For a `NoProgressReport`, all other arguments are ignored and `reporter` is returned as is.
"""
function make_mcmc_reporter(reporter::NoProgressReport, total_steps;
                            currently_warmup::Bool = false, meta...)
    return reporter
end
"""
$(TYPEDEF)

Report progress into the `Logging` framework, using `@info`.

For the information reported, a *step* is a NUTS transition, not a leapfrog step.

Construct it with keyword arguments, e.g. `LogProgressReport(; chain_id = 1)`. Any omitted
field takes its default value.

# Fields

$(FIELDS)
"""
@with_kw struct LogProgressReport{T}
    "ID of chain. Can be an arbitrary object, eg `nothing`. Added to log messages unless `nothing`."
    chain_id::T = nothing
    "Always report progress once at least `step_interval` steps have passed since the last report."
    step_interval::Int = 100
    "Always report progress once at least this much time (in seconds) has passed since the last report."
    time_interval_s::Float64 = 1000.0
end
"""
$(SIGNATURES)

Assemble the metadata attached to a log message.

When `chain_id` is `nothing`, `meta` is returned unchanged. Otherwise the result is a
`NamedTuple` with `chain_id` first, followed by the entries of `meta`.
"""
function _log_meta(chain_id::Nothing, meta)
    return meta
end
function _log_meta(chain_id, meta)
    # `(chain_id = chain_id, meta...)` lowers to exactly this merge
    return merge((chain_id = chain_id,), meta)
end
# Log `message` at `Info` level, tagged with `meta` and the reporter's chain ID (if any).
function report(reporter::LogProgressReport, message::AbstractString; meta...)
    chain_id = reporter.chain_id
    @info message _log_meta(chain_id, meta)...
    return nothing
end
"""
$(TYPEDEF)

A composite type for tracking the state for which the last log message was emitted, for MCMC
reporting with a given total number of steps (see [`make_mcmc_reporter`](@ref)).

It wraps a [`LogProgressReport`](@ref). Its `last_reported_*` fields are updated in place
every time a step report is actually logged.

# Fields

$(FIELDS)
"""
mutable struct LogMCMCReport{T}
    "The progress report sink."
    log_progress_report::T
    "Total steps for this stage."
    total_steps::Int
    "Index of the last reported step, or `-1` if no step has been reported yet."
    last_reported_step::Int
    "The last time a report was logged, or the object was created if none was (determined using `time_ns`)."
    last_reported_time_ns::UInt64
end
# Free-form messages bypass the step throttling and are always logged, tagged with the
# chain ID of the underlying `LogProgressReport` (if any).
function report(reporter::LogMCMCReport, message::AbstractString; meta...)
    sink = reporter.log_progress_report
    @info message _log_meta(sink.chain_id, meta)...
    return nothing
end
"""
$(SIGNATURES)

Log the start of an MCMC stage with `total_steps` steps. Return a [`LogMCMCReport`](@ref)
that tracks when step progress was last logged.

The start message includes the chain ID of `reporter` when it is not `nothing`, like every
other message logged for the chain, followed by `meta`. `currently_warmup` is accepted for
compatibility with other reporters but does not change the output.
"""
function make_mcmc_reporter(reporter::LogProgressReport, total_steps::Integer;
                            currently_warmup::Bool = false, meta...)
    # Tag with the chain ID so that starts of parallel chains can be told apart.
    start_meta = _log_meta(reporter.chain_id, (total_steps = total_steps,))
    @info "Starting MCMC" merge(start_meta, meta)...
    # -1 marks "no step reported yet", which forces the first step report to be logged.
    return LogMCMCReport(reporter, total_steps, -1, time_ns())
end
"""
$(SIGNATURES)

Report that `step` (one of `1, …, total_steps`) of the MCMC stage has been completed.

An `Info` message is logged for the first reported step. After that, one is logged whenever
at least `step_interval` steps or `time_interval_s` seconds have passed since the last logged
message. Each message contains the step, the average seconds per step since the last logged
message (or since the reporter was created) and an estimate of the remaining seconds.
"""
function report(reporter::LogMCMCReport, step::Integer; meta...)
    @unpack (log_progress_report, total_steps, last_reported_step,
             last_reported_time_ns) = reporter
    @unpack chain_id, step_interval, time_interval_s = log_progress_report
    @argcheck 1 ≤ step ≤ total_steps
    is_first_report = last_reported_step < 0
    # Before the first report, the clock started when the reporter was created, when zero
    # steps had been completed. Do not count from the `-1` sentinel, which would add a step.
    Δ_steps = step - max(last_reported_step, 0)
    t_ns = time_ns()
    Δ_time_s = (t_ns - last_reported_time_ns) / 1_000_000_000
    if is_first_report || Δ_steps ≥ step_interval || Δ_time_s ≥ time_interval_s
        seconds_per_step = Δ_time_s / Δ_steps
        meta_progress = (step = step,
                         seconds_per_step = round(seconds_per_step; sigdigits = 2),
                         estimated_seconds_left = round((total_steps - step) *
                                                        seconds_per_step; sigdigits = 2))
        @info "MCMC progress" merge(_log_meta(chain_id, meta_progress), meta)...
        reporter.last_reported_step = step
        reporter.last_reported_time_ns = t_ns
    end
    return nothing
end
"""
$(TYPEDEF)

Report progress via a progress bar, using `ProgressMeter.jl`.

The bar is created by [`make_mcmc_reporter`](@ref) and advanced by step reports on the object
it returns. Plain string messages are ignored.

Example usage:

```julia
julia> ProgressMeterReport()
```
"""
struct ProgressMeterReport
end
"""
$(TYPEDEF)

Progress-bar reporter for one stage with a known number of steps. It is returned by
`make_mcmc_reporter` for a `ProgressMeterReport`.

# Fields

$(FIELDS)
"""
struct ProgressMeterReportMCMC{T}
    "`true` if the bar tracks a warmup stage (labelled `\"Warmup: \"`), `false` for MCMC."
    currently_warmup::Bool
    "The `ProgressMeter` progress bar, advanced once per reported step."
    progress_meter::T
end
"""
$(SIGNATURES)

Return a `ProgressMeterReportMCMC` with a fresh progress bar of `total_steps` steps. The bar
is labelled `"Warmup: "` when `currently_warmup` is `true` and `"MCMC: "` otherwise, and
refreshes at most once per second. `meta` is ignored by this reporter.
"""
function make_mcmc_reporter(reporter::ProgressMeterReport, total_steps::Integer;
                            currently_warmup::Bool = false, meta...)
    description = currently_warmup ? "Warmup: " : "MCMC: "
    # Keyword form; the positional `Progress(n, dt, desc)` constructor is deprecated.
    progress_meter = ProgressMeter.Progress(total_steps; dt = 1, desc = description)
    return ProgressMeterReportMCMC(currently_warmup, progress_meter)
end
# A progress bar has nowhere to show free-form messages, so they are dropped.
report(reporter::ProgressMeterReport, message::AbstractString; meta...) = nothing
# Free-form messages are dropped; only step reports move the progress bar.
report(reporter::ProgressMeterReportMCMC, message::AbstractString; meta...) = nothing
# No bar exists before `make_mcmc_reporter` is called, so step reports are dropped here.
report(reporter::ProgressMeterReport, step::Integer; meta...) = nothing
# Each reported step advances the bar by one tick; the step index and `meta` are unused.
function report(reporter::ProgressMeterReportMCMC, step::Integer; meta...)
    bar = reporter.progress_meter
    ProgressMeter.next!(bar)
    nothing
end
"""
$(SIGNATURES)

Return the default reporter for the current session: a [`LogProgressReport`](@ref) in an
interactive session, and a [`NoProgressReport`](@ref) otherwise. Keyword arguments are passed
to the `LogProgressReport` constructor. They are ignored in non-interactive sessions.
"""
function default_reporter(; kwargs...)
    isinteractive() || return NoProgressReport()
    return LogProgressReport(; kwargs...)
end