Skip to content

Commit

Permalink
Merge pull request #19 from rveltz/update_
Browse files Browse the repository at this point in the history
update common interface
  • Loading branch information
rveltz committed Apr 14, 2017
2 parents 50586e1 + a85fa99 commit ec00d7c
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 33 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ julia 0.5
BinDeps 0.4.3
Compat 0.9.0
Parameters 0.5.0
DiffEqBase
DiffEqBase 0.15.0
90 changes: 66 additions & 24 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,42 @@
## Common Interface Solve Functions

function solve{uType,tType,isinplace,F}(
prob::AbstractODEProblem{uType,tType,isinplace,F},
function solve{uType,tType,isinplace}(
prob::AbstractODEProblem{uType,tType,isinplace},
alg::LSODAAlgorithm,
timeseries=[],ts=[],ks=[];
abstol=1/10^6,reltol=1/10^3,
tstops=Float64[],
saveat=Float64[],maxiter=Int(1e5),
timeseries_errors=true,save_timeseries=true,
save_start=true,
timeseries_errors=true,save_everystep= isempty(saveat),
save_timeseries = nothing,
userdata=nothing,kwargs...)

if save_timeseries != nothing
warn("save_timeseries is deprecated. Use save_everystep instead")
_save_everystep = save_timeseries
end

tspan = prob.tspan
t0 = tspan[1]
T = tspan[end]

save_ts = sort(unique([t0;saveat;T]))
if typeof(saveat) <: Number
saveat_vec = convert(Vector{tType},saveat:saveat:(tspan[end]-saveat))
# Exclude the endpoint because of floating point issues
else
saveat_vec = convert(Vector{tType},collect(saveat))
end

if !isempty(saveat_vec) && saveat_vec[end] == tspan[2]
pop!(saveat_vec)
end

if !isempty(saveat_vec) && saveat_vec[1] == tspan[1]
save_ts = sort(unique([saveat_vec;T]))
else
save_ts = sort(unique([t0;saveat_vec;T]))
end

if T < save_ts[end]
error("Final saving timepoint is past the solving timespan")
Expand Down Expand Up @@ -54,7 +76,7 @@ function solve{uType,tType,isinplace,F}(
ttmp = [t0]
t = [t0]
t2 = [t0]
ts = [t0]
save_start ? ts = [t0] : ts = Vector{typeof(t0)}(0)

neq = Int32(length(u0))
userfun = UserFunctionAndData(f!, userdata,neq)
Expand All @@ -78,7 +100,7 @@ function solve{uType,tType,isinplace,F}(
opt.ixpr = 0
opt.rtol = pointer(rtol)
opt.atol = pointer(atol)
if save_timeseries
if save_everystep
itask_tmp = 2
else
itask_tmp = 1
Expand All @@ -97,40 +119,60 @@ function solve{uType,tType,isinplace,F}(

for k in 2:length(save_ts)
ttmp[1] = save_ts[k]
while t[1]<ttmp[1]
lsoda(ctx,utmp,t,ttmp[1])
if t[1]>ttmp[1] # overstepd, interpolate back
t2[1] = t[1] # save step values
copy!(utmp2,utmp) # save step values
opt.itask = 1 # change to interpolating
if t[1]<ttmp[1]
while t[1]<ttmp[1]
lsoda(ctx,utmp,t,ttmp[1])
opt.itask = itask_tmp
push!(ures,copy(utmp))
push!(ts,t[1])
if k != length(save_ts) # don't overstep the last timestep
push!(ures,copy(utmp2))
push!(ts,t2[1])
if t[1]>ttmp[1] # overstepd, interpolate back
t2[1] = t[1] # save step values
copy!(utmp2,utmp) # save step values
opt.itask = 1 # change to interpolating
lsoda(ctx,utmp,t,ttmp[1])
opt.itask = itask_tmp
push!(ures,copy(utmp))
push!(ts,t[1])
if k != length(save_ts) && save_ts[k+1] > t2[1] # don't overstep the last timestep
push!(ures,copy(utmp2))
push!(ts,t2[1])
end
copy!(utmp,utmp2)
t[1] = t2[1]
else
push!(ures,copy(utmp))
push!(ts,t[1])
end
else
push!(ures,copy(utmp))
push!(ts,t[1])
end
else
t2[1] = t[1] # save step values
copy!(utmp2,utmp) # save step values
opt.itask = 1 # change to interpolating
lsoda(ctx,utmp,t,ttmp[1])
opt.itask = itask_tmp
push!(ures,copy(utmp))
push!(ts,t[1])
if k != length(save_ts) && save_ts[k+1] > t2[1] # don't overstep the last timestep
push!(ures,copy(utmp2))
push!(ts,t2[1])
end
copy!(utmp,utmp2)
t[1] = t2[1]
end
end

### Finishing Routine

timeseries = Vector{uType}(0)
save_start ? start_idx = 1 : start_idx = 2
if typeof(prob.u0)<:Number
for i=1:length(ures)
for i=start_idx:length(ures)
push!(timeseries,ures[i][1])
end
else
for i=1:length(ures)
for i=start_idx:length(ures)
push!(timeseries,reshape(ures[i],sizeu))
end
end

build_solution(prob,alg,ts,timeseries,
timeseries_errors = timeseries_errors)
timeseries_errors = timeseries_errors,
retcode = :Success)
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
push!(LOAD_PATH, "/Users/rveltz/work/prog_gd/julia")
using LSODA
using Base.Test

Expand Down
3 changes: 1 addition & 2 deletions test/test2.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
include("/Users/rveltz/work/prog_gd/julia/repLSODA/LSODA.jl/src/LSODA.jl")
using LSODA

function rhs!(t, x, ydot, data)
Expand Down Expand Up @@ -30,4 +29,4 @@ println("\n####################################\n--> Use of a lsoda_evolve!")
# LSODA.lsoda_evolve!(ctx,y0,tspan[k-1:k])
# @printf("at t = %12.4e y= %14.6e %14.6e %14.6e\n",tspan[k],y0[1], y0[2], y0[3])
# end
# lsoda_free(ctx)
# lsoda_free(ctx)
34 changes: 29 additions & 5 deletions test/test_common.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
using LSODA, DiffEqProblemLibrary
using LSODA, DiffEqProblemLibrary, Base.Test
prob = prob_ode_linear
sol = solve(prob,lsoda(),save_timeseries=false,saveat=[1/2])
sol = solve(prob,lsoda(),saveat=[1/2])
@test sol.t == [0,1/2,1]
prob = prob_ode_2Dlinear
sol = solve(prob,lsoda(),save_timeseries=false,saveat=[1/2])
sol = solve(prob,lsoda(),saveat=[1/2])
@test sol.t == [0,1/2,1]
sol = solve(prob,lsoda(),saveat=1/10)
@test sol.t == collect(0:1/10:1)

prob = prob_ode_linear
sol = solve(prob,lsoda(),save_timeseries=true,saveat=[1/2])
sol = solve(prob,lsoda())
sol = solve(prob,lsoda(),save_everystep=true,saveat=[1/2])
@test 1/2 sol.t
prob = prob_ode_2Dlinear
sol = solve(prob,lsoda(),save_timeseries=true,saveat=[1/2])
sol = solve(prob,lsoda(),save_everystep=true,saveat=[1/2])
@test 1/2 sol.t
sol = solve(prob,lsoda(),save_everystep=true,saveat=1/2)
@test 1/2 sol.t
sol = solve(prob,lsoda(),save_everystep=true,saveat=[1/10,1/5,3/10])#,2/5,1/2,3/5,7/10])
@test 1/10 sol.t
@test 1/5 sol.t
@test 3/10 sol.t
sol = solve(prob,lsoda(),save_everystep=true,saveat=1/10)
for i in 2:length(sol.t)
@test sol.t[i] > sol.t[i-1]
end
for k in 0:1/10:1
@test k sol.t
end

sol = solve(prob,lsoda(),save_start=false,saveat=1/10)
sol.t[1] == 0.1
sol.u[1] != prob.u0

0 comments on commit ec00d7c

Please sign in to comment.