In [1]:
abstract type AbstractJumpAggregator end
abstract type AbstractJump end
abstract type AbstractSSAJumpAggregator <: AbstractJumpAggregator end
mutable struct DirectJumpAggregation{T,S,F1,F2,RNG} <: AbstractSSAJumpAggregator
  next_jump::Int
  prev_jump::Int
  next_jump_time::T
  end_time::T
  cur_rates::Vector{T}
  sum_rate::T
  ma_jumps::S
  rates::F1
  affects!::F2
  save_positions::Tuple{Bool,Bool}
  rng::RNG
end
DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG; kwargs...) where {T,S,F1,F2,RNG} =
  DirectJumpAggregation{T,S,F1,F2,RNG}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng)

DirectJumpAggregation

In [2]:
abstract type AbstractAggregatorAlgorithm end
struct Direct <: AbstractAggregatorAlgorithm end

In [3]:
# creating the JumpAggregation structure (tuple-based constant jumps)
function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps,
    ma_jumps, save_positions, rng; kwargs...)

# handle constant jumps using tuples
rates, affects! = get_jump_info_tuples(constant_jumps)

build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps,
          rates, affects!, save_positions, rng; kwargs...)
end

aggregate (generic function with 1 method)

In [4]:
function get_jump_info_tuples(constant_jumps)
    if (constant_jumps !== nothing) && !isempty(constant_jumps)
      rates    = ((c.rate for c in constant_jumps)...,)
      affects! = ((c.affect! for c in constant_jumps)...,)
    else
      rates    = ()
      affects! = ()
    end
  
    rates, affects!
  end

get_jump_info_tuples (generic function with 1 method)

In [5]:
function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates,
    affects!, save_positions, rng; kwargs...)

# mass action jumps
majumps = ma_jumps
if majumps === nothing
majumps = MassActionJump(Vector{typeof(t)}(),
 Vector{Vector{Pair{Int,eltype(u)}}}(),
 Vector{Vector{Pair{Int,eltype(u)}}}())
end

# current jump rates, allows mass action rates and constant jumps
cur_rates = Vector{typeof(t)}(undef, get_num_majumps(majumps) + length(rates))

sum_rate = zero(typeof(t))
next_jump = 0
next_jump_time = typemax(typeof(t))
jump_agg_type(next_jump, next_jump_time, end_time, cur_rates, sum_rate,
majumps, rates, affects!, save_positions, rng; kwargs...)
end

build_jump_aggregation (generic function with 1 method)

## where `get_num_majumps` is defined by:

Here invoque type `MassActionJump`  so have to define it as follows

# Define ConstantRateJump and MassActionJump

In [6]:
#TODO Simplify the code pmapper maybe
struct MassActionJump{T,S,U,V} <: AbstractJump
    scaled_rates::T
    reactant_stoch::S
    net_stoch::U
    param_mapper::V
  
    function MassActionJump{T,S,U,V}(rates::T, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, nocopy::Bool) where {T <: AbstractVector, S, U, V}
      sr  = nocopy ? rates : copy(rates)
      rs = nocopy ? rs_in : copy(rs_in)
      for i in eachindex(rs)
        if useiszero && (length(rs[i]) == 1) && iszero(rs[i][1][1])
          rs[i] = typeof(rs[i])()
        end
      end
  
      if scale_rates && !isempty(sr)
        scalerates!(sr, rs)
      end
      new(sr, rs, ns, pmapper)
    end
    function MassActionJump{Nothing,Vector{S},Vector{U},V}(::Nothing, rs_in::Vector{S}, ns::Vector{U}, pmapper::V, scale_rates::Bool, useiszero::Bool, nocopy::Bool) where {S<:AbstractVector, U<:AbstractVector, V}
      rs = nocopy ? rs_in : copy(rs_in)
      for i in eachindex(rs)
        if useiszero && (length(rs[i]) == 1) && iszero(rs[i][1][1])
          rs[i] = typeof(rs[i])()
        end
      end
      new(nothing, rs, ns, pmapper)
    end
    function MassActionJump{T,S,U,V}(rate::T, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, nocopy::Bool) where {T <: Number, S, U, V}
      rs = rs_in
      if useiszero && (length(rs) == 1) && iszero(rs[1][1])
        rs = typeof(rs)()
      end
      sr = scale_rates ? scalerate(rate, rs) : rate
      new(sr, rs, ns, pmapper)
    end
    function MassActionJump{Nothing,S,U,V}(::Nothing, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, nocopy::Bool) where {S, U, V}
      rs = rs_in
      if useiszero && (length(rs) == 1) && iszero(rs[1][1])
        rs = typeof(rs)()
      end
      new(nothing, rs, ns, pmapper)
    end
  
