In [None]:
using Pkg
# Directory holding this notebook; its Project.toml defines the environment.
const NOTEBOOKS_ROOT = @__DIR__
# First run only: create the project via the bootstrap script when it is missing.
if !isfile(joinpath(NOTEBOOKS_ROOT, "Project.toml"))
    include("init_project.jl")
end
Pkg.activate(NOTEBOOKS_ROOT)
using Revise

In [None]:
# Re-resolve the active environment against its Project.toml, then upgrade the
# dependencies. NOTE(review): updating on every run means results may depend on
# when the notebook was executed.
for pkg_step in (Pkg.resolve, Pkg.update)
    pkg_step()
end

### Discrete Cox model

In [None]:
using Distributions
using Particles
using PlotlyJS

"""
    DiscreteCox(; mu=0.0, sigma=1.0, phi=0.95)

State-space model with a Gaussian AR(1) latent state and Poisson counts as
observations:

    x₀ ~ N(mu, sigma² / (1 - phi²))
    xₜ | xₜ₋₁ ~ N(mu + phi * (xₜ₋₁ - mu), sigma²)
    yₜ | xₜ ~ Poisson(exp(xₜ))

`sigma` is the standard deviation of the latent innovations. The initial law is only
well defined for `|phi| < 1`.
"""
Base.@kwdef struct DiscreteCox <: StateSpaceModel{Float64, Int}
    mu::Float64 = 0.0
    sigma::Float64 = 1.0
    phi::Float64 = 0.95
end
# Initial state: the stationary law of the AR(1) latent process. Its variance
# sigma² / ((1 - phi)(1 + phi)) is finite and positive only for |phi| < 1.
function Particles.ssm_PX0(ssm::DiscreteCox, ::Nothing)
    stationary_sd = ssm.sigma / sqrt((1 - ssm.phi) * (1 + ssm.phi))
    return Normal(ssm.mu, stationary_sd)
end

# Transition: Gaussian AR(1) step around `mu` with innovation std `sigma`.
function Particles.ssm_PX(ssm::DiscreteCox, ::Nothing, t::Integer, xp::Real)
    next_mean = muladd(ssm.phi, xp - ssm.mu, ssm.mu)  # mu + phi * (xp - mu)
    return Normal(next_mean, ssm.sigma)
end

# Observation: Poisson count whose log-intensity is the latent state.
function Particles.ssm_PY(ssm::DiscreteCox, ::Nothing, t::Integer, x::Real)
    return Poisson(exp(x))
end

# Upper bound on the transition log-density: the peak of a Normal density with
# standard deviation `sigma`, i.e. -log(sigma) - log(2π)/2.
# Stated as valid for phi ≤ 1. NOTE(review): the peak itself does not involve phi;
# confirm what that condition is meant to guard.
function Particles.ssm_sup_logpdf_PX(ssm::DiscreteCox, ::Nothing, t::Integer, x::Real)
    return -log(ssm.sigma) - log(2 * pi) / 2
end

"""
    Score

Additive function for the score of the model: the gradient of the log-likelihood
with respect to the latent innovation variance `σ² = sigma^2`, evaluated at the
parameters stored in the model.

Only the latent-state densities depend on `σ²`; the Poisson observation density does
not. The additive terms are therefore the `σ²`-derivatives of:

- the initial log-density, `x₀ ~ N(mu, σ² / (1 - phi²))`:
  `(1 - phi²) * (x₀ - mu)² / (2σ⁴) - 1 / (2σ²)`;
- each transition log-density, `xₜ | xₜ₋₁ ~ N(mu + phi * (xₜ₋₁ - mu), σ²)`:
  `(xₜ - mu - phi * (xₜ₋₁ - mu))² / (2σ⁴) - 1 / (2σ²)`.

By Fisher's identity, the smoothed expectation of the sum of these terms is the score.
"""
struct Score <: Particles.ImmutableAdditiveFunction end

