From 98ea4f0e4cf40e69aab6cab5eaa6e99914d01b67 Mon Sep 17 00:00:00 2001 From: Romain Veltz Date: Mon, 7 Oct 2019 22:38:21 +0200 Subject: [PATCH] add finalizer --- examples/tcp2d.jl | 2 +- src/chv.jl | 4 +++- src/chvdiffeq.jl | 11 ++++++----- src/rejection.jl | 3 ++- src/rejectiondiffeq.jl | 7 ++++--- test/pdmpStiff.jl | 1 + test/runtests.jl | 6 +++++- 7 files changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/tcp2d.jl b/examples/tcp2d.jl index 4190b9d..8ba7f51 100644 --- a/examples/tcp2d.jl +++ b/examples/tcp2d.jl @@ -48,7 +48,7 @@ result3 = PDMP.solve(problem, CHV(:lsoda); n_jumps = nj, save_positions=(false, #################################################################################################### # DEBUG DEBUG -# +# # algo = CHV(Tsit5()) # xd1 = zeros(Float64, length(xc0)+1) # xdd1 = similar(xd1) diff --git a/src/chv.jl b/src/chv.jl index 88b048a..2573fb2 100644 --- a/src/chv.jl +++ b/src/chv.jl @@ -2,7 +2,7 @@ include("chvdiffeq.jl") -function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind_save_d = -1:1, ind_save_c = -1:1, n_jumps = Inf64, reltol = 1e-7, abstol = 1e-9, save_positions = (false, true), save_rate = false) where {Tode <: Symbol} +function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind_save_d = -1:1, ind_save_c = -1:1, n_jumps = Inf64, reltol = 1e-7, abstol = 1e-9, save_positions = (false, true), save_rate = false, finalizer = finalize_dummy) where {Tode <: Symbol} verbose && println("#"^30) ode = algo.ode @assert ode in [:cvode, :lsoda, :adams, :bdf] @@ -96,6 +96,8 @@ function solve(problem::PDMPProblem, algo::CHV{Tode}; verbose::Bool = false, ind save_rate && push!(problem.rate_hist, caract.R(ratecache.rate, X_extended, Xd, caract.parms, t, true)[1]) + finalizer(ratecache.rate, caract.xc, caract.xd, caract.parms, t) + δt = - log(rand()) else diff --git a/src/chvdiffeq.jl b/src/chvdiffeq.jl index 9bda6f4..c33d0ad 100644 --- a/src/chvdiffeq.jl +++ b/src/chvdiffeq.jl @@ -76,7 +76,7 @@ end function chv_diffeq!(problem::PDMPProblem, ti::Tc, tf::Tc, X_extended::vece, - verbose = false; ode = Tsit5(), save_positions = (false, true), n_jumps::Td = Inf64, reltol=1e-7, abstol=1e-9, save_rate = false) where {Tc, Td, vece} + verbose = false; ode = Tsit5(), save_positions = (false, true), n_jumps::Td = Inf64, reltol=1e-7, abstol=1e-9, save_rate = false, finalizer = finalizer) where {Tc, Td, vece} verbose && println("#"^30) verbose && printstyled(color=:red,"Entry in chv_diffeq\n") @@ -134,6 +134,7 @@ function chv_diffeq!(problem::PDMPProblem, njumps +=1 verbose && println("----> end save post-jump, ") end + finalizer(ratecache.rate, caract.xc, caract.xd, caract.parms, t) end # we check that the last bit [t_last_jump, tf] is not missing if t>tf @@ -148,15 +149,15 @@ function chv_diffeq!(problem::PDMPProblem, return PDMPResult(problem, save_positions) end -function solve(problem::PDMPProblem{Tc, Td, vectype_xc, vectype_xd, Tcar}, algo::CHV{Tode}, X_extended; verbose = false, n_jumps = Inf64, save_positions = (false, true), reltol = 1e-7, abstol = 1e-9, save_rate = false) where {Tc, Td, vectype_xc, vectype_xd, vectype_rate, Tnu, Tp, TF, TR, Tcar, Tode <: DiffEqBase.DEAlgorithm} +function solve(problem::PDMPProblem{Tc, Td, vectype_xc, vectype_xd, Tcar}, algo::CHV{Tode}, X_extended; verbose = false, n_jumps = Inf64, save_positions = (false, true), reltol = 1e-7, abstol = 1e-9, save_rate = false, finalizer = finalize_dummy) where {Tc, Td, vectype_xc, vectype_xd, vectype_rate, Tnu, Tp, TF, TR, Tcar, Tode <: DiffEqBase.DEAlgorithm} - return chv_diffeq!(problem, problem.tspan[1], problem.tspan[2], X_extended, verbose; ode = algo.ode, save_positions = save_positions, n_jumps = n_jumps, reltol = reltol, abstol = abstol, save_rate = save_rate) + return chv_diffeq!(problem, problem.tspan[1], problem.tspan[2], X_extended, verbose; ode = algo.ode, save_positions = save_positions, n_jumps = n_jumps, reltol = reltol, abstol = abstol, save_rate = save_rate, finalizer = finalizer) end -function solve(problem::PDMPProblem{Tc, Td, vectype_xc, vectype_xd, Tcar}, algo::CHV{Tode}; verbose = false, n_jumps = Inf64, save_positions = (false, true), reltol = 1e-7, abstol = 1e-9, save_rate = false) where {Tc, Td, vectype_xc, vectype_xd, vectype_rate, Tnu, Tp, TF, TR, Tcar, Tode <: DiffEqBase.DEAlgorithm} +function solve(problem::PDMPProblem{Tc, Td, vectype_xc, vectype_xd, Tcar}, algo::CHV{Tode}; verbose = false, n_jumps = Inf64, save_positions = (false, true), reltol = 1e-7, abstol = 1e-9, save_rate = false, finalizer = finalize_dummy) where {Tc, Td, vectype_xc, vectype_xd, vectype_rate, Tnu, Tp, TF, TR, Tcar, Tode <: DiffEqBase.DEAlgorithm} # resize the extended vector to the proper dimension X_extended = zeros(Tc, length(problem.caract.xc) + 1) - return chv_diffeq!(problem, problem.tspan[1], problem.tspan[2], X_extended, verbose; ode = algo.ode, save_positions = save_positions, n_jumps = n_jumps, reltol = reltol, abstol = abstol, save_rate = save_rate) + return chv_diffeq!(problem, problem.tspan[1], problem.tspan[2], X_extended, verbose; ode = algo.ode, save_positions = save_positions, n_jumps = n_jumps, reltol = reltol, abstol = abstol, save_rate = save_rate, finalizer = finalizer ) end diff --git a/src/rejection.jl b/src/rejection.jl index fa23491..347d963 100644 --- a/src/rejection.jl +++ b/src/rejection.jl @@ -1,6 +1,6 @@ struct RejectionExact <: AbstractRejectionExact end -function solve(problem::PDMPProblem, Flow::Function; verbose::Bool = false, save_rejected = false, ind_save_d = -1:1, ind_save_c = -1:1, n_jumps = Inf64, save_positions = (false, true), save_rate = false) +function solve(problem::PDMPProblem, Flow::Function; verbose::Bool = false, save_rejected = false, ind_save_d = -1:1, ind_save_c = -1:1, n_jumps = Inf64, save_positions = (false, true), save_rate = false, finalizer = finalize_dummy) verbose && println("#"^30) verbose && printstyled(color=:red,"--> Start Rejection method\n") @@ -83,6 +83,7 @@ function solve(problem::PDMPProblem, Flow::Function; verbose::Bool = false, save push!(xc_hist, X0[ind_save_c]) push!(xd_hist, Xd[ind_save_d]) save_rate && push!(problem.rate_hist, sum(ratecache.rate)) + finalizer(ratecache.rate, caract.xc, caract.xd, caract.parms, t) end if verbose println("--> Done") end if verbose println("--> xd = ",xd_hist[:,1:nsteps]) end diff --git a/src/rejectiondiffeq.jl b/src/rejectiondiffeq.jl index c52ed08..0c5e972 100644 --- a/src/rejectiondiffeq.jl +++ b/src/rejectiondiffeq.jl @@ -79,7 +79,7 @@ end function rejection_diffeq!(problem::PDMPProblem, ti::Tc, tf::Tc, verbose = false; ode = Tsit5(), - save_positions = (false, true), n_jumps::Td = Inf64, reltol=1e-7, abstol=1e-9, save_rate = false) where {Tc, Td} + save_positions = (false, true), n_jumps::Td = Inf64, reltol=1e-7, abstol=1e-9, save_rate = false, finalizer = finalize_dummy) where {Tc, Td} verbose && println("#"^30) verbose && printstyled(color=:red,"Entry in rejection_diffeq\n") ti, tf = problem.tspan @@ -140,6 +140,7 @@ function rejection_diffeq!(problem::PDMPProblem, #put the flag for rejection simjptimes.reject = true end + finalizer(ratecache.rate, caract.xc, caract.xd, caract.parms, t) end # we check whether the last bit [t_last_jump, tf] is missing if t>tf @@ -155,7 +156,7 @@ function rejection_diffeq!(problem::PDMPProblem, end -function solve(problem::PDMPProblem{Tc, Td, vectype_xc, vectype_xd, Tcar}, algo::Rejection{Tode}; verbose = false, n_jumps = Inf64, save_positions = (false, true), reltol = 1e-7, abstol = 1e-9, save_rate = true) where {Tc, Td, vectype_xc, vectype_xd, vectype_rate, Tnu, Tp, TF, TR, Tcar, Tode <: DiffEqBase.DEAlgorithm} +function solve(problem::PDMPProblem{Tc, Td, vectype_xc, vectype_xd, Tcar}, algo::Rejection{Tode}; verbose = false, n_jumps = Inf64, save_positions = (false, true), reltol = 1e-7, abstol = 1e-9, save_rate = true, finalizer = finalize_dummy) where {Tc, Td, vectype_xc, vectype_xd, vectype_rate, Tnu, Tp, TF, TR, Tcar, Tode <: DiffEqBase.DEAlgorithm} - return rejection_diffeq!(problem, problem.tspan[1], problem.tspan[2], verbose; ode = algo.ode, save_positions = save_positions, n_jumps = n_jumps, reltol = reltol, abstol = abstol, save_rate = save_rate ) + return rejection_diffeq!(problem, problem.tspan[1], problem.tspan[2], verbose; ode = algo.ode, save_positions = save_positions, n_jumps = n_jumps, reltol = reltol, abstol = abstol, save_rate = save_rate, finalizer = finalizer ) end diff --git a/test/pdmpStiff.jl b/test/pdmpStiff.jl index 8061418..b4ffb07 100644 --- a/test/pdmpStiff.jl +++ b/test/pdmpStiff.jl @@ -119,6 +119,7 @@ println("\n\nComparison of solvers - rejection") problem = PDMP.PDMPProblem(F!, R!, nu, xc0, xd0, parms, (ti, tf)) res = PDMP.solve(problem, Rejection(ode[1]); n_jumps = 4, verbose = false) println("--> norm difference = ", norm(res.time - res_a_rej[1][1:4], Inf64), " - solver = ",ode[2]) + @test norm(res.time - res_a_rej[1][1:4], Inf64) < 0.0043 end Random.seed!(8) diff --git a/test/runtests.jl b/test/runtests.jl index e4d5391..80d51dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,9 @@ using PiecewiseDeterministicMarkovProcesses, Test, LinearAlgebra, Random, DifferentialEquations +macro testS(label, args...) + :(@testset $label begin @test $(args...); end) +end + @testset "Example TCP" begin include("../examples/tcp.jl") @test norm(errors[6:end], Inf64) < 1e-4 @@ -13,7 +17,7 @@ end @testset "Example with stiff ODE part" begin include("pdmpStiff.jl") @test norm(errors, Inf64) < 1e-3 - @test restime1 == res12.time + @testS "Call many times the same problem" restime1 == res12.time end @testset "Controlling allocations" begin