Skip to content

Commit

Permalink
Add a GPU demo
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Nov 14, 2019
1 parent adc4a84 commit 1eb0387
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 20 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Takafumi Arakaki <aka.tkf@gmail.com>"]
version = "0.0.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
80 changes: 80 additions & 0 deletions examples/gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Bifurcations
using LinearAlgebra
using OrdinaryDiffEq
using Setfield
using Parameters: @unpack

# Optional GPU support: try to load CuArrays and CUDAnative and record the outcome in
# `HAS_CUDA` (`true` on success, `false` if anything in the `try` body throws).
HAS_CUDA = try
using CuArrays
import CUDAnative
true
catch err
@error "CuArrays cannot be imported" exception = (err, catch_backtrace())
# Fallback binding so that the `u isa CuArray` checks in the functions below still
# evaluate (to `false` for any array `u`) when the GPU packages are unavailable.
global const CuArray = Nothing
false
end

# In-place right-hand side of the rate model
#
#     du = -u + g * (J * tanh(u)) + h
#
# `p` must provide the fields `J`, `h` and `g`; `t` is unused.  Returns `du`.
function phase_dynamics!(du, u, p, t)
    @unpack J, h, g = p
    # Pick the `tanh` implementation matching the array type of `u`.
    activation = if u isa CuArray
        CUDAnative.tanh
    else
        Base.tanh
    end
    recurrent = J * activation.(u)
    @. du = -u + g * recurrent + h
    return du
end

# In-place Jacobian of `phase_dynamics!` with respect to `u`.
#
# For F_i(u) = -u_i + g * sum_j J[i, j] * tanh(u_j) + h_i the derivative is
#
#     dF_i/du_j = -δ_ij + g * J[i, j] * sech(u_j)^2
#
# i.e. `-I + g * J * Diagonal(sech.(u) .^ 2)` with a *matrix* product: column `j` of `J`
# is scaled by `sech(u_j)^2`.  `p` must provide the fields `J` and `g`; `t` is unused.
# The result is written into `M`, which is also returned.
function jacobian!(M, u, p, t)
    @unpack J, g = p
    # Pick the `cosh` implementation matching the array type of `u`.
    cosh = u isa CuArray ? CUDAnative.cosh : Base.cosh
    sech2 = inv.(cosh.(u)) .^ 2
    # Broadcasting against a 1×N row scales the columns of `J`.  An element-wise product
    # with `Diagonal(sech2)` would instead zero every off-diagonal entry of `J`.
    M .= -I + g .* J .* reshape(sech2, 1, :)
    return M
end

# Problem size: length of the state `u0` and of `h`, and the side length of `J`.
N = 10000
# ODE problem with an explicit Jacobian, a standard-normal initial state, the time span
# (0, 50), and a parameter NamedTuple holding the fields that `phase_dynamics!` and
# `jacobian!` unpack (`J`, `h`, `g`).
ode = ODEProblem(
ODEFunction(phase_dynamics!; jac = jacobian!),
randn(N),
(0.0, 50.0),
(
J = randn(N, N) ./ .√(N), # standard-normal entries scaled by 1/√N
h = randn(N),
g = 0.5,
),
)

# Integrate the ODE with `Tsit5` and rebuild the problem so that it starts from the last
# saved state of that run.
ts = solve(ode, Tsit5())
ode0 = remake(ode, u0 = ts.u[end])

# The lens selects the `g` field of the parameter tuple as the bifurcation parameter.
# NOTE(review): (0.5, 1.5) is presumably the range of `g` to explore — confirm against
# the `BifurcationProblem` documentation.
param_axis = @lens _.g
prob = BifurcationProblem(ode0, param_axis, (0.5, 1.5))
# prob = BifurcationProblem(ode, param_axis, (0.5, 1.5))

using Bifurcations.Continuations: ContinuationSolver, ContinuationOptions
# Options shared by the CPU solver below and the GPU solver built later in this script.
opts = ContinuationOptions(
atol = 0.05,
rtol = 0.05,
max_branches = 0,
max_samples = 1,
# bidirectional_first_sweep = false,
# start_from_nearest_root = true,
verbose = true,
)
# Only the solver object is constructed here; the `solve`/`solve!` calls are left
# commented out, so running this script does not start a continuation.
solver = ContinuationSolver(prob, opts)
# solve(prob; start_from_nearest_root = true)
# sol = @time solve!(solver)
# sol = @time solve!(solver)

# GPU variant, built only when CuArrays and CUDAnative were imported successfully.
if HAS_CUDA
# Same problem as `ode0`, with the state and the array-valued parameters passed through
# `cu` (presumably converting them to GPU arrays — see the CuArrays documentation); the
# scalar `g` is reused unchanged.
ode_gpu = remake(
ode0;
u0 = cu(ode0.u0),
p = (
J = cu(ode0.p.J),
h = cu(ode0.p.h),
g = ode0.p.g,
),
)
prob_gpu = BifurcationProblem(ode_gpu, param_axis, (0.5, 1.5))
# As with the CPU solver, this is only constructed; the solve calls are commented out.
solver_gpu = ContinuationSolver(prob_gpu, opts)
# solve(prob_gpu; start_from_nearest_root = true)
# sol_gpu = @time solve!(solver_gpu)
# sol_gpu = @time solve!(solver_gpu)
end
6 changes: 3 additions & 3 deletions src/continuations/branching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ function find_more_nullspaces(Q, L, rtol, atol, max_steps)
end
=#

