Skip to content

Commit

Permalink
Setup ForwardDiff config for codim2 (partially)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Jun 22, 2018
1 parent 5dc17dd commit eee7d83
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions src/codim2/diffeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ using ForwardDiff
using Setfield: Lens, set, get

using ...Bifurcations: maybe_subtract!
using ..Continuations: _similar

"""
Codimension-2 fixed point bifurcation problem wrapper for DifferentialEquations.
Expand Down Expand Up @@ -54,23 +53,21 @@ end
DiffEqCodim2BifurcationCache(prob::DiffEqCodim2BifurcationProblem) =
DiffEqCodim2BifurcationCache(
prob,
setup_fd_config(statekind(prob), prob.de_prob))
setup_fd_config(statekind(prob), prob))

function setup_fd_config(::MutableState, de_prob)
u0 = de_prob.u0
y = similar(u0, length(u0) * 2 + 1)
x = similar(u0, length(u0) * 2 + 2)
function setup_fd_config(::MutableState, prob)
x = vcat(prob.x0, prob.v0, prob.t0) # length = 2N + 1
y = copy(x)
return ForwardDiff.JacobianConfig(
(y, x) -> de_prob.f(y, x, de_prob, 0),
(y, x) -> _residual!(y, x, prob, statekind(cache.prob)),
y,
x)
end

function setup_fd_config(::ImmutableState, de_prob)
u0 = de_prob.u0
x = _similar(u0, length(u0) * 2 + 2)
function setup_fd_config(::ImmutableState, prob)
x = vcat(prob.x0, prob.v0, prob.t0) # length = 2N + 1
return ForwardDiff.JacobianConfig(
(x) -> de_prob.f(x, de_prob, 0),
(x) -> _residual!(x, x, prob, statekind(cache.prob)),
x)
end

Expand Down Expand Up @@ -100,7 +97,7 @@ function isindomain(u, cache::DiffEqCodim2BifurcationCache)
end

residual!(H, u, cache::DiffEqCodim2BifurcationCache) =
_residual!(H, u, cache, statekind(cache.prob))
_residual!(H, u, cache.prob, statekind(cache.prob))

residual_jacobian!(H, J, u, cache::DiffEqCodim2BifurcationCache) =
_residual_jacobian!(H, J, u, cache, statekind(cache.prob))
Expand All @@ -117,9 +114,8 @@ end

# ------------------------------------------------------------------- residual!

function _residual!(H, u, cache::DiffEqCodim2BifurcationCache,
function _residual!(H, u, prob::DiffEqCodim2BifurcationProblem,
::MutableState)
prob = cache.prob
q = modified_param!(prob, u)

H1, H2, H3 = output_vars(H)
Expand All @@ -131,7 +127,7 @@ function _residual!(H, u, cache::DiffEqCodim2BifurcationCache,
(dx, x) -> prob.de_prob.f(dx, x, q, 0),
H1, # dx
x,
# cache.cfg, # TODO: make it work
# TODO: setup cache
)

A_mul_B!(H2, J, v)
Expand All @@ -142,9 +138,8 @@ function _residual!(H, u, cache::DiffEqCodim2BifurcationCache,
return H
end

function _residual!(::Any, u, cache::DiffEqCodim2BifurcationCache,
function _residual!(::Any, u, prob::DiffEqCodim2BifurcationProblem,
::ImmutableState)
prob = cache.prob
q = modified_param!(prob, u)

# TODO: Can I compute H and J in one go? Or is it already
Expand All @@ -154,7 +149,7 @@ function _residual!(::Any, u, cache::DiffEqCodim2BifurcationCache,
J = ForwardDiff.jacobian(
(x) -> prob.de_prob.f(x, q, 0),
x,
# cache.cfg, # TODO: make it work
# TODO: setup cache
)

v = ds_eigvec(u)
Expand All @@ -176,7 +171,7 @@ function _residual_jacobian!(H, J, u, cache::DiffEqCodim2BifurcationCache,
(y, x) -> residual!(y, x, cache),
H, # y
u, # x
# TODO: setup cache
cache.cfg,
)
return (H, J)
end
Expand All @@ -189,7 +184,7 @@ function _residual_jacobian!(_H, _J, u, cache::DiffEqCodim2BifurcationCache,
J = ForwardDiff.jacobian(
(x) -> residual!(_H, x, cache),
u, # x
# TODO: setup cache
cache.cfg,
)
return (H, J)
end

0 comments on commit eee7d83

Please sign in to comment.