Skip to content

ReverseDiff: Do not always compile tape#85

Merged
tpapp merged 2 commits into
tpapp:masterfrom
devmotion:dw/reversediff
Jul 21, 2022
Merged

ReverseDiff: Do not always compile tape#85
tpapp merged 2 commits into
tpapp:masterfrom
devmotion:dw/reversediff

Conversation

@devmotion
Copy link
Copy Markdown
Collaborator

It's nice - and very efficient - to pre-compute and compile tapes with ReverseDiff whenever possible. Unfortunately, if there is some branching it leads to incorrect derivatives. This is even the case for the example in the tests. A simpler example that demonstrates the issue:

julia> using ReverseDiff                                   
                                                           
julia> f(x) = x < zero(x) ? zero(x) : x   
f (generic function with 1 method)   

julia> ReverseDiff.gradient(f  only, [-1.23])
1-element Vector{Float64}:
 0.0

julia> ReverseDiff.gradient(f  only, [1.41])
1-element Vector{Float64}:
 1.0

julia> tape = ReverseDiff.GradientTape(f, [-1.3])
typename(ReverseDiff.GradientTape)(f)

julia> ReverseDiff.gradient!(tape, [1.34])
1-element Vector{Float64}:
 0.0

julia> ReverseDiff.gradient!(tape, [-0.2])
1-element Vector{Float64}:
 0.0

julia> compiledtape = ReverseDiff.compile(tape)
typename(ReverseDiff.CompiledTape)(f)

julia> ReverseDiff.gradient!(tape, [1.34])
1-element Vector{Float64}:
 0.0

julia> ReverseDiff.gradient!(tape, [-0.2])
1-element Vector{Float64}:
 0.0

Thus I suggest to not use compiled tapes by default (alternatively, one could use something like https://github.com/SciML/SciMLSensitivity.jl/blob/4a608f15413e5fdedd9939093d97d5cd982205f6/src/hasbranching.jl to determine if the function has branches but that seems a bit too dependency-heavy for LogDensityProblems, I assume).

The PR adds a (default) option for not compiling the tape and support for ReverseDiff without tapes.

Copy link
Copy Markdown
Owner

@tpapp tpapp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR and the very thorough tests. Just requesting a docstring change, then I will merge.

Comment thread src/AD_ReverseDiff.jl
tape = ReverseDiff.GradientTape(f, x)
compiledtape = ReverseDiff.compile(tape)
ReverseDiffLogDensity(ℓ, compiledtape)
function ADgradient(::Val{:ReverseDiff}, ℓ;
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please document the compile keyword in the docstring above?

@tpapp
Copy link
Copy Markdown
Owner

tpapp commented Jul 21, 2022

Thanks for the update, I will tag in a minute.

@tpapp tpapp merged commit 46c939d into tpapp:master Jul 21, 2022
@devmotion devmotion deleted the dw/reversediff branch July 21, 2022 09:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants