# NOTE(review): removed GitHub page chrome and the scraped line-number gutter
# (web-extraction artifact); the Julia source proper begins with the module
# docstring below.
"""
Gradient AD implementation using ReverseDiff.
"""
module LogDensityProblemsADReverseDiffExt
if isdefined(Base, :get_extension)
using LogDensityProblemsAD: ADGradientWrapper, SIGNATURES, dimension, logdensity
import LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import ReverseDiff
import ReverseDiff: DiffResults
else
using ..LogDensityProblemsAD: ADGradientWrapper, SIGNATURES, dimension, logdensity
import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import ..ReverseDiff
import ..ReverseDiff: DiffResults
end
# Load the shared DiffResults helpers; this file provides `_diffresults_buffer`
# and `_diffresults_extract`, which are used by `logdensity_and_gradient` below.
include("DiffResults_helpers.jl")
# Wrapper pairing a log density problem with an optional compiled ReverseDiff tape.
struct ReverseDiffLogDensity{L,C} <: ADGradientWrapper
    # The underlying log density problem (evaluated via `logdensity(ℓ, x)`).
    ℓ::L
    # Either a compiled `ReverseDiff` gradient tape (when `compile=Val(true)` was
    # requested in `ADgradient`) or `nothing`, in which case gradients are taken
    # without a pre-recorded tape on each call.
    compiledtape::C
end
"""
    ADgradient(:ReverseDiff, ℓ; compile=Val(false), x=nothing)
    ADgradient(Val(:ReverseDiff), ℓ; compile=Val(false), x=nothing)

Gradient using algorithmic/automatic differentiation via ReverseDiff.

If `compile isa Val{true}`, a tape of the log density computation is created upon
construction of the gradient function and used in every evaluation of the gradient.
One may provide an example input `x::AbstractVector` of the log density function;
if `x` is `nothing` (the default), the tape is created with input
`zeros(dimension(ℓ))`. By default, no tape is created.

!!! note
    Using a compiled tape can lead to significant performance improvements when the
    gradient of the log density is evaluated multiple times (possibly for different
    inputs). However, if the log density contains branches, use of a compiled tape
    can lead to silently incorrect results.
"""
function ADgradient(::Val{:ReverseDiff}, ℓ;
                    compile::Union{Val{true},Val{false}}=Val(false),
                    x::Union{Nothing,AbstractVector}=nothing)
    # Record (and compile) a tape only when requested; `nothing` otherwise.
    tape = _compiledtape(ℓ, compile, x)
    return ReverseDiffLogDensity(ℓ, tape)
end
_compiledtape(ℓ, compile, x) = nothing
_compiledtape(ℓ, ::Val{true}, ::Nothing) = _compiledtape(ℓ, Val(true), zeros(dimension(ℓ)))
# Record a gradient tape of `logdensity(ℓ, ·)` at the example input `x` and
# compile it for fast repeated evaluation.
function _compiledtape(ℓ, ::Val{true}, x)
    f = Base.Fix1(logdensity, ℓ)
    return ReverseDiff.compile(ReverseDiff.GradientTape(f, x))
end
# Display the wrapper, indicating whether a compiled tape is attached.
function Base.show(io::IO, ∇ℓ::ReverseDiffLogDensity)
    tape_desc = ∇ℓ.compiledtape === nothing ? "no compiled tape" : "compiled tape"
    print(io, "ReverseDiff AD wrapper for ", ∇ℓ.ℓ, " (", tape_desc, ")")
end
# Evaluate the log density and its gradient at `x`, using the compiled tape when
# one was recorded and falling back to tape-free reverse mode otherwise.
function logdensity_and_gradient(∇ℓ::ReverseDiffLogDensity, x::AbstractVector)
    ℓ = ∇ℓ.ℓ
    tape = ∇ℓ.compiledtape
    # Buffer holding both the value and the gradient (DiffResults helper).
    buffer = _diffresults_buffer(x)
    result = tape === nothing ?
        ReverseDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x) :
        ReverseDiff.gradient!(buffer, tape, x)
    return _diffresults_extract(result)
end
end # module