Skip to content

Commit

Permalink
refactoring chvdiffeq
Browse files Browse the repository at this point in the history
  • Loading branch information
rveltz committed Oct 1, 2019
1 parent 29a8716 commit a6069e9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
1 change: 1 addition & 0 deletions examples/pdmpStiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Random.seed!(8)
alloc1 = @allocated PDMP.solve(problem, CHV(Tsit5()); n_jumps = nj, save_positions = (false, false))
Random.seed!(8)
alloc2 = @allocated PDMP.solve(problem, CHV(Tsit5()); n_jumps = 2nj, save_positions = (false, false))
println("--> allocations = ", (alloc1, alloc2))

# test for many calls to solve, the trajectories should be the same
problem = PDMP.PDMPProblem(F!, R!, nu, xc0, xd0, parms, (ti, tf))
Expand Down
54 changes: 29 additions & 25 deletions src/chvdiffeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,34 @@ function chvjump(integrator, prob::PDMPProblem, save_pre_jump, save_rate, verbos
caract = prob.caract
ratecache = caract.ratecache
simjptimes = prob.simjptimes

# final simulation time
tf = prob.tspan[2]

# find the next jump time
t = integrator.u[end]
prob.simjptimes.lastjumptime = t
simjptimes.lastjumptime = t

verbose && printstyled(color=:green, "--> Jump detected at t = $t !!\n")
verbose && printstyled(color=:green, "--> jump not yet performed, xd = ", caract.xd,"\n")

if (save_pre_jump) && (t <= tf)
verbose && printstyled(color=:green, "----> saving pre-jump\n")
push!(prob.Xc, (integrator.u[1:end-1]))
push!(prob.Xd, copy(caract.xd))
push!(prob.time,t)
pushXc!(prob, (integrator.u[1:end-1]))
pushXd!(prob, copy(caract.xd))
pushTime!(prob, t)
#save rates for debugging
save_rate && push!(prob.rate_hist, sum(ratecache.rate))
end

# execute the jump
caract.R(get_rate(caract.ratecache, integrator.u), integrator.u, caract.xd, caract.parms, t, false)
caract.R(get_rate(ratecache, integrator.u), integrator.u, caract.xd, caract.parms, t, false)
if (t < tf)
#save rates for debugging
save_rate && push!(prob.rate_hist, sum(caract.ratecache.rate))
save_rate && push!(prob.rate_hist, sum(ratecache.rate))

# Update event
ev = pfsample(caract.ratecache.rate, sum(caract.ratecache.rate), length(caract.ratecache.rate))
ev = pfsample(ratecache.rate, sum(ratecache.rate), length(ratecache.rate))

# we perform the jump
affect!(caract.pdmpjump, ev, integrator.u, caract.xd, caract.parms, t)
Expand All @@ -66,9 +68,9 @@ function chvjump(integrator, prob::PDMPProblem, save_pre_jump, save_rate, verbos
end
verbose && printstyled(color=:green,"--> jump computed, xd = ",caract.xd,"\n")
# we register the next time interval to solve the extended ode
prob.simjptimes.njumps += 1
prob.simjptimes.tstop_extended += -log(rand())
add_tstop!(integrator, prob.simjptimes.tstop_extended)
simjptimes.njumps += 1
simjptimes.tstop_extended += -log(rand())
add_tstop!(integrator, simjptimes.tstop_extended)
verbose && printstyled(color=:green,"--> End jump\n\n")
end

Expand All @@ -86,6 +88,8 @@ function chv_diffeq!(problem::PDMPProblem,

# we declare the characteristics for convenience
caract = problem.caract
ratecache = caract.ratecache
simjptimes = problem.simjptimes

#ISSUE HERE, IF USING A PROBLEM p MAKE SURE THE TIMES in p.sim ARE WELL SET
# set up the current time as the initial time
Expand All @@ -107,25 +111,25 @@ function chv_diffeq!(problem::PDMPProblem,

# define the ODE flow, this leads to big memory saving
# prob_CHV = ODEProblem((xdot,x,data,tt) -> problem(xdot, x, data, tt), X_extended, (0.0, 1e9))
prob_CHV = ODEProblem((xdot, x, data, tt) -> algopdmp(xdot, x, problem.caract, tt), X_extended, (0.0, 1e9))
integrator = init(prob_CHV, ode, tstops = problem.simjptimes.tstop_extended, callback = cb, save_everystep = false, reltol = reltol, abstol = abstol, advance_to_tstop = true)
prob_CHV = ODEProblem((xdot, x, data, tt) -> algopdmp(xdot, x, caract, tt), X_extended, (0.0, 1e9))
integrator = init(prob_CHV, ode, tstops = simjptimes.tstop_extended, callback = cb, save_everystep = false, reltol = reltol, abstol = abstol, advance_to_tstop = true)

# current jump number
njumps = 0

while (t < tf) && problem.simjptimes.njumps < n_jumps-1
verbose && println("--> n = $(problem.simjptimes.njumps), t = $t, δt = ",problem.simjptimes.tstop_extended)
while (t < tf) && simjptimes.njumps < n_jumps-1
verbose && println("--> n = $(problem.simjptimes.njumps), t = $t, δt = ", simjptimes.tstop_extended)
step!(integrator)

@assert( t < problem.simjptimes.lastjumptime, "Could not compute next jump time $(problem.simjptimes.njumps).\nReturn code = $(integrator.sol.retcode)\n $t < $(problem.simjptimes.lastjumptime),\n solver = $ode. dt = $(t - problem.simjptimes.lastjumptime)")
t, tprev = problem.simjptimes.lastjumptime, t
@assert( t < simjptimes.lastjumptime, "Could not compute next jump time $(simjptimes.njumps).\nReturn code = $(integrator.sol.retcode)\n $t < $(simjptimes.lastjumptime),\n solver = $ode. dt = $(t - simjptimes.lastjumptime)")
t, tprev = simjptimes.lastjumptime, t

# the previous step was a jump! should we save it?
if njumps < problem.simjptimes.njumps && save_positions[2] && (t <= tf)
if njumps < simjptimes.njumps && save_positions[2] && (t <= tf)
verbose && println("----> save post-jump, xd = ",problem.Xd)
push!(problem.Xc, copy(caract.xc))
push!(problem.Xd, copy(caract.xd))
push!(problem.time, t)
pushXc!(problem, copy(caract.xc))
pushXd!(problem, copy(caract.xd))
pushTime!(problem, t)
njumps +=1
verbose && println("----> end save post-jump, ")
end
Expand All @@ -136,11 +140,11 @@ function chv_diffeq!(problem::PDMPProblem,
prob_last_bit = ODEProblem((xdot,x,data,tt) -> caract.F(xdot, x, caract.xd, caract.parms, tt), copy(caract.xc), (tprev, tf))
sol = DiffEqBase.solve(prob_last_bit, ode)
verbose && println("-------> xc[end] = ",sol.u[end])
push!(problem.Xc, sol.u[end])
push!(problem.Xd, copy(caract.xd))
push!(problem.time, sol.t[end])
pushXc!(problem, sol.u[end])
pushXd!(problem, copy(caract.xd))
pushTime!(problem, sol.t[end])
end
return PDMPResult(problem.time, problem.Xc, problem.Xd, problem.rate_hist, save_positions)
return PDMPResult(problem, save_positions)
end

function solve(problem::PDMPProblem{Tc, Td, vectype_xc, vectype_xd, Tcar}, algo::CHV{Tode}, X_extended; verbose = false, n_jumps = Inf64, save_positions = (false, true), reltol = 1e-7, abstol = 1e-9, save_rate = false) where {Tc, Td, vectype_xc, vectype_xd, vectype_rate, Tnu, Tp, TF, TR, Tcar, Tode <: DiffEqBase.DEAlgorithm}
Expand Down

0 comments on commit a6069e9

Please sign in to comment.