end
MassActionJump(usr::T, rs::S, ns::U, pmapper::V; scale_rates = true, useiszero = true, nocopy=false) where {T,S,U,V} = MassActionJump{T,S,U,V}(usr, rs, ns, pmapper, scale_rates, useiszero, nocopy)

MassActionJump(usr::T, rs, ns; scale_rates = true, useiszero = true, nocopy=false) where {T <: AbstractVector,S,U} = MassActionJump(usr, rs, ns, nothing; scale_rates=scale_rates, useiszero=useiszero, nocopy=nocopy)

MassActionJump(usr::T, rs, ns; scale_rates = true, useiszero = true, nocopy=false) where {T <: Number,S,U} = MassActionJump(usr, rs, ns, nothing; scale_rates=scale_rates, useiszero=useiszero, nocopy=nocopy)

# with parameter indices or mapping, multiple jump case
function MassActionJump(rs, ns; param_idxs=nothing, param_mapper=nothing, scale_rates = true, useiszero = true, nocopy=false)
  if param_mapper === nothing 
    (param_idxs === nothing) && error("If no parameter indices are given via param_idxs, an explicit parameter mapping must be passed in via param_mapper.")
    pmapper = MassActionJumpParamMapper(param_idxs)
  else
    (param_idxs !== nothing) && error("Only one of param_idxs and param_mapper should be passed.")
    pmapper = param_mapper
  end
                      
  MassActionJump(nothing, nocopy ? rs : copy(rs), ns, pmapper; scale_rates=scale_rates, 
                 useiszero=useiszero, nocopy=true)
end

MassActionJump

In [7]:
@inline get_num_majumps(maj::MassActionJump) = length(maj.scaled_rates)
@inline get_num_majumps(maj::Nothing) = 0

get_num_majumps (generic function with 2 methods)

In [8]:
struct ConstantRateJump{F1,F2} <: AbstractJump
    rate::F1
    affect!::F2
end  

In [9]:
function scalerates!(unscaled_rates::AbstractVector{U}, stochmat::AbstractVector{V}) where {U,S,T,W <: Pair{S,T}, V <: AbstractVector{W}}
    @inbounds for i in eachindex(unscaled_rates)
        coef = one(T)
        @inbounds for specstoch in stochmat[i]
            coef *= factorial(specstoch[2])
        end
        unscaled_rates[i] /= coef
    end
    nothing
end

function scalerate(unscaled_rate::U, stochmat::AbstractVector{Pair{S,T}}) where {U <: Number, S, T}
    coef = one(T)
    @inbounds for specstoch in stochmat
        coef *= factorial(specstoch[2])
    end
    unscaled_rate /= coef
end

scalerate (generic function with 1 method)

## An example of MassActionJump

In [66]:
rates1 = [0.1/1000.0]
reactant_stoich1 = [[1=>1,2=>1]]
net_stoich1 = [[1=>-1,2=>1]]
jump1 = MassActionJump(rates1, reactant_stoich1, net_stoich1; scale_rates = false)

MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}([0.0001], [[1 => 1, 2 => 1]], [[1 => -1, 2 => 1]], nothing)

## An example of ConstantRateJump

In [67]:
rate2 = (u,p,t) -> 0.01u[2]
affect! = function (integrator)
  integrator.u[2] -= 1
  integrator.u[3] += 1
end
jump2 = ConstantRateJump(rate2,affect!)

ConstantRateJump{var"#18#19", var"#20#21"}(var"#18#19"(), var"#20#21"())

# Next we want to build aggregator

In [97]:
using SciMLBase, Random
u0 = [999.,1.,0.]
p = SciMLBase.NullParameters()
t0 = 0.0 # initial time
end_time  = 200.
constant_jumps = jump2 
ma_jumps = jump1
save_positions = (false,false)
rng = Random.seed!(1234)

MersenneTwister(1234)

