In [37]:
using Revise
using Plots
using Hungarian
using Gen
using Statistics

includet("./distribution_utils.jl")
includet("./model.jl")
includet("./visualizations.jl")
includet("./inference.jl")
using .FireflyInference
revise(FireflyInference)



true

In [38]:
scene_size = 64
max_fireflies = 4
steps = 40

# Force the ground-truth scene to contain `max_fireflies` fireflies. Tying the
# constraint to `max_fireflies` (instead of a literal) keeps it valid when the
# maximum is changed above, e.g. to 1 for the single-object accuracy check.
constraints = choicemap()
constraints[:init => :n_fireflies] = max_fireflies
gt_trace, _ = generate(model, (scene_size, max_fireflies, steps), constraints);

# `animate_trace` returns an animation, not a figure.
gt_anim = animate_trace(gt_trace)
mp4(gt_anim, fps=10)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mSaved animation to /Users/yonifriedman/Research/ProbComp/Fireflies/tmp.mp4


In [40]:
num_particles = 100
# Draw as many samples as there are particles.
num_samples = num_particles

# Run SMC against the ground-truth trace, keeping the particle set after every step.
intermediate_traces = smc(
    gt_trace, model, num_particles, num_samples;
    record_json = false,
    return_intermediate_traces = true,
    experiment_tag = "",
);

Running SMC
.......................................| END
.

LoadError: KeyError: key 3 not found

In [30]:
println("Visualizing SMC")
anim = visualize_particles_over_time(intermediate_traces, gt_trace)
# The target is an .mp4, so use Plots' `mp4` writer; `gif` runs the GIF
# palette pipeline, which is meant for .gif output.
mp4(anim, "fireflies_hi-blink_tight_proposal.mp4", fps = 10)

Visualizing SMC


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mSaved animation to /Users/yonifriedman/Research/ProbComp/Fireflies/fireflies_hi-blink_tight_proposal.mp4


In [None]:
# Check how concentrated the particle set is at an intermediate SMC step.
# `get_score` returns a LOG probability, so the scores must be normalized in
# log space: shift by the maximum for numerical stability, exponentiate, then
# divide by the total. (Dividing raw log scores by their sum is meaningless.)
check_step = 10
log_scores = [get_score(trace) for trace in intermediate_traces[check_step]]

unnormalized = exp.(log_scores .- maximum(log_scores))
scores = unnormalized ./ sum(unnormalized)
# Largest normalized weight: close to 1 means a single particle dominates.
println(maximum(scores))

In [11]:
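# Assumes every inferred trace has the same number of fireflies as the ground truth
# (objects are compared row by row, with no matching). Use optimal_position_error
# when the object counts can differ.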
"""
    position_accuracy(gt_trace, inferred_traces)

Score each trace in `inferred_traces` by its squared position error against `gt_trace`.

Both traces' return values are read as `(states, _)`, and `states[:xs]` / `states[:ys]`
are compared element-wise, with no matching between objects.

# Returns
A `Vector{Float64}` with one entry per inferred trace. Each entry is
`(sum((gt_xs - xs).^2) + sum((gt_ys - ys).^2)) / length(gt_xs)`. Note that
`length(gt_xs)` equals the number of time steps only when there is a single firefly.

# Throws
`DimensionMismatch` if an inferred trace's position arrays differ in size from the
ground truth's. Without this check, broadcasting could silently compare, for example,
one ground-truth row against several inferred rows.
"""
function position_accuracy(gt_trace, inferred_traces)
    gt_states, _ = get_retval(gt_trace)
    gt_xs = gt_states[:xs]
    gt_ys = gt_states[:ys]
    steps = length(gt_xs)
    scores = Float64[]
    for inferred_trace in inferred_traces
        states, _ = get_retval(inferred_trace)
        inferred_xs = states[:xs]
        inferred_ys = states[:ys]
        if size(inferred_xs) != size(gt_xs) || size(inferred_ys) != size(gt_ys)
            throw(DimensionMismatch(
                "inferred positions have size $(size(inferred_xs)) but ground truth has " *
                "size $(size(gt_xs)); use optimal_position_error for mismatched objects"))
        end
        l2_dist = (sum(abs2, gt_xs .- inferred_xs) + sum(abs2, gt_ys .- inferred_ys)) / steps
        push!(scores, l2_dist)
    end
    return scores
end

# Only meaningful with a single firefly (position_accuracy does no object matching).
# `inferred_traces` is assigned in a later cell, so take the final SMC step directly.
if max_fireflies == 1
    scores = position_accuracy(gt_trace, intermediate_traces[steps])
    println(mean(scores))
end

In [None]:
# Build the square, padded matching cost matrix for one particle: rows are inferred
# objects, columns are ground-truth objects, entries are clipped RMS trajectory distances.
# Padding (dummy) rows/columns are filled with `max_cost`.
function _trajectory_cost_matrix(gt_xs, gt_ys, inferred_xs, inferred_ys, max_cost)
    num_gt_objects, steps = size(gt_xs)
    num_inferred_objects = size(inferred_xs, 1)
    max_objects = max(num_gt_objects, num_inferred_objects)
    # `float` so that an integer `max_cost` does not make the matrix Int, which would
    # fail (InexactError) when the first real-valued distance is stored.
    cost_matrix = fill(float(max_cost), max_objects, max_objects)
    @views for inf_obj in 1:num_inferred_objects, gt_obj in 1:num_gt_objects
        squared_distances = (gt_xs[gt_obj, :] .- inferred_xs[inf_obj, :]) .^ 2 .+
                            (gt_ys[gt_obj, :] .- inferred_ys[inf_obj, :]) .^ 2
        average_distance = sqrt(sum(squared_distances) / steps)
        cost_matrix[inf_obj, gt_obj] = min(average_distance, max_cost)
    end
    return cost_matrix
end

"""
    optimal_position_error(gt_trace, inferred_traces; max_cost=100.)

Optimally match inferred fireflies to ground-truth fireflies, per particle, and score the match.

The pairwise cost is the root-mean-square Euclidean distance between the two trajectories
(`states[:xs]`, `states[:ys]`, one row per object, one column per time step), clipped at
`max_cost`. The cost matrix is padded to square with `max_cost` entries so that unequal
object counts can be matched. The matching is solved with `hungarian`.

# Returns
A tuple `(all_scores, all_misses, assignments)`, one entry per inferred trace:
- `all_scores::Vector{Float64}`: the summed cost of the particle's inferred objects. An
  inferred object matched to a padding column (a false positive) costs `max_cost`. Missed
  ground-truth objects do not add to this score.
- `all_misses::Vector{Int}`: the number of ground-truth objects matched to padding rows,
  i.e. ground-truth objects that no inferred object covers.
- `assignments`: the assignment vector returned by `hungarian` (entry `r` is the column
  matched to row `r`), or an empty `Int[]` when both sides have no objects.
"""
function optimal_position_error(gt_trace, inferred_traces; max_cost=100.)
    gt_states, _ = get_retval(gt_trace)
    gt_xs = gt_states[:xs]
    gt_ys = gt_states[:ys]
    num_gt_objects = size(gt_xs, 1)

    all_scores = Float64[]
    all_misses = Int[]
    assignments = []

    for inferred_trace in inferred_traces
        states, _ = get_retval(inferred_trace)
        inferred_xs = states[:xs]
        inferred_ys = states[:ys]
        num_inferred_objects = size(inferred_xs, 1)
        max_objects = max(num_gt_objects, num_inferred_objects)

        if max_objects == 0
            # Empty scene and empty inference: nothing to match, nothing missed.
            push!(all_scores, 0.0)
            push!(all_misses, 0)
            push!(assignments, Int[])
            continue
        end

        cost_matrix = _trajectory_cost_matrix(gt_xs, gt_ys, inferred_xs, inferred_ys, max_cost)
        assignment, _ = hungarian(cost_matrix)

        # Cost of each real inferred object under the optimal assignment.
        particle_scores = zeros(num_inferred_objects)
        for inf_obj in 1:num_inferred_objects
            assigned_gt = assignment[inf_obj]
            particle_scores[inf_obj] = assigned_gt <= num_gt_objects ?
                cost_matrix[inf_obj, assigned_gt] :  # real ground-truth object
                max_cost                              # padding column: false positive
        end

        # A ground-truth object is missed when it is matched to a padding row, i.e.
        # a row beyond the real inferred objects.
        missed_gt = count(row -> assignment[row] <= num_gt_objects,
                          (num_inferred_objects + 1):max_objects)

        push!(all_scores, sum(particle_scores))
        push!(all_misses, missed_gt)
        push!(assignments, assignment)
    end

    return all_scores, all_misses, assignments
end

# Evaluate the particle set after the final SMC step.
inferred_traces = intermediate_traces[steps]
scores, misses, assignments = optimal_position_error(gt_trace, inferred_traces; max_cost=64.)

println(mean(scores), " ", std(scores), " ", mean(misses), " ", std(misses))

# Plots.jl sets the figure size in pixels via `size`; `figsize` is not a Plots
# attribute and would be ignored with a warning.
fig = plot(layout=(1, 2), size=(1000, 500))
scatter!(fig[1], scores, xlabel="Particle", ylabel="Error", title="Tracking Error per particle")
bar!(fig[2], misses, xlabel="Particle", ylabel="Missed objects", title="Missed objects per particle")
fig

In [None]:
"""
    pixel_reconstruction_error(gt_trace, particles)

Compare each particle's rendered observations with the ground-truth observations.

The scene size and number of steps come from `get_args(gt_trace)`, and the ground-truth
frame for step `t` is the choice at `:observations => t`. For every particle and every
step `t`, the particle's states are rendered with `render!(states, t, scene_size)`. The
rendering is scored against the ground-truth frame with
`logpdfmap(image_likelihood, gt_frame, rendered, 0.01)`.

# Returns
A tuple `(l2_dists, errormaps, inferred_renders)`, one entry per particle:
- `l2_dists`: the sum of squared pixel differences over all steps, divided by
  `scene_size^2 * steps`. This assumes square `scene_size × scene_size` frames
  (TODO confirm against `render!`).
- `errormaps`: a vector (one entry per step) of `logpdfmap` outputs.
- `inferred_renders`: a vector (one entry per step) of rendered frames.
"""
function pixel_reconstruction_error(gt_trace, particles)
    scene_size, _, steps = get_args(gt_trace)
    gt_choices = get_choices(gt_trace)
    gt_obs = [gt_choices[:observations => t] for t in 1:steps]
    normalizer = scene_size^2 * steps

    l2_dists = Float64[]
    errormaps = []
    inferred_renders = []
    for particle in particles
        states, _ = get_retval(particle)
        # Render and score one step at a time, in step order, so every frame is
        # scored right after it is produced.
        per_step = map(1:steps) do t
            rendered = render!(states, t, scene_size)
            (rendered, logpdfmap(image_likelihood, gt_obs[t], rendered, 0.01))
        end
        renders = first.(per_step)
        particle_errormap = last.(per_step)

        squared_error = sum(sum(abs2, gt_obs[t] .- renders[t]) for t in 1:steps)
        push!(l2_dists, squared_error / normalizer)
        push!(errormaps, particle_errormap)
        push!(inferred_renders, renders)
    end
    return l2_dists, errormaps, inferred_renders
end

l2_dists, errormaps, obses = pixel_reconstruction_error(gt_trace, inferred_traces)

In [None]:
# One frame per time step: average the per-particle `errormaps` for that step, then
# keep only the magnitude of the negative entries (non-negative entries become zero).
# The entries are presumably per-pixel log-likelihoods from `logpdfmap` (TODO confirm),
# so brighter pixels mark worse reconstructions. The subplot title shows the mean
# of the displayed map.
anim = @animate for t in 1:steps
    plot(size=(400, 400), plot_title="Reconstruction Error [Frame $t]")
    n_maps = length(errormaps)
    step_maps = [errormaps[p][t] for p in 1:n_maps]
    avg_map = sum(step_maps) / n_maps
    # `zero(v)` keeps the element type concrete; `ifelse.(…, 0)` mixed Int and Float.
    err_map = map(v -> v < 0 ? -v : zero(v), avg_map)
    mean_error = mean(err_map)
    heatmap!(err_map, color=:grays,
        axis=nothing, legend=false,
        aspect_ratio=:equal,
        colorbar=false,
        xticks=false, yticks=false,
        title="mean error: $(round(mean_error; digits=4))", yflip=true)
end

mp4(anim, "reconstruction_error.mp4", fps = 10)

In [None]:
# One frame per time step: the pixel-wise average of every particle's rendered
# observation (converted with `mat_to_img`) at that step.
anim = @animate for t in 1:steps
    plot(size=(600, 600), plot_title="Particle Reconstruction [Frame $t]")
    frames = [mat_to_img(obses[p][t]) for p in eachindex(obses)]
    mean_frame = sum(frames) / length(frames)

    heatmap!(mean_frame, axis=nothing, legend=false,
        aspect_ratio=:equal,
        colorbar=false,
        xticks=false, yticks=false,
        yflip=true,
        title="Particle Reconstruction")
end

# Saved under its own name: the error-map cell already writes "reconstruction_error.mp4",
# and reusing that name would overwrite it.
mp4(anim, "particle_reconstruction.mp4", fps = 10)