In [None]:
using Revise
using Plots
using Gen

includet("./distribution_utils.jl")
includet("./model.jl")
includet("./visualizations.jl")
includet("./inference.jl")

In [None]:
# Simulate one ground-truth trace from `model` and render its observed frames as an
# animation written to firefly.mp4.
scene_size = 32
max_fireflies = 1
steps = 1

gt_trace = simulate(model, (scene_size, max_fireflies, steps,))
choices = get_choices(gt_trace)
# `mat_to_img` comes from one of the included files (presumably visualizations.jl) —
# it is assumed to turn each `:observations => t` value into something `heatmap!` accepts.
observations = [mat_to_img(choices[:observations => t]) for t in 1:steps]

anim = Plots.@animate for t in 1:steps
    # A fresh figure per frame, so every frame is drawn on clean axes.
    fig = plot(layout = grid(1, 1), legend = false, size = (400, 400))
    Plots.heatmap!(fig, observations[t], title = "Frame $t")
end

# The output file is an .mp4, so encode it as video (`mp4`) rather than through the
# GIF palette pipeline that `gif` uses.
mp4(anim, "firefly.mp4", fps = 10)

In [None]:
# Copy a handful of ground-truth choices for firefly index 1 (the `:init` values and the
# `:states => 1` position entries — presumably step 1; confirm against model.jl) into a
# fresh choice map, then score that partial assignment with `generate` and print the
# returned log weight.
test_chm = choicemap()
let copied_addrs = (
        :init => :init_x => 1,
        :init => :init_y => 1,
        :init => :color => 1,
        :init => :init_vx => 1,
        :init => :init_vy => 1,
        :states => 1 => :x => 1,
        :states => 1 => :y => 1,
    )
    for addr in copied_addrs
        test_chm[addr] = gt_trace[addr]
    end
end
out_tr, w = generate(model, (scene_size, max_fireflies, steps,), test_chm)
println(w)

In [4]:
# One constraint choice map per (firefly index n, x, y, color) combination, fixing only
# the `:init` position and color of firefly n. Order: n outermost, then x, then y, with
# color varying fastest. The result is kept as a Vector{Any}.
colors = [1, 2, 3]
choicemaps = Any[
    choicemap(
        (:init => :init_x => n, x),
        (:init => :init_y => n, y),
        (:init => :color => n, c),
    )
    for n in 1:max_fireflies for x in 1:scene_size for y in 1:scene_size for c in colors
]

In [5]:
# Score every candidate choice map and accumulate, per (x, y) location, the log weight
# marginalized over the enumerated colors.
#
# `generate` returns a log weight, so marginalizing over color is a log-sum-exp of the
# per-color weights. Adding the log weights directly would multiply the per-color
# likelihoods, which is not a marginal.
# NOTE(review): only the firefly-1 addresses are read. The maps built for n > 1 in the
# enumeration cell have no `:init => :init_x => 1` entry, so this assumes
# max_fireflies == 1.
likelihoods = Dict{Pair{Int,Int},Float64}()
for chm in choicemaps
    _, logw = generate(model, (scene_size, max_fireflies, steps,), chm)
    # An impossible assignment adds nothing to the marginal. Skipping it also leaves
    # locations with no support out of the dict entirely.
    logw == -Inf && continue
    loc = chm[:init => :init_x => 1] => chm[:init => :init_y => 1]
    prev = get(likelihoods, loc, -Inf)
    # Numerically stable log(exp(prev) + exp(logw)). When prev == -Inf this yields logw.
    hi = max(prev, logw)
    likelihoods[loc] = hi + log(exp(prev - hi) + exp(logw - hi))
end

In [6]:
# Dense grid of the per-location log-likelihoods, indexed `likelihood_map[x, y]`.
# Locations missing from `likelihoods` only had -Inf weights, so their log-likelihood
# is -Inf. A 0.0 default is not neutral in log space and would rank them above typical
# (negative) log weights.
likelihood_map = Float64[
    get(likelihoods, x => y, -Inf) for x in 1:scene_size, y in 1:scene_size
]

In [None]:


# Toy location model. A latent integer grid position (:x, :y) is drawn uniformly from
# 1:scene_max on each axis. It is then observed through independent normal noise with
# standard deviation 1.5 (:observed_x, :observed_y). Returns the noisy observation as a
# tuple.
@gen function sample_location(scene_max)
    noise_std = 1.5
    lx = {:x} ~ uniform_discrete(1, scene_max)
    ly = {:y} ~ uniform_discrete(1, scene_max)
    ox = {:observed_x} ~ normal(lx, noise_std)
    oy = {:observed_y} ~ normal(ly, noise_std)
    return (ox, oy)
end

# Enumerate every latent grid location (x, y) and score it against one simulated noisy
# observation. Plot the resulting unnormalized posterior with the observation marked.
scene_max = 32
gt_trace = simulate(sample_location, (scene_max,))
observations = get_choices(gt_trace)

# ws[y, x] = exp(log weight) of constraining (x, y) together with the observation.
# Rows are y and columns are x, the same layout the old `hcat` of per-x columns produced.
# It is now built directly as a concrete Matrix{Float64} instead of nested Vector{Any}s.
# The comprehension is column-major (y varies fastest), so `generate` is called in the
# same x-outer / y-inner order as before.
ws = let obs_x = observations[:observed_x], obs_y = observations[:observed_y]
    [
        exp(last(generate(
            sample_location,
            (scene_max,),
            choicemap((:observed_x, obs_x), (:observed_y, obs_y), (:x, x), (:y, y)),
        )))
        for y in 1:scene_max, x in 1:scene_max
    ]
