Skip to content

Commit

Permalink
refactoring chvdiffeq
Browse files Browse the repository at this point in the history
  • Loading branch information
rveltz committed Oct 2, 2019
1 parent a6069e9 commit 216a5e4
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 22 deletions.
33 changes: 14 additions & 19 deletions src/chv.jl
Expand Up @@ -13,6 +13,7 @@ function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind

# we declare the characteristics for convenience
caract = problem.caract
ratecache = caract.ratecache

ti, tf = problem.tspan
n_jumps += 1 # to hold initial vector
Expand All @@ -30,8 +31,6 @@ function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind
end
X_extended[end] = ti

t_hist = [ti]

#useful to use the same array, as it can be used in CHV(ode)
Xd = caract.xd
if ind_save_c[1] == -1
Expand All @@ -43,22 +42,18 @@ function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind
end
xc_hist = VectorOfArray([copy(xc0)[ind_save_c]])
xd_hist = VectorOfArray([copy(xd0)[ind_save_d]])
rate_hist = eltype(xc0)[]

res_ode = zeros(2, length(X_extended))

nsteps += 1

numpf = size(caract.pdmpjump.nu, 1) # number of reactions
rate = zeros(numpf) # vector of rates

# define the ODE flow, this leads to big memory saving
if ode == :cvode || ode == :bdf
Flow = (X0_,Xd_,Δt,r_) -> Sundials.cvode((tt,x,xdot) -> algo(xdot, x, problem.caract, tt), X0_, [0., Δt], abstol = abstol, reltol = reltol, integrator = :BDF)
Flow = (X0_,Xd_,Δt,r_) -> Sundials.cvode((tt,x,xdot) -> algo(xdot, x, caract, tt), X0_, [0., Δt], abstol = abstol, reltol = reltol, integrator = :BDF)
elseif ode==:adams
Flow = (X0_,Xd_,Δt,r_) -> Sundials.cvode((tt,x,xdot) -> algo(xdot, x, problem.caract, tt), X0_, [0., Δt], abstol = abstol, reltol = reltol, integrator = :Adams)
Flow = (X0_,Xd_,Δt,r_) -> Sundials.cvode((tt,x,xdot) -> algo(xdot, x, caract, tt), X0_, [0., Δt], abstol = abstol, reltol = reltol, integrator = :Adams)
elseif ode==:lsoda
Flow = (X0_,Xd_,Δt,r_) -> LSODA.lsoda((tt,x,xdot,data) -> algo(xdot, x, problem.caract, tt), X0_, [0., Δt], abstol = abstol, reltol = reltol)
Flow = (X0_,Xd_,Δt,r_) -> LSODA.lsoda((tt,x,xdot,data) -> algo(xdot, x, caract, tt), X0_, [0., Δt], abstol = abstol, reltol = reltol)
end

# we use the first time interval from the one generated by the constructor PDMPProblem
Expand All @@ -69,7 +64,7 @@ function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind

verbose && println("--> t = ", t," - δt = ", δt, ",nstep = ", nsteps)

res_ode .= Flow(X_extended, Xd, δt, rate)
res_ode .= Flow(X_extended, Xd, δt, ratecache.rate)

verbose && println("--> ode solve is done!")

Expand All @@ -81,44 +76,44 @@ function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind
# this is the next jump time
t = res_ode[end, end]

problem.caract.R(rate, X_extended, Xd, problem.caract.parms, t, false)
caract.R(ratecache.rate, X_extended, Xd, caract.parms, t, false)

# jump time:
if (t < tf) && nsteps < n_jumps
# Update event
ev = pfsample(rate, sum(rate), numpf)
ev = pfsample(ratecache.rate)

# we perform the jump
affect!(problem.caract.pdmpjump, ev, X_extended, Xd, problem.caract.parms, t)
affect!(caract.pdmpjump, ev, X_extended, Xd, caract.parms, t)

verbose && println("--> Which reaction? => ", ev)
verbose && println("--> xd = ", Xd)

# save state, post-jump
push!(t_hist, t)
pushTime!(problem, t)
push!(xc_hist, X_extended[ind_save_c])
push!(xd_hist, Xd[ind_save_d])

