Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Common Interface Bindings #5

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
julia 0.5
BinDeps 0.4.3
Compat 0.9.0
Parameters 0.5.0
Parameters 0.5.0
DiffEqBase
6 changes: 5 additions & 1 deletion src/LSODA.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LSODA

using Compat
using Compat, DiffEqBase
import DiffEqBase: solve

const depsfile = joinpath(dirname(dirname(@__FILE__)),"deps","deps.jl")
if isfile(depsfile)
Expand All @@ -11,7 +12,10 @@ end

export lsoda, lsoda_0, lsoda_opt_t, lsoda_context_t, lsoda_prepare, lsoda_opt_t, lsoda_free, lsoda_evolve!

export LSODAAlgorithm, LSODAAlg, solve

include("types_and_consts.jl")
include("solver.jl")
include("common.jl")

end # module
138 changes: 138 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
abstract LSODAAlgorithm <: AbstractODEAlgorithm
immutable LSODAAlg <: LSODAAlgorithm end

## Common Interface Solve Functions

function solve{uType,tType,isinplace,F}(
prob::AbstractODEProblem{uType,tType,isinplace,F},
alg::LSODAAlgorithm,
timeseries=[],ts=[],ks=[];
abstol=1/10^6,reltol=1/10^3,
saveat=Float64[],maxiter=Int(1e5),
timeseries_errors=true,save_timeseries=true,
userdata=nothing,kwargs...)

tspan = prob.tspan
t0 = tspan[1]
T = tspan[end]

save_ts = sort(unique([t0;saveat;T]))

if T < save_ts[end]
error("Final saving timepoint is past the solving timespan")
end
if t0 > save_ts[1]
error("First saving timepoint is before the solving timespan")
end

if typeof(prob.u0) <: Number
u0 = [prob.u0]
else
u0 = vec(deepcopy(prob.u0))
end

sizeu = size(prob.u0)

### Fix the more general function to Sundials allowed style
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
f! = (t,u,du,userdata) -> (du[:] = prob.f(t,u); nothing)
elseif !isinplace && typeof(prob.u0)<:AbstractArray
f! = (t,u,du,userdata) -> (du[:] = vec(prob.f(t,reshape(u,sizeu))); nothing)
elseif typeof(prob.u0)<:Vector{Float64}
f! = (t,u,du,userdata) -> prob.f(t,u,du)
else # Then it's an in-place function on an abstract array
f! = (t,u,du,userdata) -> (prob.f(t,reshape(u,sizeu),reshape(du,sizeu));
u = vec(u); du=vec(du); nothing)
end

ures = Vector{Vector{Float64}}()
push!(ures,u0)
utmp = copy(u0)
ttmp = [t0]
t = [t0]
ts = [t0]

neq = Int32(length(u0))
userfun = UserFunctionAndData(f!, userdata,neq)

atol = ones(Float64,neq)
rtol = ones(Float64,neq)

if typeof(abstol) == Float64
atol *= abstol
else
atol = copy(abstol)
end

if typeof(reltol) == Float64
rtol *= reltol
else
rtol = copy(reltol)
end

opt = lsoda_opt_t()
opt.ixpr = 0
opt.rtol = pointer(rtol)
opt.atol = pointer(atol)
opt.itask = 1

const fex_c = cfunction(lsodafun,Cint,(Cdouble,Ptr{Cdouble},Ptr{Cdouble},Ref{typeof(userfun)}))

ctx = lsoda_context_t()
ctx.function_ = fex_c
ctx.neq = neq
ctx.state = 1
ctx.data = pointer_from_objref(userfun)

lsoda_prepare(ctx,opt)

# The Inner Loops : Style depends on save_timeseries
if save_timeseries
#=
for k in 2:length(save_ts)
looped = false
while tout[end] < save_ts[k]
looped = true
flag = @checkflag CVode(mem,
save_ts[k], utmp, tout, CV_ONE_STEP)
push!(ures,copy(utmp))
push!(ts, tout...)
end
if looped
# Fix the end
flag = @checkflag CVodeGetDky(
mem, save_ts[k], Cint(0), ures[end])
ts[end] = save_ts[k]
else # Just push another value
flag = @checkflag CVodeGetDky(
mem, save_ts[k], Cint(0), utmp)
push!(ures,copy(utmp))
push!(ts, save_ts[k]...)
end
end
=#
else # save_timeseries == false, so use saveat style
for k in 2:length(save_ts)
ttmp[1] = save_ts[k]
lsoda(ctx,utmp,t,ttmp[1])
push!(ures,copy(utmp))
end
ts = save_ts
end

### Finishing Routine

timeseries = Vector{uType}(0)
if typeof(prob.u0)<:Number
for i=1:length(ures)
push!(timeseries,ures[i][1])
end
else
for i=1:length(ures)
push!(timeseries,reshape(ures[i],sizeu))
end
end

build_solution(prob,alg,ts,timeseries,
timeseries_errors = timeseries_errors)
end
1 change: 1 addition & 0 deletions test/REQUIRE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DiffEqProblemLibrary
5 changes: 5 additions & 0 deletions test/test_common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using LSODA, DiffEqProblemLibrary
prob = prob_ode_linear
sol = solve(prob,LSODAAlg(),save_timeseries=false,saveat=[1/2])
prob = prob_ode_2Dlinear
sol = solve(prob,LSODAAlg(),save_timeseries=false,saveat=[1/2])