Skip to content

Commit

Permalink
Use Jacobian from DiffEqFunction if provided
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Aug 14, 2019
1 parent fd350ad commit 5b3ae50
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
52 changes: 48 additions & 4 deletions src/diffeq.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DiffEqBase: AbstractODEProblem, DiscreteProblem
using Setfield: Lens, set, get
using Setfield: Lens, set, get, @set

const DEP{iip} = AbstractODEProblem{uType, tType, iip} where {uType, tType}

Expand All @@ -9,11 +9,14 @@ TimeKind(::Type{<: AbstractODEProblem}) = Continuous()
StateKind(::Type{<: DEP{true}}) = MutableState()
StateKind(::Type{<: DEP{false}}) = ImmutableState()

struct DiffEqWrapper{P, L}
struct DiffEqWrapper{P, L, D}
de_prob::P
param_axis::L
param_diff::D
end

DiffEqWrapper(de_prob, param_axis) = DiffEqWrapper(de_prob, param_axis, nothing)

function diffeq_homotopy(H, x, p::DiffEqWrapper{<:DEP{true}}, t)
q = set(p.de_prob.p, p.param_axis, t)
p.de_prob.f(H, x, q, 0)
Expand All @@ -30,6 +33,35 @@ maybe_subtract!(H, ::Any, ::StateKind, ::Continuous) = H
maybe_subtract!(H, x, ::MutableState, ::Discrete) = H .-= x
maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x

struct MutableParamDiff{F, X, C}
param_to_state::F
x::X
cfg::C
end

function MutableParamDiff(p::DiffEqWrapper)
x = deepcopy(p.de_prob.u0)
param_to_state = let x = x
(H, t) -> diffeq_homotopy(H, x, p, t)
end
cfg = ForwardDiff.DerivativeConfig(param_to_state, x, 1.0)
return MutableParamDiff(param_to_state, x, cfg)
end

function diffeq_homotopy_jacobian(H, J, x, p::DiffEqWrapper{<:DEP{true}}, t)
p.param_diff :: MutableParamDiff
p.param_diff.x .= x
ForwardDiff.derivative!(
(@view J[:, end]),
p.param_diff.param_to_state,
H,
t,
p.param_diff.cfg,
)
q = set(p.de_prob.p, p.param_axis, t)
p.de_prob.f.jac((@view J[:, 1:end-1]), x, q, t)
return (H, J)
end

"""
BifurcationProblem(ode_or_map::AbstractODEProblem,
Expand All @@ -45,10 +77,22 @@ maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x
"""
function BifurcationProblem(
de_prob::DEP, param_axis::Lens, t_domain::Tuple;
kwargs...)
kwargs0...)
de_prob = deepcopy(de_prob)
u0 = de_prob.u0
t0 = get(de_prob.p, param_axis)
p = DiffEqWrapper(deepcopy(de_prob), param_axis)
p0 = DiffEqWrapper(de_prob, param_axis)
if de_prob.f.jac === nothing
p = p0
kwargs = kwargs0
elseif de_prob isa DEP{true}
p = @set p0.param_diff = MutableParamDiff(p0)
kwargs = (; homotopy_jacobian=diffeq_homotopy_jacobian, kwargs0...)
else
# TODO: implement
p = p0
kwargs = kwargs0
end
return FixedPointBifurcationProblem(
statekind(de_prob),
timekind(de_prob),
Expand Down
2 changes: 1 addition & 1 deletion src/fixedpoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ FixedPointBifurcationCache(prob::FixedPointBifurcationProblem) =
_FixedPointBifurcationCache(statekind(prob), hasjac(prob), prob)

_FixedPointBifurcationCache(::Any, ::HasJac, prob) =
FixedPointBifurcationCache(prob, nothing)
FixedPointBifurcationCache(prob, nothing, nothing)

function _FixedPointBifurcationCache(::MutableState, ::NoJac, prob)
x = get_u0(prob)
Expand Down

0 comments on commit 5b3ae50

Please sign in to comment.