In [None]:
using Pkg
# Directory containing this notebook; used as the project environment root.
const NOTEBOOKS_ROOT = @__DIR__
# Activate the notebook-local environment before loading any other package.
Pkg.activate(NOTEBOOKS_ROOT)
# Python bridge setup; presumably brings `pyimport` / `@pycall` (used in later cells)
# into scope — TODO confirm against init_python.jl.
include(joinpath(NOTEBOOKS_ROOT, "init_python.jl"))
# Loaded before the packages in the next cell so it can track them for live edits.
using Revise

In [None]:
#Pkg.resolve()
#Pkg.update()

In [None]:
using Revise
using Distributions
using Particles
using PlotlyJS

In [None]:
# Python-side reference implementation (the `particles` library) and helpers, bound as
# constants so later cells can use them directly.
const pyparticles = pyimport("particles")
const kalman = pyimport("particles.kalman")  # provides LinearGauss and the exact Kalman filter
const np = pyimport("numpy")
const Moments = pyparticles.collectors.Moments  # collector passed to SMC in run_smc_py

In [None]:
"""
    LinearGauss(; rho=0.9, sigmaX=1.0, sigmaY=0.2, sigma0=sigmaX/sqrt(1-rho^2))

Univariate linear-Gaussian state-space model:

    X₀ ~ N(0, sigma0²)
    Xₜ | Xₜ₋₁ = xp ~ N(rho * xp, sigmaX²)
    Yₜ | Xₜ = x    ~ N(x, sigmaY²)

The default `sigma0` is the stationary standard deviation of the AR(1) state
process, so it is only finite for `abs(rho) < 1`.
"""
Base.@kwdef struct LinearGauss <: StateSpaceModel{Float64, Float64}
    rho::Float64 = 0.9
    sigmaX::Float64 = 1.0
    sigmaY::Float64 = 0.2
    sigma0::Float64 = sigmaX / sqrt(1 - rho^2)
end

# The `::Nothing` slot is not used by this model (presumably the parameter slot of the
# Particles interface — confirm); the time index is unused as the model is homogeneous.

# Initial state distribution.
Particles.ssm_PX0(model::LinearGauss, ::Nothing) = Normal(0.0, model.sigma0)

# State transition: AR(1) step from the previous state `xp`.
function Particles.ssm_PX(model::LinearGauss, ::Nothing, ::Integer, xp::Real)
    return Normal(model.rho * xp, model.sigmaX)
end

# Observation: current state plus Gaussian noise.
function Particles.ssm_PY(model::LinearGauss, ::Nothing, ::Integer, x::Real)
    return Normal(x, model.sigmaY)
end
#     def proposal0(self, data):
#         sig2post = 1. / (1. / self.sigma0**2 + 1. / self.sigmaY**2)
#         mupost = sig2post * (data[0] / self.sigmaY**2)
#         return dists.Normal(loc=mupost, scale=np.sqrt(sig2post))

#     def proposal(self, t, xp, data):
#         sig2post = 1. / (1. / self.sigmaX**2 + 1. / self.sigmaY**2)
#         mupost = sig2post * (self.rho * xp / self.sigmaX**2
#                              + data[t] / self.sigmaY**2)
#         return dists.Normal(loc=mupost, scale=np.sqrt(sig2post))

#     def logeta(self, t, x, data):
#         law = dists.Normal(loc=self.rho * x,
#                            scale=np.sqrt(self.sigmaX**2 + self.sigmaY**2))
#         return law.logpdf(data[t + 1])

In [None]:
# Build the Python reference model and the Julia model from ONE parameter set so they
# cannot silently diverge. sigma0 is left at its default on both sides; the Julia default
# is sigmaX/sqrt(1-rho^2), and particles' default is presumed to match — confirm.
ssm_py, ssm_jl = let ssm_params = (sigmaX=1.0, sigmaY=2.0, rho=0.9)
    kalman.LinearGauss(; ssm_params...), LinearGauss(; ssm_params...)
end

In [None]:
T = 100  # number of time steps to simulate
# Simulate one dataset from the Python model; the Kalman filter and both particle
# filters below all run on it.
xtrue_py, data_py = ssm_py.simulate(T)
# first.(…) keeps the first element of each per-step entry (presumably length-1 arrays
# returned by particles — confirm), giving plain Julia vectors of scalars.
xtrue, data = first.(xtrue_py), first.(data_py)
plot([
    scatter(x=1:T, y=xtrue, mode="markers", name="state"),
    scatter(x=1:T, y=data, mode="markers", name="observation"),
])

In [None]:
# xtrue, data = rand(ssm_jl, T)
# plot([
#     scatter(x=1:T, y=xtrue, mode="markers", name="state"),
#     scatter(x=1:T, y=data, mode="markers", name="observation"),
# ])

