In [None]:
# Activate the project environment rooted at the directory reported by `@__DIR__`,
# so that the packages loaded in the following cells resolve against it.
using Pkg
const NOTEBOOKS_ROOT = @__DIR__
Pkg.activate(NOTEBOOKS_ROOT)
# Optional Python initialisation script, currently disabled:
# include(joinpath(NOTEBOOKS_ROOT, "init_python.jl"))
using Revise

In [None]:
using Random
using StaticArrays
using Distributions
using StaticDistributions
using Particles
using PlotlyJS

In [None]:
# Stochastic, discretised Lotka-Volterra (prey/predator) state-space model.
# Both the latent state and the observation are 2-vectors ordered (prey, predator).
struct LotkaVolterra <: StateSpaceModel{SVector{2, Float64}, SVector{2, Float64}}
    # Step size of the explicit update in `ssm_PX`.
    dt::Float64
    # Spread of the initial state, relative to its mean (see `ssm_PX0`).
    sigma0::Float64
    # Transition noise: part proportional to the state, placed under a square root in `ssm_PX`.
    sigmaX_rel::Float64
    # Observation noise: part proportional to the state, placed under a square root in `ssm_PY`.
    sigmaY_rel::Float64
    # Transition noise: constant part added under the same square root.
    sigmaX_abs::Float64
    # Observation noise: constant part added under the same square root.
    sigmaY_abs::Float64
end

# Keyword constructor providing a default for every field.
# Note that the positional field order (dt first) differs from the keyword order.
function LotkaVolterra(;
    sigma0=0.1,
    sigmaX_rel=0.1,
    sigmaY_rel=2.0,
    sigmaX_abs=0.01,
    sigmaY_abs=0.1,
    dt=0.05
)
    return LotkaVolterra(dt, sigma0, sigmaX_rel, sigmaY_rel, sigmaX_abs, sigmaY_abs)
end

# Parameter interface of the model: θ is a length-4 `Vector{Float64}` holding the
# logarithms of the rates (alpha, beta, gamma, delta) that the `ssm_*` methods
# exponentiate.
# TODO allow different parameter type: add a function that only checks type isparametertypecorrect or something
function Particles.parameter_type(::LotkaVolterra)
    return Vector{Float64}
end

# Reference parameter value (log scale); also used as the template when building filters.
function Particles.parameter_template(::LotkaVolterra)
    return Float64[log(2/3), log(4/3), log(1), log(1)]
end

# A valid parameter is exactly a `Vector{Float64}` with four entries.
function Particles.isparameter(::LotkaVolterra, θ)
    return θ isa Vector{Float64} && length(θ) == 4
end

# Distribution of the initial state (prey, predator).
# Each component is a Normal truncated below at zero, centred on a ratio of the
# rates with standard deviation `sigma0` times that centre.
# NOTE(review): the fixed point of the drift used in `ssm_PX` is
# prey = gamma / delta, predator = alpha / beta; the centres used here are the
# swapped ratios. Confirm that starting away from the fixed point is intended.
function Particles.ssm_PX0(ssm::LotkaVolterra, θ::AbstractVector{<:Real})
    # θ stores the rates on the log scale.
    alpha = exp(θ[1])
    beta = exp(θ[2])
    gamma = exp(θ[3])
    delta = exp(θ[4])
    # The upper bound is `nothing` rather than an infinite value: NB Inf64 break AD
    positive_normal(center) = truncated(
        Normal(center, ssm.sigma0 * center; check_args=false), 0, nothing,
    )
    return SIndependent(
        positive_normal(alpha / beta),   # prey
        positive_normal(gamma / delta),  # predator
    )
end

