In [5]:
using Pun

In [6]:
struct LinearModel
    slope
    intercept
end

# We define a function to compute y for a given x
function evaluate_model(model::LinearModel, x::Real)
    return model.slope * x + model.intercept
end

linear_model_prior = @prob begin
    # We begin by sampling a slope and intercept for the line.
    # Before we have seen the data, we don't know the values of
    # these parameters, so we treat them as random choices. The
    # distributions they are drawn from represent our prior beliefs
    # about the parameters: in this case, that neither the slope nor the
    # intercept will be more than a couple points away from 0.
    slope <<= normal(0, 1)
    intercept <<= normal(0, 2)
    return LinearModel(slope, intercept)
end

line_model(xs; prior=linear_model_prior) = @prob begin
    linear_model <<= prior

    # Given the slope and intercept, we can sample y coordinates
    # for each of the x coordinates in our input vector.
    ys <<= mapM(x -> normal(evaluate_model(linear_model, x), 0.1), xs)

    return linear_model, ys
end;

In [None]:
using LogExpFunctions

function effective_sample_size(log_normalized_weights::Vector{Float64})
    log_ess = -logsumexp(2. * log_normalized_weights)
    return exp(log_ess)
end

function normalize_weights(log_weights)
    log_total_weight = logsumexp(log_weights)
    log_normalized_weights = log_weights .- log_total_weight
    return (log_total_weight, log_normalized_weights)
end

function initialize_particle_filter(p, q, y, basis, n)
    return importance_sampling(p, q, y, basis, n)
end

function maybe_resample(particles; ess_threshold=length(particles)/2)
    num_particles = length(particles)
    weights = getindex.(particles, 2)
    log_total_weight, log_normalized_weights = normalize_weights(weights)
    ess = effective_sample_size(log_normalized_weights)
    do_resample = ess < ess_threshold
    if do_resample
        weights = exp.(log_normalized_weights)
        resampler = @prob begin
            index <<= iid(categorical(weights), num_particles)
            return index
        end
        index, _, _ = simulate(resampler)
        particles = particles[index]
    end
    return particles
end

function particle_filter_step(p, q, y, basis, particles)
    new_particles = []
    for (particle, weight) in particles
        new_particle, new_weight, new_basis = proposal(particle, y)
        push!(new_particles, (new_particle, new_weight))
    end
    # Implicitly assuming everything has the same basis
    return new_particles, new_basis
end

maybe_resample (generic function with 2 methods)

In [73]:
xs = [-5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.];
ys = [6.75003, 6.1568, 4.26414, 1.84894, 3.09686, 1.94026, 1.36411, -0.83959, -0.976, -1.93363, -2.91303];
_, _, basis = simulate(line_model(xs))
weighted_particles = Pun.importance_sampling(line_model(xs), y -> linear_model_prior, ys, basis, 100);
resampled = maybe_resample(weighted_particles)

100-element Vector{Tuple{Any, Any}}:
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 â‹®
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -400.7098206106388)
 (LinearModel(-0.8890334989441273, 2.248194656280551), -40

In [87]:
function smc(model, xs, ys, n_particles)
    _, _, basis = simulate(model(xs[1:1]))
    particles = initialize_particle_filter(
        model(xs[1:1]), y -> linear_model_prior, basis, n_particles)

    # steps
    for t = 1:length(zs)-1
        particles = maybe_resample(particles, ess_threshold=num_particles / 2)
        proposal(particle, y) = @prob begin
            (model, ys), _ = particle
            new_y <<= normal(evaluate_model(model, xs[t+1]), 0.1)
            new_ys .<<= [ys; y]
            new_y .>>= new_ys[end]
            new_model .<<= model
            return new_model, new_ys
        end
        particle_filter_step(model[1:t], proposal,)
    end
end

smc (generic function with 1 method)