end
println(size(ws))

fig = plot()
heatmap!(fig, ws)
scatter!(fig, [observations[:observed_x]], [observations[:observed_y]], color="red",
    label="Observations", markershape=:x, markersize=10)

display(fig)

In [None]:
"""
    draw_square(x, y, scene_max, half_width=2)

Return a `3 × scene_max × scene_max` Float64 array of zeros, indexed
`[channel, row, column]`. A filled square is set to `1.0` in all three channels, covering
rows `y-half_width:y+half_width` and columns `x-half_width:x+half_width`, clipped to the
canvas. A fully visible square therefore has side `2half_width + 1`. A square lying
entirely off the canvas leaves it all zeros.
"""
function draw_square(x, y, scene_max, half_width=2)
    canvas = zeros(3, scene_max, scene_max)
    rows = max(1, y - half_width):min(scene_max, y + half_width)
    cols = max(1, x - half_width):min(scene_max, x + half_width)
    # Empty ranges (square off-canvas) make this a no-op.
    canvas[:, rows, cols] .= 1
    return canvas
end

# Two-square model. For k = 1, 2, a grid position (:x => k, :y => k) is drawn uniformly
# from 1:scene_max per axis and a square is added into a 3-channel canvas with
# `draw_square` (overlapping squares sum). The canvas and 0.1 are passed to
# `image_likelihood` (defined elsewhere — 0.1 is presumably a noise scale; confirm) to
# produce :observation, which is returned.
@gen function sample_location(scene_max)
    canvas = zeros(3, scene_max, scene_max)
    for k in 1:2
        sx = {:x => k} ~ uniform_discrete(1, scene_max)
        sy = {:y => k} ~ uniform_discrete(1, scene_max)
        canvas .+= draw_square(sx, sy, scene_max)
    end
    obs = {:observation} ~ image_likelihood(canvas, 0.1)
    return obs
end


"""
    enumerate_likelihoods(gt_trace; fix_one=true)

Score every grid cell as the position of square 2, given the observation in `gt_trace`.

For each `(x, y)` with `1 ≤ x, y ≤ scene_max` (where `scene_max = get_args(gt_trace)[1]`),
`generate` is run on the generative function that produced `gt_trace`. It is called with
`:observation`, `:x => 2 = x` and `:y => 2 = y` constrained, and its log weight is
recorded.

# Keywords
- `fix_one=true`: also constrain `:x => 1` / `:y => 1` to their values in `gt_trace`.
  When `false` those choices are left unconstrained and `generate` samples them for
  every cell, so the result varies between calls.

# Returns
A `scene_max × scene_max` `Matrix{Float64}` of log weights, indexed `[y, x]`.
"""
function enumerate_likelihoods(gt_trace; fix_one=true)
    # Re-score with the function (and arguments) that produced the trace. The global
    # `sample_location` may have been rebound since, because the notebook defines it twice.
    gen_fn = get_gen_fn(gt_trace)
    args = get_args(gt_trace)
    scene_max = args[1]
    choices = get_choices(gt_trace)
    observation = choices[:observation]

    ws = Matrix{Float64}(undef, scene_max, scene_max)
    # x outer / y inner: with `fix_one=false` each call consumes random numbers, and this
    # keeps the same call order as before.
    for x in 1:scene_max, y in 1:scene_max
        chm = choicemap()
        chm[:observation] = observation
        if fix_one
            chm[:x => 1] = choices[:x => 1]
            chm[:y => 1] = choices[:y => 1]
        end
        chm[:x => 2] = x
        chm[:y => 2] = y
        _, w = generate(gen_fn, args, chm)
        ws[y, x] = w
    end
    return ws
end

# Compare square-2 log-weight surfaces with square 1 fixed vs. sampled, next to the
# observation they were scored against.
scene_max = 64
gt_trace = simulate(sample_location, (scene_max,))
observations = get_choices(gt_trace)

fixed_ws = enumerate_likelihoods(gt_trace; fix_one=true)
unfixed_ws = enumerate_likelihoods(gt_trace; fix_one=false)

using Plots.PlotMeasures
h1 = heatmap(mat_to_img(observations[:observation]), xlims=(1, scene_max), ylims=(1, scene_max), yflip=true, title="Observation")
h2 = heatmap(fixed_ws, ylims=(1, scene_max), xlims=(1, scene_max), yflip=true, colorbar=true, title="Fix one, Propose one")
h3 = heatmap(unfixed_ws, ylims=(1, scene_max), xlims=(1, scene_max), yflip=true, colorbar=true, title="Sample one, Propose one")

# Plots.jl sets figure dimensions with `size`. `figsize` is not a Plots attribute and was
# ignored with a warning. The [a; b c] layout of square panels needs roughly square
# overall dimensions.
fig = plot(h1, h2, h3, plot_title="Likelihood Comparisons For Two 'Fireflies'",
    layout = @layout([a; b c]), size=(800, 800), aspect_ratio=1, margin=1mm)

display(fig)
