Skip to content

Commit

Permalink
Use StaticArrays in codim1lc/problem.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Aug 9, 2018
1 parent 3faf36b commit ca1d29d
Showing 1 changed file with 56 additions and 73 deletions.
129 changes: 56 additions & 73 deletions src/codim1lc/problem.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ForwardDiff
using Parameters: @with_kw, @unpack
using Setfield: set
using StaticArrays: SVector, SMatrix, Size

using Jacobi: lagrange, zgj, wgj
using ..PolyUtils: dlagrange
Expand Down Expand Up @@ -79,13 +80,13 @@ num_mesh(prob::LimitCycleProblem) = prob.num_mesh

# ----------------------------------------------------------------------- cache

struct LimitCycleCache{P, V, M} <: AbstractProblemCache{P}
struct LimitCycleCache{P, VRef, MV, ML, VV} <: AbstractProblemCache{P}
prob::P
reference::V
dv::M
lagrange_polynomial_vals::M
lagrange_polynomial_driv::M
gauss_quadrature_weight::V
reference::VRef
dv::MV
lagrange_polynomial_vals::ML
lagrange_polynomial_driv::ML
gauss_quadrature_weight::VV
end

# Indices:
Expand Down Expand Up @@ -119,15 +120,17 @@ function LimitCycleCache(prob)
return LimitCycleCache(
prob,
reference,
dv,
lagrange_polynomial_vals,
lagrange_polynomial_driv,
gauss_quadrature_weight,
SMatrix{n, m}(dv),
SMatrix{m, m + 1}(lagrange_polynomial_vals),
SMatrix{m, m + 1}(lagrange_polynomial_driv),
SVector{m}(gauss_quadrature_weight),
)
end
# TODO: dv is only used to hold the dimension of the dynamical system.
# Remove it.

dim_state(cache::LimitCycleCache) = size(cache.dv, 1)
degree(cache::LimitCycleCache) = size(cache.dv, 2)
Base.@pure dim_state(cache::LimitCycleCache) = Size(cache.dv)[1]
Base.@pure degree(cache::LimitCycleCache) = Size(cache.dv)[2]
num_mesh(cache::LimitCycleCache) = cache.prob.num_mesh

function set_reference!(cache::LimitCycleCache, u)
Expand Down Expand Up @@ -197,42 +200,40 @@ end

# ------------------------------------------------------------------- workspace

struct LimitCycleWorkspace{C, M}
cache::C
x::M
dx::M
end

function make_workspace(cache, u)
x = similar(u, size(cache.dv))
dx = similar(x)
return LimitCycleWorkspace(cache, x, dx)
end

function diff_range(cache, j, i)
@inline function diff_range(cache, j, i)
n = dim_state(cache)
m = degree(cache)
offset = n * m * (j - 1) + n * (i - 1)
return offset + 1:offset + n
end

function get_samples(cache, u, j) # TODO: better name than "samples"?
n = dim_state(cache)
m = degree(cache)
is_last = j == num_mesh(cache)

if is_last
dims = (n, m)
else
dims = (n, m + 1)
# TODO: get_samples --- better name than "samples"?
@generated function get_samples(
cache::LimitCycleCache{P, VRef, MV, ML, VV},
u, j,
) where {P, VRef, MV, ML, VV}
n = Size(MV)[1]
m = Size(MV)[2]

idxs_rest = [:($(n * m) * (j - 1) + $i) for i in 1:(n * (m + 1))]
idxs_last = [
idxs_rest[1:end - n]
1:n
]
Mat = SMatrix{n, m + 1, eltype(u)}

quote
if j == num_mesh(cache)
@inbounds return $Mat($([:(u[$i]) for i in idxs_last]...))
else
@inbounds return $Mat($([:(u[$i]) for i in idxs_rest]...))
end
end
s = n * m * (j - 1) + 1
e = s - 1 + prod(dims)
= reshape(view(u, s:e), dims)
xτ0 = view(u, 1:n)

