In [4]:
using Pkg
Pkg.activate("..")

using Pun
using Pun: interpret_program, uninterpret_program, EvalState
using Plots

include("../examples.jl")

[32m[1m  Activating[22m[39m project at `~/Documents/code/Pun.jl`


gamma (generic function with 1 method)

In [11]:
"""
    LineModel(slope, intercept, noise)

Parameters of a noisy line: the mean at input `x` is `x * slope + intercept`,
and `noise` is passed as the second argument of `normal` in `evaluate`.

Each field has its own concrete type parameter (constrained to `Real`), so
values such as `Float64` are stored unboxed and code using a `LineModel` can
specialize, while any mix of `Real` argument types is still accepted.
"""
struct LineModel{S<:Real,I<:Real,N<:Real}
    slope::S
    intercept::I
    noise::N
end

"""
    evaluate(m::LineModel)

Return a function mapping an input `x` to a `@prob` program that draws
`y <<= normal(x * m.slope + m.intercept, m.noise)` and returns `y`.
"""
function evaluate(m::LineModel)
    # A fresh program is built for each `x`, closing over the parameters in `m`.
    return (x) -> @prob begin
        y <<= normal(x * m.slope + m.intercept, m.noise)
        return y
    end
end

"""
    get_noisy_point(x, prob_outlier, line_model)

Build a `@prob` program for a single observation at input `x`.

The program samples `is_outlier <<= flip(prob_outlier)`. When it is true,
`y` is drawn from `normal(0, 10)`; otherwise `y` comes from the line program
`evaluate(line_model)(x)`. `is_outlier` is then reversed out with
`>>= flip(prob_outlier)` (the same distribution it was sampled from), and the
program returns `y`.

NOTE(review): `>>=` is presumed to be the reverse ("uninterpret") of `<<=`,
so `is_outlier` does not stay live after the program — confirm against Pun.
"""
get_noisy_point(x, prob_outlier, line_model) = @prob begin
    is_outlier <<= flip(prob_outlier)
    # Outliers come from a fixed normal(0, 10) that does not depend on the line.
    y <<= is_outlier ? normal(0, 10)  : evaluate(line_model)(x)
    is_outlier >>= flip(prob_outlier)
    return y
end

"""
    regression_with_outliers(xs::Vector{<:Real})

Build a `@prob` program for linear regression with outliers over inputs `xs`.

Steps:
- `slope` and `intercept` are drawn from `normal(0, 2)`, and `noise` is drawn
  from `gamma(1, 1)`.
- The three values are packed into a `LineModel` with `.<<=`.
- Each is then reversed out with `.>>=` against the matching `LineModel`
  field.
- `prob_outlier` is drawn from `uniform(0, 1)`.
- `ys` is produced by mapping `get_noisy_point` over `xs` with `mapM`.
- `prob_outlier` is reversed out with `>>= beta(1, 1)`. Beta(1, 1) has the
  same density as Uniform(0, 1).

The program returns `(ys, line_model)`.

NOTE(review): `.>>=` / `>>=` are read here as the reverse ("uninterpret") of
`.<<=` / `<<=`, leaving only `ys` and `line_model` live — confirm against Pun.
"""
regression_with_outliers(xs::Vector{<:Real}) = @prob begin
    slope <<= normal(0, 2)
    intercept <<= normal(0, 2)
    noise <<= gamma(1, 1)
    line_model .<<= LineModel(slope, intercept, noise)
    slope .>>= line_model.slope
    intercept .>>= line_model.intercept
    noise .>>= line_model.noise
    prob_outlier <<= uniform(0, 1)

    # Generate the y coordinates: one `get_noisy_point` program per input x.
    ys <<= mapM(x -> get_noisy_point(x, prob_outlier, line_model), xs)
    prob_outlier >>= beta(1, 1)
    return ys, line_model
end;

In [12]:
# Eleven evenly spaced inputs -5.0, -4.0, ..., 5.0 (a Vector{Float64}),
# followed by one forward simulation of the outlier-regression model.
xs = collect(-5.0:1.0:5.0);
(ys, line_model), score, basis = simulate(regression_with_outliers(xs));

In [44]:
# Round-trip a single primitive distribution: interpret `beta(1, 1)` forward
# on a fresh state, then uninterpret it on that same state using the value
# and dependency set the forward pass produced.
state = EvalState()
p = beta(1, 1)
(val, deps) = interpret_program(p, state)
uninterpret_program(p, state, val, deps)

In [None]:
"""
    mh(p, q, y)

Work-in-progress Metropolis–Hastings step for reversible `@prob` programs.

Forward pass (`y -> y'`):
- Run the proposal `q(y)` on a fresh `EvalState`.
- Uninterpret the target `p` on `(val, y)` with the combined dependency set.
- Collect tape partials and compute a log pseudo-determinant correction.
- Add that correction to the state's accumulated log weight.

Reverse pass (`y' -> y`): uninterpret `q((val, deps))` on a second fresh
state.

Returns `min(1, exp(weight))`, computed in log space.

NOTE(review): this cell was never executed (`In [None]`) and is incomplete:
- `dim` is neither a parameter nor a local. It resolves to a global if one
  exists, otherwise it raises `UndefVarError`.
- The reverse pass does not contribute to `weight`; its state is discarded.
- `DynamicForwardDiff`, `accumulate_partials!`, `dictrows_to_sparse` and
  `logpdet_from_tall` are not imported by the visible cells.
"""
function mh(p, q, y)
    # Forward direction (y -> y'): run the proposal, then un-run the target.
    state = EvalState()
    state.cfg.n_inputs[] = dim
    val, deps = interpret_program(q(y), state)
    uninterpret_program(p, state, (val, y), union(deps, Set(1:dim)))

    # Change-of-variables correction from the partials recorded on the tape.
    tape_partials = DynamicForwardDiff.Partials[]
    accumulate_partials!(state.tape, tape_partials)
    # `part` (not `p`) so the comprehension does not shadow the target program.
    rows = Dict{Int,Float64}[part.values for part in tape_partials]
    Qblk = dictrows_to_sparse(rows, state.cfg.n_inputs[])
    correction = logpdet_from_tall(Qblk; tol=1e-14)
    weight = state.logweight + correction

    # Reverse direction (y' -> y).
    # NOTE(review): the weight of this pass is not yet folded into `weight`.
    state = EvalState()
    state.cfg.n_inputs[] = dim
    uninterpret_program(q((val, deps)), state)

    # `minimum(exp(weight), 1)` threw a MethodError (`minimum(f, itr)` with a
    # number as `f`). Clamping in log space equals min(exp(weight), 1) and
    # cannot overflow `exp` for large weights.
    accept = exp(min(weight, 0))
    return accept
end


"""
    block_resimulation_update(tr)

Perform one sweep of block-resimulation Metropolis–Hastings on trace `tr`.

The sweep runs in three stages:
1. One move on the line parameters `:noise`, `:slope` and `:intercept`
   together.
2. One move per data point on its `:data => i => :is_outlier` choice.
3. One move on `:prob_outlier`.

Returns the updated trace.

NOTE(review): these calls use a Gen-style two-argument `mh(trace, selection)`
together with `select` and `get_args`. That does not match the three-argument
`mh(p, q, y)` defined in this notebook, so a Gen-compatible `mh` must be in
scope.
"""
function block_resimulation_update(tr)
    # Stage 1: the three line parameters as a single block.
    tr, _ = mh(tr, select(:noise, :slope, :intercept))

    # Stage 2: each point's outlier indicator, one block per point.
    xs = first(get_args(tr))
    for idx in 1:length(xs)
        tr, _ = mh(tr, select(:data => idx => :is_outlier))
    end

    # Stage 3: the shared outlier probability.
    tr, _ = mh(tr, select(:prob_outlier))

    return tr
end;

"""
    block_resimulation_inference(xs, ys, observations = make_constraints(ys);
                                 n_iters = 500)

Run block-resimulation MCMC for `regression_with_outliers` on data `(xs, ys)`.

An initial trace comes from `generate(regression_with_outliers, (xs,),
observations)`. It is then updated `n_iters` times with
`block_resimulation_update`. Returns the final trace.

# Arguments
- `xs`: model inputs, passed as the model's only argument.
- `ys`: observed outputs; used to build the default `observations`.
- `observations`: constraints passed to `generate`. Defaults to
  `make_constraints(ys)` when omitted.

# Keywords
- `n_iters::Integer = 500`: number of update sweeps. Values ≤ 0 run none.
"""
function block_resimulation_inference(xs, ys, observations = make_constraints(ys);
                                      n_iters::Integer = 500)
    # The caller's `observations` is honored; it used to be overwritten here
    # unconditionally, which made the argument dead.
    (tr, _) = generate(regression_with_outliers, (xs,), observations)
    for _ in 1:n_iters
        tr = block_resimulation_update(tr)
    end
    return tr
end;