Skip to content

Commit

Permalink
Use Jacobian from DiffEqFunction if provided
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Aug 18, 2019
1 parent fd350ad commit 9e61845
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
Expand Up @@ -5,6 +5,7 @@ version = "0.0.1"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 2 additions & 0 deletions src/Bifurcations.jl
Expand Up @@ -5,6 +5,8 @@ export init, solve, solve!, step!
using DiffEqBase: init, solve, solve!, step!
# see: continuations/solver.jl

using FillArrays: Eye

include("utils/utils.jl")

include("continuations/continuations.jl")
Expand Down
53 changes: 49 additions & 4 deletions src/diffeq.jl
@@ -1,5 +1,5 @@
using DiffEqBase: AbstractODEProblem, DiscreteProblem
using Setfield: Lens, set, get
using Setfield: Lens, set, get, @set

const DEP{iip} = AbstractODEProblem{uType, tType, iip} where {uType, tType}

Expand All @@ -9,11 +9,14 @@ TimeKind(::Type{<: AbstractODEProblem}) = Continuous()
StateKind(::Type{<: DEP{true}}) = MutableState()
StateKind(::Type{<: DEP{false}}) = ImmutableState()

struct DiffEqWrapper{P, L}
struct DiffEqWrapper{P, L, D}
de_prob::P
param_axis::L
param_diff::D
end

DiffEqWrapper(de_prob, param_axis) = DiffEqWrapper(de_prob, param_axis, nothing)

function diffeq_homotopy(H, x, p::DiffEqWrapper{<:DEP{true}}, t)
q = set(p.de_prob.p, p.param_axis, t)
p.de_prob.f(H, x, q, 0)
Expand All @@ -26,10 +29,40 @@ function diffeq_homotopy(x, p::DiffEqWrapper{<:DEP{false}}, t)
return maybe_subtract!(H, x, statekind(p.de_prob), timekind(p.de_prob))
end

maybe_subtract!(H, x, p) = maybe_subtract!(H, x, statekind(p), timekind(p))
maybe_subtract!(H, ::Any, ::StateKind, ::Continuous) = H
maybe_subtract!(H, x, ::MutableState, ::Discrete) = H .-= x
maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x

struct MutableParamDiff{F, X, C}
param_to_state::F
x::X
cfg::C
end

function MutableParamDiff(p::DiffEqWrapper)
x = deepcopy(p.de_prob.u0)
param_to_state = let x = x
(H, t) -> diffeq_homotopy(H, x, p, t)
end
cfg = ForwardDiff.DerivativeConfig(param_to_state, x, 1.0)
return MutableParamDiff(param_to_state, x, cfg)
end

function diffeq_homotopy_jacobian(H, J, x, p::DiffEqWrapper{<:DEP{true}}, t)
p.param_diff :: MutableParamDiff
p.param_diff.x .= x
ForwardDiff.derivative!(
(@view J[:, end]),
p.param_diff.param_to_state,
H,
t,
p.param_diff.cfg,
)
q = set(p.de_prob.p, p.param_axis, t)
p.de_prob.f.jac((@view J[:, 1:end-1]), x, q, t)
return (H, maybe_subtract!(J, Eye(size(J)...), p.de_prob))
end

"""
BifurcationProblem(ode_or_map::AbstractODEProblem,
Expand All @@ -45,10 +78,22 @@ maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x
"""
function BifurcationProblem(
de_prob::DEP, param_axis::Lens, t_domain::Tuple;
kwargs...)
kwargs0...)
de_prob = deepcopy(de_prob)
u0 = de_prob.u0
t0 = get(de_prob.p, param_axis)
p = DiffEqWrapper(deepcopy(de_prob), param_axis)
p0 = DiffEqWrapper(de_prob, param_axis)
if de_prob.f.jac === nothing
p = p0
kwargs = kwargs0
elseif de_prob isa DEP{true}
p = @set p0.param_diff = MutableParamDiff(p0)
kwargs = (; homotopy_jacobian=diffeq_homotopy_jacobian, kwargs0...)
else
# TODO: implement
p = p0
kwargs = kwargs0
end
return FixedPointBifurcationProblem(
statekind(de_prob),
timekind(de_prob),
Expand Down
18 changes: 17 additions & 1 deletion src/examples/calcium.jl
@@ -1,6 +1,7 @@
module Calcium

using DiffEqBase: ODEProblem
using DiffEqBase: ODEProblem, ODEFunction
using ForwardDiff
using Parameters: @with_kw, @unpack
using StaticArrays: SVector
using Setfield: @lens
Expand Down Expand Up @@ -38,7 +39,13 @@ function f(du, u, p, t)
nothing
end

function jac(J, u, p, t)
J .= ForwardDiff.jacobian(x -> f(x, p, t), SVector{2}(u))
nothing
end

const _f = f
const _jac = jac


make_prob(
Expand Down Expand Up @@ -76,4 +83,13 @@ function make_codim2_prob(
kwargs...)
end

make_ode_jac(
p = CalciumParam();
u0 = [-170.0, -170.0],
tspan = (0.0, 30.0),
f = _f,
jac = _jac,
) =
ODEProblem(ODEFunction{true}(f; jac = jac), u0, tspan, p)

end # module
2 changes: 1 addition & 1 deletion src/fixedpoint.jl
Expand Up @@ -115,7 +115,7 @@ FixedPointBifurcationCache(prob::FixedPointBifurcationProblem) =
_FixedPointBifurcationCache(statekind(prob), hasjac(prob), prob)

_FixedPointBifurcationCache(::Any, ::HasJac, prob) =
FixedPointBifurcationCache(prob, nothing)
FixedPointBifurcationCache(prob, nothing, nothing)

function _FixedPointBifurcationCache(::MutableState, ::NoJac, prob)
x = get_u0(prob)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Expand Up @@ -27,6 +27,7 @@ TEST_GROUPS = Dict{String, Vector{String}}(
"2" => [
"test_fold_lc.jl",
"test_morris_lecar.jl",
"test_diffeq_jacobian.jl",
],
)
TEST_GROUPS["all"] = vcat(last.(sort(collect(TEST_GROUPS), by=first))...)
Expand Down
29 changes: 29 additions & 0 deletions test/test_diffeq_jacobian.jl
@@ -0,0 +1,29 @@
module TestDiffeqJacobian
include("preamble.jl")

using Bifurcations.Continuations: get_prob_cache, get_u0, residual_jacobian
using Bifurcations.Examples.Calcium

@testset "ODEFunction (inplace)" begin
prob0 = Calcium.make_prob()
prob1 = Calcium.make_prob(ode=Calcium.make_ode_jac())
cache0 = get_prob_cache(prob0)
cache1 = get_prob_cache(prob1)

@test prob0.p.de_prob.f.jac === nothing
@test prob1.p.de_prob.f.jac === Calcium.jac

u0 = get_u0(prob0)
rng = MersenneTwister(0)
for _ in 1:5
u = typeof(u0)(randn(rng, eltype(u0), length(u0)))

H0, J0 = residual_jacobian(u, cache0)
H1, J1 = residual_jacobian(Array(u), cache1)

@test H1 H0
@test J1 J0
end
end

end # module

0 comments on commit 9e61845

Please sign in to comment.