In [None]:
using Pkg
const NOTEBOOKS_ROOT = @__DIR__
Pkg.activate(NOTEBOOKS_ROOT)
include(joinpath(NOTEBOOKS_ROOT, "init_python.jl"))
using Revise

In [None]:
using Revise
using StaticArrays
using Distributions
using StaticDistributions
using Particles
using PlotlyJS

In [None]:
const pyparticles = pyimport("particles")
const np = pyimport("numpy")
const Moments = pyparticles.collectors.Moments

In [None]:
# Stochastic, time-discretized Lotka–Volterra predator–prey state-space model.
# Both the hidden state and the observation are `SVector{2, Float64}` laid out as
# (prey, predator).
struct LotkaVolterra <: StateSpaceModel{SVector{2, Float64}, SVector{2, Float64}}
    alpha::Float64       # prey drift coefficient: +alpha * prey
    beta::Float64        # prey drift coefficient: -beta * prey * predator
    gamma::Float64       # predator drift coefficient: -gamma * predator
    delta::Float64       # predator drift coefficient: +delta * prey * predator
    dt::Float64          # Euler integration step
    nsteps::Int          # Euler steps performed per `ssm_PX` transition
    sigma0::Float64      # initial-state std, relative to the initial mean
    sigmaX_rel::Float64  # state-noise variance rate proportional to the population
    sigmaY_rel::Float64  # observation-noise variance proportional to the population
    sigmaX_abs::Float64  # constant state-noise variance rate
    sigmaY_abs::Float64  # constant observation-noise variance
end
"""
    LotkaVolterra(; alpha=2/3, beta=4/3, gamma=1, delta=1, sigma0=0.1, sigmaX_rel=0.1,
                    sigmaY_rel=2.0, sigmaX_abs=0.01, sigmaY_abs=0.1, dt=0.05, dt_obs=0.5)

Build a `LotkaVolterra` model from keyword parameters.

`dt_obs` is the time between two observations. It is converted to
`nsteps = ceil(Int, dt_obs / dt)` Euler steps of length `dt`, the same rule
(`int(np.ceil(dt_obs / dt))`) as the Python reference model, so both take the same
number of steps.

Throws `ArgumentError` if this gives fewer than one step.
"""
function LotkaVolterra(; alpha=2/3, beta=4/3, gamma=1, delta=1, sigma0=0.1,
                       sigmaX_rel=0.1, sigmaY_rel=2.0, sigmaX_abs=0.01, sigmaY_abs=0.1,
                       dt=0.05, dt_obs=0.5)
    # `cld(dt_obs, dt)` rounds the *exact* quotient of the two floats. That can differ
    # by one from the ceiling of the rounded quotient `dt_obs / dt` used on the Python
    # side, which would make the two models integrate different numbers of steps.
    nsteps = ceil(Int, dt_obs / dt)
    nsteps >= 1 || throw(ArgumentError(
        "dt_obs / dt must give at least one step, got dt_obs=$dt_obs, dt=$dt"))
    return LotkaVolterra(alpha, beta, gamma, delta, dt, nsteps, sigma0,
                         sigmaX_rel, sigmaY_rel, sigmaX_abs, sigmaY_abs)
end
# Initial-state law: independent normals truncated to [0, Inf), centred at
# alpha/beta (prey) and gamma/delta (predator), each with std = sigma0 * mean.
# NOTE(review): the deterministic Lotka–Volterra fixed point is
# (prey, predator) = (gamma/delta, alpha/beta); these means are swapped relative
# to it. The Python model's PX0 uses the same means, so the comparison is
# unaffected — confirm the choice is intended.
Particles.ssm_PX0(ssm::LotkaVolterra, ::Nothing) =
    SimpleContinousSMultivariateSampleable{2, Float64}(function (rng)
        mu_prey = ssm.alpha / ssm.beta
        mu_predator = ssm.gamma / ssm.delta
        law_prey = TruncatedNormal(mu_prey, ssm.sigma0 * mu_prey, 0, Inf64)
        law_predator = TruncatedNormal(mu_predator, ssm.sigma0 * mu_predator, 0, Inf64)
        # Draw prey first, then predator (same RNG consumption order as before).
        return SVector{2, Float64}(rand(rng, law_prey), rand(rng, law_predator))
    end)