In [98]:
struct JumpSet{T1,T2,T3,T4} <: AbstractJump
    variable_jumps::T1
    constant_jumps::T2
    regular_jump::T3
    massaction_jump::T4
end

In [99]:
jumpset=JumpSet((),(constant_jumps,),nothing,ma_jumps)

JumpSet{Tuple{}, Tuple{ConstantRateJump{var"#18#19", var"#20#21"}}, Nothing, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}}((), (ConstantRateJump{var"#18#19", var"#20#21"}(var"#18#19"(), var"#20#21"()),), nothing, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}([0.0001], [[1 => 1, 2 => 1]], [[1 => -1, 2 => 1]], nothing))

In [100]:
jumpset.constant_jumps

(ConstantRateJump{var"#18#19", var"#20#21"}(var"#18#19"(), var"#20#21"()),)

In [101]:
dja = aggregate(Direct(),u0,p,t0,end_time,jumpset.constant_jumps, jumpset.massaction_jump, save_positions, rng)

DirectJumpAggregation{Float64, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}, Tuple{var"#18#19"}, Tuple{var"#20#21"}, MersenneTwister}(0, 0, Inf, 200.0, [2.2860795294e-314, 2.2857227113e-314], 0.0, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}([0.0001], [[1 => 1, 2 => 1]], [[1 => -1, 2 => 1]], nothing), (var"#18#19"(),), (var"#20#21"(),), (false, false), MersenneTwister(1234))

## So if we look at the aggregator, the issue is that the next_jump_time = Inf, this is normall because we haven't initialize the prob, in order to do it, one needs to define the `JumpProblem`. Maybe we can skip the defining the DEProblem, because it will have a lot of dependencies. So one idea is to just update `aggregator` for the first time.

In [102]:

"""
fill_rates_and_sum!(p::AbstractSSAJumpAggregator, u, params, t)

Reevaluate all rates and their sum.
"""
function fill_rates_and_sum!(p::AbstractSSAJumpAggregator, u, params, t)
    sum_rate = zero(typeof(p.sum_rate))

    # mass action jumps
    majumps   = p.ma_jumps
    cur_rates = p.cur_rates
    @inbounds for i in 1:get_num_majumps(majumps)
        cur_rates[i] = evalrxrate(u, i, majumps)
        sum_rate    += cur_rates[i]
    end

    # constant rates
    rates = p.rates
    idx   = get_num_majumps(majumps) + 1
    @inbounds for rate in rates
        cur_rates[idx] = rate(u, params, t)
        sum_rate += cur_rates[idx]
        idx += 1
    end

    p.sum_rate = sum_rate
    nothing
end
@inline @fastmath function evalrxrate(speciesvec::AbstractVector{T}, rxidx::S, majump::MassActionJump{U,V,W,X})::R where {T,S,R,U <: AbstractVector{R},V,W,X}
    val = one(T)
    @inbounds for specstoch in majump.reactant_stoch[rxidx]
        specpop = speciesvec[specstoch[1]]
        val    *= specpop
        @inbounds for k = 2:specstoch[2]
            specpop -= one(specpop)
            val     *= specpop
        end
    end

    @inbounds return val * majump.scaled_rates[rxidx]
end

evalrxrate (generic function with 1 method)

## update the rates and sum of rates

In [103]:
fill_rates_and_sum!(dja, u0, p, t0 )

In [104]:
dja

DirectJumpAggregation{Float64, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}, Tuple{var"#18#19"}, Tuple{var"#20#21"}, MersenneTwister}(0, 0, Inf, 200.0, [0.0999, 0.01], 0.1099, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}([0.0001], [[1 => 1, 2 => 1]], [[1 => -1, 2 => 1]], nothing), (var"#18#19"(),), (var"#20#21"(),), (false, false), MersenneTwister(1234))

In [105]:
function initialize!(p::DirectJumpAggregation, integrator, u, params, t)
  generate_jumps!(p, integrator, u, params, t)
  nothing
end
# calculate the next jump / jump time
function generate_jumps!(p::DirectJumpAggregation, integrator, u, params, t)
  p.sum_rate, ttnj = time_to_next_jump(p, u, params, t)
  @fastmath p.next_jump_time = t + ttnj
  @inbounds p.next_jump = searchsortedfirst(p.cur_rates, rand(p.rng) * p.sum_rate) # 这里用到了 searchsortedfirst 是一个加速的function
  nothing
