Skip to content

Commit

Permalink
Merge pull request #30 from ScottPJones/spj/fix166
Browse files Browse the repository at this point in the history
Add checking of keywords to make solve more robust, fix #166
  • Loading branch information
rveltz committed May 12, 2017
2 parents 3f0465b + f5a4138 commit 9b09fcb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
10 changes: 10 additions & 0 deletions src/LSODA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ module LSODA
using Compat, DiffEqBase
import DiffEqBase: solve

const warnkeywords =
(:save_idxs, :d_discontinuities, :isoutofdomain, :unstable_check,
:calck, :progress, :timeseries_steps, :dense,
:dtmin, :dtmax,
:internalnorm, :gamma, :beta1, :beta2, :qmax, :qmin, :qoldinit)

function __init__()
const global warnlist = Set(warnkeywords)
end

@compat abstract type LSODAAlgorithm <: AbstractODEAlgorithm end
immutable lsoda <: LSODAAlgorithm end

Expand Down
28 changes: 16 additions & 12 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ function solve{uType,tType,isinplace}(
prob::AbstractODEProblem{uType,tType,isinplace},
alg::LSODAAlgorithm,
timeseries=[],ts=[],ks=[];

verbose=true,
abstol=1/10^6,reltol=1/10^3,
tstops=Float64[],
saveat=Float64[],maxiter=Int(1e5),
Expand All @@ -13,6 +15,8 @@ function solve{uType,tType,isinplace}(
save_timeseries = nothing,
userdata=nothing,kwargs...)

verbose && !isempty(kwargs) && check_keywords(alg, kwargs, warnlist)

if save_timeseries != nothing
warn("save_timeseries is deprecated. Use save_everystep instead")
_save_everystep = save_timeseries
Expand All @@ -31,20 +35,20 @@ function solve{uType,tType,isinplace}(
T = tspan[end]

if typeof(saveat) <: Number
saveat_vec = convert(Vector{tType},saveat+tspan[1]:saveat:(tspan[end]-saveat))
# Exclude the endpoint because of floating point issues
saveat_vec = convert(Vector{tType},saveat+tspan[1]:saveat:(tspan[end]-saveat))
# Exclude the endpoint because of floating point issues
else
saveat_vec = convert(Vector{tType},collect(saveat))
saveat_vec = convert(Vector{tType},collect(saveat))
end

if !isempty(saveat_vec) && saveat_vec[end] == tspan[2]
pop!(saveat_vec)
pop!(saveat_vec)
end

if !isempty(saveat_vec) && saveat_vec[1] == tspan[1]
save_ts = sort(unique([saveat_vec;T]))
save_ts = sort(unique([saveat_vec;T]))
else
save_ts = sort(unique([t0;saveat_vec;T]))
save_ts = sort(unique([t0;saveat_vec;T]))
end

if T < save_ts[end]
Expand Down Expand Up @@ -94,25 +98,25 @@ function solve{uType,tType,isinplace}(
rtol = ones(Float64,neq)

if typeof(abstol) == Float64
atol *= abstol
atol *= abstol
else
atol = copy(abstol)
atol = copy(abstol)
end

if typeof(reltol) == Float64
rtol *= reltol
rtol *= reltol
else
rtol = copy(reltol)
rtol = copy(reltol)
end

opt = lsoda_opt_t()
opt.ixpr = 0
opt.rtol = pointer(rtol)
opt.atol = pointer(atol)
if save_everystep
itask_tmp = 2
itask_tmp = 2
else
itask_tmp = 1
itask_tmp = 1
end
opt.itask = itask_tmp

Expand Down

0 comments on commit 9b09fcb

Please sign in to comment.