L2 = @view L[1:end-1, 1:end-1]
L2 = popbottomright(L)
R2 = L2'
if Q isa StaticArray
L2 = LowerTriangular(SMatrix{size(L2)...}(L2))
R2 = UpperTriangular(SMatrix{size(R2)...}(R2))
end
y, cotJ = _find_more_nullspaces(L2, R2, y, rtol, atol, max_steps)
tJ2 = (@view Q[1:end-1, :])' * y
tJ2 = (Q[1:end-1, :])' * y

if y isa SVector # TODO: don't
return SVector(tJ2...), cotJ
Expand All @@ -47,7 +47,7 @@ function _find_more_nullspaces(L2, R2,
ker_L2 = nullspace(Array(L2))
ker_R2 = nullspace(Array(R2))
if size(ker_L2, 2) > 0 && size(ker_R2, 2) > 0
return (T(@view ker_L2[:, 1]), T(@view ker_R2[:, 1]))
return (T(ker_L2[:, 1]), T(ker_R2[:, 1]))
end
# Otherwise, let's fallback to the manual method.
end
Expand Down
2 changes: 1 addition & 1 deletion src/continuations/continuations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using StaticArrays: StaticArray, SMatrix, SVector, SArray
import DiffEqBase: solve, solve!, init, step!

using ..ArrayUtils: _similar, _zeros, isalmostzero, zero_if_nan, _lq!, _det,
_normalize!, bottomrow
_normalize!, bottomrow, popbottomright

include("interface.jl")
include("euler_newton.jl")
Expand Down
14 changes: 11 additions & 3 deletions src/continuations/euler_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,16 @@ rawtangent(Q) = vec(rawtangentmat(Q))
function rawtangentmat(Q)
if Q isa StaticArray
return bottomrow(Q)
elseif Q isa Adjoint # CuArray takes this path
x′ = _similar(Q, size(Q, 1), 1)
fill!(x′, false)
x′[end, 1] = 1
# return (Q' * x′)' # `vec(x::Adjoint{_, <:CuArray})` does not work
return conj(reshape(Q' * x′, 1, :))
else
x = zeros(1, size(Q, 1))
x = _similar(Q, 1, size(Q, 1))
fill!(x, false)
# x .= (x .* false .+ CartesianIndices(x)) .== Ref(size(x)) # InvalidIRError
x[1, end] = 1
rmul!(x, Q)
return x
Expand All @@ -97,7 +105,7 @@ end

function tangent(L, Q)
tJ = rawtangent(Q)
if _det(Q) * det(@view L[1:end-1, 1:end-1]) < 0
if _det(Q) * _det(popbottomright(L)) < 0
tJ *= -1
end
return tJ
Expand Down Expand Up @@ -130,7 +138,7 @@ function corrector_step!(H::HType,
H, J = residual_jacobian!(H, J, v, prob_cache)
A = vcat(J, _zeros(J, 1, size(J, 2))) # TODO: improve
L, Q = _lq!(A)
y = vcat((@view L[1:end-1, 1:end-1]) \ H, false)
y = vcat(popbottomright(L) \ H, _zeros(J, 1))
dv = Q' * y
w = v - dv
return (w :: vType,
Expand Down
4 changes: 2 additions & 2 deletions src/continuations/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ end


function residual(u::AbstractArray, cache::AbstractProblemCache)
H = similar(@view u[1:end-1])
H = similar(u[1:end-1])
return residual!(H, u, cache)
end

Expand All @@ -196,7 +196,7 @@ end


function residual_jacobian(u::AbstractArray, cache::AbstractProblemCache)
H = similar(@view u[1:end-1])
H = similar(u[1:end-1])
J = similar(H, (length(H), length(u)))
return residual_jacobian!(H, J, u, cache)
end
Expand Down
4 changes: 2 additions & 2 deletions src/fixedpoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ end

function _residual!(H, u, prob::FixedPointBifurcationProblem,
::MutableState, ::Any)
x = @view u[1:end-1]
x = u[1:end-1]
t = u[end]
prob.homotopy(H, x, prob.p, t)
return H
Expand All @@ -185,7 +185,7 @@ end
function _residual_jacobian!(H, J, u, cache::FixedPointBifurcationCache,
::MutableState, ::HasJac)
prob = cache.prob
x = @view u[1:end-1]
x = u[1:end-1]
t = u[end]
prob.homotopy_jacobian(H, J, x, prob.p, t)
return (H, J)
Expand Down
46 changes: 37 additions & 9 deletions src/utils/array_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using LinearAlgebra: QRPackedQ, LQPackedQ, BlasFloat

using ForwardDiff: Dual
using StaticArrays: SVector, SMatrix, StaticArray, Size, similar_type
using ArrayInterface: ismutable

container_array_of(::SVector{S}) where {S} = SVector{S}
container_array_of(::SMatrix{S1, S2}) where {S1, S2} = SMatrix{S1, S2}
Expand Down Expand Up @@ -39,7 +40,15 @@ _eig(A) = eigen(A)
_eig(A::SMatrix) = eigen(Array(A))
_eig(A::Adjoint) = eigen(Array(A))

_similar(x::AbstractArray, dims...) = similar(x, dims)
# Allocate an array of size `dims` that is `similar` to the innermost storage of `x`.
# Wrapper arrays (anything whose `parent` has a different type) are peeled off one layer
# at a time; once `parent(x)` has the same type as `x`, `similar(x, dims)` is called.
function _similar(x::AbstractArray, dims...)
    inner = parent(x)
    # Still a wrapper: recurse on the wrapped array.
    typeof(inner) === typeof(x) || return _similar(inner, dims...)
    return similar(x, dims)
end
_similar(x::LinearAlgebra.AbstractQ, dims...) = _similar(x.factors, dims...)
_similar(x::StaticArray, dims...) = _zeros(x, dims...)
_zeros(x::AbstractArray{T}, dims...) where T = fill!(similar(x, T, dims), zero(T))
_zeros(x::StaticArray, dims...) = zeros(similar_type(x, Size(dims)))
Expand Down Expand Up @@ -85,22 +94,27 @@ zero_if_nan(x) = isnan(x) ? zero(x) : x
A
end

const MatrixWithLQ = StridedMatrix{<:BlasFloat}

function _lq!(A::MatrixWithLQ)
L, Q = lq!(A)
return (LowerTriangular(L), Q)
function _lq!(A::SMatrix)
Q, R = qr(A')
return (LowerTriangular(R'), Q')
end

function _lq!(A)
function _lq!(A) # for CuArray
Q, R = qr(A')
return (LowerTriangular(R'), Q')
return (UpperTriangular(R)', Q')
end

function _lq!(A::Matrix{<:BlasFloat})
L, Q = lq!(A)
return (LowerTriangular(L), Q)
end

_det(Q) = det(Q)
# https://github.com/JuliaLang/julia/pull/32887
_det(Q::Union{QRPackedQ{T}, LQPackedQ{T}}) where {T <: Real} =
isodd(count(!iszero, Q.τ)) ? -1 : 1
_det(Q::StaticArray) = det(Q)
_det(Q::Transpose) = _det(Q.parent)
_det(Q::Adjoint) = adjoint(_det(Q.parent))

@inline foldlargs(op, x) = x
@inline foldlargs(op, x1, x2, xs...) = foldlargs(op, op(x1, x2), xs...)
Expand All @@ -113,6 +127,20 @@ bottomrow(M::SMatrix{S1, S2}) where {S1, S2} =
end
)

# `popbottomright(A)` returns `A` without its last row and last column.
# Generic fallback: plain range indexing.
popbottomright(A) = A[1:end-1, 1:end-1]
# Wrapper types: apply the operation to the wrapped matrix and re-wrap the result, so
# the adjoint / triangular structure is kept.
popbottomright(A::Adjoint) = popbottomright(A')'
popbottomright(A::LowerTriangular) = LowerTriangular(popbottomright(parent(A)))
popbottomright(A::UpperTriangular) = UpperTriangular(popbottomright(parent(A)))

# Static matrices: collect the retained entries into a tuple with nested `foldlargs`
# calls (outer fold over columns `j`, inner fold over rows `i`, so the tuple is built
# column by column) and construct an `SMatrix{S1 - 1,S2 - 1}` from it.
function popbottomright(A::SMatrix{S1,S2}) where {S1,S2}
xs = foldlargs((), ntuple(identity, S2 - 1)...) do xs, j
foldlargs(xs, ntuple(identity, S1 - 1)...) do xs, i
(xs..., @inbounds A[i, j])
end
end
return SMatrix{S1 - 1,S2 - 1}(xs)
end

function _normalize!(x)
normalize!(x)
return x
Expand Down
8 changes: 8 additions & 0 deletions test/test_smoke.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,12 @@ using Bifurcations: examples
end
end

# Smoke test for the GPU demo: evaluate the script inside a fresh module so that its
# globals do not leak into the test module, then check that it defined `solver` as a
# `ContinuationSolver`.
@testset "include(gpu.jl)" begin
m = Module()
# The `include` is kept inside `@test` so that it is evaluated as part of the test
# expression together with the type check.
@test begin
Base.include(m, "../examples/gpu.jl")
m.solver isa Bifurcations.Continuations.ContinuationSolver
end
end

end # module

0 comments on commit 1eb0387

Please sign in to comment.