In [3]:
using Pun

In [4]:
"""
    LinearModel(slope, intercept)

A straight line `y = slope * x + intercept`.

Both fields are untyped (`Any`), so the struct accepts whatever values it is
constructed with.
"""
struct LinearModel
    slope::Any
    intercept::Any
end

"""
    evaluate_model(model::LinearModel, x::Real)

Return the y coordinate of the line `model` at input `x`,
i.e. `model.slope * x + model.intercept`.
"""
function evaluate_model(model::LinearModel, x::Real)
    slope, intercept = model.slope, model.intercept
    return slope * x + intercept
end

# Prior over lines: a Pun model that returns a `LinearModel` built from a
# randomly drawn slope and intercept.
# NOTE(review): assumes Pun's `normal(mu, sigma)` takes a standard deviation
# (not a variance) as its second argument — confirm against Pun.
linear_model_prior = @prob begin
    # We begin by sampling a slope and intercept for the line.
    # Before we have seen the data, we don't know the values of
    # these parameters, so we treat them as random choices. The
    # distributions they are drawn from represent our prior beliefs
    # about the parameters: in this case, that neither the slope nor the
    # intercept is likely to be more than a couple of units away from 0.
    slope <<= normal(0, 1)
    intercept <<= normal(0, 2)
    return LinearModel(slope, intercept)
end

"""
    line_model(xs; prior=linear_model_prior, noise=0.1)

Generative model for noisy observations of a line.

Draws `linear_model` from `prior`. Then, for each element `x` of `xs`, draws a
y coordinate from `normal(evaluate_model(linear_model, x), noise)`. The model
returns the tuple `(linear_model, ys)`.

`noise` is the observation-noise scale passed to `normal`. It defaults to
`0.1`, the value previously hard-coded, so existing callers see no change.
Larger values let the model explain data that scatters further from a
straight line.
"""
line_model(xs; prior=linear_model_prior, noise=0.1) = @prob begin
    linear_model <<= prior

    # Given the slope and intercept, we can sample y coordinates
    # for each of the x coordinates in our input vector, scattered
    # around the line with scale `noise`.
    ys <<= mapM(x -> normal(evaluate_model(linear_model, x), noise), xs)

    return linear_model, ys
end;

In [168]:
using LogExpFunctions

"""
    effective_sample_size(log_normalized_weights)

Return the effective sample size `1 / sum(w_i^2)` of a set of importance weights.

The input holds the logarithms `log(w_i)` of weights already normalized so
that `sum(w_i) == 1`. The computation stays in log space via `logsumexp`,
which avoids underflow when the weights are tiny. Any real-valued vector is
accepted, not only `Vector{Float64}`.
"""
function effective_sample_size(log_normalized_weights::AbstractVector{<:Real})
    # log(sum(w_i^2)) == logsumexp(2 * log(w_i)); ESS is exp of its negation.
    log_ess = -logsumexp(2 .* log_normalized_weights)
    return exp(log_ess)
end

"""
    normalize_weights(log_weights)

Normalize a collection of log importance weights.

Returns the tuple `(log_total_weight, log_normalized_weights)`, where:
- `log_total_weight` is `logsumexp(log_weights)`, the log of the summed weights;
- `log_normalized_weights` is `log_weights .- log_total_weight`, so its
  exponentiated entries sum to one.
"""
function normalize_weights(log_weights)
    lse = logsumexp(log_weights)
    return lse, log_weights .- lse
end

"""
    initialize_particle_filter(p, q, y, basis, n)

Build the initial particle population.

Delegates directly to `Pun.importance_sampling(p, q, y, basis, n)` and returns
its result unchanged.
"""
initialize_particle_filter(p, q, y, basis, n) = Pun.importance_sampling(p, q, y, basis, n)

"""
    maybe_resample(particles; ess_threshold=length(particles) / 2)

Resample a weighted particle population when its effective sample size is low.

`particles` is an indexable collection of `(particle, log_weight)` pairs.

If the effective sample size of the normalized weights is at least
`ess_threshold`, `particles` is returned unchanged. Otherwise, `length(particles)`
indices are drawn i.i.d. from the categorical distribution given by the
normalized weights, and the selected particles are returned. Each selected
particle gets the same log weight, `log_total_weight - log(n_particles)`: the
log of the average incoming weight.
"""
function maybe_resample(particles; ess_threshold=length(particles) / 2)
    n_particles = length(particles)
    log_weights = getindex.(particles, 2)
    log_total_weight, log_normalized_weights = normalize_weights(log_weights)
    ess = effective_sample_size(log_normalized_weights)
    ess < ess_threshold || return particles

    probabilities = exp.(log_normalized_weights)
    resampler = @prob begin
        index <<= iid(categorical(probabilities), n_particles)
        return index
    end
    index, _, _ = simulate(resampler)

    # Resampling already encodes the weights by duplicating heavy particles.
    # Carrying the old weights forward would therefore count them twice.
    # Giving every survivor the average weight keeps the log-mean-exp of the
    # weights (the running marginal-likelihood estimate) unchanged.
    log_mean_weight = log_total_weight - log(n_particles)
    return [(particle, log_mean_weight) for (particle, _) in particles[index]]