"""
    rand_transition(rng, ssm::LotkaVolterra, xp::SVector{2})

Advance the state `xp = (prey, predator)` by one Euler step of length `ssm.dt`.

The Euler drift is clamped at zero. Each component then gets normal noise,
truncated to `[0, Inf)` and centred at the clamped value `m`, with standard
deviation `sqrt(dt) * sqrt(sigmaX_abs + sigmaX_rel * m)`.
"""
function rand_transition(rng, ssm::LotkaVolterra, xp::SVector{2})
    x_prey, x_pred = xp[1], xp[2]
    h = ssm.dt
    # Deterministic Euler drift, written exactly as before to keep rounding identical.
    m_prey = max(0.0, x_prey + h * (ssm.alpha * x_prey - ssm.beta * x_prey * x_pred))
    m_pred = max(0.0, x_pred + h * (ssm.delta * x_prey * x_pred - ssm.gamma * x_pred))
    step_sd(m) = sqrt(h) * sqrt(ssm.sigmaX_abs + ssm.sigmaX_rel * m)
    # Sample prey before predator so the RNG stream is consumed in the same order.
    next_prey = rand(rng, TruncatedNormal(m_prey, step_sd(m_prey), 0, Inf64))
    next_pred = rand(rng, TruncatedNormal(m_pred, step_sd(m_pred), 0, Inf64))
    return SVector{2, Float64}(next_prey, next_pred)
end
"""
    Particles.ssm_PX(ssm::LotkaVolterra, ::Nothing, t::Integer, xp::SVector{2})

Transition law between two observation times. The returned sampler applies
`rand_transition` `ssm.nsteps` times, starting from `xp` on every draw.
"""
Particles.ssm_PX(ssm::LotkaVolterra, ::Nothing, t::Integer, xp::SVector{2}) =
    SimpleContinousSMultivariateSampleable{2, Float64}(rng -> begin
        # Iterate on a closure-local variable. Reassigning the captured argument `xp`
        # boxed it (type-unstable) and made each later draw from this sampler
        # continue from the previous draw's end state instead of from `xp`.
        x = SVector{2, Float64}(xp)
        for _ in 1:ssm.nsteps
            x = rand_transition(rng, ssm, x)
        end
        return x
    end)
"""
    Particles.ssm_PY(ssm::LotkaVolterra, ::Nothing, t::Integer, x::SVector{2})

Observation law given the state `x = (prey, predator)`: independent normals
truncated to `[0, Inf)`, centred on each population level `l`, with std
`sqrt(sigmaY_abs + sigmaY_rel * l)`.
"""
function Particles.ssm_PY(ssm::LotkaVolterra, ::Nothing, t::Integer, x::SVector{2})
    obs_law(level) =
        TruncatedNormal(level, sqrt(ssm.sigmaY_abs + ssm.sigmaY_rel * level), 0, Inf64)
    return SIndependent(obs_law(x[1]), obs_law(x[2]))
end

In [None]:
# Python helper distribution used by the Python `LotkaVolterra.PX`: the law of the
# state reached after `nsteps` successive applications of `onestep`.
py"""
import numpy as np
import particles
from particles import distributions as dists

class MultipleSteps(dists.ProbDist):
    '''Law of the state after `nsteps` applications of a one-step transition.

    `onestep(xp)` must return a distribution (a StructDist) of the next state
    given the previous state `xp`.
    '''
    def __init__(self, onestep, xp, nsteps):
        # Explicit check rather than `assert`, which `python -O` removes.
        if nsteps < 1:
            raise ValueError('nsteps must be >= 1, got %r' % (nsteps,))
        super().__init__()
        self.onestep = onestep
        self.xp = xp
        self.nsteps = nsteps
        self.dtype = onestep(xp).dtype
    def rvs(self, size=None):
        # First step: sample everything at once from the law given `self.xp`.
        d = self.onestep(self.xp)
        names = list(d.laws.keys())
        x = d.rvs(size=size)
        assert len(x.flat[0]) == len(names)
        # Remaining steps: advance each record individually, passing it back to
        # `onestep` as a {field name: value} dict.
        for _ in range(self.nsteps-1):
            for i in np.ndindex(x.shape):
                xp = {name: value for (name, value) in zip(names, x[i])}
                x[i] = self.onestep(xp).rvs(size=None).flat[0]
        return x
"""

