
Commit 95a9861

Merge pull request #18 from ChrisRackauckas/speed

Specialize on the function and improve performance via StaticArrays

sdwfrost committed May 11, 2020 (2 parents: 51fd982 + 05fa9df)
Showing 5 changed files with 94 additions and 61 deletions.
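
The gist of the change, as a standalone sketch (the struct names here are illustrative, not from the package): a field typed `::Any` erases the rate function's type and forces dynamic dispatch on every call, whereas a type parameter records the concrete function type so calls through it specialize, and `StaticArrays` keeps the small fixed-size state and stoichiometry arrays off the heap.

```julia
using StaticArrays

struct UntypedArgs          # hypothetical: the field is erased to Any
    F::Any
end

struct TypedArgs{Ftype}     # hypothetical: the function's concrete type is a parameter
    F::Ftype
end

rates(x, parms) = SA[parms[1]*x[1]*x[2], parms[2]*x[2]]

x0    = SA[9999, 1, 0]          # stack-allocated SVector{3,Int64}
parms = SA[0.1/10000.0, 0.05]

a = UntypedArgs(rates)
b = TypedArgs(rates)
a.F(x0, parms)   # a.F is typed Any: dynamic dispatch on every call
b.F(x0, parms)   # b.F is typed typeof(rates): statically dispatched, inlinable
```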
1 change: 1 addition & 0 deletions Project.toml
@@ -7,6 +7,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
35 changes: 14 additions & 21 deletions README.md
@@ -22,14 +22,14 @@ This is an implementation of [Gillespie's direct method](http://en.wikipedia.org
The stable release of ```Gillespie.jl``` can be installed from the Julia REPL using the following commands.

```julia
using Pkg
Pkg.add("Gillespie")
```

The development version from this repository can be installed as follows.

```julia
Pkg.clone("https://github.com/sdwfrost/Gillespie.jl")
Pkg.build("Gillespie")
Pkg.add("https://github.com/sdwfrost/Gillespie.jl")
```

## Example usage
@@ -86,33 +86,26 @@ The development version of ```Gillespie.jl``` includes code to simulate via unif

The development version of ```Gillespie.jl``` also includes code to simulate assuming time-varying rates via the true jump method; the API is the same as for the SSA, with the exception that the rate function must have three arguments, as described above.
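
A minimal sketch of such a three-argument rate function (`F_tv` and its seasonal forcing are illustrative, not from the package):

```julia
# Time-varying rates take (state, parameters, time) instead of (state, parameters).
function F_tv(x, parms, t)
    (S, I, R) = x
    (beta0, gamma) = parms
    beta = beta0 * (1.0 + 0.1 * sin(2pi * t / 365.0))  # illustrative seasonal forcing
    [beta * S * I, gamma * I]                          # infection and recovery rates
end
```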

## Performance considerations

Passing functions as arguments in Julia v0.4 incurs a performance penalty. One can circumvent this by passing an immutable object, with ```call``` overloaded, as follows.

```julia
immutable G; end
call(::Type{G},x,parms) = F(x,parms)
```

An example of this approach is given [here](https://github.com/sdwfrost/Gillespie.jl/blob/master/examples/sir2.jl). This is the default behaviour in v0.5 and above.

## Benchmarks

The speed of an SIR model in `Gillespie.jl` was compared to:

- A version using the R package `GillespieSSA`
- Handcoded versions of the SIR model in Julia, R, and Rcpp
- [DifferentialEquations.jl's](https://docs.sciml.ai/latest/) jump interface.

1000 simulations were performed, and the time per simulation was computed (lower is better). Benchmarks were run on a Mac Pro (Late 2013) with a 3 GHz 8-core Intel Xeon E5 and 64 GB of 1866 MHz RAM, running OS X 10.11.3 (El Capitan), using Julia v0.4.5 and R v3.3. Jupyter notebooks for [Julia](https://gist.github.com/sdwfrost/8a0e926a5e16d7d104bd2bc1a5f9ed0b) and [R](https://gist.github.com/sdwfrost/afed3b881ef5742623b905a539197c7a) with the code and benchmarks are available as gists. A plain Julia file is also provided [in the benchmarks subdirectory](https://github.com/sdwfrost/Gillespie.jl/blob/master/benchmarks/sir-jl-benchmark.jl) for ease of benchmarking locally.

| Implementation | Time per simulation (ms) |
| -------------------------------------- | ------------------------ |
| R (GillespieSSA) | 894.25 |
| R (handcoded) | 1087.94 |
| Rcpp (handcoded) | 1.31 |
| Julia (Gillespie.jl) | 3.99 |
| Julia (Gillespie.jl, passing object) | 1.78 |
| Julia (handcoded) | 1.20 |
| Implementation | Time per simulation (ms) |
| -------------------------------------------| ------------------------ |
| R (GillespieSSA) | 463 |
| R (handcoded) | 785 |
| Rcpp (handcoded) | 1.40 |
| Julia (Gillespie.jl) | 1.69 |
| Julia (Gillespie.jl, Static) | 0.89 |
| Julia (DifferentialEquations.jl) | 1.14 |
| Julia (DifferentialEquations.jl, Static) | 0.72 |
| Julia (handcoded) | 0.49 |

(smaller is better)

60 changes: 43 additions & 17 deletions benchmarks/sir-jl-benchmark.jl
@@ -1,6 +1,4 @@

using DataFrames
using DataArrays
using Distributions
using Gillespie
using BenchmarkTools
@@ -10,10 +8,10 @@ function sir(beta,gamma,N,S0,I0,R0,tf)
S = S0
I = I0
R = R0
ta=DataArray(Float64,0)
Sa=DataArray(Float64,0)
Ia=DataArray(Float64,0)
Ra=DataArray(Float64,0)
ta=Vector{Float64}(undef,0)
Sa=Vector{Float64}(undef,0)
Ia=Vector{Float64}(undef,0)
Ra=Vector{Float64}(undef,0)
while t < tf
push!(ta,t)
push!(Sa,S)
@@ -37,10 +35,10 @@ function sir(beta,gamma,N,S0,I0,R0,tf)
end
end
results = DataFrame()
results[:t] = ta
results[:S] = Sa
results[:I] = Ia
results[:R] = Ra
results[!,:t] = ta
results[!,:S] = Sa
results[!,:I] = Ia
results[!,:R] = Ra
return(results)
end
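
The switch from `results[:t]` to `results[!,:t]` tracks the DataFrames.jl API: bare-`Symbol` `setindex!` was deprecated, and `df[!, :col] = v` binds `v` as a column without copying. A minimal sketch, assuming current DataFrames.jl semantics:

```julia
using DataFrames

ta = [0.0, 0.1, 0.3]
df = DataFrame()
df[!, :t] = ta   # binds ta as column :t without copying
df.t === ta      # true: the column is the same underlying vector
```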

@@ -58,15 +56,43 @@ parms = [0.1/10000.0,0.05]
tf = 1000.0

ssa(x0,F,nu,parms,tf) # compile
srand(1234)
Random.seed!(1234)
@benchmark ssa($x0,$F,$nu,$parms,$tf) samples=1000 seconds=100

immutable G; end
call(::Type{G},x,parms) = F(x,parms)
ssa(x0,G,nu,parms,tf) # compile
srand(1234)
@benchmark ssa($x0,$G,$nu,$parms,$tf) samples=1000 seconds=100
function F2(x,parms)
(S,I,R) = x
(beta,gamma) = parms
infection = beta*S*I
recovery = gamma*I
[infection,recovery]
end

x0 = SA[9999,1,0]
nu = SA[-1 1 0
0 -1 1]
parms = SA[0.1/10000.0,0.05]
tf = 1000.0

ssa(x0,F2,nu,parms,tf) # compile
Random.seed!(1234)
@benchmark ssa($x0,$F2,$nu,$parms,$tf) samples=1000 seconds=100
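
For reference on the `SA` literals used above (standard StaticArrays behavior): `SA[...]` builds stack-allocated arrays whose size is part of the type, which lets per-step state updates compile to straight-line code:

```julia
using StaticArrays

x0 = SA[9999, 1, 0]       # SVector{3,Int64}
nu = SA[-1 1 0; 0 -1 1]   # SMatrix{2,3,Int64}
nu[2, :]                  # row extraction yields an SVector{3,Int64}
x0 + nu[2, :]             # whole-state update with no heap allocation
```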

sir(0.1/10000,0.05,10000,9999,1,0,1000) # compile
srand(1234)
Random.seed!(1234)
@benchmark sir(0.1/10000,0.05,10000,9999,1,0,1000) samples=1000 seconds=100

using DiffEqBiological

sir_model = @reaction_network rn begin
0.1/10000.0, s + i --> 2i
0.05, i --> r
end
sir_prob = DiscreteProblem([9999,1,0],(0.0,tf))
sir_jump_prob = JumpProblem(sir_prob,Direct(),sir_model)
sir_sol = solve(sir_jump_prob,SSAStepper()) # compile
@benchmark solve(sir_jump_prob,SSAStepper()) samples=1000 seconds=100

sir_prob = DiscreteProblem(SA[9999,1,0],(0.0,tf))
sir_jump_prob = JumpProblem(sir_prob,Direct(),sir_model)
sir_sol = solve(sir_jump_prob,SSAStepper()) # compile
@benchmark solve(sir_jump_prob,SSAStepper()) samples=1000 seconds=100
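
The two DifferentialEquations.jl benchmarks above differ only in the initial-condition container; the `JumpProblem` construction and `SSAStepper` solve are otherwise identical. Note that `Random.seed!` and `SA` are used without visible imports; the truncated hunks may well hide them, but a sketch of what the file appears to assume:

```julia
using Random             # Random.seed!
using StaticArrays       # SA[...] literals
using DiffEqBiological   # @reaction_network
# JumpProblem, Direct, and SSAStepper may additionally need DiffEqJump (or the
# umbrella DifferentialEquations package) if DiffEqBiological does not re-export them.
```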
3 changes: 2 additions & 1 deletion src/Gillespie.jl
@@ -4,6 +4,7 @@ using Distributions
using DataFrames
using QuadGK
using Roots
using StaticArrays

export
ssa,
@@ -34,7 +35,7 @@ There are several named arguments:
- **thin**: (`Bool`) whether to thin jumps for Jensen's method (default: `true`).
"
function ssa(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Vector{Float64},tf::Float64; algo=:gillespie, max_rate::Float64=0.0, thin::Bool=true)
function ssa(x0::AbstractVector{Int64},F::Base.Callable,nu::AbstractMatrix{Int64},parms::AbstractVector{Float64},tf::Float64; algo=:gillespie, max_rate::Float64=0.0, thin::Bool=true)
@assert algo in [:gillespie,:jensen,:tjm] "Available algorithms are :gillespie, :jensen, and :tjm"
if algo == :gillespie
return gillespie(x0,F,nu,parms,tf)
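
Relaxing the signature from `Vector`/`Matrix` to `AbstractVector`/`AbstractMatrix` is what admits static arrays. A minimal sketch of the static call path (the rate function is illustrative):

```julia
using Gillespie, StaticArrays

F(x, parms) = SA[parms[1]*x[1]*x[2], parms[2]*x[2]]

x0    = SA[9999, 1, 0]          # an AbstractVector{Int64}
nu    = SA[-1 1 0; 0 -1 1]      # an AbstractMatrix{Int64}
parms = SA[0.1/10000.0, 0.05]   # an AbstractVector{Float64}

result = ssa(x0, F, nu, parms, 1000.0)   # algo defaults to :gillespie
```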
56 changes: 34 additions & 22 deletions src/SSA.jl
@@ -19,11 +19,11 @@ end
- **alg** : the algorithm used (`Symbol`, either `:gillespie`, `:jensen`, or `:tjm`).
- **tvc** : whether rates are time varying.
"
struct SSAArgs
x0::Vector{Int64}
F::Any
nu::Matrix{Int64}
parms::Vector{Float64}
struct SSAArgs{X,Ftype,N,P}
x0::X
F::Ftype
nu::N
parms::P
tf::Float64
alg::Symbol
tvc::Bool
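
With these type parameters, the concrete types of the state, rate function, stoichiometry matrix, and parameters become part of `SSAArgs`'s own type, so code unpacking the struct stays type-stable and compiles one specialized method per rate function. A hypothetical check (the struct may be internal; qualify it as `Gillespie.SSAArgs` if it is not exported):

```julia
F(x, parms) = [parms[1]*x[1]*x[2], parms[2]*x[2]]
args = SSAArgs([9999, 1, 0], F, [-1 1 0; 0 -1 1], [0.1/10000.0, 0.05],
               1000.0, :gillespie, false)
typeof(args)   # SSAArgs{Vector{Int64}, typeof(F), Matrix{Int64}, Vector{Float64}}
```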
@@ -54,7 +54,7 @@ This function is a substitute for `StatsBase.sample(wv::WeightVec)`, which avoid
- **n** : the length of `w`.
"
function pfsample(w::Array{Float64,1},s::Float64,n::Int64)
function pfsample(w::AbstractArray{Float64,1},s::Float64,n::Int64)
t = rand() * s
i = 1
cw = w[1]
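
The diff truncates the body of `pfsample`; for orientation, here is a self-contained sketch of the standard inverse-CDF linear scan the shown lines begin (a reconstruction, not necessarily the package's exact code):

```julia
function pfsample_sketch(w::AbstractVector{Float64}, s::Float64, n::Int64)
    t = rand() * s           # uniform point in [0, s), where s = sum(w)
    i = 1
    cw = w[1]
    while cw < t && i < n    # walk the cumulative weights until they cover t
        i += 1
        cw += w[i]
    end
    return i                 # index i is drawn with probability w[i]/s
end
```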
@@ -74,7 +74,7 @@ This function performs Gillespie's stochastic simulation algorithm. It takes the
- **parms** : a `Vector` of `Float64` representing the parameters of the system.
- **tf** : the final simulation time (`Float64`).
"
function gillespie(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Vector{Float64},tf::Float64)
function gillespie(x0::AbstractVector{Int64},F::Base.Callable,nu::AbstractMatrix{Int64},parms::AbstractVector{Float64},tf::Float64)
# Args
args = SSAArgs(x0,F,nu,parms,tf,:gillespie,false)
# Set up time array
@@ -84,7 +84,7 @@ function gillespie(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::V
# Set up initial x
nstates = length(x0)
x = copy(x0')
xa = copy(x0)
xa = copy(Array(x0))
# Number of propensity functions
numpf = size(nu,1)
# Main loop
@@ -103,9 +103,13 @@ function gillespie(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::V
push!(ta,t)
# Update event
ev = pfsample(pf,sumpf,numpf)
deltax = view(nu,ev,:)
for i in 1:nstates
@inbounds x[1,i] += deltax[i]
if x isa SVector
@inbounds x[1] += nu[ev,:]
else
deltax = view(nu,ev,:)
for i in 1:nstates
@inbounds x[1,i] += deltax[i]
end
end
for xx in x
push!(xa,xx)
@@ -127,7 +131,7 @@ This function performs the true jump method for piecewise deterministic Markov p
- **parms** : a `Vector` of `Float64` representing the parameters of the system.
- **tf** : the final simulation time (`Float64`).
"
function truejump(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Vector{Float64},tf::Float64)
function truejump(x0::AbstractVector{Int64},F::Base.Callable,nu::AbstractMatrix{Int64},parms::AbstractVector{Float64},tf::Float64)
# Args
args = SSAArgs(x0,F,nu,parms,tf,:tjm,true)
# Set up time array
@@ -161,9 +165,13 @@ function truejump(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Ve
push!(ta,t)
# Update event
ev = pfsample(pf,sumpf,numpf)
deltax = view(nu,ev,:)
for i in 1:nstates
@inbounds x[1,i] += deltax[i]
if x isa SVector
@inbounds x[1] += nu[ev,:]
else
deltax = view(nu,ev,:)
for i in 1:nstates
@inbounds x[1,i] += deltax[i]
end
end
for xx in x
push!(xa,xx)
@@ -186,9 +194,9 @@ This function performs stochastic simulation using thinning/uniformization/Jense
- **tf** : the final simulation time (`Float64`).
- **max_rate**: the maximum rate (`Float64`).
"
function jensen(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Vector{Float64},tf::Float64,max_rate::Float64,thin::Bool=true)
function jensen(x0::AbstractVector{Int64},F::Base.Callable,nu::AbstractMatrix{Int64},parms::AbstractVector{Float64},tf::Float64,max_rate::Float64,thin::Bool=true)
if thin==false
return jensen_alljumps(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Vector{Float64},tf::Float64,max_rate::Float64)
return jensen_alljumps(x0::AbstractVector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::AbstractVector{Float64},tf::Float64,max_rate::Float64)
end
tvc=true
try
@@ -232,10 +240,14 @@ function jensen(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Vect
# Update event
ev = pfsample([pf; max_rate-sumpf],max_rate,numpf+1)
if ev < numpf
deltax = view(nu,ev,:)
for i in 1:nstates
@inbounds x[1,i] += deltax[i]
end
if x isa SVector
@inbounds x[1] += nu[ev,:]
else
deltax = view(nu,ev,:)
for i in 1:nstates
@inbounds x[1,i] += deltax[i]
end
end
for xx in x
push!(xa,xx)
end
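
The padded draw above, `pfsample([pf; max_rate-sumpf], max_rate, numpf+1)`, is the thinning step of Jensen's method: the propensities are extended with a dummy no-event channel so the total rate is the constant `max_rate`. A small numeric illustration:

```julia
pf       = [0.3, 0.2]                # illustrative propensities
max_rate = 1.0
padded   = [pf; max_rate - sum(pf)]  # [0.3, 0.2, 0.5]
# Drawing the last channel (probability 0.5 here) is a thinned step:
# time advances but the state is left unchanged.
```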
@@ -259,7 +271,7 @@ This function performs stochastic simulation using thinning/uniformization/Jense
- **tf** : the final simulation time (`Float64`).
- **max_rate**: the maximum rate (`Float64`).
"
function jensen_alljumps(x0::Vector{Int64},F::Base.Callable,nu::Matrix{Int64},parms::Vector{Float64},tf::Float64,max_rate::Float64)
function jensen_alljumps(x0::AbstractVector{Int64},F::Base.Callable,nu::AbstractMatrix{Int64},parms::AbstractVector{Float64},tf::Float64,max_rate::Float64)
# Args
tvc=true
try
