diff --git a/Project.toml b/Project.toml index c0f88ec..1d304a3 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.0.1" [deps] DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/Bifurcations.jl b/src/Bifurcations.jl index 91e83c1..fbe35a4 100644 --- a/src/Bifurcations.jl +++ b/src/Bifurcations.jl @@ -5,6 +5,8 @@ export init, solve, solve!, step! using DiffEqBase: init, solve, solve!, step! # see: continuations/solver.jl +using FillArrays: Eye + include("utils/utils.jl") include("continuations/continuations.jl") diff --git a/src/diffeq.jl b/src/diffeq.jl index 0c22b33..e2d8e93 100644 --- a/src/diffeq.jl +++ b/src/diffeq.jl @@ -1,5 +1,5 @@ using DiffEqBase: AbstractODEProblem, DiscreteProblem -using Setfield: Lens, set, get +using Setfield: Lens, set, get, @set const DEP{iip} = AbstractODEProblem{uType, tType, iip} where {uType, tType} @@ -9,11 +9,14 @@ TimeKind(::Type{<: AbstractODEProblem}) = Continuous() StateKind(::Type{<: DEP{true}}) = MutableState() StateKind(::Type{<: DEP{false}}) = ImmutableState() -struct DiffEqWrapper{P, L} +struct DiffEqWrapper{P, L, D} de_prob::P param_axis::L + param_diff::D end +DiffEqWrapper(de_prob, param_axis) = DiffEqWrapper(de_prob, param_axis, nothing) + function diffeq_homotopy(H, x, p::DiffEqWrapper{<:DEP{true}}, t) q = set(p.de_prob.p, p.param_axis, t) p.de_prob.f(H, x, q, 0) @@ -26,10 +29,40 @@ function diffeq_homotopy(x, p::DiffEqWrapper{<:DEP{false}}, t) return maybe_subtract!(H, x, statekind(p.de_prob), timekind(p.de_prob)) end +maybe_subtract!(H, x, p) = maybe_subtract!(H, x, statekind(p), timekind(p)) maybe_subtract!(H, ::Any, ::StateKind, ::Continuous) = H maybe_subtract!(H, x, ::MutableState, ::Discrete) = H .-= x maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x +struct MutableParamDiff{F, X, C} + param_to_state::F + x::X + cfg::C +end + +function MutableParamDiff(p::DiffEqWrapper) + x = deepcopy(p.de_prob.u0) + param_to_state = let x = x + (H, t) -> diffeq_homotopy(H, x, p, t) + end + cfg = ForwardDiff.DerivativeConfig(param_to_state, x, 1.0) + return MutableParamDiff(param_to_state, x, cfg) +end + +function diffeq_homotopy_jacobian(H, J, x, p::DiffEqWrapper{<:DEP{true}}, t) + p.param_diff :: MutableParamDiff + p.param_diff.x .= x + ForwardDiff.derivative!( + (@view J[:, end]), + p.param_diff.param_to_state, + H, + t, + p.param_diff.cfg, + ) + q = set(p.de_prob.p, p.param_axis, t) + p.de_prob.f.jac((@view J[:, 1:end-1]), x, q, t) + return (H, maybe_subtract!(J, Eye(size(J)...), p.de_prob)) +end """ BifurcationProblem(ode_or_map::AbstractODEProblem, @@ -45,10 +78,22 @@ maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x """ function BifurcationProblem( de_prob::DEP, param_axis::Lens, t_domain::Tuple; - kwargs...) + kwargs0...) + de_prob = deepcopy(de_prob) u0 = de_prob.u0 t0 = get(de_prob.p, param_axis) - p = DiffEqWrapper(deepcopy(de_prob), param_axis) + p0 = DiffEqWrapper(de_prob, param_axis) + if de_prob.f.jac === nothing + p = p0 + kwargs = kwargs0 + elseif de_prob isa DEP{true} + p = @set p0.param_diff = MutableParamDiff(p0) + kwargs = (; homotopy_jacobian=diffeq_homotopy_jacobian, kwargs0...) + else + # TODO: implement + p = p0 + kwargs = kwargs0 + end return FixedPointBifurcationProblem( statekind(de_prob), timekind(de_prob), diff --git a/src/examples/calcium.jl b/src/examples/calcium.jl index 9fe946a..ce76959 100644 --- a/src/examples/calcium.jl +++ b/src/examples/calcium.jl @@ -1,6 +1,7 @@ module Calcium -using DiffEqBase: ODEProblem +using DiffEqBase: ODEProblem, ODEFunction +using ForwardDiff using Parameters: @with_kw, @unpack using StaticArrays: SVector using Setfield: @lens @@ -38,7 +39,13 @@ function f(du, u, p, t) nothing end +function jac(J, u, p, t) + J .= ForwardDiff.jacobian(x -> f(x, p, t), SVector{2}(u)) + nothing +end + const _f = f +const _jac = jac make_prob( @@ -76,4 +83,13 @@ function make_codim2_prob( kwargs...) end +make_ode_jac( + p = CalciumParam(); + u0 = [-170.0, -170.0], + tspan = (0.0, 30.0), + f = _f, + jac = _jac, +) = + ODEProblem(ODEFunction{true}(f; jac = jac), u0, tspan, p) + end # module diff --git a/src/fixedpoint.jl b/src/fixedpoint.jl index 876966b..e8f6af4 100644 --- a/src/fixedpoint.jl +++ b/src/fixedpoint.jl @@ -115,7 +115,7 @@ FixedPointBifurcationCache(prob::FixedPointBifurcationProblem) = _FixedPointBifurcationCache(statekind(prob), hasjac(prob), prob) _FixedPointBifurcationCache(::Any, ::HasJac, prob) = - FixedPointBifurcationCache(prob, nothing) + FixedPointBifurcationCache(prob, nothing, nothing) function _FixedPointBifurcationCache(::MutableState, ::NoJac, prob) x = get_u0(prob) diff --git a/test/runtests.jl b/test/runtests.jl index b89d26b..d5e3b8b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,7 @@ TEST_GROUPS = Dict{String, Vector{String}}( "2" => [ "test_fold_lc.jl", "test_morris_lecar.jl", + "test_diffeq_jacobian.jl", ], ) TEST_GROUPS["all"] = vcat(last.(sort(collect(TEST_GROUPS), by=first))...) diff --git a/test/test_diffeq_jacobian.jl b/test/test_diffeq_jacobian.jl new file mode 100644 index 0000000..c8f6a0a --- /dev/null +++ b/test/test_diffeq_jacobian.jl @@ -0,0 +1,29 @@ +module TestDiffeqJacobian +include("preamble.jl") + +using Bifurcations.Continuations: get_prob_cache, get_u0, residual_jacobian +using Bifurcations.Examples.Calcium + +@testset "ODEFunction (inplace)" begin + prob0 = Calcium.make_prob() + prob1 = Calcium.make_prob(ode=Calcium.make_ode_jac()) + cache0 = get_prob_cache(prob0) + cache1 = get_prob_cache(prob1) + + @test prob0.p.de_prob.f.jac === nothing + @test prob1.p.de_prob.f.jac === Calcium.jac + + u0 = get_u0(prob0) + rng = MersenneTwister(0) + for _ in 1:5 + u = typeof(u0)(randn(rng, eltype(u0), length(u0))) + + H0, J0 = residual_jacobian(u, cache0) + H1, J1 = residual_jacobian(Array(u), cache1) + + @test H1 ≈ H0 + @test J1 ≈ J0 + end +end + +end # module