In [None]:
# Compute exact solution using Kalman filter
kf = kalman.Kalman(ssm=ssm_py, data=data_py)
kf.filter()
# Running sum of kf.logpyt (presumably per-step predictive log-densities, i.e. the
# cumulative exact log-likelihood); not used by the later cells shown here.
true_loglik = np.cumsum(kf.logpyt)
# One kf.filt entry per time step, indexed as (mean, variance) — presumably particles'
# MeanAndCov pairs; `first` extracts the scalar from each 1-d quantity — TODO confirm.
true_filt_means = [first(f[1]) for f in kf.filt]
true_filt_vars = [first(f[2]) for f in kf.filt]
nothing  # suppress cell output

In [None]:
# Exact filtering distribution (mean ± 1 sd) over the simulated states and observations.
let ts = 1:T, sd = sqrt.(true_filt_vars)
    trace_line(y, color, label) =
        scatter(x=ts, y=y, mode="lines", line_color=color, name=label)
    trace_dots(y, label) = scatter(x=ts, y=y, mode="markers", name=label)

    plot([
        trace_line(true_filt_means, "black", "filter mean"),
        trace_line(true_filt_means .+ sd, "gray", "filter mean + 1sd"),
        trace_line(true_filt_means .- sd, "gray", "filter mean - 1sd"),
        trace_dots(xtrue, "state"),
        trace_dots(data, "observation"),
    ])
end

In [None]:
# Bootstrap Feynman-Kac model for the Python particle filter, on the shared dataset.
fk_bf_py = pyparticles.state_space_models.Bootstrap(ssm=ssm_py, data=data_py)
# Guided-filter variant, currently disabled:
# fk_gf_py = pyparticles.state_space_models.GuidedPF(ssm=ssm_py, data=data_py)

In [None]:
"""
    run_smc_py(fk, N; resampling="systematic", ESSrmin=0.5, qmc=false, verbose=false)

Build the Python `particles.SMC` algorithm for the Feynman-Kac model `fk` with `N`
particles, collecting `Moments` summaries, run it, and return the `SMC` `PyObject`.

The keyword arguments are forwarded unchanged to `particles.SMC`. Their defaults are
the settings this function previously hard-coded, so existing calls behave the same.
"""
function run_smc_py(fk, N; resampling="systematic", ESSrmin=0.5, qmc=false, verbose=false)
    alg = @pycall pyparticles.SMC(
        fk=fk, N=N, collect=[Moments], verbose=verbose, qmc=qmc,
        resampling=resampling, ESSrmin=ESSrmin,
    )::PyObject
    alg.run()
    return alg
end

In [None]:
N = 5000  # number of particles (reused for the Julia run below)
GC.gc()  # collect garbage up front so it is less likely to pollute the @time figure
bf_py = @time run_smc_py(fk_bf_py, N)
# gf_py = @time run_smc_py(fk_gf_py, N)

In [None]:
# Split the collected Moments records (presumably one per time step) into separate
# mean and variance series.
bf_py_means, bf_py_vars = let moments = bf_py."summaries".moments
    ([rec["mean"] for rec in moments], [rec["var"] for rec in moments])
end;
# gf_py_means = [x["mean"] for x in gf_py."summaries".moments];
# gf_py_vars = [x["var"] for x in gf_py."summaries".moments];

In [None]:
# Exact Kalman filter vs. the Python bootstrap particle filter (mean ± 1 sd each).
let ts = 1:T, kf_sd = sqrt.(true_filt_vars), pf_sd = sqrt.(bf_py_vars)
    trace_line(y, color, label) =
        scatter(x=ts, y=y, mode="lines", line_color=color, name=label)
    trace_dots(y, label) = scatter(x=ts, y=y, mode="markers", name=label)

    plot([
        trace_line(true_filt_means, "black", "Kalman filter mean"),
        trace_line(true_filt_means .+ kf_sd, "gray", "Kalman filter mean + 1sd"),
        trace_line(true_filt_means .- kf_sd, "gray", "Kalman filter mean - 1sd"),
        trace_line(bf_py_means, "red", "particle (py) filter mean"),
        trace_line(bf_py_means .+ pf_sd, "orange", "particle (py) filter mean + 1sd"),
        trace_line(bf_py_means .- pf_sd, "orange", "particle (py) filter mean - 1sd"),
        trace_dots(xtrue, "state"),
        trace_dots(data, "observation"),
    ])
end

In [None]:
# Julia-side bootstrap Feynman-Kac model on the same (Python-simulated) observations.
fk_bf_jl = BootstrapFilter(ssm_jl, data)

In [None]:
"""
    run_smc_jl(fk::FeynmanKacModel, N::Integer)

Build an `SMC` sampler for `fk` with `N` particles, run it over all the data with
`offlinefilter!`, and return the sampler.

Configuration:
- the full particle history is kept;
- a running mean/variance summary is recorded with full history, exposed afterwards as
  `history_run.mean_and_var`;
- systematic resampling is applied adaptively with threshold 0.5 (presumably the ESS
  ratio, matching `ESSrmin=0.5` in `run_smc_py` — confirm).
"""
function run_smc_jl(fk::FeynmanKacModel, N::Integer)
    smc = SMC(
        # `nothing` fills the slot that the model's `ssm_*` methods receive as `::Nothing`
        # (presumably the parameter slot — confirm).
        fk, nothing, N,

        # How much particle history to keep.
        ParticleHistoryLength(FullHistory()),

        # Summaries computed while running. Alternatives tried:
        #   (mean_and_var=RunningSummary(MeanAndVariance(), StaticFiniteHistory{3}()), )
        #   (mean_and_var=MeanAndVariance(), )
        #   NamedTuple()
        (mean_and_var=RunningSummary(MeanAndVariance(), FullHistory()),),

        AdaptiveResampling(SystematicResampling(), 0.5),
    )
    offlinefilter!(smc)
    return smc