save_rate && push!(rate_hist, problem.caract.R(rate, X_extended, Xd, problem.caract.parms, t, true)[1])
save_rate && push!(problem.rate_hist, caract.R(ratecache.rate, X_extended, Xd, caract.parms, t, true)[1])

δt = - log(rand())

else
if ode in [:cvode, :bdf, :adams]
res_ode_last = Sundials.cvode((tt, x, xdot)->problem.caract.F(xdot,x,Xd,problem.caract.parms,tt), xc_hist[end], [t_hist[end], tf], abstol = 1e-9, reltol = 1e-7)
res_ode_last = Sundials.cvode((tt, x, xdot)->caract.F(xdot,x,Xd,caract.parms,tt), xc_hist[end], [problem.time[end], tf], abstol = 1e-9, reltol = 1e-7)
else#if ode==:lsoda
res_ode_last = LSODA.lsoda((tt, x, xdot, data)->problem.caract.F(xdot,x,Xd,problem.caract.parms,tt), xc_hist[end], [t_hist[end], tf], abstol = 1e-9, reltol = 1e-7)
res_ode_last = LSODA.lsoda((tt, x, xdot, data)->caract.F(xdot,x,Xd,caract.parms,tt), xc_hist[end], [problem.time[end], tf], abstol = 1e-9, reltol = 1e-7)
end
t = tf

# save state
push!(t_hist, tf)
pushTime!(problem, tf)
push!(xc_hist, res_ode_last[end,ind_save_c])
push!(xd_hist, Xd[ind_save_d])
end
nsteps += 1
end
verbose && println("--> Done")
verbose && println("--> xc = ", xd_hist[:,1:nsteps-1])
return PDMPResult(t_hist, xc_hist, xd_hist, rate_hist, save_positions)
return PDMPResult(problem.time, xc_hist, xd_hist, problem.rate_hist, save_positions, length(problem.time), 0)
end
2 changes: 1 addition & 1 deletion src/chvdiffeq.jl
Expand Up @@ -55,7 +55,7 @@ function chvjump(integrator, prob::PDMPProblem, save_pre_jump, save_rate, verbos
save_rate && push!(prob.rate_hist, sum(ratecache.rate))

# Update event
ev = pfsample(ratecache.rate, sum(ratecache.rate), length(ratecache.rate))
ev = pfsample(ratecache.rate)

# we perform the jump
affect!(caract.pdmpjump, ev, integrator.u, caract.xd, caract.parms, t)
Expand Down
2 changes: 1 addition & 1 deletion src/rejection.jl
Expand Up @@ -68,7 +68,7 @@ function solve(problem::PDMPProblem, Flow::Function; verbose::Bool = false, save
if (t < tf)
verbose && println("----> Jump!, ratio = ", ppf[1] / ppf[2], ", xd = ", Xd)
# make a jump
ev = pfsample(ratecache.rate, sum(ratecache.rate), length(ratecache.rate))
ev = pfsample(ratecache.rate)

# we perform the jump
affect!(caract.pdmpjump, ev, X0, Xd, caract.parms, t)
Expand Down
2 changes: 1 addition & 1 deletion src/rejectiondiffeq.jl
Expand Up @@ -56,7 +56,7 @@ function rejectionjump(integrator, prob::PDMPProblem, save_pre_jump, save_rate,
end

# Update event
ev = pfsample(ratecache.rate, sum(ratecache.rate), length(ratecache.rate))
ev = pfsample(ratecache.rate)

# we perform the jump
affect!(caract.pdmpjump, ev, integrator.u, caract.xd, caract.parms, t)
Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Expand Up @@ -49,6 +49,8 @@ function pfsample(w::vec, s::Tc, n::Int64) where {Tc, vec <: AbstractVector{Tc}}
return i
end

pfsample(rate) = pfsample(rate, sum(rate), length(rate))

"""
This type stores the output composed of:
- **time** : a `Vector` of `Float64`, containing the times of simulated events.
Expand Down

0 comments on commit 216a5e4

Please sign in to comment.