return xτ, xτ0, is_last
end
# TODO: Stop hard-coding output type here. Assuming that m is always
# small (for StaticArrays to be effective) is probably fine but
# assuming n to be small is not OK. But I need to optimize a lot for
# Bifurcations.jl to be useful for non-small n anyway.

# Indices:
# i ∈ 1:m : collocation/evaluation/Gauss point/node index; ζᵢ
Expand All @@ -241,40 +242,23 @@ end
# Arrays: Index: Size:
# x, dx [p, i] (n, m)
# lpv, lpd [i, k] (m, m + 1)
# xτ [p, k] (n, m + 1) or (n, m)
# xτ0 [p] (n,)

function collocation!(ws, u, j)
@unpack x, dx = ws
lpv = ws.cache.lagrange_polynomial_vals # ℓᵀ; ℓᵢₖ = [ℓᵀ]ₖᵢ = ℓₖ(ζᵢ)
lpd = ws.cache.lagrange_polynomial_driv
xτ, xτ0, is_last = get_samples(ws.cache, u, j) # xₚ(τₖ)
if ! is_last
A_mul_Bt!(x, xτ, lpv) # x(ζ) = x(τ) * ℓ
A_mul_Bt!(dx, xτ, lpd) # x'(ζ) = x'(τ) * ℓ
else
# It is the last mesh. So, the last node has to come from the
# first node of the first mesh.
A_mul_Bt!(x, xτ, @view lpv[:, 1:end-1])
A_mul_Bt!(dx, xτ, @view lpd[:, 1:end-1])
x .+= xτ0 * (@view lpv[:, end])'
dx .+= xτ0 * (@view lpd[:, end])'
# TODO: optimize
end
# xτ [p, k] (n, m + 1)

@inline function collocation!(cache, u, j)
lpv = cache.lagrange_polynomial_vals # ℓᵀ; ℓᵢₖ = [ℓᵀ]ₖᵢ = ℓₖ(ζᵢ)
lpd = cache.lagrange_polynomial_driv
= get_samples(cache, u, j) # xₚ(τₖ)
x =* lpv' # x(ζ) = x(τ) * ℓ
dx =* lpd' # x'(ζ) = x'(τ) * ℓ
return x, dx
end
# TODO: store ℓ (not ℓᵀ) in lpv? (Why not?)

function reference_diff!(cache, j)
@inline function reference_diff!(cache, j)
dv = cache.dv
lpd = cache.lagrange_polynomial_driv
vτ, vτ0, is_last = get_samples(cache, cache.reference, j)
if ! is_last
A_mul_Bt!(dv, vτ, lpd) # x'(ζ) = x'(τ) * ℓ
else
A_mul_Bt!(dv, vτ, @view lpd[:, 1:end-1])
dv .+= vτ0 * (@view lpd[:, end])'
# TODO: optimize
end
= get_samples(cache, cache.reference, j)
dv =* lpd' # x'(ζ) = x'(τ) * ℓ
return dv
end

Expand Down Expand Up @@ -306,16 +290,15 @@ function residual_lc!(H, u, q, cache::LimitCycleCache)
prob = cache.prob

l = u[u_idx_period(cache)]
ws = make_workspace(cache, u)
weight = cache.gauss_quadrature_weight

phase_condition = zero(eltype(H))
for j in 1:num_mesh(cache)
x, dx = collocation!(ws, u, j)
x, dx = collocation!(cache, u, j)
dv = reference_diff!(cache, j)
@views for (i, t) in enumerate(turns(cache, j))
@inbounds for (i, t) in enumerate(turns(cache, j))
r = diff_range(cache, j, i)
f = H[r]
f = @view H[r]
prob_time = t * l + prob.time_offset
prob.de_prob.f(f, x[:, i], q, prob_time)

Expand Down

0 comments on commit ca1d29d

Please sign in to comment.