# Transition distribution of the state at step `t` given the previous state `xp`.
# The mean is one explicit step of size `dt` of the Lotka-Volterra drift, clamped at
# zero; the noise is a Normal truncated below at zero whose variance is
# `dt * (sigmaX_abs + sigmaX_rel * mean)`. `t` is not used by this model.
function Particles.ssm_PX(ssm::LotkaVolterra, θ::AbstractVector{<:Real}, t::Integer, xp::SVector{2})
    # θ stores the rates on the log scale.
    alpha = exp(θ[1])
    beta = exp(θ[2])
    gamma = exp(θ[3])
    delta = exp(θ[4])
    prey, predator = xp

    prey_drift = alpha * prey - beta * prey * predator
    predator_drift = delta * prey * predator - gamma * predator
    # Clamp so that the mean handed to the truncated Normal is never negative.
    next_prey = max(0.0, prey + ssm.dt * prey_drift)
    next_predator = max(0.0, predator + ssm.dt * predator_drift)

    sqrt_dt = sqrt(ssm.dt)
    step_noise(center) = truncated(
        Normal(center, sqrt_dt * sqrt(ssm.sigmaX_abs + ssm.sigmaX_rel * center); check_args=false),
        0.0, nothing,
    )
    return SIndependent(step_noise(next_prey), step_noise(next_predator))
end

# Observation distribution given the current state `x`.
# Each component is observed through a Normal truncated below at zero, centred on the
# state value with variance `sigmaY_abs + sigmaY_rel * value`.
# Neither `θ` nor `t` enters the observation model.
function Particles.ssm_PY(ssm::LotkaVolterra, θ::AbstractVector{<:Real}, t::Integer, x::SVector{2})
    observation_noise(level) = truncated(
        Normal(level, sqrt(ssm.sigmaY_abs + ssm.sigmaY_rel * level); check_args=false),
        0.0, nothing,
    )
    prey, predator = x
    return SIndependent(observation_noise(prey), observation_noise(predator))
end

In [None]:
# Model instance used in the rest of the notebook; every keyword given here
# overrides the corresponding default of the keyword constructor.
ssm = LotkaVolterra(dt=0.5, sigma0=1.0, sigmaX_rel=0.15, sigmaX_abs=0.001, sigmaY_rel=0.15, sigmaY_abs=0.05)
# Reference parameter (log-rates) used to simulate data and as the starting point below.
theta0 = Particles.parameter_template(ssm)

In [None]:
# Simulate T steps of states and observations at the reference parameter.
T = 250
xtrue, data_full = rand(ssm, theta0, T);

# Keep only every 4th observation; all other time steps are marked `missing`.
data = fill!(similar(data_full, Union{Missing, eltype(data_full)}), missing)
mask = eachindex(data)[begin:4:end]
for i in mask
    data[i] = data_full[i]
end

# Stack one panel per species: simulated state against the retained observations.
let steps = 1:T, obs_steps = (1:T)[mask], observed = data[mask]
    prey_panel = plot([
        scatter(x=steps, y=getindex.(xtrue, 1), mode="markers", name="state (prey)"),
        scatter(x=obs_steps, y=getindex.(observed, 1), mode="markers", name="observation (prey)"),
    ])
    predator_panel = plot([
        scatter(x=steps, y=getindex.(xtrue, 2), mode="markers", name="state (predator)"),
        scatter(x=obs_steps, y=getindex.(observed, 2), mode="markers", name="observation (predator)"),
    ])
    [prey_panel; predator_panel]
end

---

In [None]:
# Callable object that evaluates a particle-filter estimate for a given parameter.
# It owns a bootstrap-filter SMC sampler and the matching cache, which are reused
# across calls. Calling it returns the last stored `logCnorm` value of the filter
# history; no explicit prior term is added by this wrapper.
struct LogPosterior{T_SMC <: SMC, T_CACHE}
    pf::T_SMC
    cache::T_CACHE
    function LogPosterior(ssm::StateSpaceModel, data, nparticles::Integer)
        bootstrap = BootstrapFilter(ssm, data)
        template = Particles.parameter_template(ssm)
        # Only the most recent `logCnorm` entry is kept in the history.
        history = ParticleHistoryLength(; logCnorm=StaticFiniteHistory{1}())
        # presumably 0.5 is the threshold that triggers resampling (TODO confirm)
        resampling = AdaptiveResampling(SystematicResampling(), 0.5)
        # The empty NamedTuple means that no running summaries are requested.
        smc = SMC(bootstrap, template, nparticles, history, NamedTuple(), resampling)
        smc_cache = Particles.SMCCache(smc)
        return new{typeof(smc), typeof(smc_cache)}(smc, smc_cache)
    end