end

# One SMC extension step. For each weighted particle, build a Pun model that
# proposes `new_particle` from `q(particle, y)` and scores against the target
# `p`. The model is then simulated, and its returned weight is added to the
# particle's running log weight.
#
# Arguments:
#   p         - target Pun model for this step.
#   q         - proposal constructor, called as `q(particle, y)`.
#   y         - observations, passed to the proposal through `args`.
#   particles - iterable of `(particle, log_weight)` pairs.
# Returns an untyped `Vector` (`[]`) of `(new_particle, log_weight + new_weight)` pairs.
function particle_filter_step(p, q, y, particles)
    # `p` is simulated only to obtain its `basis`, which is passed below as
    # `arg_basis`; the sampled value and weight are discarded.
    _, _, basis = simulate(p)
    new_particles = []
    for (particle, weight) in particles
        # The inner `y` is the model's argument (shadowing the outer `y`); it is
        # bound through `args=(; y, particle)` in the `simulate` call below.
        smc_proposal = (@prob y -> begin
            new_particle <<= q(particle, y)
            # NOTE(review): this presumably scores the *previous* `particle`
            # (not `new_particle`) against `p`. That is only equivalent when `q`
            # returns the particle unchanged — confirm before using a proposal
            # that actually moves particles.
            (particle, y) >>= p
            return new_particle
        end)
        new_particle, new_weight, _ = simulate(smc_proposal; args=(; y, particle), arg_basis=basis)
        # NOTE(review): `new_weight` is added to the running weight without
        # dividing out the previous step's target density. Verify against
        # Pun's `simulate`/`arg_basis` semantics whether `new_weight` is already
        # an incremental weight; otherwise earlier observations are counted again.
        push!(new_particles, (new_particle, weight + new_weight))
    end
    return new_particles
end

"""
    smc(model, xs, ys, n_particles)

Run sequential Monte Carlo over the observations `(xs, ys)`, adding one data
point per step.

`model` is called as `model(xs_prefix)` and must return a Pun model whose
observations are matched against the same prefix of `ys`.

- The initial population of `n_particles` particles comes from importance
  sampling on the first observation, proposing from `linear_model_prior`.
- Each later step `t` first resamples if the effective sample size drops
  below `n_particles / 2`.
- It then extends every particle to the target conditioned on the first
  `t` observations.

Returns the final collection of `(particle, log_weight)` pairs.

Throws `ArgumentError` if `xs` is empty, and `DimensionMismatch` if `xs` and
`ys` have different lengths.
"""
function smc(model, xs, ys, n_particles)
    isempty(xs) && throw(ArgumentError("smc needs at least one observation"))
    nx, ny = length(xs), length(ys)
    nx == ny || throw(DimensionMismatch("xs has $nx elements but ys has $ny"))

    # Initial population: importance sampling on the first observation only.
    first_model = model(xs[1:1])
    _, _, basis = simulate(first_model)
    particles = initialize_particle_filter(
        first_model, y -> linear_model_prior, ys[1:1], basis, n_particles
    )

    # The proposal does not depend on the step, so define it once, outside the loop.
    proposal(particle, y) = @prob begin
        new_particle .<<= particle
        return new_particle
    end

    # TODO: add rejuvenation step

    # Step t targets the first t observations. Observation 1 is already covered
    # by the initial population, so the steps are t = 2, ..., length(xs). This
    # ensures the final observation is actually incorporated.
    for t in 2:length(xs)
        particles = maybe_resample(particles; ess_threshold=n_particles / 2)
        particles = particle_filter_step(model(xs[1:t]), proposal, ys[1:t], particles)
    end
    return particles
end

smc (generic function with 1 method)

In [170]:
# Eleven evenly spaced inputs -5.0, -4.0, ..., 5.0 and the observed y value at each.
xs = collect(-5.0:1.0:5.0);
ys = [6.75003, 6.1568, 4.26414, 1.84894, 3.09686, 1.94026, 1.36411, -0.83959, -0.976, -1.93363, -2.91303];
smc(line_model, xs, ys, 1000)

1000-element Vector{Any}:
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-0.9513322142638918, 2.0845264656616216), -1956.5966388878155)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 ⋮
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.3967585262964684), -1676.7267935469406)
 (LinearModel(-1.0872035135763787, 1.39675852