# Initial term: d/dσ² of log N(x; mu, σ² / (1 - phi²)).
function (::Score)(bf::BootstrapFilter{Float64, Int, DiscreteCox}, ::Nothing, x, ::Nothing)::Float64
    ssm = bf.ssm
    quadratic = ((1 - ssm.phi) * (1 + ssm.phi) / (2 * ssm.sigma^4)) * (x - ssm.mu)^2
    # The log-normalising constant -log(σ²)/2 differentiates to -1/(2σ²).
    return quadratic - 1 / (2 * ssm.sigma^2)
end

# Transition term: d/dσ² of log N(x; mu + phi * (xp - mu), σ²).
function (::Score)(bf::BootstrapFilter{Float64, Int, DiscreteCox}, ::Nothing, t::Integer, xp, x, ::Nothing)::Float64
    ssm = bf.ssm
    # muladd(phi, xp - mu, mu - x) == -(x - mu - phi * (xp - mu)); it is squared below.
    residual = muladd(ssm.phi, xp - ssm.mu, ssm.mu - x)
    return residual^2 / (2 * ssm.sigma^4) - 1 / (2 * ssm.sigma^2)
end

Particles.return_type(::Score, ::BootstrapFilter{Float64, Int, DiscreteCox}) = Float64

In [None]:
# Model used throughout the notebook, with its parameters spelled out explicitly.
ssm = DiscreteCox(; mu=0.0, sigma=1.0, phi=0.95)

### Simulate data

In [None]:
# Number of time steps.
T = 100
# The (latent states, observations) pair below is hard-coded so that results are
# reproducible across runs; presumably it is one earlier draw of the line below,
# which can be uncommented to simulate fresh data instead:
# xtrue, data = rand(ssm, T)
xtrue, data = [2.4021265350056944, 1.681624508654635, 2.3916947359876395, 4.77788647209322, 2.817148421953582, 3.4693282990887497, 3.2924510794624573, 5.011681413526453, 7.190454976294111, 6.967928865313658, 6.102722306108012, 4.487636893855688, 3.703846197361809, 3.7468282413507787, 2.8988502482381753, 2.5257424292420287, 1.3329589939737236, 0.8693102565291805, 2.711898711746743, 3.4232838880367797, 3.2453131166199407, 2.331887449217647, 3.3700451288125493, 4.252488078154874, 2.4038348942332135, 2.261371117945066, 2.9478443945807964, 2.682914704873883, 1.3585135164014874, 1.3926437420156972, -0.20822765483888794, -0.2658680311238287, 0.07618351356071174, 2.1709769773058425, 2.742357879584931, 2.628257920122673, 3.183596799522256, 2.0409879872348107, -0.03713106151521539, -0.5805168512529257, -0.6194246449677187, 1.3123945516084254, 0.970109940654305, 0.8950591383004288, 0.10543607139233557, 0.35279437922984114, -0.3491891121737035, -2.016274098243435, -1.1083793648677271, -0.4117322280012812, -0.5118909120660742, 1.1099251920213644, 1.2851962828014611, 1.5377462404247215, 0.057517125953026715, 0.5130136355717464, 0.9159240740327812, 1.6531357064720393, 1.748726382153127, 2.355084014692249, 1.9546883298793616, 2.271199044211417, 2.5570104439531662, 1.9408311991694989, 0.4747386323002716, -0.3847378120076566, 0.09586345989864509, -1.788164232258839, 0.19845446961195812, 0.7032762459314381, 0.581864986690299, 1.6669135895956826, 1.8116436386236507, 0.6530231531931909, 0.6795396754452527, 2.3592872304719084, 3.2559674377841263, 3.403227474923672, 4.130354333847429, 5.465074809642444, 5.477761488716303, 5.6612352411497096, 6.309660800274545, 7.688024615871875, 8.59400161729781, 8.25763747662676, 8.35246885224497, 7.692127177332258, 8.601585504922028, 8.408117815341656, 7.008772701254552, 7.340153378498685, 8.115159339115328, 7.958472562194978, 7.819763909587895, 8.003961033420545, 6.748063926944804, 6.804687618088711, 7.218016132369026, 7.441399002392016], [11, 7, 11, 
119, 31, 28, 29, 133, 1311, 1017, 465, 91, 35, 34, 19, 12, 4, 3, 11, 25, 20, 13, 29, 64, 9, 6, 22, 19, 1, 3, 1, 0, 3, 14, 13, 13, 27, 7, 0, 0, 1, 5, 3, 4, 0, 2, 1, 0, 0, 3, 0, 3, 3, 3, 1, 4, 0, 2, 2, 11, 4, 13, 15, 5, 2, 1, 0, 0, 2, 1, 3, 5, 5, 2, 2, 8, 26, 39, 73, 221, 270, 300, 540, 2247, 5357, 3947, 4223, 2157, 5441, 4425, 1121, 1525, 3343, 2922, 2446, 2998, 868, 952, 1340, 1662];