end
# Run the particle filter at `theta` and return the final `logCnorm` value as a Float64.
# The wrapped sampler and cache are mutated in place, and every call is a new
# stochastic realisation of the filter.
function (logp::LogPosterior)(theta)::Float64
    smc = logp.pf
    reset!(smc, theta)
    offlinefilter!(smc, logp.cache)
    return smc.history_pf.logCnorm[end]
end

# Callable object holding two bootstrap-filter SMC samplers built on the same filter:
#   - `pf` / `cache`:           value only (no summaries requested);
#   - `pf_grad` / `cache_grad`: additionally accumulates a `Score()` summary, which
#                               the gradient methods read back through `:score`.
struct LogPosteriorWithGradient{T_SMC1 <: SMC, T_SMC2 <: SMC, T_CACHE1, T_CACHE2}
    pf::T_SMC1
    pf_grad::T_SMC2
    cache::T_CACHE1
    cache_grad::T_CACHE2
    function LogPosteriorWithGradient(ssm::StateSpaceModel, data, nparticles::Integer)
        bootstrap = BootstrapFilter(ssm, data)
        # Both samplers share every setting except the requested summaries; each call
        # creates its own history and resampling configuration objects.
        build_smc(summaries) = SMC(
            bootstrap, Particles.parameter_template(ssm), nparticles,
            ParticleHistoryLength(; logCnorm=StaticFiniteHistory{1}()),
            summaries,
            AdaptiveResampling(SystematicResampling(), 0.5),
        )
        value_smc = build_smc(NamedTuple())
        value_cache = Particles.SMCCache(value_smc)
        grad_smc = build_smc((score=Score(), ))
        grad_cache = Particles.SMCCache(grad_smc)
        return new{typeof(value_smc), typeof(grad_smc), typeof(value_cache), typeof(grad_cache)}(
            value_smc, grad_smc, value_cache, grad_cache,
        )
    end
end
# In-place variant: run the score-tracking filter at `theta`, write the `:score`
# summary into `gradient`, and return the final `logCnorm` value.
function (logp::LogPosteriorWithGradient)(gradient, theta)
    smc = logp.pf_grad
    reset!(smc, theta)
    offlinefilter!(smc, logp.cache_grad)
    # `gradient` is overwritten with the summary computed by the filter run above.
    compute_summary!(gradient, smc, :score)
    value = smc.history_pf.logCnorm[end]::Float64
    # Debugging aid, disabled:
    # if !isfinite(value) || !all(isfinite, gradient)
    #     @info "logp is not finite" theta logp=value ∇logp=gradient
    #     flush(stdout)
    # end
    return value
end
# Allocating variant selected with `Val(:return)`: returns the tuple
# `(value, gradient)`, where `gradient` is a fresh Float64 array shaped like `theta`.
function (logp::LogPosteriorWithGradient)(::Val{:return}, theta)
    grad = similar(theta, Float64)
    # The in-place method fills `grad` before the tuple is assembled.
    value = logp(grad, theta)::Float64
    return value, grad
end
# Value-only variant: uses the sampler without the score summary and returns the
# final `logCnorm` value.
function (logp::LogPosteriorWithGradient)(theta)
    smc = logp.pf
    reset!(smc, theta)
    offlinefilter!(smc, logp.cache)
    value = smc.history_pf.logCnorm[end]::Float64
    # Debugging aid, disabled:
    # if !isfinite(value)
    #     @info "logp is not finite" theta logp=value
    #     flush(stdout)
    # end
    return value
