julia> using FiniteDiff, Zygote
julia> function f(x,p)
           grad = FiniteDiff.finite_difference_gradient(y -> sum(y.^3), x)
           return grad .* p
       end
f (generic function with 1 method)
julia> x,p = rand(3),rand(3);
julia> Zygote.gradient(p->sum(f(x,p)), p)[1]
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:33
[2] _throw_mutation_error(f::Function, args::Vector{Float64})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:70
[3] (::Zygote.var"#444#445"{Vector{Float64}})(#unused#::Nothing)
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:82
[4] (::Zygote.var"#2496#back#446"{Zygote.var"#444#445"{Vector{Float64}}})(Δ::Nothing)
@ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
[5] Pullback
@ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:277 [inlined]
[6] (::typeof(∂(#finite_difference_gradient!#16)))(Δ::Vector{Float64})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[7] Pullback
@ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:224 [inlined]
[8] (::typeof(∂(finite_difference_gradient!##kw)))(Δ::Vector{Float64})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[9] Pullback
@ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:88 [inlined]
[10] (::typeof(∂(#finite_difference_gradient#12)))(Δ::Vector{Float64})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[11] Pullback
@ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:70 [inlined]
[12] (::typeof(∂(finite_difference_gradient)))(Δ::Vector{Float64})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[13] Pullback
@ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:70 [inlined]
[14] (::typeof(∂(finite_difference_gradient)))(Δ::Vector{Float64})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[15] Pullback
@ .\REPL[94]:2 [inlined]
[16] (::typeof(∂(f)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[17] Pullback
@ .\REPL[95]:1 [inlined]
[18] (::typeof(∂(#30)))(Δ::Float64)
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[19] (::Zygote.var"#60#61"{typeof(∂(#30))})(Δ::Float64)
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
[20] gradient(f::Function, args::Vector{Float64})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
[21] top-level scope
@ REPL[95]:1
[22] top-level scope
@ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
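A possible workaround sketch for this specific example: here the finite-difference gradient depends only on `x`, not on the differentiated argument `p`, so it can be treated as a constant. The sketch below assumes ChainRulesCore is installed as a direct dependency and uses its `ignore_derivatives`; the name `f_ignored` is made up for illustration. It does not help when the derivative with respect to `x` is needed.

```julia
using FiniteDiff, Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical variant of `f`: the FiniteDiff call runs inside
# `ignore_derivatives`, so Zygote never traces the mutating internals.
function f_ignored(x, p)
    grad = ignore_derivatives() do
        FiniteDiff.finite_difference_gradient(y -> sum(y .^ 3), x)
    end
    return grad .* p   # `grad` is a constant with respect to `p`
end

x, p = rand(3), rand(3)

# Gradient of sum(grad .* p) with respect to p is `grad` itself,
# i.e. the finite-difference approximation of 3 .* x .^ 2.
dp = Zygote.gradient(p -> sum(f_ignored(x, p)), p)[1]
```

This only sidesteps the mutation error; it does not make `finite_difference_gradient` itself differentiable by Zygote.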
ArnoStrouwen commented on Nov 11, 2022
It is tricky to make this completely non-mutating for `AbstractArray`. If you look at how this case is handled in `ForwardDiff`, it also has mutations in there, even for an out-of-place `f`:
https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/gradient.jl#L106
https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/apiutils.jl#L36
https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/apiutils.jl#L58
`StaticArrays` is then handled by a specialized dispatch. `Zygote` seems to be handled by custom rules: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/forward.jl#L140
Notice how this is not really reverse over forward.
The issue is that you have a bunch of partial derivatives:
https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/src/gradients.jl#L287
But how do you gather them into an array of the correct type without mutating? The correct type being the array type of the input of `f` with the element type changed to the type of the output of `f`?
ChrisRackauckas commented on Nov 11, 2022
If you use out-of-place `setindex` it should be fine?
ArnoStrouwen commented on Nov 11, 2022
I have not used `setindex` much, but I thought that was only defined for `Tuple` and `StaticArray`, not for `Array`? Are there things in `ArrayInterface` that extend this?
ChrisRackauckas commented on Nov 11, 2022
FiniteDiff.jl has an extended version (shadowed)
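To make the out-of-place `setindex` idea concrete, here is a minimal sketch. `Base.setindex` is the real (unexported) Base function for tuples; `setindex_oop` and `gather_partials` are hypothetical names used only for illustration and are not FiniteDiff.jl's shadowed version. The sketch assumes a floating-point vector input and a scalar-valued `f`.

```julia
# Base provides an out-of-place setindex for tuples (unexported):
t  = (1.0, 2.0, 3.0)
t2 = Base.setindex(t, 10.0, 2)   # returns a new tuple; `t` is unchanged

# Hypothetical Array analogue: copy, write into the copy, return the copy.
# The caller's array is never modified, but the copy is mutated internally,
# so an AD package would still need a rule for this function.
function setindex_oop(x::AbstractArray, v, i...)
    y = copy(x)
    y[i...] = v
    return y
end

# Hypothetical: gather central-difference partial derivatives one index at a
# time without mutating any array the caller can see.
# Note: `zero(x)` keeps the element type of `x`, so this does not address the
# case where the output type of `f` differs from `eltype(x)`.
function gather_partials(f, x::AbstractVector; h = cbrt(eps(eltype(x))))
    grad = zero(x)
    for i in eachindex(x)
        xp = setindex_oop(x, x[i] + h, i)   # x with x[i] perturbed upward
        xm = setindex_oop(x, x[i] - h, i)   # x with x[i] perturbed downward
        grad = setindex_oop(grad, (f(xp) - f(xm)) / (2h), i)
    end
    return grad
end

x0 = [1.0, 2.0, 3.0]
g0 = gather_partials(y -> sum(y .^ 3), x0)   # approximately 3 .* x0 .^ 2
```

Each call allocates a fresh copy, so this trades performance for the absence of visible mutation; whether that trade is acceptable depends on the array size.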