Skip to content

Commit

Permalink
Handle other non-differentiables
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Sep 22, 2019
1 parent cc6c997 commit ef60836
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
40 changes: 34 additions & 6 deletions src/ChainCutters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,45 @@ for op in (*, +, -)
end
end

# Based on `Zygote.broadcast_forward`:
const NonDifferentiableType = Union{
Const,
# From `Broadcast.broadcastable(x) = Ref(x)`:
Symbol,
AbstractString,
# Function, # closures may contain `Real`s
UndefInitializer,
Nothing,
RoundingMode,
Missing,
Val,
Ptr,
Regex,
# From `Type` is also treated similarly in `Broadcast.broadcastable`:
Type,
}

nondifferentiable(::T) where T = nondifferentiable(T)
nondifferentiable(::Type) = false
nondifferentiable(::Type{<:NonDifferentiableType}) = true
nondifferentiable(::Type{<:AbstractArray{<:NonDifferentiableType}}) = true

differentiable(::T) where T = differentiable(T)
differentiable(::Type) = false
differentiable(::Type{<:Real}) = true
differentiable(::Type{<:AbstractArray{<:Real}}) = true
# How about Union{Missing,Real}?

supported(x) = nondifferentiable(x) || differentiable(x)

dual(x::Real, p) = Dual(x, p)
# Based on `Zygote.broadcast_forward`:

function dual_function(f::F, args0::NTuple{N, Any}) where {F, N}
nvariables = _count(x -> !(x isa Const), args0)
partials, = foldlargs(((), 0), args0...) do (partials, n), x
if x isa Const
if nondifferentiable(x)
((partials..., nothing), n)
else
@assert differentiable(x)
i = n + 1
((partials..., ntuple(j -> i == j, nvariables)), i)
end
Expand All @@ -157,7 +186,7 @@ function dual_function(f::F, args0::NTuple{N, Any}) where {F, N}
if partials[i] === nothing
args[i]
else
dual(args[i], partials[i])
Dual(args[i], partials[i])
end
end
return f(ds...)
Expand All @@ -169,8 +198,7 @@ broadcast_adjoint(f, args::Vararg{Const}) =

function broadcast_adjoint(f, args0...)
map(args0) do x
x isa Const && return
eltype(x) <: Real && return
supported(x) && return
throw(ArgumentError(string(
"Differentiation w.r.t ", x, " is not supported.\n",
"Use `cut` to mark it as a constant.",
Expand Down
10 changes: 10 additions & 0 deletions test/test_broadcastablecallable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,19 @@ end
y_actual, back_actual = Zygote.forward(v -> sum(f.(cut(u), v)), v)
y_desired, back_desired = Zygote.forward(v -> sum(f.(u, v)), v)
@test y_actual == y_desired
@test back_actual(1) isa Tuple{Vector{Float64}}
@test back_actual(1) == back_desired(1)

y_actual, back_actual = Zygote.forward(u -> sum(f.(u, cut(v))), u)
y_desired, back_desired = Zygote.forward(u -> sum(f.(u, v)), u)
@test y_actual == y_desired
@test back_actual(1) isa Tuple{Vector{Float64}}
@test back_actual(1) == back_desired(1)

y_actual, back_actual = Zygote.forward(f -> sum(f.(cut(u), cut(v))), f)
y_desired, back_desired = Zygote.forward(f -> sum(f.(u, v)), f)
@test y_actual == y_desired
@test back_actual(1) isa Tuple{NamedTuple{(:a, :b)}}
@test back_actual(1) == back_desired(1)

y_partialcut, back_partialcut = Zygote.forward(f) do f
Expand Down Expand Up @@ -121,4 +124,11 @@ end
@test back(1) === (nothing,)
end

@testset "NonDifferentiableType" begin
f = AddCall(Poly3(rand(4)...), Poly3(rand(4)...))
y, back = Zygote.forward(x -> cut(f).(x), missing)
@test y === missing
@test back(1) === (nothing,)
end

end # module

0 comments on commit ef60836

Please sign in to comment.