end

In [None]:
# Value-only estimator with 250 particles, evaluated once at the reference parameter.
# The first timed call of a method also includes its compilation time.
logp = LogPosterior(ssm, data, 250);
@time logp(theta0)

In [None]:
# Estimator with gradient support, built with only 5 particles, timed on one
# (value, gradient) evaluation at the reference parameter.
# NOTE(review): `logp` is not rebound before the derivative-check cells, so when the
# notebook is run top to bottom they use this 5-particle estimator; confirm that
# such a small particle count is intended there.
logp = LogPosteriorWithGradient(ssm, data, 5);
@time logp(Val(:return), theta0)

---
### Compare variance of LogPosterior as the number of particles is increased

In [None]:
# Violin plot of repeated `LogPosterior` evaluations at `theta`, one violin per
# particle count.
#   ssm, data  : model and observations passed to `LogPosterior`
#   nparticles : particle counts to compare
#   theta      : parameter at which the estimator is evaluated
#   nruns      : number of independent evaluations per particle count
#   kwargs...  : forwarded to `violin`
# Returns the plot object created by `plot`.
function logp_vs_nparticles(ssm::StateSpaceModel, data, nparticles::AbstractVector{<:Integer}, theta; nruns::Integer=10, kwargs...)
    total = length(nparticles) * nruns
    labels = Vector{String}(undef, total)
    estimates = Vector{Float64}(undef, total)
    for (group, n) in enumerate(nparticles)
        estimator = LogPosterior(ssm, data, n)
        # Results for the `group`-th particle count occupy one contiguous slice.
        offset = (group - 1) * nruns
        for run in 1:nruns
            labels[offset + run] = string(convert(Int, n))
            estimates[offset + run] = estimator(theta)
        end
    end
    return plot(violin(; x=labels, y=estimates, kwargs...))
end

In [None]:
# 200 evaluations for each of the 9 particle counts (1800 filter runs in total).
logp_vs_nparticles(ssm, data, [50, 75, 100, 150, 200, 250, 300, 500, 1000], theta0; nruns=200)

---
### Check derivative computation

In [None]:
# Parameter (log-rates) at which the gradient is checked; it differs from `theta0`.
theta1 = Float64[log(0.5), log(1.5), log(0.8), log(1.2)]

In [None]:
# One realization: a single (value, gradient) evaluation at `theta1`; only the
# gradient is kept. This requires `logp` to support the `Val(:return)` call form.
_, grad_pf = @time logp(Val(:return), theta1)

In [None]:
# Multiple realizations: average independent gradient estimates at `theta1`.
using Statistics
grads = map(1:10) do _
    # Keep only the gradient (second element of the returned tuple).
    convert(SVector{4}, logp(Val(:return), theta1)[2])
end
# To refine the average, further realizations can be appended:
# new_grads = [convert(SVector{4}, logp(Val(:return), theta1)[2]) for _ in 1:30]
# append!(grads, new_grads)
grad_pf = mean(grads)

In [None]:
using Dierckx
# Index (1 to 4) of the parameter component that is varied.
cmp = 1
# `f(x)` evaluates `logp` at `theta1` with `x` added to component `cmp` only.
# `f` is `nothing` when `cmp` is not one of 1, 2, 3, 4.
f = let i = cmp
    if i in 1:4
        x -> logp([j == i ? theta1[j] + x : theta1[j] for j in 1:4])
    end
end
# Evaluate the (stochastic) estimator on a fine grid of offsets around zero.
dx = 0.000025
x = -0.4:dx:0.4
y = map(f, x);

