From 0eefebc39e431ba5c88c5144268ca18376dd11fb Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Mon, 18 Jun 2018 18:20:27 -0700 Subject: [PATCH] Add TimeKind trait --- src/Bifurcations.jl | 1 + src/diffeq.jl | 31 ++++++++++++++++++++++++++---- src/fixedpoint.jl | 46 ++++++++++++++++++++++++++++++--------------- src/traits.jl | 5 +++++ 4 files changed, 64 insertions(+), 19 deletions(-) create mode 100644 src/traits.jl diff --git a/src/Bifurcations.jl b/src/Bifurcations.jl index 08e3724..23c4d00 100644 --- a/src/Bifurcations.jl +++ b/src/Bifurcations.jl @@ -11,6 +11,7 @@ import .Continuations: get_prob_cache, get_u0, residual!, residual_jacobian!, residual, isindomain const _C = AbstractProblemCache +include("traits.jl") include("fixedpoint.jl") include("diffeq.jl") include("examples/examples.jl") diff --git a/src/diffeq.jl b/src/diffeq.jl index 3428be1..80ecc74 100644 --- a/src/diffeq.jl +++ b/src/diffeq.jl @@ -1,8 +1,23 @@ -using DiffEqBase: AbstractODEProblem +using DiffEqBase: AbstractODEProblem, DiscreteProblem using Setfield: Lens, set, get const DEP{iip} = AbstractODEProblem{uType, tType, iip} where {uType, tType} + +abstract type StateKind end +struct MutableState <: StateKind end +struct ImmutableState <: StateKind end + +statekind(::T) where T = StateKind(T) +# TODO: Move StateKind to continuations/base.jl and use it everywhere. +# It's only used in here at the moment. + + +TimeKind(::Type{<: DiscreteProblem}) = Discrete() +TimeKind(::Type{<: AbstractODEProblem}) = Continuous() +StateKind(::Type{<: DEP{true}}) = MutableState() +StateKind(::Type{<: DEP{false}}) = ImmutableState() + struct DiffEqWrapper{P, L} de_prob::P param_axis::L @@ -11,13 +26,20 @@ end function diffeq_homotopy(H, x, p::DiffEqWrapper{<:DEP{true}}, t) q = set(p.param_axis, p.de_prob.p, t) p.de_prob.f(H, x, q, 0) + maybe_subtract!(H, x, statekind(p.de_prob), timekind(p.de_prob)) end function diffeq_homotopy(x, p::DiffEqWrapper{<:DEP{false}}, t) q = set(p.param_axis, p.de_prob.p, t) - return p.de_prob.f(x, q, 0) + H = p.de_prob.f(x, q, 0) + return maybe_subtract!(H, x, statekind(p.de_prob), timekind(p.de_prob)) end +maybe_subtract!(H, ::Any, ::StateKind, ::Continuous) = H +maybe_subtract!(H, x, ::MutableState, ::Discrete) = H .-= x +maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x + + """ FixedPointBifurcationProblem(ode_or_map::AbstractODEProblem, param_axis::Lens, @@ -36,6 +58,7 @@ function FixedPointBifurcationProblem( u0 = de_prob.u0 t0 = get(param_axis, de_prob.p) p = DiffEqWrapper(deepcopy(de_prob), param_axis) - return FixedPointBifurcationProblem{iip}(diffeq_homotopy, u0, t0, - t_domain, p; kwargs...) + return FixedPointBifurcationProblem{iip, typeof(timekind(de_prob))}( + diffeq_homotopy, u0, t0, + t_domain, p; kwargs...) end diff --git a/src/fixedpoint.jl b/src/fixedpoint.jl index 473e2fb..4ade0b7 100644 --- a/src/fixedpoint.jl +++ b/src/fixedpoint.jl @@ -3,7 +3,9 @@ using StaticArrays: SVector, push using ForwardDiff -struct FixedPointBifurcationProblem{iip, HJ, H, U, T, P, +struct FixedPointBifurcationProblem{iip, + tkind <: TimeKind, + HJ, H, U, T, P, } <: AbstractContinuationProblem{iip} homotopy_jacobian::HJ homotopy::H @@ -14,30 +16,44 @@ struct FixedPointBifurcationProblem{iip, HJ, H, U, T, P, # TODO: Define domain for u. Maybe use Domains.jl? - function FixedPointBifurcationProblem{iip}( + function FixedPointBifurcationProblem{iip, tkind}( homotopy::H, u0::U, t0::Real, t_domain::Tuple, p::P = nothing; homotopy_jacobian::HJ = nothing, - ) where{iip, HJ, H, U, P} + ) where{iip, tkind, HJ, H, U, P} T = promote_type(typeof(t0), map(typeof, t_domain)...) - new{iip, HJ, H, U, T, P}(homotopy_jacobian, homotopy, - u0, t0, t_domain, p) + new{iip, tkind, HJ, H, U, T, P}( + homotopy_jacobian, homotopy, + u0, t0, t_domain, p) end end -const FPBPWithHJac{iip} = FixedPointBifurcationProblem{iip, <: Function} -const FPBPNoHJac{iip} = FixedPointBifurcationProblem{iip, Void} -const FPBPScalar = FixedPointBifurcationProblem{false, HJ, H, - <: Real} where {HJ, H} -function FixedPointBifurcationProblem(homotopy, args...; kwargs...) +const FPBPWithHJac{iip, tkind} = + FixedPointBifurcationProblem{iip, tkind, <: Function} +const FPBPNoHJac{iip, tkind} = + FixedPointBifurcationProblem{iip, tkind, Void} +const FPBPScalar{tkind <: TimeKind} = + FixedPointBifurcationProblem{false, tkind, HJ, H, <: Real} where {HJ, H} + +function FixedPointBifurcationProblem(tkind::TimeKind, + homotopy, args...; kwargs...) iip = numargs(homotopy) == 4 - return FixedPointBifurcationProblem{iip}(homotopy, args...; kwargs...) + return FixedPointBifurcationProblem{iip, tkind}( + homotopy, args...; kwargs...) end -FixedPointBifurcationProblem(homotopy, u0::Tuple, args...; kwargs...) = - FixedPointBifurcationProblem{false}(homotopy, - SVector(u0), - args...; kwargs...) +as_immutable_state(x::Tuple) = SVector(x) +as_immutable_state(x::Number) = x + +function FixedPointBifurcationProblem(tkind::TimeKind, + homotopy, + u0::Union{Tuple, Number}, + args...; kwargs...) + return FixedPointBifurcationProblem{false, tkind}( + homotopy, + as_immutable_state(u0), + args...; kwargs...) +end struct FixedPointBifurcationCache{P, C} <: AbstractProblemCache{P} diff --git a/src/traits.jl b/src/traits.jl new file mode 100644 index 0000000..5bed650 --- /dev/null +++ b/src/traits.jl @@ -0,0 +1,5 @@ +abstract type TimeKind end +struct Discrete <: TimeKind end +struct Continuous <: TimeKind end + +timekind(::T) where T = TimeKind(T)