In [None]:
# Python reference model, defined in the PyCall `py` namespace so that the later
# `py"LotkaVolterra"(...)` can instantiate it. It uses the same default parameter
# values as the Julia `LotkaVolterra` constructor.
# - States are records with fields `prey` / `predator`.
# - Observations are records with fields `obs_prey` / `obs_predator`.
# - `PX_onestep` is one Euler step clamped at 0, plus normal noise truncated to [0, inf).
# - `PX` chains `ceil(dt_obs / dt)` such steps through `MultipleSteps`, which this
#   cell does not define; it must already exist in the `py` namespace from the
#   previous `py` cell.
py"""
import numpy as np
import particles
from particles import distributions as dists
from particles import state_space_models as ssm

class LotkaVolterra(ssm.StateSpaceModel):
    default_params = dict(
        alpha=2/3,
        beta=4/3,
        gamma=1.0,
        delta=1.0,
        sigma0=0.1,
        sigmaX_rel=0.1,
        sigmaY_rel=2.0,
        sigmaX_abs=0.01,
        sigmaY_abs=0.1,
        dt=0.05,
        dt_obs=0.5,
    )

    def PX0(self):
        prey0 = self.alpha / self.beta
        predator0 = self.gamma / self.delta
        return dists.StructDist(dict(
            prey=dists.TruncNormal(mu=prey0, sigma=self.sigma0 * prey0, a=0.0, b=np.inf),
            predator=dists.TruncNormal(mu=predator0, sigma=self.sigma0 * predator0, a=0.0, b=np.inf),
        ))

    def PX_onestep(self, xp):
        new_prey = np.maximum(0.0, xp['prey'] + self.dt * (self.alpha * xp['prey'] - self.beta * xp['prey'] * xp['predator']))
        new_predator = np.maximum(0.0, xp['predator'] + self.dt * (self.delta * xp['prey'] * xp['predator'] - self.gamma * xp['predator']))
        return dists.StructDist(dict(
            prey=dists.TruncNormal(mu=new_prey, sigma=np.sqrt(self.dt) * np.sqrt(self.sigmaX_abs + self.sigmaX_rel * new_prey), a=0.0, b=np.inf),
            predator=dists.TruncNormal(mu=new_predator, sigma=np.sqrt(self.dt) * np.sqrt(self.sigmaX_abs + self.sigmaX_rel * new_predator), a=0.0, b=np.inf),
        ))

    def PX(self, t, xp):
        nsteps = int(np.ceil(self.dt_obs / self.dt))
        return MultipleSteps(self.PX_onestep, xp, nsteps)

    def PY(self, t, xp, x):
        return dists.StructDist(dict(
            obs_prey=dists.TruncNormal(mu=x['prey'], sigma=np.sqrt(self.sigmaY_abs + self.sigmaY_rel * x['prey']), a=0.0, b=np.inf),
            obs_predator=dists.TruncNormal(mu=x['predator'], sigma=np.sqrt(self.sigmaY_abs + self.sigmaY_rel * x['predator']), a=0.0, b=np.inf),
        ))
"""

In [None]:
# Parameter overrides shared by both implementations, so the Python and Julia
# models are built from identical settings; unlisted parameters keep their defaults.
params = (
    sigma0 = 1.0,
    sigmaX_rel = 0.15,
    sigmaX_abs = 0.001,
    sigmaY_rel = 0.15,
    sigmaY_abs = 0.05,
)
ssm_py = py"LotkaVolterra"(; params...)
ssm_jl = LotkaVolterra(; params...)

In [None]:
T = 100

xtrue_py, data_py = ssm_py.simulate(T)

# Simulated entries are numpy structured records. Their field order is read from
# the dtype rather than assumed, so we know which position holds which species.
i_predator_xtrue, i_prey_xtrue = xtrue_py[1].dtype.names[1] == "predator" ? (1, 2) : (2, 1)
i_predator_data, i_prey_data = data_py[1].dtype.names[1] == "obs_predator" ? (1, 2) : (2, 1)