In [None]:
# Smooth the noisy evaluations `y` with a cubic spline (`k=3`) on 10 user-supplied
# knots, then compare its slope at zero with the particle-filter gradient `grad_pf`.
# presumably Dierckx fits this form by least squares on the given knots (TODO confirm)
k = range(-0.399, 0.399, length=10)
spline = Spline1D(x, y, k; k=3)
# Derivative of the fitted spline at offset 0, to compare with `grad_pf[cmp]`.
println(derivative(spline, 0.0))
plot([
    # raw estimator evaluations
    scatter(; x, y, mode="markers"),
    # straight line through the spline value at 0 with slope `grad_pf[cmp]`
    scatter(; x=[-0.3, 0.3], y=[spline(0.0) - grad_pf[cmp] * 0.3, spline(0.0) + grad_pf[cmp] * 0.3], mode="lines"),
    # fitted spline
    scatter(; x, y=spline.(x), mode="lines"),
    # spline values at the knots
    scatter(; x=k, y=spline.(k), mode="markers"),
])

---
### Metropolis-Hastings MCMC

In [None]:
using AdvancedMH
using MCMCChains
import StatsPlots

In [None]:
# Value-only estimator with 100 particles, wrapped as the target density for AdvancedMH.
# Each evaluation is a new stochastic realisation of the particle filter.
logp = LogPosterior(ssm, data, 100)
model = DensityModel(logp);

In [None]:
# Random-walk Metropolis-Hastings with independent Normal(0, 0.05) proposals on each
# of the four components, started at the reference parameter.
# NOTE(review): the sampled components are the log-rates (see `parameter_template`),
# although they are labelled α, β, γ, δ in the chain; confirm the intended labels.
spl = RWMH(fill(Normal(0.0, 0.05), 4))
chain = sample(
    model, spl, 100_000;
    init_params=theta0,
    param_names=["α", "β", "γ", "δ"],
    chain_type=Chains,
)

In [None]:
# Plot the chain with StatsPlots; the call is qualified because StatsPlots is only
# `import`ed, while the unqualified `plot` used elsewhere comes from PlotlyJS.
StatsPlots.plot(chain)

---
### Hamiltonian Monte Carlo (HMC)

In [None]:
using AdvancedHMC

In [None]:
using AdvancedHMC

# The Hamiltonian below needs the log-density and a function returning
# (value, gradient). Only `LogPosteriorWithGradient` provides the `Val(:return)` call
# form; the `LogPosterior` bound to `logp` in the Metropolis-Hastings section does
# not. The target is therefore rebuilt here, with the same particle count (100) as
# in that section; raise it if the gradient estimates are too noisy.
logp = LogPosteriorWithGradient(ssm, data, 100)

# Choose parameter dimensionality and initial parameter value.
# `theta0` is the reference parameter defined above (no variable named `theta`
# exists in this notebook).
D = length(theta0)
initial_θ = copy(theta0)

# Set the number of samples to draw and warmup iterations
n_samples, n_adapts = 200, 100

# Define a Hamiltonian system:
#   - `logp(θ)` returns the log-density estimate,
#   - `Base.Fix1(logp, Val(:return))(θ)` calls `logp(Val(:return), θ)` and returns
#     the tuple (log-density, gradient).
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, logp, Base.Fix1(logp, Val(:return)))

# Define a leapfrog solver, with initial step size chosen heuristically
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
# initial_ϵ = 0.01
integrator = Leapfrog(initial_ϵ)

# Define an HMC sampler, with the following components
#   - multinomial sampling scheme,
#   - generalised No-U-Turn criteria, and
#   - windowed adaptation for step-size and diagonal mass matrix
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

In [None]:
# Run the sampler to draw samples from the target defined by `hamiltonian`
# (the particle-filter estimate, not a Gaussian), where
#   - `samples` will store the samples
#   - `stats` will store diagnostic statistics for each sample
# The sample counts are set again here, so they can be changed without re-running
# the setup cell.
n_samples, n_adapts = 200, 100
samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=false)