end

In [None]:
GC.gc()  # same pre-timing collection as for the Python run
# NOTE: on the first call in a session this timing also includes JIT compilation.
bf_jl = @time run_smc_jl(fk_bf_jl, N);

In [None]:
using ProfileSVG
# Profile a further run; since the timed call above already ran it, compilation is
# presumably excluded from this profile.
@profview run_smc_jl(fk_bf_jl, N)

In [None]:
# Offline summary computation
# Recompute the mean/variance summary at one step offline. When a running summary was
# recorded, check that it matches the running value at the same step.
# NOTE(review): `END` is not defined in this notebook chunk; presumably a relative
# "last step" index exported by Particles — confirm.
let step = END-2
    offline_value = Particles.compute_summary(bf_jl, OfflineSummary(MeanAndVariance()), step)
    if hasproperty(bf_jl.history_run, :mean_and_var)
        # NOTE(review): exact `==` assumes the offline and running computations round
        # identically; compare fields with ≈ if this trips on floating-point noise.
        @assert offline_value == bf_jl.history_run.mean_and_var[step]
    end
    offline_value
end

In [None]:
# Extract the recorded running mean/variance series. It is only present when run_smc_jl
# records a `mean_and_var` summary. A top-level `if` opens no scope, so these are globals.
if hasproperty(bf_jl.history_run, :mean_and_var)
    bf_jl_means = getproperty.(bf_jl.history_run.mean_and_var, :mean)
    bf_jl_vars = getproperty.(bf_jl.history_run.mean_and_var, :var)
end;

In [None]:
if hasproperty(bf_jl.history_run, :mean_and_var)
    # Exact Kalman filter vs. the Julia bootstrap filter (mean ± 1 sd each).
    # The particle series is right-aligned with 1:T, because a finite-history summary
    # would presumably hold only the last few steps.
    let ts = 1:T,
        ts_jl = ts[end-length(bf_jl_means)+1:end],
        kf_sd = sqrt.(true_filt_vars),
        pf_sd = sqrt.(bf_jl_vars)

        trace_line(x, y, color, label) =
            scatter(x=x, y=y, mode="lines", line_color=color, name=label)
        trace_dots(y, label) = scatter(x=ts, y=y, mode="markers", name=label)

        plot([
            trace_line(ts, true_filt_means, "black", "Kalman filter mean"),
            trace_line(ts, true_filt_means .+ kf_sd, "gray", "Kalman filter mean + 1sd"),
            trace_line(ts, true_filt_means .- kf_sd, "gray", "Kalman filter mean - 1sd"),
            trace_line(ts_jl, bf_jl_means, "red", "particle (jl) filter mean"),
            trace_line(ts_jl, bf_jl_means .+ pf_sd, "orange",
                       "particle (jl) filter mean + 1sd"),
            trace_line(ts_jl, bf_jl_means .- pf_sd, "orange",
                       "particle (jl) filter mean - 1sd"),
            trace_dots(xtrue, "state"),
            trace_dots(data, "observation"),
        ])
    end
end

In [None]:
if hasproperty(bf_jl.history_run, :mean_and_var)
    # Python vs. Julia bootstrap particle filters (mean ± 1 sd each). The Julia series is
    # right-aligned with 1:T, as it may cover only the last steps.
    let ts = 1:T,
        ts_jl = ts[end-length(bf_jl_means)+1:end],
        py_sd = sqrt.(bf_py_vars),
        jl_sd = sqrt.(bf_jl_vars)

        trace_line(x, y, color, label) =
            scatter(x=x, y=y, mode="lines", line_color=color, name=label)
        trace_dots(y, label) = scatter(x=ts, y=y, mode="markers", name=label)

        plot([
            trace_line(ts, bf_py_means, "black", "particle (py) filter mean"),
            trace_line(ts, bf_py_means .+ py_sd, "gray", "particle (py) filter mean + 1sd"),
            trace_line(ts, bf_py_means .- py_sd, "gray", "particle (py) filter mean - 1sd"),
            trace_line(ts_jl, bf_jl_means, "red", "particle (jl) filter mean"),
            trace_line(ts_jl, bf_jl_means .+ jl_sd, "orange",
                       "particle (jl) filter mean + 1sd"),
            trace_line(ts_jl, bf_jl_means .- jl_sd, "orange",
                       "particle (jl) filter mean - 1sd"),
            trace_dots(xtrue, "state"),
            trace_dots(data, "observation"),
        ])
    end
end