end

generate_jumps! (generic function with 1 method)

## 下面这个是核心函数，这个决定如何更新 
- `next_jump`
- `next_jump_time`
- `sum_rate`

In [106]:
@fastmath function time_to_next_jump(p::DirectJumpAggregation{T,S,F1,F2,RNG}, u, params, t) where {T,S,F1 <: Tuple, F2 <: Tuple, RNG}
  prev_rate = zero(t)
  new_rate  = zero(t)
  cur_rates = p.cur_rates

  # mass action rates
  majumps   = p.ma_jumps
  idx       = get_num_majumps(majumps)
  @inbounds for i in 1:idx
    new_rate     = evalrxrate(u, i, majumps)
    cur_rates[i] = new_rate + prev_rate
    prev_rate    = cur_rates[i]
  end

  # constant jump rates
  rates = p.rates
  if !isempty(rates)
    idx  += 1
    fill_cur_rates(u, params, t, cur_rates, idx, rates...)
    @inbounds for i in idx:length(cur_rates)
      cur_rates[i] = cur_rates[i] + prev_rate
      prev_rate    = cur_rates[i]
    end
  end

  @inbounds sum_rate = cur_rates[end]
  sum_rate, randexp(p.rng) / sum_rate
end
@inline function fill_cur_rates(u, p, t, cur_rates, idx, rate, rates...)
  @inbounds cur_rates[idx] = rate(u, p, t)
  idx += 1
  fill_cur_rates(u, p, t, cur_rates, idx, rates...)
end

@inline function fill_cur_rates(u, p, t, cur_rates, idx, rate)
  @inbounds cur_rates[idx] = rate(u, p, t)
  nothing
end


fill_cur_rates (generic function with 2 methods)

## Now the Integrator is needed 
## First, let us construct the `Integrator`, which encodes the time and state information

In [79]:
mutable struct Integrator{T}
    u::Vector{T}
    t::Float64
    p::Union{Vector{T},SciMLBase.NullParameters}
end 

In [107]:
integrator = Integrator(copy(u0),copy(t0),p)

Integrator{Float64}([999.0, 1.0, 0.0], 0.0, SciMLBase.NullParameters())

In [108]:
integrator.u

3-element Vector{Float64}:
 999.0
   1.0
   0.0

In [109]:
initialize!(dja, integrator, copy(u0), p, copy(t0))

In [110]:
dja

DirectJumpAggregation{Float64, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}, Tuple{var"#18#19"}, Tuple{var"#20#21"}, MersenneTwister}(1, 0, 22.597865080896174, 200.0, [0.0999, 0.1099], 0.1099, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}([0.0001], [[1 => 1, 2 => 1]], [[1 => -1, 2 => 1]], nothing), (var"#18#19"(),), (var"#20#21"(),), (false, false), MersenneTwister(1234, (0, 1002, 0, 2)))

In [111]:
integrator

Integrator{Float64}([999.0, 1.0, 0.0], 0.0, SciMLBase.NullParameters())

## Next step update integrator, the key function is `update_state!`

In [112]:
using UnPack, StaticArrays
function (p::AbstractSSAJumpAggregator)(integrator)
    execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
    generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
    nothing
end
@inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t)
    update_state!(p, integrator, u)
    nothing
end
@inline function update_state!(p::AbstractSSAJumpAggregator, integrator, u)
    @unpack ma_jumps, next_jump = p
    num_ma_rates = get_num_majumps(ma_jumps)
    if next_jump <= num_ma_rates # is next jump a mass action jump
        if u isa SVector
            integrator.u = executerx(u, next_jump, ma_jumps)
        else
            @inbounds executerx!(u, next_jump, ma_jumps)
        end
    else
        idx = next_jump - num_ma_rates
        @inbounds p.affects![idx](integrator)
    end

    # save jump that was just executed
    p.prev_jump = next_jump
    return integrator.u
end
@inline @fastmath function executerx!(speciesvec::AbstractVector{T}, rxidx::S,
    majump::MassActionJump{U,V,W,X}) where {T,S,U,V,W,X}
@inbounds net_stoch = majump.net_stoch[rxidx]
@inbounds for specstoch in net_stoch
speciesvec[specstoch[1]] += specstoch[2]
end
nothing
end