In [None]:
# Plot the latent state next to the log of the observed counts. Both are on the
# log-intensity scale, since yₜ ~ Poisson(exp(xₜ)). Zero counts become -Inf.
# NOTE(review): such points are presumably not drawn by PlotlyJS.
let steps = 1:T
    state_trace = scatter(x=steps, y=xtrue, mode="markers", name="state")
    observation_trace = scatter(x=steps, y=log.(data), mode="markers", name="observation")
    plot([state_trace, observation_trace])
end

### Compare online naive smoother with amortized naive smoother

In [None]:
nparticles = 1000
bf = BootstrapFilter(ssm, data)
# Naive smoother of the score, passed to SMC as a named summary (`:score`).
score = AdditiveFunctionSmoother(Score(), NaiveAfsMethod())
# The same smoother with the amortized method, wrapped as an OfflineSummary.
# NOTE(review): presumably evaluated after filtering from the particle history
# requested via ParticleHistoryLength below.
score_offline = OfflineSummary(AdditiveFunctionSmoother(Score(), AmortizedNaiveAfsMethod()))
pf = SMC(
    bf, Particles.parameter_template(ssm), nparticles,
    ParticleHistoryLength(score_offline),
    (; score, ),
    Particles.required_amortized_computations(score_offline),
);
offlinefilter!(pf)
score_value = compute_summary(pf, :score)
score_value_offline = compute_summary(pf, score_offline)
# Both summaries are computed from the same particle system, so they must agree.
# Compare with a floating-point tolerance: the two methods may accumulate the
# additive terms in a different order, so bitwise equality is not guaranteed.
score_value ≈ score_value_offline ||
    error("scores should have been the same! online = $score_value, offline = $score_value_offline")
score_value

### Benchmark online naive smoother vs amortized naive smoother

In [None]:
using BenchmarkTools
"""
    benchmark_online(ssm, data, nparticles)

Filter `data` with a bootstrap filter for `ssm` using `nparticles` particles,
smoothing `Score` with the naive method as the online summary named `:score`.
Return that summary's value.
"""
function benchmark_online(ssm, data, nparticles)
    filt = BootstrapFilter(ssm, data)
    online_smoother = AdditiveFunctionSmoother(Score(), NaiveAfsMethod())
    smc = SMC(
        filt, Particles.parameter_template(ssm), nparticles,
        (; score = online_smoother),
    )
    offlinefilter!(smc)
    return compute_summary(smc, :score)
end

"""
    benchmark_offline(ssm, data, nparticles)

Filter `data` with a bootstrap filter for `ssm` using `nparticles` particles, keeping
the particle history and amortized computations required by the amortized naive
smoother of `Score`. Return that offline summary's value.
"""
function benchmark_offline(ssm, data, nparticles)
    filt = BootstrapFilter(ssm, data)
    amortized_smoother = OfflineSummary(AdditiveFunctionSmoother(Score(), AmortizedNaiveAfsMethod()))
    smc = SMC(
        filt, Particles.parameter_template(ssm), nparticles,
        ParticleHistoryLength(amortized_smoother),
        Particles.required_amortized_computations(amortized_smoother),
    )
    offlinefilter!(smc)
    return compute_summary(smc, amortized_smoother)
end

In [None]:
# Particle count shared by both benchmarks below.
nparticles = 1_000;

In [None]:
# `$` splices the current global values into the benchmark, so untyped global
# lookups are not part of what is timed.
@benchmark(benchmark_online($ssm, $data, $nparticles))

In [None]:
# Same setup as the online benchmark, for the amortized (offline) smoother.
@benchmark(benchmark_offline($ssm, $data, $nparticles))