Skip to content

Commit

Permalink
Add TimeKind trait
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Jun 19, 2018
1 parent abd7bc4 commit 0eefebc
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/Bifurcations.jl
Expand Up @@ -11,6 +11,7 @@ import .Continuations: get_prob_cache, get_u0, residual!, residual_jacobian!,
residual, isindomain
const _C = AbstractProblemCache

include("traits.jl")
include("fixedpoint.jl")
include("diffeq.jl")
include("examples/examples.jl")
Expand Down
31 changes: 27 additions & 4 deletions src/diffeq.jl
@@ -1,8 +1,23 @@
using DiffEqBase: AbstractODEProblem
using DiffEqBase: AbstractODEProblem, DiscreteProblem
using Setfield: Lens, set, get

const DEP{iip} = AbstractODEProblem{uType, tType, iip} where {uType, tType}


abstract type StateKind end
struct MutableState <: StateKind end
struct ImmutableState <: StateKind end

statekind(::T) where T = StateKind(T)
# TODO: Move StateKind to continuations/base.jl and use it everywhere.
# It's only used in here at the moment.


TimeKind(::Type{<: DiscreteProblem}) = Discrete()
TimeKind(::Type{<: AbstractODEProblem}) = Continuous()
StateKind(::Type{<: DEP{true}}) = MutableState()
StateKind(::Type{<: DEP{false}}) = ImmutableState()

struct DiffEqWrapper{P, L}
de_prob::P
param_axis::L
Expand All @@ -11,13 +26,20 @@ end
function diffeq_homotopy(H, x, p::DiffEqWrapper{<:DEP{true}}, t)
q = set(p.param_axis, p.de_prob.p, t)
p.de_prob.f(H, x, q, 0)
maybe_subtract!(H, x, statekind(p.de_prob), timekind(p.de_prob))
end

function diffeq_homotopy(x, p::DiffEqWrapper{<:DEP{false}}, t)
q = set(p.param_axis, p.de_prob.p, t)
return p.de_prob.f(x, q, 0)
H = p.de_prob.f(x, q, 0)
return maybe_subtract!(H, x, statekind(p.de_prob), timekind(p.de_prob))
end

maybe_subtract!(H, ::Any, ::StateKind, ::Continuous) = H
maybe_subtract!(H, x, ::MutableState, ::Discrete) = H .-= x
maybe_subtract!(H, x, ::ImmutableState, ::Discrete) = H .- x


"""
FixedPointBifurcationProblem(ode_or_map::AbstractODEProblem,
param_axis::Lens,
Expand All @@ -36,6 +58,7 @@ function FixedPointBifurcationProblem(
u0 = de_prob.u0
t0 = get(param_axis, de_prob.p)
p = DiffEqWrapper(deepcopy(de_prob), param_axis)
return FixedPointBifurcationProblem{iip}(diffeq_homotopy, u0, t0,
t_domain, p; kwargs...)
return FixedPointBifurcationProblem{iip, typeof(timekind(de_prob))}(
diffeq_homotopy, u0, t0,
t_domain, p; kwargs...)
end
46 changes: 31 additions & 15 deletions src/fixedpoint.jl
Expand Up @@ -3,7 +3,9 @@ using StaticArrays: SVector, push
using ForwardDiff


struct FixedPointBifurcationProblem{iip, HJ, H, U, T, P,
struct FixedPointBifurcationProblem{iip,
tkind <: TimeKind,
HJ, H, U, T, P,
} <: AbstractContinuationProblem{iip}
homotopy_jacobian::HJ
homotopy::H
Expand All @@ -14,30 +16,44 @@ struct FixedPointBifurcationProblem{iip, HJ, H, U, T, P,

# TODO: Define domain for u. Maybe use Domains.jl?

function FixedPointBifurcationProblem{iip}(
function FixedPointBifurcationProblem{iip, tkind}(
homotopy::H, u0::U, t0::Real, t_domain::Tuple,
p::P = nothing;
homotopy_jacobian::HJ = nothing,
) where{iip, HJ, H, U, P}
) where{iip, tkind, HJ, H, U, P}
T = promote_type(typeof(t0), map(typeof, t_domain)...)
new{iip, HJ, H, U, T, P}(homotopy_jacobian, homotopy,
u0, t0, t_domain, p)
new{iip, tkind, HJ, H, U, T, P}(
homotopy_jacobian, homotopy,
u0, t0, t_domain, p)
end
end
const FPBPWithHJac{iip} = FixedPointBifurcationProblem{iip, <: Function}
const FPBPNoHJac{iip} = FixedPointBifurcationProblem{iip, Void}
const FPBPScalar = FixedPointBifurcationProblem{false, HJ, H,
<: Real} where {HJ, H}

function FixedPointBifurcationProblem(homotopy, args...; kwargs...)
const FPBPWithHJac{iip, tkind} =
FixedPointBifurcationProblem{iip, tkind, <: Function}
const FPBPNoHJac{iip, tkind} =
FixedPointBifurcationProblem{iip, tkind, Void}
const FPBPScalar{tkind <: TimeKind} =
FixedPointBifurcationProblem{false, tkind, HJ, H, <: Real} where {HJ, H}

function FixedPointBifurcationProblem(tkind::TimeKind,
homotopy, args...; kwargs...)
iip = numargs(homotopy) == 4
return FixedPointBifurcationProblem{iip}(homotopy, args...; kwargs...)
return FixedPointBifurcationProblem{iip, tkind}(
homotopy, args...; kwargs...)
end

FixedPointBifurcationProblem(homotopy, u0::Tuple, args...; kwargs...) =
FixedPointBifurcationProblem{false}(homotopy,
SVector(u0),
args...; kwargs...)
as_immutable_state(x::Tuple) = SVector(x)
as_immutable_state(x::Number) = x

function FixedPointBifurcationProblem(tkind::TimeKind,
homotopy,
u0::Union{Tuple, Number},
args...; kwargs...)
return FixedPointBifurcationProblem{false, tkind}(
homotopy,
as_immutable_state(u0),
args...; kwargs...)
end


struct FixedPointBifurcationCache{P, C} <: AbstractProblemCache{P}
Expand Down
5 changes: 5 additions & 0 deletions src/traits.jl
@@ -0,0 +1,5 @@
abstract type TimeKind end
struct Discrete <: TimeKind end
struct Continuous <: TimeKind end

timekind(::T) where T = TimeKind(T)

0 comments on commit 0eefebc

Please sign in to comment.