# Convert one simulated entry to an `SVector` ordered as (prey, predator), the state
# layout of the Julia model. This replaces the old collect → SVector → reverse steps,
# which also left an unused global `names` (shadowing `Base.names`) and a global `N`
# that clashed with the particle count `N` used later.
function prey_predator_svector(entry, i_prey::Integer, i_predator::Integer)
    fields = collect(first(entry))
    return SVector{2, Float64}(fields[i_prey], fields[i_predator])
end

xtrue = [prey_predator_svector(e, i_prey_xtrue, i_predator_xtrue) for e in xtrue_py]
data = [prey_predator_svector(e, i_prey_data, i_predator_data) for e in data_py]

nothing

In [None]:
# Simulated trajectory: true state vs noisy observation, one stacked panel per species.
let
    prey_panel = plot([
        scatter(x=1:T, y=[s[1] for s in xtrue], mode="markers", name="state (prey)"),
        scatter(x=1:T, y=[o[1] for o in data], mode="markers", name="observation (prey)"),
    ])
    predator_panel = plot([
        scatter(x=1:T, y=[s[2] for s in xtrue], mode="markers", name="state (predator)"),
        scatter(x=1:T, y=[o[2] for o in data], mode="markers", name="observation (predator)"),
    ])
    [prey_panel; predator_panel]
end

In [None]:
# Bootstrap Feynman–Kac model for the Python reference implementation.
fk_bf_py = pyparticles.state_space_models.Bootstrap(data=data_py, ssm=ssm_py)

In [None]:
# Python moment collector used by `run_smc_py`: the weighted mean and variance of
# each species over the particle cloud at every time step.
py"""
import numpy as np

def meanvar(W, x):
    # Weighted mean and variance along axis 0. The variance is the weighted mean of
    # squared deviations, not E[x^2] - E[x]^2: the latter can cancel to a tiny
    # negative number (e.g. when all particles coincide after resampling), and the
    # Julia side then fails on `sqrt` of the variance.
    m = np.average(x, weights=W, axis=0)
    v = np.average((x - m)**2, weights=W, axis=0)
    return m, v

def mom_func(W, x):
    # `x` is indexed by field name, i.e. a structured particle array with
    # 'prey' / 'predator' fields.
    m_prey, v_prey = meanvar(W, x['prey'])
    m_predator, v_predator = meanvar(W, x['predator'])
    return {'mean_prey': m_prey, 'var_prey': v_prey, 'mean_predator': m_predator, 'var_predator': v_predator}
"""

In [None]:
"""
    run_smc_py(fk, N)

Run the Python `particles` SMC algorithm on the Feynman–Kac model `fk` with `N`
particles and return the finished `SMC` object.

Settings: systematic resampling with `ESSrmin=0.5`, no QMC, `verbose=false`.
The `mom_func` moments are collected at each step.
"""
function run_smc_py(fk, N)
    collectors = [Moments(mom_func=py"mom_func")]
    algo = @pycall pyparticles.SMC(fk=fk, N=N, collect=collectors, verbose=false, qmc=false, resampling="systematic", ESSrmin=0.5)::PyObject
    algo.run()
    return algo
end

In [None]:
# Number of particles for both filters. A full GC runs before timing the Python run.
N = 100
GC.gc(true)
bf_py = @time(run_smc_py(fk_bf_py, N))

In [None]:
# Per-step moment dicts produced by `mom_func`, split into one series per statistic.
let steps = bf_py."summaries".moments
    global py_means_prey = [d["mean_prey"] for d in steps]
    global py_vars_prey = [d["var_prey"] for d in steps]
    global py_means_predator = [d["mean_predator"] for d in steps]
    global py_vars_predator = [d["var_predator"] for d in steps]
end;

In [None]:
# Julia counterpart of `fk_bf_py`: a bootstrap filter for `ssm_jl` on the same
# simulated observations `data`.
fk_bf_jl = BootstrapFilter(ssm_jl, data);