executerx! (generic function with 1 method)

In [113]:
dja(integrator)

In [114]:
integrator

Integrator{Float64}([998.0, 2.0, 0.0], 0.0, SciMLBase.NullParameters())

In [115]:
dja(integrator)

In [116]:
dja

DirectJumpAggregation{Float64, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}, Tuple{var"#18#19"}, Tuple{var"#20#21"}, MersenneTwister}(1, 1, 3.96997761029981, 200.0, [0.29910000000000003, 0.32910000000000006], 0.32910000000000006, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}([0.0001], [[1 => 1, 2 => 1]], [[1 => -1, 2 => 1]], nothing), (var"#18#19"(),), (var"#20#21"(),), (false, false), MersenneTwister(1234, (0, 1002, 0, 6)))

In [118]:
integrator

Integrator{Float64}([997.0, 3.0, 0.0], 0.0, SciMLBase.NullParameters())

In [120]:
u0, t0

([999.0, 1.0, 0.0], 0.0)

# So up to now we are able to update the integrator by each iteration.
## Next question is how to store to integrator in a list

In [33]:
mutable struct SSAIntegrator{uType,tType,P,S,SA} 
    u::uType
    t::tType
    tprev::tType
    p::P
    sol::S
    i::Int
    saveat::SA
    cur_saveat::Int
    end_time::tType
    save_end::Bool
end

In [34]:
mutable struct SSASolution{tType,uType}
    t::Vector{tType}
    u::Vector{uType}
end

In [128]:
ssa_sol = SSASolution([],[])

SSASolution{Any, Any}(Any[], Any[])

In [129]:
integrator

Integrator{Float64}([997.0, 3.0, 0.0], 0.0, SciMLBase.NullParameters())

In [130]:
dja

DirectJumpAggregation{Float64, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}, Tuple{var"#18#19"}, Tuple{var"#20#21"}, MersenneTwister}(1, 1, 3.96997761029981, 200.0, [0.29910000000000003, 0.32910000000000006], 0.32910000000000006, MassActionJump{Vector{Float64}, Vector{Vector{Pair{Int64, Int64}}}, Vector{Vector{Pair{Int64, Int64}}}, Nothing}([0.0001], [[1 => 1, 2 => 1]], [[1 => -1, 2 => 1]], nothing), (var"#18#19"(),), (var"#20#21"(),), (false, false), MersenneTwister(1234, (0, 1002, 0, 6)))

In [131]:
current_time = 0.
next_end_time = dja.next_jump_time + current_time
ssaintegrator = SSAIntegrator(copy(u0),copy(t0),0.,p,ssa_sol,0,0:1:200,1,next_end_time,true)

SSAIntegrator{Vector{Float64}, Float64, SciMLBase.NullParameters, SSASolution{Any, Any}, StepRange{Int64, Int64}}([999.0, 1.0, 0.0], 0.0, 0.0, SciMLBase.NullParameters(), SSASolution{Any, Any}(Any[], Any[]), 0, 0:1:200, 1, 3.96997761029981, true)

In [132]:
using DiffEqBase
function DiffEqBase.solve!(integrator)
    end_time = integrator.end_time
    integrator.t = end_time
    if integrator.saveat !== nothing && !isempty(integrator.saveat)
        # Split to help prediction
        while integrator.cur_saveat <= length(integrator.saveat) &&
           integrator.saveat[integrator.cur_saveat] < integrator.t

            push!(integrator.sol.t,integrator.saveat[integrator.cur_saveat])
            push!(integrator.sol.u,copy(integrator.u))
            integrator.cur_saveat += 1

        end
    end

    if integrator.save_end && integrator.sol.t[end] != end_time
        push!(integrator.sol.t,end_time)
        push!(integrator.sol.u,copy(integrator.u))
    end
end

In [133]:
solve!(ssaintegrator)

5-element Vector{Any}:
 [999.0, 1.0, 0.0]
 [999.0, 1.0, 0.0]
 [999.0, 1.0, 0.0]
 [999.0, 1.0, 0.0]
 [999.0, 1.0, 0.0]

In [136]:
ssaintegrator.sol.t

5-element Vector{Any}:
 0
 1
 2
 3
 3.96997761029981