In [None]:
using Pun
using Plots
using LogExpFunctions


# Probabilistic program for a Rayleigh variate with scale `sigma`, obtained by
# transforming a single `uniform(0, 1)` draw.
# NOTE(review): the operator meanings used in these comments are inferred from how
# this file uses them (`<<=` binds a draw from a distribution, `.<<=` binds a
# computed value, `.>>=` retires a variable by giving an expression that recomputes
# it) — confirm against Pun's documentation.
rayleigh(sigma) = @prob begin
    # Source of randomness for the transform.
    u <<= uniform(0, 1)
    # x = sigma * sqrt(-2 * log(u)).
    x .<<= sigma * (-2 * log(u))^0.5
    # Algebraic inverse of the line above: solving it for `u` gives
    # exp(-(x / sigma)^2 / 2), so `u` can be recomputed from `x` alone.
    u .>>= exp(((x / sigma)^2) * -0.5)
    return x
end

# Probabilistic program returning a point on the origin-centred circle of radius
# `r`: a pair of `normal(0, 1)` draws `(x, y)` rescaled to that radius.
circle(r) = @prob begin
    x <<= normal(0, 1)
    y <<= normal(0, 1)
    # Rescale (x, y) onto the circle; only the direction of (x, y) survives in `s`.
    s .<<= r .* (x, y) ./ (x^2 + y^2)^0.5
    # `s` does not determine the length of (x, y), so (x, y) is retired with a
    # nested program instead of a formula.
    # NOTE(review): presumably `>>=` uses the nested program as the distribution of
    # (x, y) given `s` — confirm against Pun's documentation.
    (x, y) >>= @prob begin
        # Draw a length R and place a point of that length along the direction of `s`.
        R <<= rayleigh(1)
        point .<<= s .* (R / r)
        # `point` has Euclidean norm R (for r != 0), so R is recomputed from it.
        R .>>= (point[1]^2 + point[2]^2)^0.5
        return point
    end
    return s
end

# Proposal for the on-circle point behind an observation `noisy_point`: project
# `noisy_point` onto the circle of radius `r`, add a two-component `normal(0, 1)`
# step, and project the result back onto the circle.
#
# Arguments:
#   noisy_point - observed 2-D point (indexable, broadcastable collection).
#   r           - circle radius.
# Returns the re-projected point `normalized_point_estimate`.
circle_posterior_estimate(noisy_point, r) = @prob begin
    # Point on the circle in the direction of the observation.
    closest_point .<<= r .* noisy_point ./ sum(noisy_point.^2)^0.5
    gaussian_step <<= iid(normal(0, 1), 2)
    # This is a computed value, not a draw, so it is bound with `.<<=` like every
    # other computed value in this file (the original used the sampling operator
    # `<<=` here while still retiring the variable with `.>>=` below).
    point_estimate .<<= closest_point .+ gaussian_step
    normalized_point_estimate .<<= r .* point_estimate ./ sum(point_estimate.^2)^0.5
    # Same expression as the binding above, so `point_estimate` is recoverable.
    point_estimate .>>= closest_point .+ gaussian_step
    # NOTE(review): `noisy_point` is not an inverse for `gaussian_step` — the step
    # equals `point_estimate .- closest_point`, and its radial part is lost by the
    # normalisation above. This probably needs a stochastic `>>=` with its own
    # proposal; left unchanged pending confirmation of the intended semantics.
    # NOTE(review): `closest_point` is never retired, unlike the intermediates in
    # the other programs in this file — confirm whether Pun requires that.
    gaussian_step .>>= noisy_point
    return normalized_point_estimate
end

# Probabilistic program returning a noisy observation of a point on the circle of
# radius `r`: a draw from `circle(r)` plus a two-component `normal(0, 1)` noise
# vector built with `iid`.
noisy_circle(r) = @prob begin
    point <<= circle(r)
    noise <<= iid(normal(0, 1), 2)
    noisy_point .<<= point .+ noise
    # Exact inverse of the line above, so `noise` is recomputed from the other two.
    noise .>>= noisy_point .- point
    # `point` cannot be recomputed from `noisy_point` alone; it is retired through
    # `circle_posterior_estimate` instead.
    # NOTE(review): presumably `>>=` uses that program as a proposal for `point`
    # given the observation — confirm against Pun's documentation.
    point >>= circle_posterior_estimate(noisy_point, r)
    return noisy_point
end

# Distribution the radius `r` is drawn from in `point_cloud`; the same object is
# returned by the proposal closure passed to `Pun.importance_sampling`.
r_prior = uniform(3, 5)

# Joint program: draw a radius from `r_prior`, then 100 draws of `noisy_circle(r)`
# that all share it. Returns the radius together with the points.
point_cloud = @prob begin
    r <<= r_prior
    # NOTE(review): the 100 here presumably has to stay in step with the
    # `Pun.LebesgueBase(200)` argument used at the importance-sampling call
    # (100 points x 2 coordinates) — confirm.
    points <<= iid(noisy_circle(r), 100)
    return r, points
end

# Simulate one synthetic data set from the joint program; `r` is the radius that
# was drawn alongside `points`.
(r, points), weight, basis = simulate(point_cloud)

# Importance sampling over the joint program with `r_prior` as the proposal.
# NOTE(review): argument roles (model, proposal-given-observation, observation,
# base measure, particle count) are inferred from the names — confirm against Pun.
# NOTE(review): the output recorded for this cell is
# `MethodError: no method matching length(::Pun.Block)`, the error Julia raises
# when a non-iterable value is collected or broadcast over; the truncated trace
# does not show which call triggers it — TODO locate with the full stack trace.
weighted_particles = Pun.importance_sampling(
    point_cloud, _ -> r_prior, points, Pun.LebesgueBase(200), 100
)

# Each entry is indexed as (particle, log-weight).
particles = map(entry -> entry[1], weighted_particles)
log_weights = map(entry -> entry[2], weighted_particles)

# Normalise in log space so the weights sum to one.
weights = exp.(log_weights .- logsumexp(log_weights))

# Weighted mean of the particles, and the single highest-weight particle.
importance_estimate = sum(weights .* particles)
map_estimate = particles[argmax(log_weights)]

println("r = ", r)
println("weight = ", weight)
println("IS estimate for r = ", importance_estimate)
println("MAP estimate = ", map_estimate)

# Scatter the simulated points: first coordinate against second, equal axis scaling.
scatter(map(p -> p[1], points), map(p -> p[2], points); aspect_ratio=:equal)

MethodError: MethodError: no method matching length(::Pun.Block)
The function `length` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  length(!Matched::BitSet)
   @ Base bitset.jl:355
  length(!Matched::Profile.HeapSnapshot.Nodes)
   @ Profile ~/.julia/juliaup/julia-1.11.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Profile/src/heapsnapshot_reassemble.jl:52
  length(!Matched::JSON.Parser.MemoryParserState)
   @ JSON ~/.julia/packages/JSON/93Ea8/src/Parser.jl:28
  ...