In [None]:
"""
    run_smc_jl(fk::FeynmanKacModel, N::Integer)

Run the Particles.jl SMC algorithm offline (`offlinefilter!`) on `fk` with `N`
particles and return the `SMC` object.

Settings: `ParticleHistoryLength(FullHistory())`, a `MeanAndVariance` running
summary stored under `mean_and_var` with full history, and systematic resampling
wrapped in `AdaptiveResampling` with threshold 0.5.
"""
function run_smc_jl(fk::FeynmanKacModel{T}, N::Integer) where {T}
    summaries = (mean_and_var = RunningSummary(MeanAndVariance(), FullHistory()),)
    resampler = AdaptiveResampling(SystematicResampling(), 0.5)
    smc = SMC(fk, nothing, N, ParticleHistoryLength(FullHistory()), summaries, resampler)
    offlinefilter!(smc)
    return smc
end

In [None]:
GC.gc(true)  # full collection before timing the Julia run
bf_jl = @time(run_smc_jl(fk_bf_jl, N));

In [None]:
# Per-step running mean/variance recorded by the Julia filter.
let steps = bf_jl.history_run.mean_and_var
    global jl_means = [s.mean for s in steps]
    global jl_vars = [s.var for s in steps]
end
# Component 1 is prey and component 2 is predator (the model's state layout).
jl_means_prey = [m[1] for m in jl_means]
jl_means_predator = [m[2] for m in jl_means]
jl_vars_prey = [v[1] for v in jl_vars]
jl_vars_predator = [v[2] for v in jl_vars]
nothing

In [None]:
# Prey: Python vs Julia filter mean ± 1 sd, overlaid on the true state and observations.
let sd_py = sqrt.(py_vars_prey), sd_jl = sqrt.(jl_vars_prey)
    plot([
        scatter(x=1:T, y=py_means_prey, mode="lines", line_color="black", name="particle (py) filter mean"),
        scatter(x=1:T, y=py_means_prey .+ sd_py, mode="lines", line_color="gray", name="particle (py) filter mean + 1sd"),
        scatter(x=1:T, y=py_means_prey .- sd_py, mode="lines", line_color="gray", name="particle (py) filter mean - 1sd"),
        scatter(x=1:T, y=jl_means_prey, mode="lines", line_color="red", name="particle (jl) filter mean"),
        scatter(x=1:T, y=jl_means_prey .+ sd_jl, mode="lines", line_color="orange", name="particle (jl) filter mean + 1sd"),
        scatter(x=1:T, y=jl_means_prey .- sd_jl, mode="lines", line_color="orange", name="particle (jl) filter mean - 1sd"),
        scatter(x=1:T, y=[s[1] for s in xtrue], mode="markers", name="state"),
        scatter(x=1:T, y=[o[1] for o in data], mode="markers", name="observation"),
    ])
end

In [None]:
# Predator: Python vs Julia filter mean ± 1 sd, overlaid on the true state and observations.
let
    sd_py = sqrt.(py_vars_predator)
    sd_jl = sqrt.(jl_vars_predator)
    py_mean = scatter(x=1:T, y=py_means_predator, mode="lines", line_color="black", name="particle (py) filter mean")
    py_hi = scatter(x=1:T, y=py_means_predator .+ sd_py, mode="lines", line_color="gray", name="particle (py) filter mean + 1sd")
    py_lo = scatter(x=1:T, y=py_means_predator .- sd_py, mode="lines", line_color="gray", name="particle (py) filter mean - 1sd")
    jl_mean = scatter(x=1:T, y=jl_means_predator, mode="lines", line_color="red", name="particle (jl) filter mean")
    jl_hi = scatter(x=1:T, y=jl_means_predator .+ sd_jl, mode="lines", line_color="orange", name="particle (jl) filter mean + 1sd")
    jl_lo = scatter(x=1:T, y=jl_means_predator .- sd_jl, mode="lines", line_color="orange", name="particle (jl) filter mean - 1sd")
    truth = scatter(x=1:T, y=[s[2] for s in xtrue], mode="markers", name="state")
    obs = scatter(x=1:T, y=[o[2] for o in data], mode="markers", name="observation")
    plot([py_mean, py_hi, py_lo, jl_mean, jl_hi, jl_lo, truth, obs])
end